-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbash_tool.py
More file actions
188 lines (169 loc) · 6.92 KB
/
bash_tool.py
File metadata and controls
188 lines (169 loc) · 6.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
#!/usr/bin/env python3
# bash_tool.py
# Safe Bash executor with an allowlist and working-directory tracking.
# Works on Linux/macOS/WSL. For Windows PowerShell/CMD, adapt the executable and command wrapping.
from __future__ import annotations
import os
import shlex
import subprocess
from typing import Any, Dict, Iterable, List, Set
# DEFAULT_ALLOWED: Set[str] = {
# # navigation
# "pwd", "ls", "cd",
# # files & viewing
# "cat", "head", "tail", "less", "more",
# "stat", "wc", "cut", "sort", "uniq", "nl",
# "echo", "printf", "grep", "egrep", "fgrep", "sed", "awk",
# "tr", "tee",
# # filesystem (non-destructive)
# "touch", "mkdir", "mktemp",
# # system info
# "df", "du", "free", "uptime", "uname", "whoami", "id", "env",
# # net
# "ping", "curl", "wget",
# # archives
# "tar", "zip", "unzip",
# # basic process queries
# "ps", "top", "htop",
# }
def _executables_on_path(path: str) -> Set[str]:
    """
    Return the names of all executable regular files found in the directories of 'path'.

    'path' is a PATH-style string (entries joined by os.pathsep). Empty entries and
    entries that cannot be listed (missing, not a directory, permission denied) are
    skipped instead of raising, so a broken PATH entry cannot make the import fail.
    """
    names: Set[str] = set()
    for directory in path.split(os.pathsep):
        if not directory:
            continue
        try:
            entries = os.listdir(directory)
        except OSError:
            # Covers FileNotFoundError, NotADirectoryError and PermissionError.
            continue
        for name in entries:
            full_path = os.path.join(directory, name)
            if os.path.isfile(full_path) and os.access(full_path, os.X_OK):
                names.add(name)
    return names


# Bash built-ins (bash 4/5+), plus "rm" and "rmdir", which are listed explicitly.
# NOTE(review): "eval", "exec", "source", ".", "command" and "builtin" run whatever
# they are given as arguments, and the allowlist check only looks at the first word
# of each simple command -- consider excluding them (e.g. via BASH_TOOL_EXCLUDE).
_BASH_BUILTINS: Set[str] = {
    ":", ".", "alias", "bg", "bind", "break", "builtin", "caller", "cd", "command",
    "compgen", "complete", "compopt", "continue", "declare", "dirs", "disown", "echo",
    "enable", "eval", "exec", "exit", "export", "fc", "fg", "getopts", "hash", "help",
    "history", "jobs", "kill", "let", "local", "logout", "mapfile", "readarray",
    "popd", "printf", "pushd", "pwd", "read", "readonly", "return", "set", "shift",
    "shopt", "source", "suspend", "test", "times", "trap", "type", "typeset",
    "ulimit", "umask", "unalias", "unset", "wait", "rm", "rmdir",
}

