Skip to content

Commit 667908f

Browse files
Saurav AgarwalSaurav Agarwal
authored andcommitted
torch2.4.0 compat
1 parent fd33141 commit 667908f

File tree

3 files changed

+8
-8
lines changed

3 files changed

+8
-8
lines changed

python/coverage_control/algorithms/controllers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def __init__(self, config: dict, params: Parameters, env: CoverageSystem):
125125
self.actions_std = self.model.actions_std.to(self.device)
126126
self.model = self.model.to(self.device)
127127
self.model.eval()
128-
# self.model = torch.compile(self.model, dynamic=True)
128+
self.model = torch.compile(self.model, dynamic=True)
129129

130130
def step(self, env):
131131
"""

python/coverage_control/io_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def load_tensor(path: str) -> torch.tensor:
7676
if not os.path.exists(path):
7777
raise FileNotFoundError(f"IOUtils::load_tensor: File not found: {path}")
7878
# Load data
79-
data = torch.load(path)
79+
data = torch.load(path, weights_only=True)
8080
# Extract tensor if data is in jit script format
8181

8282
if isinstance(data, torch.jit.ScriptModule):

python/coverage_control/nn/models/lpac.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,32 +82,32 @@ def forward(self, data: torch_geometric.data.Data) -> torch.Tensor:
8282

8383
def load_compiled_state_dict(self, model_state_dict_path: str) -> None:
8484
# remove _orig_mod from the state dict keys
85-
state_dict = torch.load(model_state_dict_path)
85+
state_dict = torch.load(model_state_dict_path, weights_only=True)
8686
new_state_dict = {}
8787
for key in state_dict.keys():
8888
new_state_dict[key.replace("_orig_mod.", "")] = state_dict[key]
89-
self.load_state_dict(new_state_dict, strict=True)
89+
self.load_state_dict(new_state_dict, strict=True, weights_only=True)
9090

9191
def load_model(self, model_state_dict_path: str) -> None:
9292
"""
9393
Load the model from the state dict
9494
"""
95-
self.load_state_dict(torch.load(model_state_dict_path), strict=True)
95+
self.load_state_dict(torch.load(model_state_dict_path), strict=True, weights_only=True)
9696

9797
def load_model_state_dict(self, model_state_dict: dict) -> None:
9898
"""
9999
Load the model from the state dict
100100
"""
101-
self.load_state_dict(model_state_dict, strict=True)
101+
self.load_state_dict(model_state_dict, strict=True, weights_only=True)
102102

103103
def load_cnn_backbone(self, model_path: str) -> None:
104104
"""
105105
Load the CNN backbone from the model path
106106
"""
107-
self.load_state_dict(torch.load(model_path).state_dict(), strict=True)
107+
self.load_state_dict(torch.load(model_path).state_dict(), strict=True, weights_only=True)
108108

109109
def load_gnn_backbone(self, model_path: str) -> None:
110110
"""
111111
Load the GNN backbone from the model path
112112
"""
113-
self.load_state_dict(torch.load(model_path).state_dict(), strict=True)
113+
self.load_state_dict(torch.load(model_path).state_dict(), strict=True, weights_only=True)

0 commit comments

Comments
 (0)