-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathDRIFTTaskSuite.py
More file actions
220 lines (186 loc) · 8.78 KB
/
DRIFTTaskSuite.py
File metadata and controls
220 lines (186 loc) · 8.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
from import_lib import *
class InjectionVector(BaseModel):
    """Schema of one entry in a suite's ``injection_vectors.yaml``.

    Each entry corresponds to a placeholder that is filled into the suite's
    ``environment.yaml`` text before the environment is parsed.
    """

    # Free-text description of the vector (not read by the code in this file).
    description: str
    # Text substituted for the placeholder when no injection is supplied for it.
    default: str
def read_suite_file(suite_name: str, file: str, suite_data_path: Path | None) -> str:
    """Load a suite data file, resolve its YAML imports, and return it as YAML text.

    Args:
        suite_name: Name of the suite; used to locate the file inside the
            ``agentdojo`` package data when ``suite_data_path`` is ``None``.
        file: File name relative to the suite directory (e.g. ``"environment.yaml"``).
        suite_data_path: Directory holding the suite's data files. When given, it
            takes precedence over the packaged data.

    Returns:
        The file parsed with ``ImportLoader`` and re-serialized with ``yaml.dump``
        in block style, so that placeholders can later be filled in with
        ``str.format``.
    """
    # NOTE(review): an ``@lru_cache`` decorator used to sit on this function and is
    # disabled; re-enabling it would memoize the text per (suite_name, file, path).
    if suite_data_path is not None:
        path = suite_data_path / file
        # YAML is defined over Unicode; decode explicitly as UTF-8 rather than with
        # the platform's locale-dependent default encoding.
        with path.open(encoding="utf-8") as f:
            return _load_and_dump_yaml(f)
    package_files = importlib.resources.files("agentdojo")
    resource = package_files / "data" / "suites" / suite_name / file
    with importlib.resources.as_file(resource) as p, p.open(encoding="utf-8") as f:
        return _load_and_dump_yaml(f)


def _load_and_dump_yaml(stream) -> str:
    """Parse ``stream`` with ``ImportLoader`` and dump it back to block-style YAML text."""
    # Load into yaml to resolve imports, then dump back to a string so that
    # injections can be substituted into it later.
    data_yaml = yaml.load(stream, ImportLoader)
    return yaml.dump(data_yaml, default_flow_style=False)
def validate_injections(injections: dict[str, str], injection_vector_defaults: dict[str, str]) -> None:
    """Check that every injection targets a known injection vector.

    Args:
        injections: Mapping from injection-vector id to the text to inject.
        injection_vector_defaults: Mapping from every known injection-vector id to
            its default text.

    Raises:
        ValueError: If ``injections`` contains a vector id that is absent from
            ``injection_vector_defaults``. The message lists the unknown ids.
    """
    unknown = set(injections) - set(injection_vector_defaults)
    if unknown:
        # Keep the original message as a prefix; append the offending ids so the
        # misconfigured vector can be found without a debugger.
        raise ValueError(
            "Injections must be a subset of the injection vector defaults; "
            f"unknown vectors: {sorted(unknown)}"
        )