# Default allowlist: the names above plus every executable currently on PATH,
# minus commands that would power the machine off or restart it.
DEFAULT_ALLOWED: Set[str] = (
    _BASH_BUILTINS | _executables_on_path(os.environ.get("PATH", ""))
) - {"shutdown", "reboot", "poweroff"}
# Optional: honor additional excludes from env (comma-separated)
DEFAULT_ALLOWED -= {x.strip() for x in os.environ.get("BASH_TOOL_EXCLUDE", "").split(",") if x.strip()}
class Bash:
    """
    A restricted Bash tool.

    - Keeps track of the current working directory between calls.
    - Enforces an allowlist of *base* commands (first word of each simple command).
    - Returns stdout/stderr and the updated cwd.

    The allowlist check is a guardrail, not a sandbox: it splits the command line on
    unquoted ';', '|', '&' and newlines and checks the first word of every piece.
    It does not look inside command substitutions ("$(...)", backticks), here-documents
    or the arguments of commands such as 'eval'.
    """

    def __init__(self, cwd: str | None = None, allowed_commands: Iterable[str] | None = None,
                 bash_executable: str = "/bin/bash"):
        """
        Args:
            cwd: Starting directory; defaults to the process's current directory.
            allowed_commands: Base commands that may be executed. ``None`` selects
                DEFAULT_ALLOWED; an empty iterable allows nothing.
            bash_executable: Path of the shell used to run commands.
        """
        self.cwd = os.path.abspath(cwd or os.getcwd())
        # Compare against None explicitly: an empty allowlist must stay empty instead
        # of silently falling back to the permissive default.
        self._allowed_commands: Set[str] = set(
            DEFAULT_ALLOWED if allowed_commands is None else allowed_commands
        )
        self._bash_executable = bash_executable

    # ---------- Public API (for tool-calling) ----------
    def exec_bash_command(self, cmd: str) -> Dict[str, str]:
        """
        Execute a bash command (already confirmed by caller).

        Checks the allowlist on *every* simple command within 'cmd'.

        Returns:
            {"stdout": ..., "stderr": ..., "cwd": ...} after running the command, or
            {"error": "..."} when 'cmd' is empty, cannot be tokenized (e.g. unbalanced
            quotes) or contains a command that is not allowlisted.
        """
        cmd = (cmd or "").strip()
        if not cmd:
            return {"error": "No command was provided"}
        if not self._is_allowlisted(cmd):
            return {"error": "Parts of this command were not in the allowlist."}
        return self._run_bash_command(cmd)

    def to_json_schema(self) -> Dict[str, Any]:
        """
        JSON Schema describing the function for OpenAI-compatible tool calling.
        """
        return {
            "type": "function",
            "function": {
                "name": "exec_bash_command",
                "description": "Execute a bash command and return stdout/stderr and the working directory",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "cmd": {
                            "type": "string",
                            "description": "The bash command to execute from the current working directory"
                        }
                    },
                    "required": ["cmd"],
                },
            },
        }

    # ---------- Internals ----------
    def _run_bash_command(self, cmd: str) -> Dict[str, str]:
        """
        Run 'cmd' in self.cwd and record the directory the shell ends up in.

        The command is followed by an ``echo`` of a marker and a ``pwd`` so that the
        effect of any 'cd' can be tracked. self.cwd is only updated when the reported
        path is an existing directory.

        Returns:
            {"stdout": command output, "stderr": captured stderr or the launch error,
             "cwd": the tracked working directory after the command}.
        """
        separator = "__END_OF_COMMAND__"
        # Newlines (not ';') end the user's command so that a trailing "# comment"
        # cannot swallow the probe and a trailing '&' does not become the syntax
        # error "&;".
        wrapped = f"{cmd}\necho {separator}\npwd"
        try:
            result = subprocess.run(
                wrapped,
                shell=True,
                cwd=self.cwd,
                capture_output=True,
                text=True,
                executable=self._bash_executable,
            )
        except (OSError, ValueError, subprocess.SubprocessError) as exc:
            # Launch/decoding failures are reported to the caller, not raised.
            return {"stdout": "", "stderr": str(exc), "cwd": self.cwd}

        stderr = (result.stderr or "").rstrip()
        raw_out = result.stdout or ""
        # Split on the LAST marker so command output that happens to contain the
        # marker text is not truncated.
        head, found, tail = raw_out.rpartition(separator)
        if not found:
            # The probe never ran (e.g. the command called 'exit'); keep all output.
            head, tail = raw_out, ""

        cmd_out = head.strip()
        if not cmd_out and not stderr:
            cmd_out = "Command executed successfully, without any output."

        # 'pwd' output is the final non-empty line after the marker.
        pwd_lines = [ln.strip() for ln in tail.splitlines() if ln.strip()]
        if pwd_lines and os.path.isdir(pwd_lines[-1]):
            self.cwd = pwd_lines[-1]
        return {"stdout": cmd_out, "stderr": stderr, "cwd": self.cwd}

    def _is_allowlisted(self, cmd: str) -> bool:
        """
        Return True when the base command of every simple command in 'cmd' is allowed.

        'cmd' is split on unquoted ';', '|', '||', '&', '&&' and newlines.
        Redirections are allowed. A piece that cannot be tokenized (unbalanced quotes,
        dangling backslash) makes the whole command rejected.
        """
        for segment in self._split_simple_commands(cmd):
            try:
                base = self._extract_base_command(segment)
            except ValueError:
                # shlex could not tokenize the segment; refuse rather than guess.
                return False
            if base and base not in self._allowed_commands:
                return False
        return True

    @staticmethod
    def _split_simple_commands(cmd: str) -> List[str]:
        """
        Split 'cmd' into simple-command segments at unquoted control operators.

        Separators are ';', newline, '|' and '&' (so '&&', '||' and '|&' split too).
        Characters inside single/double quotes or escaped with a backslash never
        split. '&' and '|' that belong to a redirection ('>&', '<&', '&>', '>|') are
        kept as part of the segment. Segments may be empty or whitespace-only.
        """
        segments: List[str] = []
        current: List[str] = []
        quote = ""  # the active quote character, or "" when outside quotes
        prev = ""   # previous character if it was plain unquoted text, else ""
        i, n = 0, len(cmd)
        while i < n:
            ch = cmd[i]
            nxt = cmd[i + 1] if i + 1 < n else ""
            last, prev = prev, ""
            if quote:
                current.append(ch)
                if ch == "\\" and quote == '"' and nxt:
                    # Inside double quotes a backslash can escape the closing quote.
                    current.append(nxt)
                    i += 1
                elif ch == quote:
                    quote = ""
            elif ch == "\\" and nxt:
                # Escaped character: keep both and never treat it as an operator.
                current.append(ch + nxt)
                i += 1
            elif ch in ("'", '"'):
                quote = ch
                current.append(ch)
            elif (
                ch in (";", "\n")
                or (ch == "|" and last != ">")
                or (ch == "&" and last not in ("<", ">") and nxt != ">")
            ):
                segments.append("".join(current))
                current = []
            else:
                current.append(ch)
                prev = ch
            i += 1
        segments.append("".join(current))
        return segments

    @staticmethod
    def _extract_base_command(segment: str) -> str:
        """
        Return the first word of 'segment' that is neither a redirection nor a
        grouping token, or "" when there is none.

        Leading redirections are skipped together with their target, in both the
        spaced ("> file cmd") and attached ("2>/dev/null cmd") forms. The grouping
        tokens "(", ")", "{" and "}" are skipped so the command behind them is the
        one that gets checked.

        Raises:
            ValueError: if 'segment' cannot be tokenized by shlex.
        """
        segment = segment.strip()
        if not segment:
            return ""
        skip_target = False
        for token in shlex.split(segment):
            if skip_target:
                # This token is the file name of the preceding bare redirection.
                skip_target = False
                continue
            operator = token.lstrip("0123456789")  # drop a leading fd number
            if operator[:1] in ("<", ">") or operator.startswith("&>"):
                # A bare operator (">", "2>>", "&>") takes the next token as target.
                skip_target = operator.strip("<>&|") == ""
                continue
            if token in {"(", ")", "{", "}"}:
                continue
            # Builtins like 'cd' are handled by allowlist too
            return token
        return ""