# Convenience functions for getting ocp paths

import ocpmodels as om
from pathlib import Path

def ocp_root():
    """Locate the root directory of the installed ocp package.

    This is the directory that contains the ``ocpmodels`` package directory,
    i.e. two levels above the package's ``__file__``.
    """
    package_dir = Path(om.__file__).parent
    return package_dir.parent

def ocp_main():
    """Return the path of ``main.py`` under the ocp root directory.

    The path is built by joining, not by searching: no check is made that
    the file actually exists.
    """
    # Joining with an empty string would just hand back ocp_root() itself,
    # so the script name has to be spelled out.
    return ocp_root() / "main.py"
import subprocess
import sys
import numba
import numpy as np
import ase
import e3nn
import pymatgen.core as pc
import torch
import torch.cuda as tc
import torch_geometric as tg
import platform
import psutil

def describe_ocp():
    """Print some system information that could be useful in debugging.

    Prints, in order: the Python interpreter and version, where ocp is
    installed and the git commit it is at, the versions of the main
    dependencies, CUDA availability (plus device details when available),
    and platform / memory / disk statistics.

    Returns None; everything goes to stdout.
    """
    print(sys.executable, sys.version)
    print(f'ocp is installed at {ocp_root()}')

    # Best effort: ocp may have been installed without a git checkout, or git
    # may not be on PATH. Neither should stop the rest of the report.
    try:
        commit_hash = (
            subprocess.check_output(
                ['git', '-C', str(ocp_root()), 'describe', '--always'],
                stderr=subprocess.DEVNULL,
            )
            .strip()
            .decode('ascii', errors='replace')
        )
    except (OSError, subprocess.CalledProcessError):
        commit_hash = 'unknown'
    print(f'ocp repo is at git commit: {commit_hash}')

    print(f'numba: {numba.__version__}')
    print(f'numpy: {np.version.version}')
    print(f'ase: {ase.__version__}')
    print(f'e3nn: {e3nn.__version__}')
    print(f'pymatgen: {pc.__version__}')
    print(f'torch: {torch.version.__version__}')
    print(f'torch.version.cuda: {torch.version.cuda}')
    print(f'torch.cuda: is_available: {tc.is_available()}')
    if tc.is_available():
        # Device-specific details are reported for device 0 only.
        print('  __CUDNN VERSION:', torch.backends.cudnn.version())
        print('  __Number CUDA Devices:', tc.device_count())
        print('  __CUDA Device Name:', tc.get_device_name(0))
        print('  __CUDA Device Total Memory [GB]:', tc.get_device_properties(0).total_memory / 1e9)
    print(f'torch geometric: {tg.__version__}')
    print(f'Platform: {platform.platform()}')
    print(f'  Processor: {platform.processor()}')
    print(f'  Virtual memory: {psutil.virtual_memory()}')
    print(f'  Swap memory: {psutil.swap_memory()}')
    print(f'  Disk usage: {psutil.disk_usage("/")}')

# Convenience function for getting checkpoints

import urllib
import os
from pathlib import Path
import requests

# Maps a human-readable checkpoint name to its download URL.
# NOTE(review): every URL below is an empty string in this copy of the file;
# fill in the real download links before relying on this table.
checkpoints = {
    # Open Catalyst 2020 (OC20)
    'CGCNN 200k': '',
    'CGCNN 2M': '',
    'CGCNN 20M': '',
    'CGCNN All': '',
    'DimeNet 200k': '',
    'DimeNet 2M': '',
    'SchNet 200k': '',
    'SchNet 2M': '',
    'SchNet 20M': '',
    'SchNet All': '',
    'DimeNet++ 200k': '',
    'DimeNet++ 2M': '',
    'DimeNet++ 20M': '',
    'DimeNet++ All': '',
    'SpinConv 2M': '',
    'SpinConv All': '',
    'GemNet-dT 2M': '',
    'GemNet-dT All': '',
    'PaiNN All': '',
    'GemNet-OC 2M': '',
    'GemNet-OC All': '',
    'GemNet-OC All+MD': '',
    'GemNet-OC-Large All+MD': '',
    'SCN 2M': '',
    'SCN-t4-b2 2M': '',
    'SCN All+MD': '',
    'eSCN-L4-M2-Lay12 2M': '',
    'eSCN-L6-M2-Lay12 2M': '',
    'eSCN-L6-M2-Lay12 All+MD': '',
    'eSCN-L6-M3-Lay20 All+MD': '',
    'EquiformerV2 (83M) 2M': '',
    'EquiformerV2 (31M) All+MD': '',
    'EquiformerV2 (153M) All+MD': '',
    # Open Catalyst 2022 (OC22)
    'GemNet-dT OC22': '',
    'GemNet-OC OC22': '',
    'GemNet-OC OC20+OC22': '',
    'GemNet-OC trained with `enforce_max_neighbors_strictly=False` #467 OC20+OC22': '',
    'GemNet-OC OC20->OC22': '',
}

def list_checkpoints():
    """List checkpoints that are available to download.

    Prints every key of the module-level ``checkpoints`` dict, one per line,
    followed by a hint on how to use a key. Returns None.
    """
    # NOTE(review): the target of "See ... for more details." is missing from
    # this message; restore the intended link/reference once it is known.
    print('See for more details.')
    for key in checkpoints:
        print(key)
    print('Copy one of these keys to get_checkpoint(key) to download it.')

