-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathtest_plot.py
More file actions
140 lines (119 loc) · 4.87 KB
/
test_plot.py
File metadata and controls
140 lines (119 loc) · 4.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import json
import os
import tempfile
import unittest
from unittest.mock import patch, MagicMock

import networkx as nx
from pydantic import ValidationError

from python_workflow_definition.plot import plot
from python_workflow_definition.shared import (
    NODES_LABEL,
    EDGES_LABEL,
    SOURCE_LABEL,
    TARGET_LABEL,
    SOURCE_PORT_LABEL,
    TARGET_PORT_LABEL,
)
class TestPlot(unittest.TestCase):
    """Unit tests for ``python_workflow_definition.plot.plot``.

    Rendering is stubbed out by patching ``networkx.nx_agraph.to_agraph`` and
    the ``SVG`` / ``display`` names looked up in ``python_workflow_definition.plot``,
    so the tests inspect the ``networkx.DiGraph`` that ``plot`` hands to
    ``to_agraph`` and the calls made to ``SVG`` / ``display``.
    """

    def setUp(self):
        """Write a three-node, three-edge workflow fixture to a private temp file."""
        # A per-test temporary directory keeps the fixture out of the current
        # working directory: parallel runs cannot clobber each other's file and
        # nothing is left behind if a run is interrupted. addCleanup also runs
        # when a later step of setUp fails, unlike tearDown.
        self._tmp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(self._tmp_dir.cleanup)
        self.test_file = os.path.join(self._tmp_dir.name, "test_workflow.json")
        self.workflow_data = {
            "version": "0.0.1",
            NODES_LABEL: [
                {"id": 1, "name": "Node 1", "type": "function", "value": "a.b"},
                {"id": 2, "name": "Node 2", "type": "function", "value": "c.d"},
                {"id": 3, "name": "Node 3", "type": "function", "value": "e.f"},
            ],
            EDGES_LABEL: [
                {
                    SOURCE_LABEL: 1,
                    TARGET_LABEL: 2,
                    SOURCE_PORT_LABEL: "out1",
                    TARGET_PORT_LABEL: "in1",
                },
                {
                    SOURCE_LABEL: 2,
                    TARGET_LABEL: 3,
                    SOURCE_PORT_LABEL: "out2",
                    TARGET_PORT_LABEL: "in2",
                },
                {
                    SOURCE_LABEL: 1,
                    TARGET_LABEL: 3,
                    SOURCE_PORT_LABEL: None,
                    TARGET_PORT_LABEL: "in3",
                },
            ],
        }
        self._write_workflow(self.workflow_data)

    def _write_workflow(self, data):
        """Serialize ``data`` as JSON into ``self.test_file``, replacing its content."""
        with open(self.test_file, "w", encoding="utf-8") as f:
            json.dump(data, f)

    def _render_and_get_graph(self, mock_to_agraph):
        """Run ``plot`` on the fixture and return the graph passed to ``to_agraph``.

        Args:
            mock_to_agraph: the patched ``networkx.nx_agraph.to_agraph``.

        Returns:
            The single positional argument of the one ``to_agraph`` call,
            after checking that it is a ``networkx.DiGraph``.
        """
        mock_agraph = MagicMock()
        # plot() feeds whatever draw() returns into SVG(); use a stub string.
        mock_agraph.draw.return_value = "<svg></svg>"
        mock_to_agraph.return_value = mock_agraph

        plot(self.test_file)

        self.assertEqual(1, mock_to_agraph.call_count)
        graph = mock_to_agraph.call_args[0][0]
        self.assertIsInstance(graph, nx.DiGraph)
        return graph

    @patch("python_workflow_definition.plot.display")
    @patch("python_workflow_definition.plot.SVG")
    @patch("networkx.nx_agraph.to_agraph")
    def test_plot(self, mock_to_agraph, mock_svg, mock_display):
        graph = self._render_and_get_graph(mock_to_agraph)

        # Node ids are exposed as strings; the "name" attribute holds the
        # node's "value" field from the workflow JSON.
        self.assertCountEqual(["1", "2", "3"], graph.nodes)
        self.assertEqual("a.b", graph.nodes["1"]["name"])
        self.assertEqual("c.d", graph.nodes["2"]["name"])
        self.assertEqual("e.f", graph.nodes["3"]["name"])
        self.assertCountEqual([("1", "2"), ("2", "3"), ("1", "3")], graph.edges)

        # A single connection with a source port is labelled
        # "<target_port>=result[<source_port>]"; with no source port the label
        # is just the target port name.
        expected_labels = {
            ("1", "2"): "in1=result[out1]",
            ("2", "3"): "in2=result[out2]",
            ("1", "3"): "in3",
        }
        for (source, target), label in expected_labels.items():
            with self.subTest(edge=(source, target)):
                edge_data = graph.get_edge_data(source, target)
                self.assertIn("label", edge_data)
                self.assertEqual(label, edge_data["label"])

        mock_svg.assert_called_once_with("<svg></svg>")
        mock_display.assert_called_once()

    @patch("python_workflow_definition.plot.display")
    @patch("python_workflow_definition.plot.SVG")
    @patch("networkx.nx_agraph.to_agraph")
    def test_plot_multiple_edges_same_source(self, mock_to_agraph, mock_svg, mock_display):
        # Add a second 1 -> 2 connection (out2 -> in2) next to the existing
        # out1 -> in1 one and rewrite the fixture file.
        self.workflow_data[EDGES_LABEL].append(
            {
                SOURCE_LABEL: 1,
                TARGET_LABEL: 2,
                SOURCE_PORT_LABEL: "out2",
                TARGET_PORT_LABEL: "in2",
            }
        )
        self._write_workflow(self.workflow_data)

        graph = self._render_and_get_graph(mock_to_agraph)

        # plot() groups all connections between one source node and one target
        # node. When there is more than one of them (here out1->in1 and
        # out2->in2), it draws a single unlabeled edge to represent them all.
        edge_n1_n2_data = graph.get_edge_data("1", "2")
        self.assertIsNotNone(edge_n1_n2_data, "expected an edge from node 1 to node 2")
        self.assertNotIn("label", edge_n1_n2_data)

    def test_plot_file_not_found(self):
        # A path inside this test's fresh temp directory is guaranteed absent.
        missing_file = os.path.join(self._tmp_dir.name, "non_existent_file.json")
        with self.assertRaises(FileNotFoundError):
            plot(missing_file)

    def test_plot_invalid_json(self):
        # Single-quoted keys and a missing closing brace make this malformed
        # JSON; plot() is expected to report it as a pydantic ValidationError.
        with open(self.test_file, "w", encoding="utf-8") as f:
            f.write("{'invalid': 'json'")
        with self.assertRaises(ValidationError):
            plot(self.test_file)

    def test_plot_missing_keys(self):
        # Well-formed JSON that lacks the nodes list must fail validation.
        self._write_workflow({"version": "0.0.1", EDGES_LABEL: []})
        with self.assertRaises(ValidationError):
            plot(self.test_file)
if __name__ == "__main__":
    # Discover and run the TestCase classes of this module when it is executed
    # directly as a script (module="__main__" is unittest.main's default).
    unittest.main(module="__main__")