Convenience functions for getting ocp paths#
import ocpmodels as om
from pathlib import Path
def ocp_root():
    """Return the root directory of the installed ocp package."""
    return Path(om.__file__).parent.parent


def ocp_main():
    """Return the path to ocp main.py"""
    return ocp_root() / "main.py"
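For example, ocp_main() is handy when launching a training run from a notebook or script. The sketch below is illustrative only: it assumes the usual OCP command-line arguments (--mode and --config-yml) and a hypothetical config file named run.yml, and it leaves the actual subprocess call commented out.

import subprocess
import sys

# Build an illustrative command line for the OCP entry point; run.yml is a hypothetical config file.
cmd = [sys.executable, str(ocp_main()), '--mode', 'train', '--config-yml', 'run.yml']
print(' '.join(cmd))
# subprocess.run(cmd, check=True)  # uncomment to actually launch the run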
import subprocess
import sys
import numba
import numpy as np
import ase
import e3nn
import pymatgen.core as pc
import torch
import torch.cuda as tc
import torch_geometric as tg
import platform
import psutil
def describe_ocp():
    """Print some system information that could be useful in debugging."""
    print(sys.executable, sys.version)
    print(f'ocp is installed at {ocp_root()}')

    commit_hash = (
        subprocess.check_output(
            [
                "git",
                "-C",
                om.__path__[0],
                "describe",
                "--always",
            ]
        )
        .strip()
        .decode("ascii")
    )

    print(f'ocp repo is at git commit: {commit_hash}')
    print(f'numba: {numba.__version__}')
    print(f'numpy: {np.version.version}')
    print(f'ase: {ase.__version__}')
    print(f'e3nn: {e3nn.__version__}')
    print(f'pymatgen: {pc.__version__}')
    print(f'torch: {torch.version.__version__}')
    print(f'torch.version.cuda: {torch.version.cuda}')
    print(f'torch.cuda: is_available: {tc.is_available()}')
    if tc.is_available():
        print(' __CUDNN VERSION:', torch.backends.cudnn.version())
        print(' __Number CUDA Devices:', torch.cuda.device_count())
        print(' __CUDA Device Name:', torch.cuda.get_device_name(0))
        print(' __CUDA Device Total Memory [GB]:', torch.cuda.get_device_properties(0).total_memory / 1e9)
    print(f'torch geometric: {tg.__version__}')
    print()
    print(f'Platform: {platform.platform()}')
    print(f' Processor: {platform.processor()}')
    print(f' Virtual memory: {psutil.virtual_memory()}')
    print(f' Swap memory: {psutil.swap_memory()}')
    print(f' Disk usage: {psutil.disk_usage("/")}')
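describe_ocp() takes no arguments and prints to stdout; a typical use is to call it once at the top of a notebook so the environment is recorded alongside the results.

# Print the environment summary for the current machine (output depends on your installation).
describe_ocp()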
Convenience function for getting checkpoints#
import urllib.parse
import os
from pathlib import Path
import requests
checkpoints = {
# Open Catalyst 2020 (OC20)
'CGCNN 200k' :'https://dl.fbaipublicfiles.com/opencatalystproject/models/2020_11/s2ef/cgcnn_200k.pt',
'CGCNN 2M' :'https://dl.fbaipublicfiles.com/opencatalystproject/models/2020_11/s2ef/cgcnn_2M.pt',
'CGCNN 20M' :'https://dl.fbaipublicfiles.com/opencatalystproject/models/2020_11/s2ef/cgcnn_20M.pt',
'CGCNN All' :'https://dl.fbaipublicfiles.com/opencatalystproject/models/2020_11/s2ef/cgcnn_all.pt',
'DimeNet 200k' :'https://dl.fbaipublicfiles.com/opencatalystproject/models/2020_11/s2ef/dimenet_200k.pt',
'DimeNet 2M' :'https://dl.fbaipublicfiles.com/opencatalystproject/models/2020_11/s2ef/dimenet_2M.pt',
'SchNet 200k' :'https://dl.fbaipublicfiles.com/opencatalystproject/models/2020_11/s2ef/schnet_200k.pt',
'SchNet 2M' :'https://dl.fbaipublicfiles.com/opencatalystproject/models/2020_11/s2ef/schnet_2M.pt',
'SchNet 20M' :'https://dl.fbaipublicfiles.com/opencatalystproject/models/2020_11/s2ef/schnet_20M.pt',
'SchNet All' :'https://dl.fbaipublicfiles.com/opencatalystproject/models/2020_11/s2ef/schnet_all_large.pt',
'DimeNet++ 200k' :'https://dl.fbaipublicfiles.com/opencatalystproject/models/2021_02/s2ef/dimenetpp_200k.pt',
'DimeNet++ 2M' :'https://dl.fbaipublicfiles.com/opencatalystproject/models/2021_02/s2ef/dimenetpp_2M.pt',
'DimeNet++ 20M' :'https://dl.fbaipublicfiles.com/opencatalystproject/models/2021_02/s2ef/dimenetpp_20M.pt',
'DimeNet++ All' :'https://dl.fbaipublicfiles.com/opencatalystproject/models/2021_02/s2ef/dimenetpp_all.pt',
'SpinConv 2M' :'https://dl.fbaipublicfiles.com/opencatalystproject/models/2021_12/s2ef/spinconv_force_centric_2M.pt',
'SpinConv All' :'https://dl.fbaipublicfiles.com/opencatalystproject/models/2021_08/s2ef/spinconv_force_centric_all.pt',
'GemNet-dT 2M' :'https://dl.fbaipublicfiles.com/opencatalystproject/models/2021_12/s2ef/gemnet_t_direct_h512_2M.pt',
'GemNet-dT All' :'https://dl.fbaipublicfiles.com/opencatalystproject/models/2021_08/s2ef/gemnet_t_direct_h512_all.pt',
'PaiNN All' :'https://dl.fbaipublicfiles.com/opencatalystproject/models/2022_05/s2ef/painn_h512_s2ef_all.pt',
'GemNet-OC 2M' :'https://dl.fbaipublicfiles.com/opencatalystproject/models/2022_07/s2ef/gemnet_oc_base_s2ef_2M.pt',
'GemNet-OC All' :'https://dl.fbaipublicfiles.com/opencatalystproject/models/2022_07/s2ef/gemnet_oc_base_s2ef_all.pt',
'GemNet-OC All+MD' :'https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_03/s2ef/gemnet_oc_base_s2ef_all_md.pt',
'GemNet-OC-Large All+MD' :'https://dl.fbaipublicfiles.com/opencatalystproject/models/2022_07/s2ef/gemnet_oc_large_s2ef_all_md.pt',
'SCN 2M' :'https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_03/s2ef/scn_t1_b1_s2ef_2M.pt',
'SCN-t4-b2 2M' :'https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_03/s2ef/scn_t4_b2_s2ef_2M.pt',
'SCN All+MD' :'https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_03/s2ef/scn_all_md_s2ef.pt',
'eSCN-L4-M2-Lay12 2M' :'https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_03/s2ef/escn_l4_m2_lay12_2M_s2ef.pt',
'eSCN-L6-M2-Lay12 2M' :'https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_03/s2ef/escn_l6_m2_lay12_2M_s2ef.pt',
'eSCN-L6-M2-Lay12 All+MD' :'https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_03/s2ef/escn_l6_m2_lay12_all_md_s2ef.pt',
'eSCN-L6-M3-Lay20 All+MD' :'https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_03/s2ef/escn_l6_m3_lay20_all_md_s2ef.pt',
'EquiformerV2 (83M) 2M' :'https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_06/oc20/s2ef/eq2_83M_2M.pt',
'EquiformerV2 (31M) All+MD' :'https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_06/oc20/s2ef/eq2_31M_ec4_allmd.pt',
'EquiformerV2 (153M) All+MD' :'https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_06/oc20/s2ef/eq2_153M_ec4_allmd.pt',
# Open Catalyst 2022 (OC22)
'GemNet-dT OC22' : 'https://dl.fbaipublicfiles.com/opencatalystproject/models/2022_09/oc22/s2ef/gndt_oc22_all_s2ef.pt',
'GemNet-OC OC22' : 'https://dl.fbaipublicfiles.com/opencatalystproject/models/2022_09/oc22/s2ef/gnoc_oc22_all_s2ef.pt',
'GemNet-OC OC20+OC22' : 'https://dl.fbaipublicfiles.com/opencatalystproject/models/2022_09/oc22/s2ef/gnoc_oc22_oc20_all_s2ef.pt',
'GemNet-OC trained with `enforce_max_neighbors_strictly=False` #467 OC20+OC22' : 'https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_05/oc22/s2ef/gnoc_oc22_oc20_all_s2ef.pt',
'GemNet-OC OC20->OC22' : 'https://dl.fbaipublicfiles.com/opencatalystproject/models/2022_09/oc22/s2ef/gnoc_finetune_all_s2ef.pt'
}
def list_checkpoints():
    """List checkpoints that are available to download."""
    print('See https://github.com/Open-Catalyst-Project/ocp/blob/main/MODELS.md for more details.')
    for key in checkpoints:
        print(key)
    print('Copy one of these keys to get_checkpoint(key) to download it.')
def get_checkpoint(key):
    """Download a checkpoint.

    key: string in checkpoints.

    Returns name of checkpoint that was saved.
    """
    url = checkpoints.get(key, None)
    if url is None:
        raise Exception(f'No url found for {key}')

    pt = Path(urllib.parse.urlparse(url).path).name

    if not os.path.exists(pt):
        with open(pt, 'wb') as f:
            print(f'Downloading {url}')
            f.write(requests.get(url).content)
    return pt
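A typical workflow is to list the available keys and then download one of them. The sketch below uses the 'GemNet-dT 2M' key from the dictionary above; the checkpoint file is saved in the current working directory.

# Show the available keys, then fetch one of the smaller checkpoints.
list_checkpoints()
ckpt_path = get_checkpoint('GemNet-dT 2M')
print(ckpt_path)  # gemnet_t_direct_h512_2M.pt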
Train/test/val split for an ase db#
import os
from pathlib import Path
import numpy as np
from ase.db import connect
def train_test_val_split(ase_db, ttv=(0.8, 0.1, .1), files=('train.db', 'test.db', 'val.db'), seed=42):
    """Split an ase db into train, test and validation dbs.

    ase_db: path to an ase db containing all the data.
    ttv: a tuple containing the fraction of train, test and val data. This will be normalized.
    files: a tuple of filenames to write the splits into. An exception is raised if these exist.
        You should delete them first.
    seed: an integer for the random number generator seed

    Returns the absolute path to files.
    """
    for db in files:
        if os.path.exists(db):
            raise Exception(f'{db} exists. Please delete it before proceeding.')

    src = connect(ase_db)
    N = src.count()

    ttv = np.array(ttv)
    ttv /= ttv.sum()

    train_end = int(N * ttv[0])
    test_end = train_end + int(N * ttv[1])

    train = connect(files[0])
    test = connect(files[1])
    val = connect(files[2])

    ids = np.arange(1, N + 1)
    rng = np.random.default_rng(seed=seed)
    rng.shuffle(ids)

    for _id in ids[0:train_end]:
        row = src.get(id=int(_id))
        train.write(row.toatoms())

    for _id in ids[train_end:test_end]:
        row = src.get(id=int(_id))
        test.write(row.toatoms())

    for _id in ids[test_end:]:
        row = src.get(id=int(_id))
        val.write(row.toatoms())

    return [Path(f).absolute() for f in files]
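As a usage sketch, suppose all of your data is in a single database called data.db (a hypothetical filename); the call below would create the three split databases in the current directory.

# Hypothetical 80/10/10 split of data.db; train.db, test.db and val.db must not already exist.
train_db, test_db, val_db = train_test_val_split('data.db',
                                                 ttv=(0.8, 0.1, 0.1),
                                                 files=('train.db', 'test.db', 'val.db'),
                                                 seed=42)
print(train_db, test_db, val_db)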
Generating a config from a checkpoint#
from yaml import load, dump
from yaml import CLoader as Loader, CDumper as Dumper
import torch
import os
from ocpmodels.common.relaxation.ase_utils import OCPCalculator
from io import StringIO
import sys
import contextlib
def generate_yml_config(checkpoint_path, yml='run.yml', delete=(), update=()):
    """Generate a yml config file from an existing checkpoint file.

    checkpoint_path: string to path of an existing checkpoint
    yml: name of file to write to.
    delete: list of keys to remove from the config
    update: dictionary of key:values to update
        Use a dot notation in update.

    Returns an absolute path to the generated yml file.
    """
    # You can't just read in the checkpoint with torch. The calculator does some things to it.
    # Rather than recreate that here I just reuse the calculator machinery. I don't want to
    # see the output though, so I capture it.
    with contextlib.redirect_stdout(StringIO()) as _:
        config = OCPCalculator(checkpoint=checkpoint_path).config

    for key in delete:
        if key in config and len(key.split('.')) == 1:
            del config[key]
        else:
            keys = key.split('.')
            if keys[0] in config:
                d = config[keys[0]]
            else:
                continue
            if isinstance(d, dict):
                for k in keys[1:]:
                    # Skip keys that are not present in this config.
                    if k in d and isinstance(d[k], dict):
                        d = d[k]
                    elif k in d:
                        del d[k]

    def nested_set(dic, keys, value):
        for key in keys[:-1]:
            dic = dic.setdefault(key, {})
        dic[keys[-1]] = value

    for _key in update:
        keys = _key.split('.')
        nested_set(config, keys, update[_key])

    out = dump(config)
    with open(yml, 'wb') as f:
        f.write(out.encode('utf-8'))

    return Path(yml).absolute()
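Here is a sketch of how this can be combined with get_checkpoint above to prepare a config for a new run. The delete and update entries are illustrative dot-notation paths, not a prescription; which keys you need to prune or override depends on the checkpoint and on your training setup.

# Build run.yml from a downloaded checkpoint, dropping a few top-level sections
# and overriding one nested value (the key names here are only illustrative).
ckpt_path = get_checkpoint('GemNet-dT 2M')
yml_path = generate_yml_config(ckpt_path, 'run.yml',
                               delete=['slurm', 'cmd', 'logger'],
                               update={'optim.max_epochs': 1})
print(yml_path)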