-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgraph_utils.py
More file actions
388 lines (318 loc) · 16.3 KB
/
graph_utils.py
File metadata and controls
388 lines (318 loc) · 16.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
import fnmatch
import json
import logging
import os
import random
from collections import deque
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple
# Module-level logger, one per module per the stdlib logging convention.
logger = logging.getLogger(__name__)
# NOTE(review): hard-coding DEBUG in a library module overrides the host
# application's logging configuration; consider leaving the level unset
# and letting the application decide — TODO confirm.
logger.setLevel(logging.DEBUG)
def load_components(components_dir: str, exclude: Optional[List[str]] = None) -> Dict[str, dict]:
    """Load all component descriptors from the components directory.

    Recursively scans ``components_dir`` for ``*.json`` descriptor files.
    The special ``types.json`` file is loaded first (so a missing or broken
    type registry fails fast) but is not returned as a component.

    Args:
        components_dir: Root directory containing component descriptors.
        exclude: Component names or glob patterns, relative to
            ``components_dir`` (e.g. ``["wrapped/*", "audio_*"]``), to skip.
            Defaults to no exclusions.

    Returns:
        Mapping of component name (relative path without the ``.json``
        suffix) to its parsed descriptor dict.

    Raises:
        OSError: If ``types.json`` is missing or a file cannot be read.
        json.JSONDecodeError: If a descriptor contains invalid JSON.
    """
    # Normalize to a fresh list: avoids the shared-mutable-default pitfall.
    exclude = list(exclude) if exclude else []
    log = logging.getLogger(__name__)
    components: Dict[str, dict] = {}
    # Load types.json first.
    with open(os.path.join(components_dir, "types.json")) as f:
        types = json.load(f)
    log.debug("[DEBUG] Loaded types: %s", types)
    # Recursively walk through all directories.
    for root, _dirs, files in os.walk(components_dir):
        for filename in files:
            if not filename.endswith(".json") or filename == "types.json":
                continue
            json_path = os.path.join(root, filename)
            # Exclude patterns are matched against the path relative to the
            # components directory — checked BEFORE opening the file, so
            # excluded descriptors are never read.
            rel_path = os.path.relpath(json_path, components_dir)
            log.debug("[DEBUG] Checking %s against %s", rel_path, exclude)
            if any(fnmatch.fnmatch(rel_path, pattern) for pattern in exclude):
                log.debug("[DEBUG] Excluding %s", rel_path)
                continue
            name = rel_path[:-5]  # Strip the ".json" suffix, keep subdirs.
            with open(json_path) as f:
                descriptor = json.load(f)
            components[name] = descriptor
            log.debug("[DEBUG] Loaded component %s: %s", name, descriptor)
    return components
def find_components_producing_type(type_name: str, components: Dict[str, dict]) -> List[str]:
    """Find all components that have an output of the given type.

    Args:
        type_name: The output type to look for.
        components: Mapping of component name to descriptor dict, as
            returned by ``load_components``.

    Returns:
        Names of components with at least one output of ``type_name``,
        in the iteration order of ``components``.
    """
    matching_components = [
        name
        for name, descriptor in components.items()
        # Components with no "outputs" key simply never match.
        if any(output["type"] == type_name for output in descriptor.get("outputs", []))
    ]
    # Lazy %-style args: the message is only formatted if DEBUG is enabled.
    logging.getLogger(__name__).debug(
        "[DEBUG] Components producing %s: %s", type_name, matching_components
    )
    return matching_components
