forked from OpenHands/OpenHands
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathutils.py
More file actions
112 lines (93 loc) · 4 KB
/
utils.py
File metadata and controls
112 lines (93 loc) · 4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import json
import os
import re
import string
import zipfile
import gdown
import requests
def download_data(dir):
    """Ensure the ToolQA external corpus is available under ``<dir>/data``.

    If ``<dir>/data/external_corpus`` already exists, nothing is downloaded.
    Otherwise the archive is fetched from Google Drive with ``gdown`` into
    ``<dir>/data.zip``, extracted into ``<dir>/data`` and the archive is
    deleted afterwards.

    Args:
        dir: Base directory under which ``data/`` is created.

    Returns:
        The path ``<dir>/data/external_corpus``. NOTE(review): this directory
        is assumed to come from the archive's contents; its existence is not
        re-checked after extraction.
    """
    data_path = os.path.join(dir, 'data/external_corpus')
    if os.path.exists(data_path):
        return data_path

    url = 'https://drive.google.com/uc?id=1zRbHzPW2x4dDcfmphBWlan8cxUCRNmqk'
    zip_path = os.path.join(dir, 'data.zip')
    try:
        gdown.download(url, zip_path, quiet=False)
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(os.path.join(dir, 'data'))
    finally:
        # Remove the archive even when extraction fails (e.g. a corrupt file
        # raising zipfile.BadZipFile) so no stale zip is left behind.
        if os.path.exists(zip_path):
            os.remove(zip_path)
    return data_path
def download_tools(dir, wolfram_alpha_appid='YOUR_WOLFRAMALPHA_APPID'):
    """Download the ToolQA ReAct tool scripts into ``<dir>/tools``.

    The scripts are fetched from the upstream ToolQA GitHub repository and
    patched in memory before being written:

    * ``calculator.py``: the ``YOUR_WOLFRAMALPHA_APPID`` placeholder is
      replaced with *wolfram_alpha_appid*;
    * ``agenda_retriever.py``, ``mysql_db_create.py`` and
      ``scirex_retriever.py``: the ``/<YOUR_OWN_PATH>/ToolQA/`` prefix is
      removed, presumably so their data paths resolve relative to the working
      directory (NOTE(review): confirm against the upstream tool sources).

    If ``<dir>/tools`` already exists, nothing is downloaded.

    Args:
        dir: Base directory in which the ``tools`` directory is created.
        wolfram_alpha_appid: WolframAlpha App ID substituted into
            ``calculator.py``.

    Returns:
        The path ``<dir>/tools``, both when it already existed and after a
        fresh download.

    Raises:
        requests.RequestException: if any script cannot be fetched (including
            ``requests.HTTPError`` for a non-success status). All downloads
            complete before ``<dir>/tools`` is created, so a failed run leaves
            no partial directory and the next call retries.
    """
    tool_path = os.path.join(dir, 'tools')
    if os.path.exists(tool_path):
        return tool_path

    base_url = 'https://raw.githubusercontent.com/night-chen/ToolQA/main/benchmark/ReAct/code/tools/'
    tools = [
        'code/sql_interpreter.py',
        'graph/graphtools.py',
        'math/calculator.py',
        'table/mysql_db_create.py',
        'table/tabtools.py',
        'text/agenda_retriever.py',
        'text/scirex_retriever.py',
    ]
    # (placeholder, replacement) per downloaded file name. Working on bytes
    # avoids depending on the platform's default text encoding.
    local_prefix = b'/<YOUR_OWN_PATH>/ToolQA/'
    patches = {
        'calculator.py': (b'YOUR_WOLFRAMALPHA_APPID', wolfram_alpha_appid.encode('utf-8')),
        'agenda_retriever.py': (local_prefix, b''),
        'mysql_db_create.py': (local_prefix, b''),
        'scirex_retriever.py': (local_prefix, b''),
    }

    # Fetch and patch everything before touching the filesystem: the early
    # return above treats an existing tools directory as complete, so it must
    # never be created half-populated.
    contents = {}
    for tool in tools:
        response = requests.get(base_url + tool, timeout=60)
        # Without this check an HTTP error response body would be saved as
        # if it were the tool's source code.
        response.raise_for_status()
        content = response.content
        name = tool.rsplit('/', 1)[-1]
        if name in patches:
            old, new = patches[name]
            content = content.replace(old, new)
        contents[name] = content

    os.mkdir(tool_path)
    for name, content in contents.items():
        with open(os.path.join(tool_path, name), 'wb') as f:
            f.write(content)
    return tool_path
def get_data(dataset, hardness):
data = []
url = f'https://raw.githubusercontent.com/night-chen/ToolQA/main/data/questions/{hardness}/{dataset}-{hardness}.jsonl'
url = requests.get(url)
if url.status_code == 200:
lines = url.text.splitlines()
for line in lines:
data.append(json.loads(line))
return data
# Prompt template sent to the agent for each ToolQA question; the
# `{question}` placeholder is filled by `encode_question`. The agent is told to
# answer as `Finish[answer]`, which is the pattern `eval_answer` extracts.
# The exact wording is part of the benchmark setup, so change it with care.
REACT_INSTRUCTION = """Use tools in the tools directory to solve the task: {question}
You could use all tools which are under the tools/ directory and all the data under the data/ directory.
When you think you finished the task, respond with `Finish[answer]` where you include your answer in `[]`.
IMPORTANT: Make sure that in your final answer, you should not print any additional text/instructions other than the actual answer, which should be a word or a simple phrase.
"""
def encode_question(question):
    """Render the ReAct instruction prompt for a single ToolQA question.

    Args:
        question: Question text, substituted for the ``{question}``
            placeholder of ``REACT_INSTRUCTION``.

    Returns:
        The complete instruction string to hand to the agent.
    """
    template = REACT_INSTRUCTION
    return template.format_map({'question': question})
# imported from https://github.com/night-chen/ToolQA/tree/main/benchmark/ReAct/code/agents_chatgpt.py
def normalize_answer(s):
    """Canonicalize an answer string for exact-match comparison.

    The steps are applied in this order:
      1. lowercase the text;
      2. delete every character in ``string.punctuation``;
      3. replace the standalone words "a", "an", "the" and "usd" with a space;
      4. collapse whitespace runs into single spaces and trim both ends.

    >>> normalize_answer("The Cat!")
    'cat'
    """
    text = s.lower()
    text = text.translate(str.maketrans('', '', string.punctuation))
    text = re.sub(r'\b(a|an|the|usd)\b', ' ', text)
    return ' '.join(text.split())
def eval_answer(pred, answer):
    """Return whether the agent's prediction matches the gold answer.

    If *pred* contains ``Finish[...]``, only the text inside the first such
    brackets (non-greedy, up to the first ``]``) is compared; otherwise the
    whole of *pred* is used. Both sides go through ``normalize_answer``
    before an exact string comparison.
    """
    found = re.search(r'Finish\[(.*?)\]', pred)
    extracted = pred if found is None else found.group(1)
    return normalize_answer(extracted) == normalize_answer(answer)