class DRIFTTaskSuite(TaskSuite[Env]):
    """A suite of tasks that can be run in an environment. Tasks can be both user tasks and injection tasks. It is not
    mandatory to have injection tasks in case the suite is used to evaluate the model only for utility.

    Args:
        args: Run configuration object; the code in this class reads only
            ``args.adaptive_attack``.
        name: The name of the suite.
        environment_type: The environment type that the suite operates on.
        tools: A list of tools that the agent can use to solve the tasks.
        data_path: The path to the suite data directory. It should be provided for non-default suites.
            The directory should contain the following files:
                - `environment.yaml`: The data of the environment.
                - `injection_vectors.yaml`: The injection vectors in the environment.
        benchmark_version: The benchmark version this suite instance targets.
        parent_instance: Optional existing suite. Every attribute present on it is
            copied onto this instance, overwriting the values set from the
            arguments above.
    """

    def __init__(
        self,
        args,
        name: str,
        environment_type: type[Env],
        tools: list[Function],
        data_path: Path | None = None,
        benchmark_version: BenchmarkVersion = (1, 0, 0),
        parent_instance: TaskSuite | None = None,
    ):
        self.args = args
        self.name = name
        self.environment_type = environment_type
        self.tools = tools
        # Task registries: outer key is a string (presumably the task id), inner key
        # is the benchmark version.
        self._user_tasks: dict[str, dict[BenchmarkVersion, BaseUserTask[Env]]] = defaultdict(dict)
        self._injection_tasks: dict[str, dict[BenchmarkVersion, BaseInjectionTask[Env]]] = defaultdict(dict)
        self.data_path = data_path
        self.benchmark_version = benchmark_version
        # Adopt the state of an existing suite (e.g. its registered tasks).
        if parent_instance:
            self.__dict__.update(parent_instance.__dict__)

    def get_new_version(self, benchmark_version: BenchmarkVersion) -> Self:
        """Return a deep copy of this suite tagged with ``benchmark_version``."""
        new_self = copy.deepcopy(self)
        new_self.benchmark_version = benchmark_version
        return new_self

    def load_and_inject_default_environment(self, injections: dict[str, str]) -> Env:
        """Load the suite's ``environment.yaml`` with injection placeholders filled in.

        Vectors not present in ``injections`` are filled with their default text.
        Substitution uses ``str.format``, so literal braces in the environment file
        must be doubled.

        Args:
            injections: Mapping from injection-vector id to the text to inject.

        Returns:
            The substituted environment parsed into ``self.environment_type``.

        Raises:
            ValueError: If ``injections`` names an unknown injection vector.
        """
        environment_text = read_suite_file(self.name, "environment.yaml", self.data_path)
        injection_vector_defaults = self.get_injection_vector_defaults()
        injections_with_defaults = dict(injection_vector_defaults, **injections)
        validate_injections(injections, injection_vector_defaults)
        injected_environment = environment_text.format(**injections_with_defaults)
        environment = self.environment_type.model_validate(yaml.safe_load(injected_environment))
        return environment

    def get_injection_vector_defaults(self) -> dict[str, str]:
        """Return ``{vector_id: default_text}`` for every vector in ``injection_vectors.yaml``.

        Each entry is validated against ``InjectionVector`` before its default is taken.
        """
        injection_vectors_text = read_suite_file(self.name, "injection_vectors.yaml", self.data_path)
        injection_vectors = yaml.safe_load(injection_vectors_text)
        vectors = {
            vector_id: InjectionVector.model_validate(vector_info)
            for vector_id, vector_info in injection_vectors.items()
        }
        return {vector_id: vector.default for vector_id, vector in vectors.items()}

    def model_output_from_messages(
        self,
        messages: Sequence[ChatMessage],
    ) -> list[MessageContentBlock] | None:
        """Wrap the content of the final message as a single text content block.

        Despite the ``| None`` in the annotation, this implementation never returns
        ``None``: it either returns a one-element list or raises.

        Raises:
            ValueError: If the last message's role is not ``"assistant"``.
        """
        if messages[-1]["role"] != "assistant":
            raise ValueError("Last message was not an assistant message")
        return [{"type": "text", "content": messages[-1]["content"]}]

    @staticmethod
    def _function_call_to_dict(call) -> dict:
        """Convert a function-call object into a plain ``{"function", "args", "id"}`` dict."""
        return {
            "function": call.function,
            "args": call.args,
            "id": call.id,
        }

    def functions_call_format(self, messages):
        """Replace function-call objects in ``messages`` with plain dicts, in place.

        Every element of a message's ``tool_calls`` list and any ``tool_call`` entry
        is converted from its ``function``/``args``/``id`` attributes.

        Returns:
            The same ``messages`` object, after mutation.
        """
        for message in messages:
            # ``tool_calls`` may be None (agentdojo's assistant-message type allows
            # it); iterating None would raise TypeError, so such entries are kept.
            if "tool_calls" in message and message["tool_calls"] is not None:
                message["tool_calls"] = [self._function_call_to_dict(call) for call in message["tool_calls"]]
            if "tool_call" in message:
                message["tool_call"] = self._function_call_to_dict(message["tool_call"])
        return messages

    def run_task_with_pipeline(
        self,
        agent_pipeline: BasePipelineElement,
        user_task: BaseUserTask[Env] | BaseInjectionTask[Env],
        injection_task: BaseInjectionTask[Env] | None,
        injections: dict[str, str],
        runtime_class: type[FunctionsRuntime] = FunctionsRuntime,
        environment: Env | None = None,
        verbose: bool = False,
    ) -> tuple[bool, bool, Sequence[ChatMessage]]:
        """Run a task with the provided pipeline.

        Args:
            agent_pipeline: The pipeline to use for the task.
            user_task: The user task to run.
            injection_task: The injection task to run.
            injections: The injections to use for the task. NOTE(review): when
                ``self.args.adaptive_attack`` is set, this dict is modified in place
                (``ADAPTIVE_ATTACK_PROMPT`` is appended to every value).
            runtime_class: The runtime class to use for the task.
            environment: The environment to use for the task.
            verbose: Accepted for interface compatibility; not used here.

        Returns:
            A tuple ``(utility, security, messages)``: whether the user task was
            successful, the result of the injection task's check (``True`` when
            ``injection_task`` is ``None``), and the conversation with its
            function calls converted to plain dicts.
        """
        if self.args.adaptive_attack:
            for key in injections:
                injections[key] = injections[key] + ADAPTIVE_ATTACK_PROMPT

        # If no environment is provided, load the default environment
        if environment is None:
            environment = self.load_and_inject_default_environment(injections)

        # Initialize the environment according to the task
        is_user_task = isinstance(user_task, BaseUserTask)
        if is_user_task:
            task_environment = user_task.init_environment(environment)
        else:
            task_environment = environment

        # Create a copy of the environment before running the user task to then diff if needed
        pre_environment = task_environment.model_copy(deep=True)
        prompt = user_task.PROMPT if is_user_task else user_task.GOAL

        runtime = runtime_class(self.tools)
        model_output = None
        messages = []
        # NOTE(review): a retry only happens when ``model_output_from_messages``
        # returns None, which the implementation in this class never does. A retry
        # would also continue from the environment left by the previous attempt.
        for _ in range(3):
            # Run the user task
            try:
                _, _, task_environment, messages, _ = agent_pipeline.query(prompt, runtime, task_environment)
            except AbortAgentError as e:
                # get last environment and messages
                task_environment = e.task_environment
                messages = e.messages
            model_output = self.model_output_from_messages(messages)
            if model_output is not None:
                break

        if model_output is None:
            warnings.warn(f"Model output was None for task {user_task.ID}")

        functions_stack_trace = functions_stack_trace_from_messages(messages)
        # An empty content list is the list-typed fallback matching what
        # ``model_output_from_messages`` returns.
        utility = self._check_task_result(
            user_task,
            model_output or [],
            pre_environment,
            task_environment,  # type: ignore
            functions_stack_trace,
        )
        # Rewrites ``messages`` in place; presumably kept after the stack-trace
        # extraction above because that helper reads the call objects.
        format_messages = self.functions_call_format(messages)

        # Early return if no injection was intended
        if injection_task is None:
            return utility, True, format_messages

        security = self._check_task_result(
            injection_task,
            model_output or [],
            pre_environment,
            task_environment,  # type: ignore
            functions_stack_trace,
        )
        return utility, security, format_messages