Common gotchas with OCP

%run ocp-tutorial.ipynb

OutOfMemoryError

If you see errors like:

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 390.00 MiB (GPU 0; 10.76 GiB total capacity; 9.59 GiB already allocated; 170.06 MiB free; 9.81 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

This means your GPU has run out of memory. One common cause is having several notebooks open that use the GPU, for example because each of them has loaded a calculator. Try closing the other notebooks and shutting down their kernels; closing the browser tab alone does not release the memory.

It could also mean that the batch size is too large to fit in memory. You can try making it smaller in the YAML config file (optim.batch_size).

It is recommended that you use automatic mixed precision (--amp), either as an option to main.py or in config.yml.

If it is an option, you can try a GPU with more memory, or you may be able to split the job over multiple GPUs.
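When the GPU is shared with other notebooks, it helps to see how much memory is actually free before and after loading a calculator. The sketch below is an illustration, not part of OCP: it uses PyTorch's torch.cuda.mem_get_info, memory_allocated and memory_reserved, sets the PYTORCH_CUDA_ALLOC_CONF option that the error message points to (the value 128 is an arbitrary example, not a recommendation), and returns cached memory to the driver once a calculator is no longer referenced. Running nvidia-smi in a terminal also lists which processes hold GPU memory.

```python
import gc
import os

# Allocator option mentioned in the OOM message. It only takes effect if it is set
# before PyTorch initializes CUDA, so put it at the top of a fresh kernel.
# The value 128 is an arbitrary example.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128")

try:
    import torch
except ImportError:  # lets the sketch run on machines without PyTorch
    torch = None

GIB = 1024 ** 3


def gpu_memory_report(device=0):
    """Print and return free/total GPU memory and what this process holds, in GiB."""
    if torch is None or not torch.cuda.is_available():
        print("No CUDA device available")
        return None
    free, total = torch.cuda.mem_get_info(device)
    report = {
        "free": free / GIB,
        "total": total / GIB,
        "allocated": torch.cuda.memory_allocated(device) / GIB,
        "reserved": torch.cuda.memory_reserved(device) / GIB,
    }
    print(", ".join(f"{name} {value:.2f} GiB" for name, value in report.items()))
    return report


def release_cached_gpu_memory():
    """Collect unreferenced objects, then release PyTorch's cached GPU blocks.

    This only frees a model's memory once nothing refers to it any more (e.g. after del calc).
    """
    gc.collect()
    if torch is not None and torch.cuda.is_available():
        torch.cuda.empty_cache()
```

For example, call gpu_memory_report() before and after creating an OCPCalculator, and run del calc followed by release_cached_gpu_memory() when you are finished with a calculator.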

I want the energy of a gas-phase atom

But I get an error like:

RuntimeError: cannot reshape tensor of 0 elements into shape [0, -1] because the unspecified dimension size -1 can be any value and is ambiguous

The problem here is that no neighbors are found for the single atom, so the model ends up with an empty tensor that it cannot reshape. This may be model dependent. For some models, there is currently no way to get the energy of an isolated atom.

%%capture
from ocpmodels.common.relaxation.ase_utils import OCPCalculator
cp = "gnoc_oc22_oc20_all_s2ef.pt"
calc = OCPCalculator(checkpoint=cp)
from ase.build import bulk
atoms = bulk('Cu', a=10)
atoms.set_calculator(calc)
atoms.get_potential_energy()
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[3], line 4
      2 atoms = bulk('Cu', a=10)
      3 atoms.set_calculator(calc)
----> 4 atoms.get_potential_energy()

(...)

File ~/shared-scratch/jkitchin/tutorial/ocp-tutorial/fine-tuning/ocp/ocpmodels/models/gemnet_oc/layers/spherical_basis.py:117, in SphericalBasisLayer.__init__.<locals>.<lambda>(cosφ, θ)
    113 elif sbf_name == "legendre_outer":
    114     circular_basis = get_sph_harm_basis(
    115         num_spherical, zero_m_only=True
    116     )
--> 117     self.spherical_basis = lambda cosφ, ϑ: (
    118         circular_basis(cosφ)[:, :, None]
    119         * circular_basis(torch.cos(ϑ))[:, None, :]
    120     ).reshape(cosφ.shape[0], -1)
    122 elif sbf_name == "gaussian_outer":
    123     self.circular_basis = GaussianBasis(
    124         start=-1, stop=1, num_gaussians=num_spherical, **sbf_hparams
    125     )

RuntimeError: cannot reshape tensor of 0 elements into shape [0, -1] because the unspecified dimension size -1 can be any value and is ambiguous
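If you loop over many structures and some of them may be isolated atoms, a small wrapper can turn this particular failure into a missing value instead of stopping the loop. This is a hypothetical helper, not part of OCP or ASE. It matches on the text of the error shown above, which is an assumption that may not hold across versions.

```python
def energy_or_none(atoms):
    """Return atoms.get_potential_energy(), or None if the model hit the
    empty-tensor reshape error shown above (e.g. for an isolated atom)."""
    try:
        return atoms.get_potential_energy()
    except RuntimeError as err:
        # Only swallow this specific failure; re-raise anything else.
        if "cannot reshape tensor of 0 elements" in str(err):
            return None
        raise
```

Usage would look like energies = [energy_or_none(a) for a in structures], where structures is a hypothetical list of Atoms objects that already have the calculator attached.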

I get wildly different energies from the different models

Some models are trained on adsorption energies, and some are trained on total energies. You need to know which kind of model you are using.

Sometimes you can tell from the magnitude of the energies, but be careful with this. If energies are “small” and near zero, they are likely adsorption energies. If energies are “large” in magnitude, they are probably total energies. This can be misleading, though, because the magnitude of a total energy scales with the number of atoms in the system.

%run ocp-tutorial.ipynb

# StringIO and contextlib are used to suppress the output printed while making the calculators.
import contextlib
import os
from io import StringIO

from ase.build import fcc111, add_adsorbate
from ocpmodels.common.relaxation.ase_utils import OCPCalculator

slab = fcc111('Pt', size=(2, 2, 5), vacuum=10.0)
add_adsorbate(slab, 'O', height=1.2, position='fcc')

# OC20 model - trained on adsorption energies
checkpoint = get_checkpoint('GemNet-OC All')  # get_checkpoint comes from the %run ocp-tutorial.ipynb cell

with contextlib.redirect_stdout(StringIO()) as _:
    calc = OCPCalculator(checkpoint=os.path.expanduser(checkpoint), cpu=False)

slab.set_calculator(calc)
slab.get_potential_energy()
Downloading https://dl.fbaipublicfiles.com/opencatalystproject/models/2022_07/s2ef/gemnet_oc_base_s2ef_all.pt
WARNING:root:Unrecognized arguments: ['symmetric_edge_symmetrization']
1.2862205505371094
# An OC22 checkpoint - trained on total energy
checkpoint = get_checkpoint('GemNet-OC OC22')

with contextlib.redirect_stdout(StringIO()) as _:
    calc = OCPCalculator(checkpoint=checkpoint, cpu=False)

slab.set_calculator(calc)
slab.get_potential_energy()
Downloading https://dl.fbaipublicfiles.com/opencatalystproject/models/2022_09/oc22/s2ef/gnoc_oc22_all_s2ef.pt
WARNING:root:Unable to identify OCP trainer, defaulting to `forces`. Specify the `trainer` argument into OCPCalculator if otherwise.
WARNING:root:Unrecognized arguments: ['symmetric_edge_symmetrization']
-111.44766998291016
# This eSCN model is trained on adsorption energies
checkpoint = get_checkpoint('eSCN-L4-M2-Lay12 2M')

with contextlib.redirect_stdout(StringIO()) as _:
    calc = OCPCalculator(checkpoint=checkpoint, cpu=False)

slab.set_calculator(calc)
slab.get_potential_energy()
1.6795454025268555
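One rough way to account for the number of atoms is to look at the energy per atom. This is only a heuristic, and the 1 eV/atom threshold is an arbitrary assumption, not an OCP convention: total energies per atom are typically several eV in magnitude, while an adsorption energy divided by the number of atoms is usually close to zero. The slab above is a Pt slab with size=(2, 2, 5) plus one O atom, so len(slab) is 21.

```python
def guess_energy_kind(energy, n_atoms, threshold=1.0):
    """Roughly guess whether `energy` (eV) is a total or an adsorption energy.

    `threshold` (eV/atom) is an arbitrary cut-off chosen for illustration.
    Returns the energy per atom and a label.
    """
    per_atom = energy / n_atoms
    if abs(per_atom) > threshold:
        return per_atom, "probably a total energy"
    return per_atom, "probably an adsorption energy"


# Energies printed above for the same 21-atom slab
for name, energy in [("GemNet-OC All", 1.2862205505371094),
                     ("GemNet-OC OC22", -111.44766998291016),
                     ("eSCN-L4-M2-Lay12 2M", 1.6795454025268555)]:
    per_atom, label = guess_energy_kind(energy, 21)
    print(f"{name}: {per_atom:.3f} eV/atom, {label}")
```

In a notebook you would pass slab.get_potential_energy() and len(slab) directly. For very small systems the per-atom value of an adsorption energy can exceed the threshold, so treat the label as a hint, not a verdict.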

Miscellaneous warnings

In general, warnings are not errors.

Unrecognized arguments

With GemNet models you might see warnings like:

WARNING:root:Unrecognized arguments: ['symmetric_edge_symmetrization']

You can ignore this warning; it is not important for predictions.

Unable to identify OCP trainer

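Some checkpoints do not specify the trainer, in which case the calculator defaults to forces, meaning that both energies and forces are calculated. This is the default for the ASE OCP calculator, and the warning just alerts you that it is using it. As the message says, you can pass the trainer argument to OCPCalculator to choose the trainer explicitly.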

WARNING:root:Unable to identify OCP trainer, defaulting to `forces`. Specify the `trainer` argument into OCPCalculator if otherwise.
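Both warnings above carry the prefix WARNING:root:, which suggests they are emitted through Python's standard logging module on the root logger. If you want to hide just these two messages while keeping other warnings, a logging filter is one option. This sketch rests on that assumption; the class and function names are hypothetical, and the message substrings are copied from the warnings shown above.

```python
import logging


class DropMessages(logging.Filter):
    """Logging filter that drops records whose message contains any given substring."""

    def __init__(self, substrings):
        super().__init__()
        self.substrings = tuple(substrings)

    def filter(self, record):
        message = record.getMessage()
        # Returning False drops the record; True lets it through.
        return not any(s in message for s in self.substrings)


def silence_known_ocp_warnings():
    """Attach the filter to the root logger and return it so it can be removed later."""
    known = DropMessages([
        "Unrecognized arguments",
        "Unable to identify OCP trainer",
    ])
    logging.getLogger().addFilter(known)
    return known
```

Call silence_known_ocp_warnings() before creating the calculator, and logging.getLogger().removeFilter(...) with the returned filter to restore the messages. A filter on a logger only sees records logged on that logger itself, which is why this relies on the messages coming from the root logger.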

Request entity too large - can’t save your notebook

If you run commands that generate a lot of output, the Jupyter notebook can become too large to save. Unfortunately, the only thing I know to do is to delete the output of the offending cell; then you may be able to save the notebook.

Once you know a command produces a lot of output, a better solution is to redirect its output to a file.

This has happened when running training in a notebook with too many lines of output, or with a lot (20+) of inline images.
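The cells on this page hide output with %%capture or contextlib.redirect_stdout(StringIO()). To keep the output but send it to a file instead of the notebook, you can redirect both stdout and stderr (progress bars such as tqdm usually write to stderr). This is a generic sketch; run_logged and the file names are hypothetical. Logging handlers created before the call keep a reference to the original stderr, so their messages may still appear in the notebook.

```python
import contextlib


def run_logged(func, logfile, *args, **kwargs):
    """Call func(*args, **kwargs), writing anything it prints to `logfile` instead of the notebook."""
    with open(logfile, "w") as fh:
        # Both streams go to the same file, in the order they are written.
        with contextlib.redirect_stdout(fh), contextlib.redirect_stderr(fh):
            return func(*args, **kwargs)
```

For example, run_logged(trainer.train, "train.log") (trainer being a hypothetical trainer object) and then look at the end of the file with !tail -n 20 train.log.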

You need at least four atoms for molecules with some models

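GemNet in particular seems to require at least four atoms. This has to do with how it builds interactions between atoms and their neighbors: with fewer than four atoms, the set of interactions it needs ends up empty, and you get the same reshape error as for a gas-phase atom.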

%%capture
import numpy as np
from ocpmodels.common.relaxation.ase_utils import OCPCalculator
checkpoint = get_checkpoint('GemNet-OC OC20+OC22')
calc = OCPCalculator(checkpoint=checkpoint)
from ase.build import molecule
atoms = molecule('H2O')
atoms.set_tags(np.ones(len(atoms)))
atoms.set_calculator(calc)
atoms.get_potential_energy()
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[20], line 5
      3 atoms.set_tags(np.ones(len(atoms)))
      4 atoms.set_calculator(calc)
----> 5 atoms.get_potential_energy()

(...)

File ~/shared-scratch/jkitchin/esunshine-ocp/ocpmodels/models/gemnet_oc/layers/spherical_basis.py:117, in SphericalBasisLayer.__init__.<locals>.<lambda>(cosφ, θ)
    113 elif sbf_name == "legendre_outer":
    114     circular_basis = get_sph_harm_basis(
    115         num_spherical, zero_m_only=True
    116     )
--> 117     self.spherical_basis = lambda cosφ, ϑ: (
    118         circular_basis(cosφ)[:, :, None]
    119         * circular_basis(torch.cos(ϑ))[:, None, :]
    120     ).reshape(cosφ.shape[0], -1)
    122 elif sbf_name == "gaussian_outer":
    123     self.circular_basis = GaussianBasis(
    124         start=-1, stop=1, num_gaussians=num_spherical, **sbf_hparams
    125     )

RuntimeError: cannot reshape tensor of 0 elements into shape [0, -1] because the unspecified dimension size -1 can be any value and is ambiguous
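The full traceback shows the error being raised while GemNet-OC builds the basis for its quadruplet interactions (sbf_basis_qint, quad_idx). The sketch below is a simplified model to show why the count can drop to zero, not GemNet's actual code: it assumes a quadruplet is a chain c-a-b-d of four distinct atoms where a-b, c-a and b-d are all neighbor pairs within the cutoff, and it ignores periodic images. Under that assumption a three-atom molecule such as H2O has no quadruplets at all.

```python
from collections import defaultdict
from itertools import combinations


def count_quadruplets(edges):
    """Count ordered chains c-a-b-d of four distinct atoms.

    `edges` is an iterable of undirected (i, j) pairs of atoms within the cutoff.
    Simplified model of a quadruplet interaction; periodic images are ignored.
    """
    neighbors = defaultdict(set)
    for i, j in edges:
        neighbors[i].add(j)
        neighbors[j].add(i)

    count = 0
    for a in neighbors:
        for b in neighbors[a]:                   # central pair a-b (both directions)
            for c in neighbors[a] - {b}:         # c is a neighbor of a, other than b
                for d in neighbors[b] - {a, c}:  # d is a neighbor of b, distinct from a and c
                    count += 1
    return count


def all_pairs(n_atoms):
    """Neighbor pairs for a small molecule where every atom is within the cutoff of every other."""
    return combinations(range(n_atoms), 2)


for n in (1, 2, 3, 4, 5):
    print(n, "atoms:", count_quadruplets(all_pairs(n)), "quadruplets")
```

In this picture one to three atoms give zero quadruplets, and four atoms are the minimum for a non-empty set, which is consistent with the H2O failure above.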

To tag or not?

Some models use tags to determine which atoms to calculate energies for. For example, GemNet uses tag=1 to indicate that an atom should be calculated. If the tags are not set, you will get an error with this model:

%%capture
from ase.build import molecule
from ocpmodels.common.relaxation.ase_utils import OCPCalculator
checkpoint = get_checkpoint('GemNet-OC OC20+OC22')
calc = OCPCalculator(checkpoint=checkpoint)
atoms = molecule('CH4')
atoms.set_calculator(calc)
atoms.get_potential_energy()  # error
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[22], line 3
      1 atoms = molecule('CH4')
      2 atoms.set_calculator(calc)
----> 3 atoms.get_potential_energy()  # error

(...)

File ~/shared-scratch/jkitchin/esunshine-ocp/ocpmodels/models/gemnet_oc/layers/spherical_basis.py:117, in SphericalBasisLayer.__init__.<locals>.<lambda>(cosφ, θ)
    113 elif sbf_name == "legendre_outer":
    114     circular_basis = get_sph_harm_basis(
    115         num_spherical, zero_m_only=True
    116     )
--> 117     self.spherical_basis = lambda cosφ, ϑ: (
    118         circular_basis(cosφ)[:, :, None]
    119         * circular_basis(torch.cos(ϑ))[:, None, :]
    120     ).reshape(cosφ.shape[0], -1)
    122 elif sbf_name == "gaussian_outer":
    123     self.circular_basis = GaussianBasis(
    124         start=-1, stop=1, num_gaussians=num_spherical, **sbf_hparams
    125     )

RuntimeError: cannot reshape tensor of 0 elements into shape [0, -1] because the unspecified dimension size -1 can be any value and is ambiguous
atoms = molecule('CH4')
atoms.set_tags(np.ones(len(atoms)))  # <- critical line for GemNet
atoms.set_calculator(calc)
atoms.get_potential_energy()
-23.71796226501465

Not all models require tags, though. The eSCN model below does not use them, so it works without the `set_tags` call. Whether tags are needed is a model-specific detail worth keeping in mind.

%%capture
from ase.build import molecule
from ocpmodels.common.relaxation.ase_utils import OCPCalculator

checkpoint = get_checkpoint('eSCN-L6-M3-Lay20 All+MD')
calc = OCPCalculator(checkpoint=checkpoint)
atoms = molecule('CH4')
# No set_tags call here: eSCN does not use tags
atoms.set_calculator(calc)
atoms.get_potential_energy()
-2.23504638671875
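Since some models read tags and others ignore them, one defensive pattern is to assign tags only when none have been set. The sketch below assumes that tagging every atom of a gas-phase molecule with 1, as in the CH4 example above, is appropriate for your model. `tags_or_default` is a hypothetical helper, not part of OCP or ASE; it relies only on ASE leaving every tag at 0 unless you set them.

```python
import numpy as np


def tags_or_default(tags):
    """Return tags for a gas-phase molecule (hypothetical helper).

    If every tag is 0 (ASE's default when tags were never set), tag every
    atom 1, as the CH4 example does for GemNet-OC. Otherwise keep the tags
    that were already assigned.
    """
    tags = np.asarray(tags, dtype=int)
    if np.all(tags == 0):
        return np.ones_like(tags)
    return tags


# With ASE this would be used as (not run here):
#     atoms.set_tags(tags_or_default(atoms.get_tags()))
print(tags_or_default([0, 0, 0, 0, 0]))  # untagged molecule: all atoms become 1
print(tags_or_default([0, 1, 2]))        # existing tags are left unchanged
```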

Stochastic simulation results

Some models, such as SCN, eSCN, and EquiformerV2 (EqV2), are not deterministic: you can get slightly different answers each time you run them. An example is shown below; see Open-Catalyst-Project/ocp#563 for more discussion. This happens because a random selection is made when sampling edges, and a different selection is made each time you run the model.

%run ocp-tutorial.ipynb
import os
import numpy as np
from ase.build import fcc111, add_adsorbate
from ocpmodels.common.relaxation.ase_utils import OCPCalculator

checkpoint = get_checkpoint('eSCN-L6-M3-Lay20 All+MD')
calc = OCPCalculator(checkpoint=os.path.expanduser(checkpoint), cpu=True)

# O adsorbed in the fcc hollow site of a 2x2x5 Pt(111) slab
slab = fcc111('Pt', size=(2, 2, 5), vacuum=10.0)
add_adsorbate(slab, 'O', height=1.2, position='fcc')
slab.set_calculator(calc)

# Recompute the energy of the same, unchanged structure ten times
results = []
for i in range(10):
    calc.calculate(slab, ['energy'], None)
    results += [slab.get_potential_energy()]

print(np.mean(results), np.std(results))
for result in results:
    print(result)
1.7137908697128297 0.002903242520056841
1.7156665325164795
1.7084295749664307
1.7146689891815186
1.711714267730713
1.7096374034881592
1.7145936489105225
1.7145006656646729
1.716435432434082
1.7138292789459229
1.718432903289795
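If you need a single number from a stochastic model, one option (not something these docs prescribe) is to repeat the calculation and report the mean together with an uncertainty. The sketch below assumes a hypothetical zero-argument `evaluate` callable, for example one that runs `calc.calculate(slab, ['energy'], None)` and returns `slab.get_potential_energy()` as in the loop above; here it simply replays the ten energies printed above.

```python
import math
import statistics


def summarize_repeats(evaluate, n=10):
    """Call ``evaluate()`` n times; return (mean, population std, standard error).

    ``evaluate`` is a hypothetical zero-argument callable returning one energy.
    The population std matches np.std's default (ddof=0); the standard error
    of the mean is computed from the sample std (ddof=1).
    """
    if n < 2:
        raise ValueError("need at least two evaluations to estimate the spread")
    values = [float(evaluate()) for _ in range(n)]
    mean = statistics.fmean(values)
    std = statistics.pstdev(values)
    sem = statistics.stdev(values) / math.sqrt(n)
    return mean, std, sem


# Replay the ten energies printed above instead of recomputing them
energies = iter([
    1.7156665325164795, 1.7084295749664307, 1.7146689891815186,
    1.711714267730713, 1.7096374034881592, 1.7145936489105225,
    1.7145006656646729, 1.716435432434082, 1.7138292789459229,
    1.718432903289795,
])
mean, std, sem = summarize_repeats(lambda: next(energies), n=10)
print(f"mean = {mean:.4f}, std = {std:.4f}, standard error = {sem:.4f}")
```

The standard error shrinks as 1/sqrt(n), so averaging more evaluations narrows the uncertainty on the mean, while the spread of individual runs (the std) stays roughly the same.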
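The next cell tries to load every pretrained checkpoint listed in `checkpoints` and prints any error it raises. Most of these failures are themselves a gotcha: the first positional argument of `OCPCalculator` is `config_yml`, so `OCPCalculator(checkpoint, cpu=True)` tries to read the binary `.pt` file as a YAML config, which fails with `'utf-8' codec can't decode byte 0x80`. Pass the path by keyword, `OCPCalculator(checkpoint=checkpoint, cpu=True)`, as the other cells on this page do.

%run ocp-tutorial.ipynb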

import os
from ocpmodels.common.relaxation.ase_utils import OCPCalculator

for ckp in checkpoints:
    try:
        checkpoint = get_checkpoint(ckp)
        calc = OCPCalculator(checkpoint, cpu=True)
    except Exception as exc:
        print(ckp, exc)
    finally:
        os.unlink(checkpoint)
Downloading https://dl.fbaipublicfiles.com/opencatalystproject/models/2020_11/s2ef/cgcnn_200k.pt
CGCNN 200k 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
Downloading https://dl.fbaipublicfiles.com/opencatalystproject/models/2020_11/s2ef/cgcnn_2M.pt
CGCNN 2M 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
CGCNN 20M argument of type 'NoneType' is not iterable
Downloading https://dl.fbaipublicfiles.com/opencatalystproject/models/2020_11/s2ef/cgcnn_all.pt
CGCNN All 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
Downloading https://dl.fbaipublicfiles.com/opencatalystproject/models/2020_11/s2ef/dimenet_200k.pt
DimeNet 200k 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
Downloading https://dl.fbaipublicfiles.com/opencatalystproject/models/2020_11/s2ef/dimenet_2M.pt
DimeNet 2M 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
Downloading https://dl.fbaipublicfiles.com/opencatalystproject/models/2020_11/s2ef/schnet_200k.pt
SchNet 200k 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
Downloading https://dl.fbaipublicfiles.com/opencatalystproject/models/2020_11/s2ef/schnet_2M.pt
SchNet 2M 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
Downloading https://dl.fbaipublicfiles.com/opencatalystproject/models/2020_11/s2ef/schnet_20M.pt
SchNet 20M 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
Downloading https://dl.fbaipublicfiles.com/opencatalystproject/models/2020_11/s2ef/schnet_all_large.pt
SchNet All 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
Downloading https://dl.fbaipublicfiles.com/opencatalystproject/models/2021_02/s2ef/dimenetpp_200k.pt
DimeNet++ 200k 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
Downloading https://dl.fbaipublicfiles.com/opencatalystproject/models/2021_02/s2ef/dimenetpp_2M.pt
DimeNet++ 2M 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
Downloading https://dl.fbaipublicfiles.com/opencatalystproject/models/2021_02/s2ef/dimenetpp_20M.pt
DimeNet++ 20M 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
Downloading https://dl.fbaipublicfiles.com/opencatalystproject/models/2021_02/s2ef/dimenetpp_all.pt
DimeNet++ All 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
Downloading https://dl.fbaipublicfiles.com/opencatalystproject/models/2021_12/s2ef/spinconv_force_centric_2M.pt
SpinConv 2M 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
Downloading https://dl.fbaipublicfiles.com/opencatalystproject/models/2021_08/s2ef/spinconv_force_centric_all.pt
SpinConv All 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
Downloading https://dl.fbaipublicfiles.com/opencatalystproject/models/2021_12/s2ef/gemnet_t_direct_h512_2M.pt
GemNet-dT 2M 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
Downloading https://dl.fbaipublicfiles.com/opencatalystproject/models/2021_08/s2ef/gemnet_t_direct_h512_all.pt
GemNet-dT All 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
PaiNN All argument of type 'NoneType' is not iterable
Downloading https://dl.fbaipublicfiles.com/opencatalystproject/models/2022_07/s2ef/gemnet_oc_base_s2ef_2M.pt
GemNet-OC 2M 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
Downloading https://dl.fbaipublicfiles.com/opencatalystproject/models/2022_07/s2ef/gemnet_oc_base_s2ef_all.pt
GemNet-OC All 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
Downloading https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_03/s2ef/gemnet_oc_base_s2ef_all_md.pt
GemNet-OC All+MD 'utf-8' codec can't decode byte 0x80 in position 128: invalid start byte
GemNet-OC-Large All+MD 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
Downloading https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_03/s2ef/scn_t1_b1_s2ef_2M.pt
SCN 2M 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
SCN-t4-b2 2M argument of type 'NoneType' is not iterable
Downloading https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_03/s2ef/scn_all_md_s2ef.pt
SCN All+MD 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
eSCN-L4-M2-Lay12 2M 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
Downloading https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_03/s2ef/escn_l6_m2_lay12_2M_s2ef.pt
eSCN-L6-M2-Lay12 2M 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
Downloading https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_03/s2ef/escn_l6_m2_lay12_all_md_s2ef.pt
eSCN-L6-M2-Lay12 All+MD 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
eSCN-L6-M3-Lay20 All+MD 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
Downloading https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_06/oc20/s2ef/eq2_83M_2M.pt
EquiformerV2 (83M) 2M 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
Downloading https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_06/oc20/s2ef/eq2_31M_ec4_allmd.pt
EquiformerV2 (31M) All+MD 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
Downloading https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_06/oc20/s2ef/eq2_153M_ec4_allmd.pt
EquiformerV2 (153M) All+MD 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
Downloading https://dl.fbaipublicfiles.com/opencatalystproject/models/2022_09/oc22/s2ef/gndt_oc22_all_s2ef.pt
GemNet-dT OC22 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
GemNet-OC OC22 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
GemNet-OC OC20+OC22 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
Downloading https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_05/oc22/s2ef/gnoc_oc22_oc20_all_s2ef.pt
GemNet-OC trained with `enforce_max_neighbors_strictly=False` #467 OC20+OC22 'utf-8' codec can't decode byte 0x80 in position 128: invalid start byte
Downloading https://dl.fbaipublicfiles.com/opencatalystproject/models/2022_09/oc22/s2ef/gnoc_finetune_all_s2ef.pt
GemNet-OC OC20->OC22 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte
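The loop above also calls `os.unlink(checkpoint)` in its `finally` clause even when `get_checkpoint` fails, in which case `checkpoint` still names the previous iteration's file (or is undefined on the first pass). A sketch of the same loop with that cleanup made safe might look like the following. `try_load_all`, `fetch`, and `load` are hypothetical names; in practice `fetch` would wrap `get_checkpoint` and `load` would call `OCPCalculator(checkpoint=path, cpu=True)`, while the fakes below stand in for downloading and loading real checkpoints.

```python
import os
import tempfile


def try_load_all(names, fetch, load):
    """Fetch and load each named checkpoint; return {name: exception} for failures.

    Hypothetical helper: ``fetch(name) -> path`` stands in for ``get_checkpoint(name)``
    and ``load(path)`` for ``OCPCalculator(checkpoint=path, cpu=True)``.
    """
    errors = {}
    for name in names:
        path = None  # reset each pass so a failed fetch never reuses the last path
        try:
            path = fetch(name)
            load(path)
        except Exception as exc:
            errors[name] = exc
        finally:
            # Delete only a file that this iteration actually fetched
            if path is not None and os.path.exists(path):
                os.unlink(path)
    return errors


# Fakes for demonstration: "missing" fails to download, "bad" fails to load
def fake_fetch(name):
    if name == "missing":
        raise KeyError(f"no URL for {name}")
    fd, path = tempfile.mkstemp(suffix=".pt")
    with os.fdopen(fd, "wb") as fh:
        fh.write(name.encode())
    return path


def fake_load(path):
    with open(path, "rb") as fh:
        if fh.read() == b"bad":
            raise ValueError("could not load checkpoint")


for name, exc in try_load_all(["good", "bad", "missing"], fake_fetch, fake_load).items():
    print(name, type(exc).__name__, exc)
```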

The forces don’t sum to zero

In DFT, the forces on all the atoms should sum to zero; a nonzero sum means there is a spurious net force acting on the system as a whole. OCP models do not enforce this. They predict each atomic force individually, with no constraint that the forces sum to zero, so the sum is only close to zero when the force predictions are very accurate. You can enforce it by subtracting the mean force from each atom, which removes the net translational force.

%run ocp-tutorial.ipynb
import os
from ase.build import fcc111, add_adsorbate
from ocpmodels.common.relaxation.ase_utils import OCPCalculator

checkpoint = get_checkpoint('eSCN-L6-M3-Lay20 All+MD')
calc = OCPCalculator(checkpoint=os.path.expanduser(checkpoint), cpu=True)

slab = fcc111('Pt', size=(2, 2, 5), vacuum=10.0)
add_adsorbate(slab, 'O', height=1.2, position='fcc')
slab.set_calculator(calc)

# Net force on the whole slab: sum of each Cartesian component over all atoms
f = slab.get_forces()
f.sum(axis=0)
Downloading https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_03/s2ef/escn_l6_m3_lay20_all_md_s2ef.pt
array([-0.00371197, -0.01800631,  0.01127684], dtype=float32)
# This makes them sum closer to zero by removing net translational force
(f - f.mean(axis=0)).sum(axis=0)
array([ 1.2270175e-07,  7.5437129e-08, -1.1920929e-07], dtype=float32)
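The correction in the last cell can be wrapped in a small function. This is a sketch, not an OCP API: `remove_net_force` is a hypothetical name, and the forces below are made-up numbers rather than model output. You would apply it to the array returned by `slab.get_forces()`.

```python
import numpy as np


def remove_net_force(forces):
    """Subtract the mean force so the forces sum to zero along each axis.

    Hypothetical helper. It removes only the net translational force;
    it does not remove any net torque.
    """
    forces = np.asarray(forces, dtype=float)
    return forces - forces.mean(axis=0)


# Made-up forces for three atoms (eV/Å) with a small net force along x
forces_demo = np.array([[0.10, 0.00, -0.20],
                        [-0.05, 0.30, 0.10],
                        [-0.04, -0.30, 0.10]])
print(forces_demo.sum(axis=0))                    # net force before correction
print(remove_net_force(forces_demo).sum(axis=0))  # zero up to round-off
```

Subtracting the mean shifts every atom's force by the same vector, so the differences between atomic forces are unchanged. If the structure has constraints such as `FixAtoms`, ASE's `get_forces()` already zeroes the forces on the fixed atoms by default, and after the mean is subtracted those atoms would carry nonzero forces again.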