Skip to content

Commit 1959556

Browse files
committed
Added uma script.
1 parent 8a0e3f3 commit 1959556

File tree

4 files changed

+281
-0
lines changed

4 files changed

+281
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
.DS_Store
12
# Byte-compiled / optimized / DLL files
23
__pycache__/
34
*.py[cod]

.idea/.gitignore

Lines changed: 8 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Binary file not shown.

python/uma/uma.py

Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
"""
2+
This module provides functions to relax molecules using the UMA (Universal Molecular Architecture) model from FAIRChem.
3+
Github: https://github.com/facebookresearch/fairchem -> UMA installation and first steps
4+
Other resources:
5+
Demo: https://facebook-fairchem-uma-demo.hf.space/
6+
Paper: https://arxiv.org/abs/2506.23971
7+
Huggingface: https://huggingface.co/facebook/UMA
8+
"""
9+
import warnings
10+
from datetime import datetime
11+
from typing import Union
12+
from copy import deepcopy
13+
from collections import defaultdict
14+
import tempfile
15+
import numpy as np
16+
import torch
17+
from pathlib import Path
18+
import ase
19+
from ase.io import read
20+
from ase.optimize import LBFGS
21+
from fairchem.core import FAIRChemCalculator, pretrained_mlip
22+
from fairchem.core.units.mlip_unit.api.inference import InferenceSettings
23+
from ase.build import molecule
24+
25+
synonyms = {'umas': 'uma-s-1p1', 'umam': 'uma-m-1p1'}
26+
27+
# Custom inference settings designed to speed up inference and minimise VRAM usage. The only difference to the default settings is tf32=True.
28+
def inference_settings_speedy():
29+
return InferenceSettings(
30+
tf32=True,
31+
activation_checkpointing=True,
32+
merge_mole=False,
33+
compile=False,
34+
wigner_cuda=False,
35+
external_graph_gen=False,
36+
internal_graph_gen_version=2,
37+
)
38+
39+
cached_mlips = defaultdict(dict)
40+
def load_mlip(method: str, device: str):
41+
print(f'Loading MLIP `{method}` on device `{device}`...')
42+
try:
43+
predict_unit = pretrained_mlip.get_predict_unit(method, device=device, inference_settings=inference_settings_speedy())
44+
except KeyError:
45+
try:
46+
predict_unit = pretrained_mlip.get_predict_unit(synonyms[method], device=device, inference_settings=inference_settings_speedy())
47+
except KeyError:
48+
raise KeyError(f'Predictor `{method}` not found.')
49+
cached_mlips[method][device] = predict_unit
50+
51+
return predict_unit
52+
53+
def get_mlip(method, device):
54+
try:
55+
return cached_mlips[method][device]
56+
except KeyError:
57+
load_mlip(method, device)
58+
return cached_mlips[method][device]
59+
60+
def uma_get_hessian(calc: FAIRChemCalculator, atoms, vmap: bool=False):
61+
"""
62+
Get the Hessian matrix for the given atomic structure.
63+
Args:
64+
atoms (Atoms): The atomic structure to calculate the Hessian for.
65+
vmap (bool): Whether to use vectorized mapping for Hessian calculation. This can speed up the calculation but for medium-sized systems (TMCs) it uses way too much RAM. Calculating the Hessian without vmap for a TMC with the small UMA model took around 6 minutes on Mac (with possible background parallelization).
66+
Returns:
67+
np.ndarray: The Hessian matrix.
68+
"""
69+
from fairchem.core.datasets import data_list_collater
70+
import torch
71+
from torch.autograd import grad
72+
# Turn on create_graph for the first derivative
73+
calc.predictor.model.module.output_heads['energyandforcehead'].head.training = True
74+
75+
# Convert using the current a2g object
76+
data_list = [calc.a2g(atoms)]# for atoms in atoms_list]
77+
78+
# Batch and predict
79+
batch = data_list_collater(data_list, otf_graph=True)
80+
pred = calc.predictor.predict(batch)
81+
82+
# Get the forces and positions
83+
positions = batch.pos
84+
forces = pred["forces"].flatten()
85+
86+
# Calculate the Hessian using autograd
87+
if vmap:
88+
hessian = torch.vmap(
89+
lambda vec: grad(
90+
-forces,
91+
positions,
92+
grad_outputs=vec,
93+
retain_graph=True,
94+
)[0],
95+
)(torch.eye(forces.numel(), device=forces.device)).detach().cpu().numpy()
96+
else:
97+
hessian = np.zeros((len(forces), len(forces)))
98+
for i in range(len(forces)):
99+
hessian[:, i] = grad(
100+
-forces[i],
101+
positions,
102+
retain_graph=True,
103+
)[0].flatten().detach().cpu().numpy()
104+
105+
# Turn off create_graph for the first derivative
106+
calc.predictor.model.module.output_heads['energyandforcehead'].head.training = False
107+
108+
return hessian
109+
110+
def _ensure_writable_arrays_inplace(atoms):
111+
"""
112+
Ensure that all arrays in the atoms object are writable. Modifies the atoms object in place.
113+
"""
114+
for _key, _arr in list(atoms.arrays.items()):
115+
try:
116+
if hasattr(_arr, "flags") and not _arr.flags.writeable:
117+
atoms.arrays[_key] = np.array(_arr, copy=True)
118+
except Exception:
119+
atoms.arrays[_key] = np.array(_arr, copy=True)
120+
121+
122+
def _run_uma(atoms: ase.Atoms, charge: int, n_unpaired: int, device: str, fmax: float, logfile: Union[str, None], method: str, steps: int, task_name: str, tempdir: str, frequencies: bool) -> dict:
123+
"""
124+
Run UMA relaxation. The output is a dictionary with all relevant results. It is intention that this function does not return a uma calc object, since it is not pickable and might not be returnable on hpc parallel runs with ray or joblib. For local, non-parallel runs, one could totally modify this function to return the calc object as well.
125+
:return: dict
126+
"""
127+
start_time = datetime.now()
128+
if device == 'cuda' and not torch.cuda.is_available():
129+
print('CUDA device requested but not available. Falling back to CPU.')
130+
device = 'cpu'
131+
132+
_ensure_writable_arrays_inplace(atoms)
133+
predict_unit = get_mlip(method, device)
134+
calc = FAIRChemCalculator(predict_unit=predict_unit, task_name=task_name)
135+
atoms.calc = calc
136+
atoms.info = {'charge': charge, 'spin': n_unpaired+1} # UMA wants the multiplicity, not the number of unpaired electrons
137+
traj_path = Path(tempdir, 'uma_relaxation.xyz')
138+
opt = LBFGS(atoms, trajectory=str(traj_path), logfile=logfile) # Traj must be str, not Path()
139+
opt.run(fmax=fmax, steps=steps)
140+
141+
if frequencies:
142+
raise NotImplementedError('The frequency calculation is not yet implemented. You can implement it here easily using the uma_get_hessian function in the code above and then calculating frequencies from the Hessian matrix and the Gibbs energies from the frequencies using the ase thermochemistry module.')
143+
144+
# Avoid serialization issues with Ray/joblib by removing the calc from the atoms object
145+
relaxed_atoms = opt.atoms.copy()
146+
relaxed_atoms.calc = None
147+
148+
# Return all relevant results. Can't return the calc and opt objects directly since they are not pickable by Ray/joblib.
149+
concat_atoms = read(traj_path, ':')
150+
energies = [atoms.calc.results['energy'] for atoms in concat_atoms]
151+
try:
152+
dE = energies[-2] - calc.results['energy']
153+
except IndexError:
154+
dE = np.nan
155+
final_forces = float(np.linalg.norm(atoms.get_forces(), axis=1).max())
156+
results = {
157+
'E': calc.results['energy'], # final energy
158+
'atoms': relaxed_atoms, # final relaxed atoms
159+
'forces': calc.results['forces'].tolist(), # final forces. Can be outcommented for less storage.
160+
'stress': calc.results['stress'].tolist(), # final stress
161+
'input': { # input parameters
162+
'charge': charge,
163+
'n_unpaired': n_unpaired,
164+
'method': method,
165+
'device': device,
166+
'task_name': task_name,
167+
'fmax': opt.fmax,
168+
'n_max_steps': opt.max_steps,
169+
},
170+
'opt': # optimization results
171+
{
172+
'n_steps': opt.nsteps, # number of optimization steps taken
173+
'f': final_forces, # final max force on any atom
174+
'dE': dE, # energy change in last step
175+
'converged': bool(final_forces < opt.fmax), # whether the optimization converged
176+
'H0': opt.H0,
177+
'energies': energies, # energies at each step
178+
'traj': concat_atoms, # trajectory as list of Atoms objects. Can be outcommented for less storage.
179+
'time': datetime.now() - start_time, # time taken for the optimization
180+
}
181+
}
182+
return results
183+
184+
def uma_relax_atoms(
185+
atoms: Union[ase.Atoms, Path, str],
186+
charge: int,
187+
n_unpaired: int,
188+
method: str = 'umas',
189+
fmax=0.05,
190+
steps=300,
191+
device='cpu',
192+
task_name='omol',
193+
frequencies: bool=False,
194+
logfile: str = None,
195+
timing: bool=False
196+
) -> dict:
197+
"""
198+
Relax a molecule using the FAIRChemCalculator.
199+
@param atoms: Atoms object to be relaxed or path to an xyz file.
200+
@param charge: Charge of the molecule.
201+
@param n_unpaired: Number of unpaired electrons of the molecule.
202+
@param method: 'umas' or 'umam', which corresponds to the predictor to be used for the relaxation.
203+
@param fmax: Maximum convergence force for the relaxation. A value of 0.05 will lead to a convergence dE of around 1E-3 eV for a typical TMC.
204+
@param steps: Number of steps for the relaxation. If 0, a single-point calculation is performed.
205+
@param device: Device to be used for the calculation. 'cpu' or 'cuda' (if available). Default is 'cpu'.
206+
@param task_name: Task name for the calculation. For molecules use 'omol', for others look into the FAIRChem documentation.
207+
@param frequencies: Whether to calculate the frequencies and return a Gibbs energy or not. Not yet implemented but possible.
208+
@param logfile: Logfile to save the output of the optimization. None to suppress output, '-' to print to stdout.
209+
@param timing: Whether to time the optimization or not.
210+
@return: A dictionary with the results of the relaxation, including energy, forces, stress, relaxed atoms, and other relevant information.
211+
"""
212+
try: # Try to read the atoms object from a file
213+
atoms = read(atoms)
214+
except AttributeError:
215+
atoms = atoms.copy() # If atoms is already an Atoms object, copy it to avoid modifying the original
216+
coords_before = deepcopy(atoms.get_positions()) # For comparison after the optimization
217+
218+
_ensure_writable_arrays_inplace(atoms)
219+
with tempfile.TemporaryDirectory() as tempdir:
220+
# Set up the FAIRChemCalculator with the specified method and run the optimization
221+
results = _run_uma(atoms=atoms, charge=charge, n_unpaired=n_unpaired, device=device, fmax=fmax, logfile=logfile, method=method, steps=steps, task_name=task_name, tempdir=tempdir, frequencies=frequencies)
222+
relaxed_atoms = results['atoms']
223+
224+
if timing:
225+
print(f'UMA opt took {results["opt"]["time"]} seconds for {results["opt"]["n_steps"]} steps.')
226+
227+
if steps > 0 and np.allclose(coords_before, relaxed_atoms.get_positions()):
228+
warnings.warn(f'Atoms did not move during the optimization. This might be due to a too high fmax ({fmax}) or too few steps ({steps}). Consider increasing these values.')
229+
230+
return results
231+
232+
233+
234+
235+
236+
if __name__ == '__main__':
237+
238+
# Todo once only: login to huggingface to download and cache the used model
239+
# from huggingface_hub import login; login(token='...') # provide here your huggingface token after registering. Do this just once for every model. Never share your token publicly, e.g. never commit the code to git with the token in it.
240+
241+
########## Example: Relax a water molecule with the small UMA model ##########
242+
atoms = molecule('H2O') # ase.Atoms object or path to an xyz file. Here we create a water molecule using ASE.
243+
method = 'umas' # 'umas' (small) or 'umam' (medium). The small model is faster but less accurate.
244+
charge = 0 # Total charge of the molecule.
245+
n_unpaired = 0 # Number of unpaired electrons of the molecule.
246+
fmax = 0.05 # Max. converge force in eV/A. fmax=0.05 will give a convergence dE of around 1E-3 eV for a typical TMC.
247+
steps = 300 # Maximum number of steps for the relaxation. 0 means single-point calculation.
248+
# Other options:
249+
logfile = None # Logfile to save the optimization output. None to suppress output, '-' to print to stdout.
250+
timing = True # Whether to print timing information or not.
251+
# NOT YET IMPLEMENTED:
252+
frequencies = False # Whether to calculate frequencies and Gibbs energy or not.
253+
device = 'cpu' # Currently only 'cpu' is supported. For using uma on a gpu with 'cuda', best ask Timo.
254+
255+
results = uma_relax_atoms(
256+
atoms=atoms,
257+
charge=charge,
258+
method=method,
259+
n_unpaired=n_unpaired,
260+
fmax=fmax,
261+
steps=steps,
262+
frequencies=frequencies,
263+
device=device,
264+
logfile=logfile,
265+
timing=timing
266+
)
267+
268+
# Optional: uncomment to view optimization trajectory in ase gui.
269+
# from ase.visualize import view
270+
# view(results['opt']['traj']) # List of Atoms objects representing the trajectory
271+
272+
print('Done!')

0 commit comments

Comments
 (0)