MarkTechPost@AI · 2 days ago, 14:04
Building a GPU-Accelerated Ollama LangChain Workflow with RAG Agents, Multi-Session Chat, and Performance Monitoring

This article explains how to build a GPU-capable LLM (large language model) stack locally that unifies Ollama and LangChain. The tutorial covers installing the required libraries, starting the Ollama server, pulling a model, and wrapping it in a custom LangChain LLM with configurable temperature, token limits, and context. It then adds a Retrieval-Augmented Generation (RAG) layer that ingests PDF or text files, chunks and embeds them, and serves context-grounded answers. Finally, it walks through managing multi-session chat memory, registering tools (such as web search and RAG queries), and building an agent that decides when to call them.

🚀 **GPU-accelerated local LLM environment**: The tutorial walks through installing and configuring Ollama, a platform for running large language models locally, and making full use of GPU resources to speed up inference. Environment variables such as `OLLAMA_NUM_PARALLEL` are set to tune parallelism, laying the groundwork for the more complex tasks that follow.

🧠 **Wrapping Ollama as a LangChain LLM**: The article shows how to create a custom LangChain LLM class that wraps the Ollama API. This lets Ollama models plug seamlessly into the LangChain ecosystem, with convenient control over generation parameters (such as temperature and maximum tokens) and over the model configuration itself.

📚 **Retrieval-Augmented Generation (RAG)**: Using `sentence-transformers` for text embeddings and `FAISS` as the vector store, the tutorial demonstrates how to build a RAG system. It ingests user-uploaded PDF or text files, then splits, embeds, and stores the content, supporting question answering grounded in those documents for more accurate, evidence-backed responses.

💬 **Multi-session memory and an intelligent agent**: The system manages chat memory across multiple sessions, with both buffer-based and summary-based modes, so context is preserved per session. The tutorial also covers registering tools (such as web search and RAG queries) and building an agent that decides on its own which tool to call for a given request, greatly extending what the system can do.

In this tutorial, we build a GPU‑capable local LLM stack that unifies Ollama and LangChain. We install the required libraries, launch the Ollama server, pull a model, and wrap it in a custom LangChain LLM, allowing us to control temperature, token limits, and context. We add a Retrieval-Augmented Generation layer that ingests PDFs or text, chunks them, embeds them with Sentence-Transformers, and serves grounded answers. We manage multi‑session chat memory, register tools (web search + RAG query), and spin up an agent that reasons about when to call them.
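Before running the setup below, it can help to confirm that the Colab runtime actually exposes a GPU. This is a small optional check, not part of the original tutorial, that simply shells out to the standard `nvidia-smi` CLI:

```python
import subprocess

# Optional sanity check: print the GPU name and memory if one is attached.
# If nvidia-smi is missing, the runtime has no NVIDIA GPU and Ollama will
# fall back to CPU inference.
try:
    print(subprocess.run(
        ["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv"],
        capture_output=True, text=True, check=True
    ).stdout)
except (FileNotFoundError, subprocess.CalledProcessError):
    print("No NVIDIA GPU detected; Ollama will run on CPU.")
```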

```python
import os
import sys
import subprocess
import time
import threading
import queue
import json
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass
from contextlib import contextmanager
import asyncio
from concurrent.futures import ThreadPoolExecutor


def install_packages():
    """Install required packages for Colab environment"""
    packages = [
        "langchain",
        "langchain-community",
        "langchain-core",
        "chromadb",
        "sentence-transformers",
        "faiss-cpu",
        "pypdf",
        "python-docx",
        "requests",
        "psutil",
        "pyngrok",
        "gradio"
    ]
    for package in packages:
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])


install_packages()

import requests
import psutil
import threading
from queue import Queue

from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.schema import BaseMessage, HumanMessage, AIMessage, SystemMessage
from langchain.memory import ConversationBufferWindowMemory, ConversationSummaryBufferMemory
from langchain.chains import ConversationChain, RetrievalQA
from langchain.prompts import PromptTemplate, ChatPromptTemplate
from langchain.document_loaders import PyPDFLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS, Chroma
from langchain.agents import AgentType, initialize_agent, Tool
from langchain.tools import DuckDuckGoSearchRun
```

We import the necessary Python utilities in Colab for concurrency, system calls, and JSON handling. We define and run install_packages() to pull LangChain, embeddings, vector stores, document loaders, monitoring, and UI dependencies. We then import LangChain LLM, memory, retrieval, and agent tools (including DuckDuckGo search) to build an extensible RAG and agent workflow.


```python
@dataclass
class OllamaConfig:
    """Configuration for Ollama setup"""
    model_name: str = "llama2"
    base_url: str = "http://localhost:11434"
    max_tokens: int = 2048
    temperature: float = 0.7
    gpu_layers: int = -1
    context_window: int = 4096
    batch_size: int = 512
    threads: int = 4
```

We define an OllamaConfig dataclass so we keep all Ollama runtime settings in one clean place. We set the model name and local API endpoint, as well as the generation behavior (max_tokens, temperature, and context_window). We control performance with gpu_layers (‑1 = load all to GPU when possible), batch_size, and threads for parallelism.
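As a quick usage sketch (the values here are illustrative, not prescriptive), we can instantiate the config and override only the fields we care about:

```python
# Illustrative example: a config tuned for a modest Colab GPU.
config = OllamaConfig(
    model_name="llama2",      # any model available in the Ollama registry
    temperature=0.3,          # lower temperature for more deterministic output
    max_tokens=1024,
    gpu_layers=-1,            # -1 = offload all layers to the GPU when possible
    threads=4
)
print(config)                 # dataclasses give us a readable repr for free
```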

```python
class OllamaManager:
    """Advanced Ollama manager for Colab environment"""

    def __init__(self, config: OllamaConfig):
        self.config = config
        self.process = None
        self.is_running = False
        self.models_cache = {}
        self.performance_monitor = PerformanceMonitor()

    def install_ollama(self):
        """Install Ollama in Colab environment"""
        try:
            subprocess.run([
                "curl", "-fsSL", "https://ollama.com/install.sh", "-o", "/tmp/install.sh"
            ], check=True)
            subprocess.run(["bash", "/tmp/install.sh"], check=True)
            print("Ollama installed successfully")
        except subprocess.CalledProcessError as e:
            print(f"Failed to install Ollama: {e}")
            raise

    def start_server(self):
        """Start Ollama server with GPU support"""
        if self.is_running:
            print("Ollama server is already running")
            return

        try:
            env = os.environ.copy()
            env["OLLAMA_NUM_PARALLEL"] = str(self.config.threads)
            env["OLLAMA_MAX_LOADED_MODELS"] = "3"

            self.process = subprocess.Popen(
                ["ollama", "serve"],
                env=env,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE
            )

            time.sleep(5)

            if self.health_check():
                self.is_running = True
                print("Ollama server started successfully")
                self.performance_monitor.start()
            else:
                raise Exception("Server failed to start properly")
        except Exception as e:
            print(f"Failed to start Ollama server: {e}")
            raise

    def health_check(self) -> bool:
        """Check if Ollama server is healthy"""
        try:
            response = requests.get(f"{self.config.base_url}/api/tags", timeout=10)
            return response.status_code == 200
        except:
            return False

    def pull_model(self, model_name: str) -> bool:
        """Pull a model from the Ollama registry"""
        try:
            print(f"Pulling model: {model_name}")
            result = subprocess.run(
                ["ollama", "pull", model_name],
                capture_output=True,
                text=True,
                timeout=1800
            )

            if result.returncode == 0:
                print(f"Model {model_name} pulled successfully")
                self.models_cache[model_name] = True
                return True
            else:
                print(f"Failed to pull model {model_name}: {result.stderr}")
                return False
        except subprocess.TimeoutExpired:
            print(f"Timeout pulling model {model_name}")
            return False
        except Exception as e:
            print(f"Error pulling model {model_name}: {e}")
            return False

    def list_models(self) -> List[str]:
        """List available local models"""
        try:
            result = subprocess.run(
                ["ollama", "list"],
                capture_output=True,
                text=True
            )

            models = []
            for line in result.stdout.split('\n')[1:]:
                if line.strip():
                    model_name = line.split()[0]
                    models.append(model_name)
            return models
        except Exception as e:
            print(f"Error listing models: {e}")
            return []

    def stop_server(self):
        """Stop Ollama server"""
        if self.process:
            self.process.terminate()
            self.process.wait()
            self.is_running = False
            self.performance_monitor.stop()
            print("Ollama server stopped")
```

We create the OllamaManager class to install, start, monitor, and manage the Ollama server in the Colab environment. We set environment variables for GPU parallelism, run the server in the background, and verify it’s up with a health check. We pull models on demand, cache them, list available ones locally, and gracefully shut down the server when the task is complete, all while tracking performance.
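A minimal sketch of driving the manager on its own, assuming the `config` object from the previous example:

```python
# Hypothetical standalone run of the manager, outside the full system class.
manager = OllamaManager(config)
manager.install_ollama()          # one-time install in the Colab VM
manager.start_server()            # launches `ollama serve` with GPU env vars
if manager.pull_model("llama2"):  # blocks until the model finishes downloading
    print("Local models:", manager.list_models())
# manager.stop_server()           # call when finished to free the GPU
```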


```python
class PerformanceMonitor:
    """Monitor system performance and resource usage"""

    def __init__(self):
        self.monitoring = False
        self.stats = {
            "cpu_usage": [],
            "memory_usage": [],
            "gpu_usage": [],
            "inference_times": []
        }
        self.monitor_thread = None

    def start(self):
        """Start performance monitoring"""
        self.monitoring = True
        self.monitor_thread = threading.Thread(target=self._monitor_loop)
        self.monitor_thread.daemon = True
        self.monitor_thread.start()

    def stop(self):
        """Stop performance monitoring"""
        self.monitoring = False
        if self.monitor_thread:
            self.monitor_thread.join()

    def _monitor_loop(self):
        """Main monitoring loop"""
        while self.monitoring:
            try:
                cpu_percent = psutil.cpu_percent(interval=1)
                memory = psutil.virtual_memory()

                self.stats["cpu_usage"].append(cpu_percent)
                self.stats["memory_usage"].append(memory.percent)

                for key in ["cpu_usage", "memory_usage"]:
                    if len(self.stats[key]) > 100:
                        self.stats[key] = self.stats[key][-100:]

                time.sleep(5)
            except Exception as e:
                print(f"Monitoring error: {e}")

    def get_stats(self) -> Dict[str, Any]:
        """Get current performance statistics"""
        return {
            "avg_cpu": sum(self.stats["cpu_usage"][-10:]) / max(len(self.stats["cpu_usage"][-10:]), 1),
            "avg_memory": sum(self.stats["memory_usage"][-10:]) / max(len(self.stats["memory_usage"][-10:]), 1),
            "total_inferences": len(self.stats["inference_times"]),
            "avg_inference_time": sum(self.stats["inference_times"]) / max(len(self.stats["inference_times"]), 1)
        }
```

We define a PerformanceMonitor class to track CPU, memory, and inference times in real time while the Ollama server runs. We launch a background thread to collect stats every few seconds, store recent metrics, and provide average usage summaries. This helps us monitor system load and optimize performance during model inference.
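A small sketch of using the monitor directly, independent of the Ollama server (it reuses the `time` import from the top of the notebook):

```python
# Illustrative: run the monitor briefly and inspect the rolling statistics.
monitor = PerformanceMonitor()
monitor.start()
time.sleep(15)                 # let a few 5-second samples accumulate
print(monitor.get_stats())     # avg_cpu, avg_memory, inference counters
monitor.stop()
```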


```python
class OllamaLLM(LLM):
    """Custom LangChain LLM for Ollama"""

    model_name: str = "llama2"
    base_url: str = "http://localhost:11434"
    temperature: float = 0.7
    max_tokens: int = 2048
    performance_monitor: Optional[PerformanceMonitor] = None

    @property
    def _llm_type(self) -> str:
        return "ollama"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Make API call to Ollama"""
        start_time = time.time()

        try:
            payload = {
                "model": self.model_name,
                "prompt": prompt,
                "stream": False,
                "options": {
                    "temperature": self.temperature,
                    "num_predict": self.max_tokens,
                    "stop": stop or []
                }
            }

            response = requests.post(
                f"{self.base_url}/api/generate",
                json=payload,
                timeout=120
            )
            response.raise_for_status()
            result = response.json()

            inference_time = time.time() - start_time
            if self.performance_monitor:
                self.performance_monitor.stats["inference_times"].append(inference_time)

            return result.get("response", "")
        except Exception as e:
            print(f"Ollama API error: {e}")
            return f"Error: {str(e)}"
```

We wrap the Ollama API inside a custom OllamaLLM class compatible with LangChain’s LLM interface. We define how prompts are sent to the Ollama server and record each inference time for performance tracking. This lets us plug Ollama directly into LangChain chains, agents, and memory components while monitoring efficiency.
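Once the server is running and a model has been pulled, the wrapper behaves like any other LangChain LLM. A hedged sketch, assuming the `manager` from the earlier example is still up and the prompt text is arbitrary:

```python
# Assumes the Ollama server is already running and "llama2" has been pulled.
llm = OllamaLLM(
    model_name="llama2",
    temperature=0.7,
    max_tokens=512,
    performance_monitor=manager.performance_monitor  # optional timing hook
)
print(llm("Explain retrieval-augmented generation in two sentences."))
```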

```python
class RAGSystem:
    """Retrieval-Augmented Generation system"""

    def __init__(self, llm: OllamaLLM, embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"):
        self.llm = llm
        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
        self.vector_store = None
        self.qa_chain = None
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len
        )

    def add_documents(self, file_paths: List[str]):
        """Add documents to the vector store"""
        documents = []

        for file_path in file_paths:
            try:
                if file_path.endswith('.pdf'):
                    loader = PyPDFLoader(file_path)
                else:
                    loader = TextLoader(file_path)
                docs = loader.load()
                documents.extend(docs)
            except Exception as e:
                print(f"Error loading {file_path}: {e}")

        if documents:
            splits = self.text_splitter.split_documents(documents)

            if self.vector_store is None:
                self.vector_store = FAISS.from_documents(splits, self.embeddings)
            else:
                self.vector_store.add_documents(splits)

            self.qa_chain = RetrievalQA.from_chain_type(
                llm=self.llm,
                chain_type="stuff",
                retriever=self.vector_store.as_retriever(search_kwargs={"k": 3}),
                return_source_documents=True
            )

            print(f"Added {len(splits)} document chunks to vector store")

    def query(self, question: str) -> Dict[str, Any]:
        """Query the RAG system"""
        if not self.qa_chain:
            return {"answer": "No documents loaded. Please add documents first."}

        try:
            result = self.qa_chain({"query": question})
            return {
                "answer": result["result"],
                "sources": [doc.metadata for doc in result.get("source_documents", [])]
            }
        except Exception as e:
            return {"answer": f"Error: {str(e)}"}
```

We build the RAGSystem class around HuggingFace sentence-transformer embeddings and a FAISS vector store. We load PDF or text files, split them into overlapping chunks with RecursiveCharacterTextSplitter, embed and index them, and wire up a RetrievalQA chain that retrieves the top three chunks for each question and returns a grounded answer along with its source metadata.
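A short usage sketch, where the file path is a placeholder for whatever you upload to Colab:

```python
# Illustrative: index a local file, then ask a question grounded in it.
rag = RAGSystem(llm)
rag.add_documents(["sample.pdf"])          # placeholder path
result = rag.query("What is the main conclusion of the document?")
print(result["answer"])
print("Sources:", result["sources"])
```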

```python
class ConversationManager:
    """Manage conversation history and memory"""

    def __init__(self, llm: OllamaLLM, memory_type: str = "buffer"):
        self.llm = llm
        self.conversations = {}
        self.memory_type = memory_type

    def get_conversation(self, session_id: str) -> ConversationChain:
        """Get or create conversation for session"""
        if session_id not in self.conversations:
            if self.memory_type == "buffer":
                memory = ConversationBufferWindowMemory(k=10)
            elif self.memory_type == "summary":
                memory = ConversationSummaryBufferMemory(
                    llm=self.llm,
                    max_token_limit=1000
                )
            else:
                memory = ConversationBufferWindowMemory(k=10)

            self.conversations[session_id] = ConversationChain(
                llm=self.llm,
                memory=memory,
                verbose=True
            )

        return self.conversations[session_id]

    def chat(self, session_id: str, message: str) -> str:
        """Chat with a specific session"""
        conversation = self.get_conversation(session_id)
        return conversation.predict(input=message)

    def clear_session(self, session_id: str):
        """Clear conversation history for session"""
        if session_id in self.conversations:
            del self.conversations[session_id]


class OllamaLangChainSystem:
    """Main system integrating all components"""

    def __init__(self, config: OllamaConfig):
        self.config = config
        self.manager = OllamaManager(config)
        self.llm = None
        self.rag_system = None
        self.conversation_manager = None
        self.tools = []
        self.agent = None

    def setup(self):
        """Complete system setup"""
        print("Setting up Ollama + LangChain system...")

        self.manager.install_ollama()
        self.manager.start_server()

        if not self.manager.pull_model(self.config.model_name):
            print("Failed to pull default model")
            return False

        self.llm = OllamaLLM(
            model_name=self.config.model_name,
            base_url=self.config.base_url,
            temperature=self.config.temperature,
            max_tokens=self.config.max_tokens,
            performance_monitor=self.manager.performance_monitor
        )

        self.rag_system = RAGSystem(self.llm)
        self.conversation_manager = ConversationManager(self.llm)
        self._setup_tools()

        print("System setup complete!")
        return True

    def _setup_tools(self):
        """Setup tools for the agent"""
        search = DuckDuckGoSearchRun()

        self.tools = [
            Tool(
                name="Search",
                func=search.run,
                description="Search the internet for current information"
            ),
            Tool(
                name="RAG_Query",
                func=lambda q: self.rag_system.query(q)["answer"],
                description="Query loaded documents using RAG"
            )
        ]

        self.agent = initialize_agent(
            tools=self.tools,
            llm=self.llm,
            agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
            verbose=True
        )

    def chat(self, message: str, session_id: str = "default") -> str:
        """Simple chat interface"""
        return self.conversation_manager.chat(session_id, message)

    def rag_chat(self, question: str) -> Dict[str, Any]:
        """RAG-based chat"""
        return self.rag_system.query(question)

    def agent_chat(self, message: str) -> str:
        """Agent-based chat with tools"""
        return self.agent.run(message)

    def switch_model(self, model_name: str) -> bool:
        """Switch to a different model"""
        if self.manager.pull_model(model_name):
            self.llm.model_name = model_name
            print(f"Switched to model: {model_name}")
            return True
        return False

    def load_documents(self, file_paths: List[str]):
        """Load documents into RAG system"""
        self.rag_system.add_documents(file_paths)

    def get_performance_stats(self) -> Dict[str, Any]:
        """Get system performance statistics"""
        return self.manager.performance_monitor.get_stats()

    def cleanup(self):
        """Clean up resources"""
        self.manager.stop_server()
        print("System cleanup complete")
```

We use the ConversationManager to maintain separate chat sessions, each with its memory type, either buffer-based or summary-based, allowing us to preserve or summarize context as needed. In the OllamaLangChainSystem, we integrate everything: we install and launch Ollama, pull the desired model, wrap it in a LangChain-compatible LLM, connect a RAG system, initialize chat memory, and register external tools like web search.
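Putting it together, here is a condensed sketch of driving the full system by hand; the session id, questions, and file path are placeholders:

```python
# Illustrative end-to-end use of the integrated system.
system = OllamaLangChainSystem(OllamaConfig(model_name="llama2"))
if system.setup():
    print(system.chat("Hi, please remember that my name is Alice.", session_id="s1"))
    print(system.chat("What is my name?", session_id="s1"))   # buffered memory
    system.load_documents(["notes.txt"])                      # placeholder file
    print(system.rag_chat("Summarize my notes.")["answer"])
    print(system.agent_chat("Search the web for today's top AI news."))
    system.cleanup()
```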

```python
def main():
    """Main function demonstrating the system"""
    config = OllamaConfig(
        model_name="llama2",
        temperature=0.7,
        max_tokens=2048
    )

    system = OllamaLangChainSystem(config)

    try:
        if not system.setup():
            return

        print("\nTesting basic chat:")
        response = system.chat("Hello! How are you?")
        print(f"Response: {response}")

        print("\nTesting model switching:")
        models = system.manager.list_models()
        print(f"Available models: {models}")

        print("\nTesting agent:")
        agent_response = system.agent_chat("What's the current weather like?")
        print(f"Agent Response: {agent_response}")

        print("\nPerformance Statistics:")
        stats = system.get_performance_stats()
        print(json.dumps(stats, indent=2))
    except KeyboardInterrupt:
        print("\nInterrupted by user")
    except Exception as e:
        print(f"Error: {e}")
    finally:
        system.cleanup()


def create_gradio_interface(system: OllamaLangChainSystem):
    """Create a Gradio interface for easy interaction"""
    try:
        import gradio as gr

        def chat_interface(message, history, mode):
            if mode == "Basic Chat":
                response = system.chat(message)
            elif mode == "RAG Chat":
                result = system.rag_chat(message)
                response = result["answer"]
            elif mode == "Agent Chat":
                response = system.agent_chat(message)
            else:
                response = "Unknown mode"

            history.append((message, response))
            return "", history

        def upload_docs(files):
            if files:
                file_paths = [f.name for f in files]
                system.load_documents(file_paths)
                return f"Loaded {len(file_paths)} documents into RAG system"
            return "No files uploaded"

        def get_stats():
            stats = system.get_performance_stats()
            return json.dumps(stats, indent=2)

        with gr.Blocks(title="Ollama + LangChain System") as demo:
            gr.Markdown("# Ollama + LangChain Advanced System")

            with gr.Tab("Chat"):
                chatbot = gr.Chatbot()
                mode = gr.Dropdown(
                    ["Basic Chat", "RAG Chat", "Agent Chat"],
                    value="Basic Chat",
                    label="Chat Mode"
                )
                msg = gr.Textbox(label="Message")
                clear = gr.Button("Clear")

                msg.submit(chat_interface, [msg, chatbot, mode], [msg, chatbot])
                clear.click(lambda: ([], ""), outputs=[chatbot, msg])

            with gr.Tab("Document Upload"):
                file_upload = gr.File(file_count="multiple", label="Upload Documents")
                upload_btn = gr.Button("Upload to RAG System")
                upload_status = gr.Textbox(label="Status")

                upload_btn.click(upload_docs, file_upload, upload_status)

            with gr.Tab("Performance"):
                stats_btn = gr.Button("Get Performance Stats")
                stats_output = gr.Textbox(label="Performance Statistics")

                stats_btn.click(get_stats, outputs=stats_output)

        return demo
    except ImportError:
        print("Gradio not installed. Skipping interface creation.")
        return None


if __name__ == "__main__":
    print("Ollama + LangChain System for Google Colab")
    print("=" * 50)
    main()

    # Or create a system instance for interactive use
    # config = OllamaConfig(model_name="llama2")
    # system = OllamaLangChainSystem(config)
    # system.setup()

    # # Create Gradio interface
    # demo = create_gradio_interface(system)
    # if demo:
    #     demo.launch(share=True)  # share=True for public link
```

We wrap everything up in the main function to run a full demo, setting up the system, testing chat, agent tools, model listing, and performance statistics. Then, in create_gradio_interface(), we build a user-friendly Gradio app with tabs for chatting, uploading documents to the RAG system, and monitoring performance. Finally, we call main() in the __main__ block for direct Colab execution, or optionally launch the Gradio UI for interactive exploration and public sharing.
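If you prefer the UI, a minimal launch sketch, assuming a `system` instance that has already completed `setup()` (as in the commented block above):

```python
# Illustrative: reuse an already set-up system and expose the Gradio UI.
demo = create_gradio_interface(system)
if demo:
    demo.launch(share=True)   # share=True produces a temporary public link
```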

In conclusion, we have a flexible playground: we switch Ollama models, converse with buffered or summary memory, question our own documents, reach out to search when context is missing, and monitor basic resource stats to stay within Colab limits. The code is modular, allowing us to extend the tool list, tune inference options (temperature, maximum tokens, concurrency) in OllamaConfig, or adapt the RAG pipeline to larger corpora or different embedding models. We launch the Gradio app with share=True to collaborate or embed these components in our projects. We now own an extensible template for fast local LLM experimentation.


