-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathagent_langgraph.py
More file actions
176 lines (148 loc) · 6 KB
/
agent_langgraph.py
File metadata and controls
176 lines (148 loc) · 6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
#!/usr/bin/env python3
# agent_langgraph.py
# LangGraph + LangChain-OpenAI Bash agent wired to Ollama (OpenAI-compatible).
# Pre-run: force model selection from local Ollama; no defaults.
from __future__ import annotations
import os
import sys
import json
from typing import Any, Dict, List, Optional
from urllib.request import urlopen, Request
from urllib.error import URLError, HTTPError
from langgraph.prebuilt import create_react_agent # deprecation warning is fine
from langgraph.checkpoint.memory import InMemorySaver
from langchain_openai import ChatOpenAI
from bash_tool import Bash, DEFAULT_ALLOWED
# ------------------------- Prompt -------------------------
# Deterministic, comma-separated listing of the allow-listed shell commands
# (sorted so the generated system prompt is stable across runs).
LIST_OF_ALLOWED_COMMANDS = ", ".join(sorted(DEFAULT_ALLOWED))
# System prompt sent to the model. The leading "/think" token and all wording
# below are runtime-visible text consumed by the LLM; do not edit casually.
SYSTEM_PROMPT = f"""/think
You are a helpful Bash assistant with the ability to execute commands in the shell.
Only use the allowed commands:
{LIST_OF_ALLOWED_COMMANDS}
Never attempt dangerous or unlisted commands. Confirm intent clearly, handle failures, and list files after cd.
"""
# ------------------------- Config (no default model) -------------------------
# OpenAI-compatible endpoint exposed by a local Ollama server (its /v1 API).
DEFAULT_BASE_URL = "http://127.0.0.1:11434/v1"
DEFAULT_API_KEY = ""  # Ollama ignores API keys; lib accepts empty string
def _env(name: str) -> Optional[str]:
return os.environ.get(name)
def _get_base_url() -> str:
    """Resolve the OpenAI-compatible base URL for the local Ollama server.

    Precedence: OLLAMA_BASE_URL, then OPENAI_BASE_URL, then DEFAULT_BASE_URL.

    Returns:
        str: The URL normalized to end with "/v1" and with no trailing slash.

    Fix: the original returned "http://host/v1/" unchanged (trailing slash
    kept) while appending "/v1" to everything else; now every input is
    normalized the same way.
    """
    url = (_env("OLLAMA_BASE_URL") or _env("OPENAI_BASE_URL") or DEFAULT_BASE_URL).rstrip("/")
    if not url.endswith("/v1"):
        url += "/v1"
    return url
def _get_api_key() -> str:
    """Return the first non-empty API key env var, else the empty default.

    Ollama ignores API keys, so an empty string is an acceptable result.
    """
    for var_name in ("OLLAMA_API_KEY", "OPENAI_API_KEY"):
        value = _env(var_name)
        if value:
            return value
    return DEFAULT_API_KEY
def _root_from_base_url(base_url: str) -> str:
s = base_url.rstrip("/")
if s.endswith("/v1"):
s = s[:-3]
return s.rstrip("/")
def _fetch_ollama_models(base_url: str) -> List[str]:
    """
    List the models installed in the local Ollama server.

    Calls the Ollama native API (GET <root>/api/tags) and returns names like
    ["devstral:latest", "qwen3-coder:7b-instruct"], deduplicated and sorted
    case-insensitively (case-sensitive tie-break) for deterministic display.

    Args:
        base_url (str): The OpenAI-compatible base URL (".../v1").

    Returns:
        list[str]: Sorted unique model names; [] on any network/parse error
        (a diagnostic is printed to stderr in that case).
    """
    root = _root_from_base_url(base_url)
    url = f"{root}/api/tags"
    req = Request(url, headers={"Accept": "application/json"})
    try:
        with urlopen(req, timeout=5) as resp:
            data = json.loads(resp.read().decode("utf-8"))
    except (HTTPError, URLError, TimeoutError, ConnectionError) as e:
        print(f"[error] Cannot reach Ollama at {url}: {e}", file=sys.stderr)
        return []
    except Exception as e:
        print(f"[error] Unexpected error reading {url}: {e}", file=sys.stderr)
        return []
    # Set comprehension dedups; keep only non-blank string names.
    raw_names = (m.get("name") for m in (data.get("models") or []))
    names = {n.strip() for n in raw_names if isinstance(n, str) and n.strip()}
    # Single sort with a compound key replaces the original redundant
    # sorted(sorted(...)): primary key is the lowercased name, ties broken
    # lexicographically — identical ordering, one pass.
    return sorted(names, key=lambda s: (s.lower(), s))