def bridge(td_proxy,
           input_handles: List[int],
           output_handles: List[int],
           reuse_weight: float = 0.7,
           exclude_components: Optional[List[str]] = None,
           include_io_config: bool = True,
           components_dir: str = "/Users/kevin/Projects/graph_explorer/components"):
    """
    Stochastically generate a network connecting input nodes to output nodes.

    Each handle represents a node in the TouchDesigner network. Types are
    determined from component descriptors. The handles from the I/O config
    are automatically added to the input and output lists when
    ``include_io_config`` is True. The caller's lists are never mutated.

    Args:
        td_proxy: The TouchDesigner proxy object.
        input_handles: A list of input handles.
        output_handles: A list of output handles.
        reuse_weight: Probability of reusing an existing compatible output
            instead of instantiating a new producer component.
        exclude_components: List of component names or glob patterns to
            exclude (e.g. ["wrapped/*", "audio_*"]). Defaults to none.
        include_io_config: Whether to include handles from the IO config.
        components_dir: Directory holding component descriptors. Defaults
            to the path previously hard-coded inside this function.

    Returns:
        List of handles for the nodes created while bridging.

    Raises:
        ValueError: If a handle has no descriptor, or a required type has
            neither a producer component nor a reusable existing output.
    """
    log = logging.getLogger(__name__)
    # Fresh list instead of a shared mutable default argument.
    exclude_components = list(exclude_components) if exclude_components else []
    # Work on copies: the original implementation extend()-ed the caller's
    # lists in place when include_io_config was set.
    input_handles = list(input_handles)
    output_handles = list(output_handles)
    log.debug("Starting bridge with inputs=%s, outputs=%s", input_handles, output_handles)
    # Fetch the IO configuration once (it was previously queried twice).
    io_config = td_proxy.get_io_handles()
    log.debug("IO config: %s", io_config)
    # Get all component descriptors.
    components = load_components(components_dir, exclude=exclude_components)
    log.debug("Available components: %s", components)
    # Log the input node descriptors.
    for handle in input_handles:
        desc = td_proxy.get_op_descriptor(handle)
        log.debug("Input handle %d descriptor: %s", handle, desc)
    # Diagnostic probe retained from the original implementation.
    producers = find_components_producing_type('waveform', components)
    log.debug("Found producers for 'waveform': %s", producers)
    if include_io_config:
        input_handles.extend(io_config["inputs"])
        output_handles.extend(io_config["outputs"])
    # Deduplicate the input and output handles.
    input_handles = list(set(input_handles))
    output_handles = list(set(output_handles))
    # Types provided by the input nodes; each entry is (handle, index, type).
    input_nodes = []
    for handle in input_handles:
        descriptor = td_proxy.get_op_descriptor(handle)
        if not (descriptor and "outputs" in descriptor):
            raise ValueError(f"No descriptor found for input handle {handle}")
        for idx, output in enumerate(descriptor["outputs"]):
            output_type = output["type"]
            input_nodes.append((handle, idx, output_type))
            log.debug("Input node %d output[%d] provides type %s", handle, idx, output_type)
    # Unsatisfied inputs that still need a source: (handle, index, type).
    # deque: items are consumed from the front and appended at the back.
    outputs_to_satisfy = deque()
    for handle in output_handles:
        descriptor = td_proxy.get_op_descriptor(handle)
        if not (descriptor and "inputs" in descriptor):
            raise ValueError(f"No descriptor found for output handle {handle}")
        for idx, input_desc in enumerate(descriptor["inputs"]):
            input_type = input_desc["type"]
            outputs_to_satisfy.append((handle, idx, input_type))
            log.debug("Output node %d input[%d] requires type %s", handle, idx, input_type)
    # All nodes created by this call (the return value).
    created_nodes = []
    # Outputs currently available for reuse: type -> [(handle, index), ...].
    available_outputs = {}
    for handle, idx, type_name in input_nodes:
        available_outputs.setdefault(type_name, []).append((handle, idx))
    # Monotonic node ordering used to prevent cycles: edges may only go from
    # a lower-ordered node to a higher-ordered one.
    node_order = {}
    current_order = 0
    # Output nodes are ordered first (new sources slot in before them).
    for handle in output_handles:
        node_order[handle] = current_order
        current_order += 1

    def can_connect_without_cycle(source_handle: int, target_handle: int) -> bool:
        """Check if connecting source to target would keep the graph acyclic."""
        nonlocal current_order
        # If target isn't in the ordering yet, assign it the next slot.
        if target_handle not in node_order:
            node_order[target_handle] = current_order
            current_order += 1
        # If source isn't in the ordering yet, place it just before target.
        if source_handle not in node_order:
            node_order[source_handle] = node_order[target_handle] - 1
        result = node_order[source_handle] < node_order[target_handle]
        log.debug(
            "[DEBUG] Cycle check: source=%s(order=%s), target=%s(order=%s), result=%s",
            source_handle, node_order.get(source_handle),
            target_handle, node_order.get(target_handle), result,
        )
        return result

    while outputs_to_satisfy:
        output_handle, output_index, required_type = outputs_to_satisfy.popleft()
        log.debug(
            "[DEBUG] Trying to satisfy output %s:%s requiring type %s",
            output_handle, output_index, required_type,
        )
        # Existing outputs of the required type that would not form a cycle.
        log.debug("[DEBUG] Available outputs by type: %s", available_outputs)
        valid_existing_outputs = [
            (h, idx)
            for h, idx in available_outputs.get(required_type, [])
            if can_connect_without_cycle(h, output_handle)
        ]
        log.debug("[DEBUG] Valid existing outputs for %s: %s", required_type, valid_existing_outputs)
        rand_val = random.random()
        use_existing = valid_existing_outputs and rand_val < reuse_weight
        log.debug(
            "[DEBUG] Random value: %s, REUSE_WEIGHT: %s, use_existing: %s",
            rand_val, reuse_weight, use_existing,
        )
        # Components that could produce the required type.
        producer_components = find_components_producing_type(required_type, components)
        log.debug("[DEBUG] Found producer components for %s: %s", required_type, producer_components)
        # Parenthesized for clarity — same precedence as the original
        # `use_existing or not producer_components and len(...)`.
        if use_existing or (not producer_components and valid_existing_outputs):
            # Reuse an existing output.
            source_handle, source_index = random.choice(valid_existing_outputs)
            log.debug(
                "[DEBUG] Reusing existing output %s:%s of type %s",
                source_handle, source_index, required_type,
            )
            td_proxy.connect(source_handle, source_index, output_handle, output_index)
            log.debug(
                "[DEBUG] Connected %s:%s -> %s:%s",
                source_handle, source_index, output_handle, output_index,
            )
        else:
            if not producer_components:
                raise ValueError(f"No components found that can produce type {required_type}")
            chosen_component = random.choice(producer_components)
            log.debug("[DEBUG] Chose component %s to produce %s", chosen_component, required_type)
            new_handle = td_proxy.load(chosen_component)
            created_nodes.append(new_handle)
            log.debug("[DEBUG] Created component with handle %s", new_handle)
            # Connect its first output to the input we are satisfying.
            td_proxy.connect(new_handle, 0, output_handle, output_index)
            log.debug("[DEBUG] Connected %s:0 -> %s:%s", new_handle, output_handle, output_index)
            # Register the component's remaining outputs for later reuse.
            component_desc = components[chosen_component]
            for i, output_desc in enumerate(component_desc.get("outputs", [])):
                if i == 0:
                    continue  # Output 0 was just consumed above.
                output_type = output_desc["type"]
                available_outputs.setdefault(output_type, []).append((new_handle, i))
                log.debug("[DEBUG] Registered available output %s:%s of type %s", new_handle, i, output_type)
            # The new component's inputs now also need to be satisfied.
            for i, input_desc in enumerate(component_desc.get("inputs", [])):
                outputs_to_satisfy.append((new_handle, i, input_desc["type"]))
                log.debug("[DEBUG] Added new output to satisfy: %s:%s type %s", new_handle, i, input_desc["type"])
    # Return all nodes created while bridging.
    return created_nodes