def get_checkpoint(key):
    """Download a checkpoint.

    key: string in checkpoints.

    Returns name of checkpoint that was saved. The file is written to the
    current working directory and is named after the last path component of
    the checkpoint URL. If a file with that name already exists it is reused
    and nothing is downloaded.

    Raises Exception if there is no (non-empty) url for key, and
    requests.HTTPError if the server answers with an error status.
    """
    # `import urllib` alone does not guarantee the `parse` submodule is loaded.
    from urllib.parse import urlparse

    url = checkpoints.get(key)
    # Treat an empty url like a missing one: there is nothing to download and
    # the derived filename would be empty.
    if not url:
        raise Exception(f'No url found for {key}')

    pt = Path(urlparse(url).path).name

    if not os.path.exists(pt):
        print(f'Downloading {url}')
        # Fetch first and only then create the file, so a failed download
        # does not leave an empty file that blocks later attempts.
        response = requests.get(url)
        response.raise_for_status()
        with open(pt, 'wb') as f:
            f.write(response.content)
    return pt

# Train/test/val split for an ase db

from pathlib import Path
import numpy as np
from ase.db import connect

def train_test_val_split(ase_db, ttv=(0.8, 0.1, .1), files=('train.db', 'test.db', 'val.db'), seed=42):
    """Split an ase db into train, test and validation dbs.

    ase_db: path to an ase db containing all the data.
    ttv: a tuple containing the fraction of train, test and val data. This will be normalized.
    files: a tuple of filenames to write the splits into. An exception is raised if these exist.
           You should delete them first.
    seed: an integer for the random number generator seed

    Returns the absolute path to files.

    Raises Exception if any of files already exists.
    """
    for db in files:
        if os.path.exists(db):
            raise Exception(f'{db} exists. Please delete it before proceeding.')

    src = connect(ase_db)
    N = src.count()

    # Force float so integer weights such as (8, 1, 1) can be normalized in place.
    ttv = np.array(ttv, dtype=float)
    ttv /= ttv.sum()

    train_end = int(N * ttv[0])
    test_end = train_end + int(N * ttv[1])

    train = connect(files[0])
    test = connect(files[1])
    val = connect(files[2])

    # NOTE(review): assumes the row ids in ase_db are contiguous 1..N — confirm
    # for databases that have had rows deleted.
    ids = np.arange(1, N + 1)
    rng = np.random.default_rng(seed=seed)
    rng.shuffle(ids)

    # Val takes whatever remains after the train and test slices, so rounding
    # in the int() truncation above never drops a row.
    splits = (
        (train, ids[:train_end]),
        (test, ids[train_end:test_end]),
        (val, ids[test_end:]),
    )
    for dst, split_ids in splits:
        for _id in split_ids:
            row = src.get(id=int(_id))
            # NOTE(review): toatoms() with default arguments may not carry the
            # row's key_value_pairs/data over — confirm that is acceptable.
            dst.write(row.toatoms())

    return [Path(f).absolute() for f in files]

# Generating a config from a checkpoint

from yaml import load, dump
from yaml import CLoader as Loader, CDumper as Dumper
import torch
import os
from ocpmodels.common.relaxation.ase_utils import OCPCalculator
from io import StringIO
import sys
import contextlib

def _nested_delete(dic, keys):
    """Remove the entry at key path *keys* from nested dict *dic*; missing paths are ignored."""
    for key in keys[:-1]:
        dic = dic.get(key)
        if not isinstance(dic, dict):
            return
    dic.pop(keys[-1], None)


def _nested_set(dic, keys, value):
    """Set the entry at key path *keys* in nested dict *dic*, creating intermediate dicts."""
    for key in keys[:-1]:
        dic = dic.setdefault(key, {})
    dic[keys[-1]] = value


def generate_yml_config(checkpoint_path, yml='run.yml', delete=(), update=()):
    """Generate a yml config file from an existing checkpoint file.

    checkpoint_path: string to path of an existing checkpoint
    yml: name of file to write to.
    delete: list of keys to remove from the config
    update: dictionary of key:values to update

    Use a dot notation in delete and update to address nested keys,
    e.g. 'a.b.c' refers to config['a']['b']['c']. Keys in delete that are
    not present are silently ignored; keys in update are created as needed.

    Returns an absolute path to the generated yml file.
    """
    # You can't just read in the checkpoint with torch. The calculator does some things to it.
    # Rather than recreate that here I just reuse the calculator machinery. I don't want to
    # see the output though, so I capture it.
    with contextlib.redirect_stdout(StringIO()) as _:
        config = OCPCalculator(checkpoint=checkpoint_path).config

    for key in delete:
        _nested_delete(config, key.split('.'))

    for _key in update:
        _nested_set(config, _key.split('.'), update[_key])

    out = dump(config)
    # dump() without a stream returns a str; the file is opened in binary
    # mode, so encode explicitly.
    with open(yml, 'wb') as f:
        f.write(out.encode('utf-8'))

    return Path(yml).absolute()