def _pick_model_interactively(base_url: str) -> Optional[str]:
    """Prompt the user to choose one of the locally installed Ollama models.

    Returns:
        The chosen model name, or None when no models are installed, the
        user cancels with 'c', or stdin is closed/interrupted.
    """
    models = _fetch_ollama_models(base_url)
    if not models:
        print("You don't have any available models currently in Ollama. Please download some.")
        return None

    print("\nAvailable local Ollama models:\n")
    for position, model_name in enumerate(models, start=1):
        print(f"{position}. {model_name}")
    print("\nChoose a model by number, or 'c' to cancel.")

    while True:
        try:
            answer = input("> ").strip().lower()
        except (EOFError, KeyboardInterrupt):
            return None
        if answer == "c":
            return None
        if answer.isdigit() and 1 <= int(answer) <= len(models):
            return models[int(answer) - 1]
        print("Invalid choice. Enter a valid number or 'c' to cancel.")
# ------------------------- Tool wrapper -------------------------
class ExecOnConfirm:
    """Wraps a Bash tool so every command requires interactive human approval."""

    def __init__(self, bash: Bash):
        # The allow-list enforcement itself lives inside the Bash tool.
        self.bash = bash

    def _confirm_execution(self, cmd: str) -> bool:
        """Ask the operator to approve *cmd*; anything but 'y' (or EOF/^C) declines."""
        try:
            answer = input(f" ▶️ Execute '{cmd}'? [y/N]: ")
        except (EOFError, KeyboardInterrupt):
            return False
        return answer.strip().lower() == "y"

    def exec_bash_command(self, cmd: str) -> Dict[str, str]:
        """
        Execute an allow-listed Bash command with human confirmation.
        Args:
            cmd (str): The bash command to run from the current working directory.
        Returns:
            dict: {"stdout": str, "stderr": str, "cwd": str} or {"error": str}
        """
        # NOTE: this docstring doubles as the tool description shown to the
        # LLM by the agent framework, so it is kept verbatim.
        if not cmd:
            return {"error": "No command provided"}
        if not self._confirm_execution(cmd):
            return {"error": "The user declined the execution of this command."}
        return self.bash.exec_bash_command(cmd)
# ------------------------- Main -------------------------
def main() -> None:
    """Entry point: pick a local Ollama model, build the agent, run the REPL."""
    base_url = _get_base_url()
    api_key = _get_api_key()
    start_dir = os.environ.get("BASH_START_DIR", os.getcwd())
    print(f"[config] base_url={base_url}")

    # Force pre-run model selection (no default!): bail out on cancel.
    model_name = _pick_model_interactively(base_url)
    if model_name is None:
        print("Cancelled.")
        sys.exit(0)
    print(f"[config] model={model_name}")

    bash_tool = Bash(cwd=start_dir)
    confirm_wrapper = ExecOnConfirm(bash_tool)
    # langchain-openai >= 1.0: use base_url/api_key
    chat_model = ChatOpenAI(model=model_name, base_url=base_url, api_key=api_key)
    react_agent = create_react_agent(
        model=chat_model,
        tools=[confirm_wrapper.exec_bash_command],
        prompt=SYSTEM_PROMPT,
        checkpointer=InMemorySaver(),
    )

    print(f"[cwd: {start_dir}] Bash computer-use agent (LangGraph). Ctrl+C to exit.")
    while True:
        try:
            user_text = input("[🙂] ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nBye.")
            break
        if not user_text:
            continue
        turn: Dict[str, Any] = react_agent.invoke(
            {"messages": [{"role": "user", "content": user_text}]},
            # One fixed thread id keeps conversation memory across turns.
            config={"configurable": {"thread_id": "bash-session"}},
        )
        reply = turn["messages"][-1].content
        # Hide any chain-of-thought the model emits before a </think> marker.
        if "</think>" in reply:
            reply = reply.split("</think>")[-1].strip()
        print(f"\n[🤖] {reply}\n")


if __name__ == "__main__":
    main()