def topo_sort_handles(td_proxy, handles):
    """Topologically sort *handles* plus any upstream nodes they reference.

    Edges are read from ``td_proxy.get_op_connectors(handle)``, which
    returns ``{"in": [...], "out": [...]}`` where each connector carries a
    ``"targets"`` list of ``(handle, index)`` pairs.

    Args:
        td_proxy: Proxy exposing ``get_op_connectors(handle)``.
        handles: Iterable of node handles to sort.

    Returns:
        List of handles in dependency order (sources before consumers).

    Raises:
        ValueError: If the connection graph contains a cycle.
    """
    log = logging.getLogger(__name__)
    log.debug("Starting topological sort")
    # Memoized proxy lookups — the original queried some handles twice
    # (once in each pass), doubling the round-trips for those nodes.
    connectors_by_handle = {}

    def get_connectors(handle):
        # Fetch from the proxy at most once per handle.
        if handle not in connectors_by_handle:
            log.debug("[DEBUG] Getting connectors for handle %s", handle)
            connectors_by_handle[handle] = td_proxy.get_op_connectors(handle)
            log.debug("[DEBUG] Raw connector info: %s", connectors_by_handle[handle])
        return connectors_by_handle[handle]

    # First pass: collect all handles, including upstream nodes referenced
    # by the input connectors of the requested handles.
    all_handles = set(handles)
    for handle in list(handles):
        connectors = get_connectors(handle)
        for in_conn in connectors["in"]:
            if in_conn["targets"]:  # Only connectors with actual connections.
                for target_handle, _ in in_conn["targets"]:
                    if target_handle is not None:
                        all_handles.add(target_handle)

    # Second pass: record the active (connected) in/out connectors per handle.
    connection_info = {}
    for handle in all_handles:
        connectors = get_connectors(handle)
        in_connections = [c for c in connectors["in"] if c["targets"]]
        out_connections = [c for c in connectors["out"] if c["targets"]]
        connection_info[handle] = {
            "in_connectors": in_connections,
            "out_connectors": out_connections,
        }
        log.debug(
            "[DEBUG] Handle %s has %d active inputs and %d active outputs",
            handle, len(in_connections), len(out_connections),
        )

    # Kahn's algorithm: repeatedly emit nodes with no remaining in-edges.
    # NOTE(review): in-degree counts *connectors*, while the decrement below
    # happens once per upstream (out_conn, target) pair — this assumes one
    # target per in-connector; TODO confirm against the proxy's contract.
    in_degree = {handle: len(connection_info[handle]["in_connectors"]) for handle in all_handles}
    # deque gives O(1) popleft; the original's list.pop(0) is O(n) per pop.
    queue = deque(handle for handle in all_handles if in_degree[handle] == 0)
    log.debug("[DEBUG] Starting nodes with no incoming edges: %s", queue)
    sorted_handles = []
    while queue:
        current = queue.popleft()
        sorted_handles.append(current)
        log.debug("[DEBUG] Adding node %s to sorted list", current)
        # Remove edges from current node to its downstream targets.
        for out_conn in connection_info[current]["out_connectors"]:
            for target_handle, _ in out_conn["targets"]:
                if target_handle in in_degree:  # Only nodes we're tracking.
                    in_degree[target_handle] -= 1
                    log.debug(
                        "[DEBUG] Reduced in-degree of %s to %s",
                        target_handle, in_degree[target_handle],
                    )
                    if in_degree[target_handle] == 0:
                        queue.append(target_handle)
                        log.debug(
                            "[DEBUG] Node %s has no more incoming edges, adding to queue",
                            target_handle,
                        )
    # Any leftover nodes still have in-edges, i.e. the graph has a cycle.
    if len(sorted_handles) != len(all_handles):
        raise ValueError("Graph has cycles")
    log.debug("[DEBUG] Topological sort complete. Order: %s", sorted_handles)
    return sorted_handles
def layout_nodes(td_proxy, sorted_handles):
    """Lay nodes out left-to-right in *sorted_handles* order.

    The row of nodes is centered as a group around x=0, with a fixed
    margin between neighbors; each node's ``nodeY`` is set to ``-h/2``
    (presumably centering it vertically around y=0, assuming nodeY is the
    node's lower edge — TODO confirm against TouchDesigner's anchor).

    Args:
        td_proxy: Proxy exposing ``get_op_node_geometry(handle)`` (returns
            ``(x, y, w, h)``) and ``set_op_attribute(handle, name, value)``.
        sorted_handles: Handles in the left-to-right order to place them.
    """
    log = logging.getLogger(__name__)
    log.debug("Starting node layout")
    MARGIN = 20  # Units between adjacent nodes.
    # First pass: collect geometry once per node and accumulate row width.
    # (The original also tracked a max_height that was never used.)
    geometry = {}
    total_width = 0
    log.debug("[DEBUG] Collecting geometry information")
    for handle in sorted_handles:
        log.debug("[DEBUG] Getting geometry for handle %s", handle)
        x, y, w, h = td_proxy.get_op_node_geometry(handle)
        geometry[handle] = (x, y, w, h)
        total_width += w
        log.debug("[DEBUG] Node %s geometry: x=%s, y=%s, w=%s, h=%s", handle, x, y, w, h)
    # Include inter-node margins, then center the whole row around x=0.
    total_width += MARGIN * (len(sorted_handles) - 1)
    current_x = -total_width / 2
    # Second pass: place nodes at consecutive x positions.
    log.debug("[DEBUG] Positioning nodes")
    for handle in sorted_handles:
        _x, _y, w, h = geometry[handle]
        center_y = -h / 2  # Center the node vertically at y=0.
        log.debug("[DEBUG] Setting position for handle %s to x=%s, y=%s", handle, current_x, center_y)
        td_proxy.set_op_attribute(handle, "nodeX", current_x)
        td_proxy.set_op_attribute(handle, "nodeY", center_y)
        # Advance past this node plus the margin.
        current_x += w + MARGIN