Fine tuning a model#
In this section we look at how to improve a model. We start by using a pre-trained model to predict the energies of several oxide polymorphs. The data here comes from this paper:
Mehta, P., Salvador, P. A., & Kitchin, J. R. (2015). Identifying potential BO2 oxide polymorphs for epitaxial growth candidates. ACS Appl. Mater. Interfaces, 6(5), 3630–3639. http://dx.doi.org/10.1021/am4059149.
This data set contains equations of state for several oxide/polymorph combinations. I use it to compare DFT energies with predictions from OCP.
First we get the checkpoint that we want. According to MODELS, the GemNet-OC OC20+OC22 combination has an energy MAE of 0.483 eV, which seems like a good place to start. This model was trained on oxides.
We get this checkpoint here.
%run ../ocp-tutorial.ipynb # load some helper functions
checkpoint = get_checkpoint('GemNet-OC OC20+OC22')
The data we need is provided in supporting-information.json. That file is embedded in the supporting information for the article, and is provided here in the tutorial. We load this data and explore it a little. The json file provides a dictionary with the structure:
[oxide][polymorph][xc][EOS][calculations]
The first key is a composition, the second is a string for the polymorph structure, the third indicates which XC functional was used (we focus on PBE here), the fourth key is for the Equation of State calculations, and the last key holds a list of results for the EOS.
import json
import numpy as np
import matplotlib.pyplot as plt
from ase import Atoms
with open('supporting-information.json', 'rb') as f:
    d = json.load(f)
oxides = list(d.keys())
polymorphs = list(d['TiO2'].keys())
oxides, polymorphs
(['SnO2', 'IrO2', 'RuO2', 'TiO2', 'VO2'],
['rutile', 'pyrite', 'columbite', 'brookite', 'fluorite', 'anatase'])
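To get a feel for the size of the dataset, one can walk the nested dictionary and count the EOS points per oxide and polymorph. The sketch below uses a tiny stand-in dictionary with the same key layout; `count_eos_points` and `toy` are hypothetical names, and with the real data you would pass `d` instead.

```python
def count_eos_points(data, xc="PBE"):
    """Return {(oxide, polymorph): number of EOS calculations} for one functional."""
    counts = {}
    for oxide, polymorphs in data.items():
        for polymorph, by_xc in polymorphs.items():
            # Missing keys are treated as "no calculations" rather than an error.
            calcs = by_xc.get(xc, {}).get("EOS", {}).get("calculations", [])
            counts[(oxide, polymorph)] = len(calcs)
    return counts


# Toy stand-in with the same nesting as supporting-information.json.
toy = {
    "TiO2": {
        "rutile": {"PBE": {"EOS": {"calculations": [{}, {}, {}]}}},
        "anatase": {"PBE": {"EOS": {"calculations": [{}, {}]}}},
    }
}

counts = count_eos_points(toy)
total = sum(counts.values())
print(counts, total)
```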
A specific calculation has a lot of details in it. You can use this to recreate the calculations. Shortly we will use these to compare the DFT results to OCP.
d['TiO2']['rutile']['PBE']['EOS']['calculations'][0]
{'incar': {'doc': 'INCAR parameters',
'prec': 'Normal',
'isif': 4,
'nbands': 20,
'ibrion': 2,
'gga': 'PE',
'encut': 520.0,
'ismear': 0,
'sigma': 0.001,
'nsw': 50},
'doc': 'JSON representation of a VASP calculation.\n\nenergy is in eV\nforces are in eV/\\AA\nstress is in GPa (sxx, syy, szz, syz, sxz, sxy)\nmagnetic moments are in Bohr-magneton\nThe density of states is reported with E_f at 0 eV.\nVolume is reported in \\AA^3\nCoordinates and cell parameters are reported in \\AA\n\nIf atom-projected dos are included they are in the form:\n{ados:{energy:data, {atom index: {orbital : dos}}}\n',
'potcar': [['O',
'potpaw/O/POTCAR',
'0cf2ce56049ca395c567026b700ed66c94a85161'],
['Ti', 'potpaw/Ti/POTCAR', '51f7f05982d6b4052becc160375a8b8b670177a7']],
'input': {'kpts': [6, 6, 6],
'reciprocal': False,
'xc': 'LDA',
'kpts_nintersections': None,
'setups': None,
'txt': '-',
'gamma': False},
'atoms': {'cell': [[4.3789762519649225, 0.0, 0.0],
[0.0, 4.3789762519649225, 0.0],
[0.0, 0.0, 2.864091775985314]],
'symbols': ['Ti', 'Ti', 'O', 'O', 'O', 'O'],
'tags': [0, 0, 0, 0, 0, 0],
'pbc': [True, True, True],
'positions': [[0.0, 0.0, 0.0],
[2.1894881259824612, 2.1894881259824612, 1.432045887992657],
[1.3181554154438013, 1.3181554154438013, 0.0],
[3.0608208365211214, 3.0608208365211214, 0.0],
[3.5076435414262623, 0.87133271053866, 1.432045887992657],
[0.87133271053866, 3.5076435414262623, 1.432045887992657]]},
'data': {'stress': [496.18519999, 496.18519999, 502.82679392, 0.0, 0.0, 0.0],
'doc': 'Data from the output of the calculation',
'volume': 54.92019999999996,
'total_energy': -56.230672,
'forces': [[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[-0.001264, -0.001264, 0.0],
[0.001264, 0.001264, 0.0],
[-0.001264, 0.001264, 0.0],
[0.001264, -0.001264, 0.0]],
'fermi_level': 3.153}}
For each result we can retrieve the atomic geometry, energy and forces from the json file. We use these to recreate an Atoms object. Here is an example.
c = d['TiO2']['rutile']['PBE']['EOS']['calculations'][0]
atoms = Atoms(symbols=c['atoms']['symbols'],
              positions=c['atoms']['positions'],
              cell=c['atoms']['cell'],
              pbc=c['atoms']['pbc'])
atoms.set_tags(np.ones(len(atoms)))
atoms, c['data']['total_energy'], c['data']['forces']
(Atoms(symbols='Ti2O4', pbc=True, cell=[4.3789762519649225, 4.3789762519649225, 2.864091775985314], tags=...),
-56.230672,
[[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[-0.001264, -0.001264, 0.0],
[0.001264, 0.001264, 0.0],
[-0.001264, 0.001264, 0.0],
[0.001264, -0.001264, 0.0]])
Next, we create an OCP calculator that we can use to get predictions.
from ocpmodels.common.relaxation.ase_utils import OCPCalculator
calc = OCPCalculator(checkpoint=checkpoint, trainer='forces', cpu=False)
WARNING:root:Unrecognized arguments: ['symmetric_edge_symmetrization']
amp: true
cmd:
checkpoint_dir: /home/jovyan/shared-scratch/jkitchin/tutorial/ocp-tutorial/fine-tuning/checkpoints/2023-08-11-14-41-04
commit: 3973c79
identifier: ''
logs_dir: /home/jovyan/shared-scratch/jkitchin/tutorial/ocp-tutorial/fine-tuning/logs/tensorboard/2023-08-11-14-41-04
print_every: 100
results_dir: /home/jovyan/shared-scratch/jkitchin/tutorial/ocp-tutorial/fine-tuning/results/2023-08-11-14-41-04
seed: null
timestamp_id: 2023-08-11-14-41-04
dataset: null
gpus: 1
logger: tensorboard
model: gemnet_oc
model_attributes:
activation: silu
atom_edge_interaction: true
atom_interaction: true
cbf:
name: spherical_harmonics
cutoff: 12.0
cutoff_aeaint: 12.0
cutoff_aint: 12.0
cutoff_qint: 12.0
direct_forces: true
edge_atom_interaction: true
emb_size_aint_in: 64
emb_size_aint_out: 64
emb_size_atom: 256
emb_size_cbf: 16
emb_size_edge: 512
emb_size_quad_in: 32
emb_size_quad_out: 32
emb_size_rbf: 16
emb_size_sbf: 32
emb_size_trip_in: 64
emb_size_trip_out: 64
envelope:
exponent: 5
name: polynomial
extensive: true
forces_coupled: false
max_neighbors: 30
max_neighbors_aeaint: 20
max_neighbors_aint: 1000
max_neighbors_qint: 8
num_after_skip: 2
num_atom: 3
num_atom_emb_layers: 2
num_before_skip: 2
num_blocks: 4
num_concat: 1
num_global_out_layers: 2
num_output_afteratom: 3
num_radial: 128
num_spherical: 7
otf_graph: true
output_init: HeOrthogonal
qint_tags:
- 1
- 2
quad_interaction: true
rbf:
name: gaussian
regress_forces: true
sbf:
name: legendre_outer
symmetric_edge_symmetrization: false
noddp: false
optim:
batch_size: 16
clip_grad_norm: 10
ema_decay: 0.999
energy_coefficient: 1
eval_batch_size: 16
eval_every: 5000
factor: 0.8
force_coefficient: 1
load_balancing: atoms
loss_energy: mae
loss_force: atomwisel2
lr_initial: 0.0005
max_epochs: 80
mode: min
num_workers: 2
optimizer: AdamW
optimizer_params:
amsgrad: true
patience: 3
scheduler: ReduceLROnPlateau
weight_decay: 0
slurm:
additional_parameters:
constraint: volta32gb
cpus_per_task: 3
folder: /checkpoint/abhshkdz/ocp_oct1_logs/57632342
gpus_per_node: 8
job_id: '57632342'
job_name: gnoc_oc22_oc20_all_s2ef
mem: 480GB
nodes: 8
ntasks_per_node: 8
partition: ocp,learnaccel
time: 4320
task:
dataset: oc22_lmdb
description: Regressing to energies and forces for DFT trajectories from OCP
eval_on_free_atoms: true
grad_input: atomic forces
labels:
- potential energy
metric: mae
primary_metric: forces_mae
train_on_free_atoms: true
type: regression
trainer: forces
Now, we loop through each structure and accumulate the OCP predictions. Then, we plot the parity results.
import time
t0 = time.time()
eos_data = {}
for oxide in oxides:
    eos_data[oxide] = {}
    for polymorph in polymorphs:
        dft = []
        ocp = []
        vols = []
        calculations = d[oxide][polymorph]['PBE']['EOS']['calculations']
        for c in calculations:
            atoms = Atoms(symbols=c['atoms']['symbols'],
                          positions=c['atoms']['positions'],
                          cell=c['atoms']['cell'],
                          pbc=c['atoms']['pbc'])
            atoms.set_tags(np.ones(len(atoms)))
            atoms.calc = calc
            # Energies are normalized per atom so different cell sizes are comparable
            ocp += [atoms.get_potential_energy() / len(atoms)]
            dft += [c['data']['total_energy'] / len(atoms)]
            vols += [atoms.get_volume()]
        plt.plot(dft, ocp, marker='s' if oxide == 'VO2' else '.',
                 alpha=0.5, label=f'{oxide}-{polymorph}')
        eos_data[oxide][polymorph] = (vols, dft, ocp)
plt.xlabel('DFT (eV/atom)')
plt.ylabel('OCP (eV/atom)')
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), ncol=3);
print(f'Elapsed time {time.time() - t0:1.1f} seconds.')
Elapsed time 19.7 seconds.
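# Note: dft and ocp are reset for each polymorph in the loop above, so at this
# point they hold only the last oxide/polymorph pair that was processed.
mae = np.mean(np.abs(np.array(dft) - np.array(ocp)))
print(f'MAE = {mae:1.3f} eV/atom')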
MAE = 1.668 eV/atom
The MAE is high compared to the reported value of 0.483 eV. That is not too surprising; although the OC22 data includes oxides, the model was not trained on all of these structures. On inspection of the parity plot, the main outliers are the VO2 structures, and these skew the error. Next we look at the EOS for each material. Here you can see offsets and qualitative issues in the shapes. The offset is largely a PBE/RPBE difference.
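Because `dft` and `ocp` are reset for every polymorph in the loop above, it can be informative to compute the error from everything stored in `eos_data`, broken down by oxide. This is a minimal sketch, assuming `eos_data[oxide][polymorph] = (vols, dft, ocp)` as built above. `mae_by_oxide` is a hypothetical helper, and the numbers in the example are made up for illustration only.

```python
def mae_by_oxide(eos_data):
    """Return ({oxide: MAE}, overall MAE) in the units of the stored energies."""
    per_oxide = {}
    all_errors = []
    for oxide, polymorphs in eos_data.items():
        errors = []
        for vols, dft, ocp in polymorphs.values():
            errors += [abs(e_dft - e_ocp) for e_dft, e_ocp in zip(dft, ocp)]
        per_oxide[oxide] = sum(errors) / len(errors)
        all_errors += errors
    return per_oxide, sum(all_errors) / len(all_errors)


# Illustrative numbers only, not results from the tutorial.
toy = {
    "AO2": {"rutile": ([10.0, 11.0], [-8.0, -8.5], [-8.5, -9.0])},
    "BO2": {"rutile": ([10.0, 11.0], [-7.0, -7.5], [-5.0, -5.5])},
}

per_oxide, overall = mae_by_oxide(toy)
print(per_oxide, overall)
```

A breakdown like this makes it easier to see whether one composition dominates the overall error.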
Some notable issues are with fluorite structures. We show one here for VO2.
oxide, polymorph = 'VO2', 'fluorite'
V, D, O = eos_data[oxide][polymorph]
plt.plot(V, D, label='dft')
plt.plot(V, O, label='ocp')
plt.title(f'pretrained {oxide} - {polymorph}')
plt.legend();
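To separate a constant offset from a genuine difference in curve shape, one option is to subtract the mean shift between the two curves before comparing them. The sketch below assumes both curves are sampled at the same volumes, as they are in `eos_data`. `compare_eos` is a hypothetical helper and the values are illustrative.

```python
def compare_eos(vols, dft, ocp):
    """Compare two EOS curves sampled at the same volumes.

    Returns the mean OCP - DFT offset, the MAE that remains after removing
    that offset, and the sampled volume of minimum energy for each curve.
    """
    n = len(vols)
    offset = sum(o - d for d, o in zip(dft, ocp)) / n
    shape_mae = sum(abs((o - offset) - d) for d, o in zip(dft, ocp)) / n
    v_min_dft = vols[dft.index(min(dft))]
    v_min_ocp = vols[ocp.index(min(ocp))]
    return offset, shape_mae, v_min_dft, v_min_ocp


# Illustrative values only.
V = [10.0, 11.0, 12.0, 13.0]
D = [-8.0, -8.5, -8.25, -7.75]
O = [-7.5, -7.75, -8.0, -7.25]

offset, shape_mae, v_min_dft, v_min_ocp = compare_eos(V, D, O)
print(offset, shape_mae, v_min_dft, v_min_ocp)
```

A small residual after removing the offset suggests the curves mostly differ by a constant; a shifted minimum suggests a difference in the predicted equilibrium volume.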
We can improve this by fine-tuning the model.
Fine tuning the checkpoint#
There are several steps in fine-tuning:
1. Create train, test, val datasets
2. Create a configuration yml file
3. Run the training
4. Check the results
First we create the training databases. OCP supports ase databases, so we use those here. The first step is to make a database containing the data. We need to write an atoms object with an attached SinglePointCalculator for each structure.
from ase.db import connect
from ase.calculators.singlepoint import SinglePointCalculator
! rm -fr oxides.db # start clean when you run this cell
db = connect('oxides.db')
for oxide in oxides:
    for polymorph in polymorphs:
        for c in d[oxide][polymorph]['PBE']['EOS']['calculations']:
            atoms = Atoms(symbols=c['atoms']['symbols'],
                          positions=c['atoms']['positions'],
                          cell=c['atoms']['cell'],
                          pbc=c['atoms']['pbc'])
            atoms.set_tags(np.ones(len(atoms)))
            # Attach the DFT energy and forces so they are stored with the structure
            atoms.calc = SinglePointCalculator(atoms,
                                               energy=c['data']['total_energy'],
                                               forces=c['data']['forces'])
            db.write(atoms)
Let’s see what we made.
! ase db oxides.db
id|age|formula|calculator| energy|natoms| fmax|pbc| volume|charge| mass
1| 4s|Sn2O4 |unknown |-41.359| 6|0.045|TTT| 64.258| 0.000|301.416
2| 4s|Sn2O4 |unknown |-41.853| 6|0.025|TTT| 66.526| 0.000|301.416
3| 4s|Sn2O4 |unknown |-42.199| 6|0.010|TTT| 68.794| 0.000|301.416
4| 4s|Sn2O4 |unknown |-42.419| 6|0.006|TTT| 71.062| 0.000|301.416
5| 4s|Sn2O4 |unknown |-42.534| 6|0.011|TTT| 73.330| 0.000|301.416
6| 4s|Sn2O4 |unknown |-42.562| 6|0.029|TTT| 75.598| 0.000|301.416
7| 4s|Sn2O4 |unknown |-42.518| 6|0.033|TTT| 77.866| 0.000|301.416
8| 4s|Sn2O4 |unknown |-42.415| 6|0.010|TTT| 80.134| 0.000|301.416
9| 4s|Sn2O4 |unknown |-42.266| 6|0.006|TTT| 82.402| 0.000|301.416
10| 4s|Sn2O4 |unknown |-42.083| 6|0.017|TTT| 84.670| 0.000|301.416
11| 4s|Sn4O8 |unknown |-81.424| 12|0.012|TTT|117.473| 0.000|602.832
12| 4s|Sn4O8 |unknown |-82.437| 12|0.005|TTT|121.620| 0.000|602.832
13| 4s|Sn4O8 |unknown |-83.147| 12|0.015|TTT|125.766| 0.000|602.832
14| 4s|Sn4O8 |unknown |-83.599| 12|0.047|TTT|129.912| 0.000|602.832
15| 4s|Sn4O8 |unknown |-83.831| 12|0.081|TTT|134.058| 0.000|602.832
16| 4s|Sn4O8 |unknown |-83.898| 12|0.001|TTT|138.204| 0.000|602.832
17| 4s|Sn4O8 |unknown |-83.805| 12|0.001|TTT|142.350| 0.000|602.832
18| 4s|Sn4O8 |unknown |-83.586| 12|0.002|TTT|146.496| 0.000|602.832
19| 4s|Sn4O8 |unknown |-83.262| 12|0.002|TTT|150.642| 0.000|602.832
20| 4s|Sn4O8 |unknown |-82.851| 12|0.013|TTT|154.788| 0.000|602.832
Rows: 295 (showing first 20)
Make the train, test, val splits#
We need to split the ase-db into three separate databases, one for training (80%), one for testing (10%) and one for validation (10%). We generate a list of ids and then shuffle them. Then we write the first 80% into train.db, the next 10% into test.db, and the remainder into val.db.
The train set is used for training. The test and val sets are used to check for overfitting.
You can choose whatever splits you want; 80:10:10 is common. We take a simple approach to splitting the database here: we make an array of integers that correspond to the ids, randomly shuffle them, and then get each row in the randomized order and write it to a new db.
We provide some helper functions in ../ocp-tutorial.ipynb to streamline this process.
! rm -fr train.db test.db val.db
train, test, val = train_test_val_split('oxides.db')
train, test, val
(PosixPath('/home/jovyan/shared-scratch/jkitchin/tutorial/ocp-tutorial/fine-tuning/train.db'),
PosixPath('/home/jovyan/shared-scratch/jkitchin/tutorial/ocp-tutorial/fine-tuning/test.db'),
PosixPath('/home/jovyan/shared-scratch/jkitchin/tutorial/ocp-tutorial/fine-tuning/val.db'))
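The shuffling and slicing described above can be sketched in plain Python. This is not the implementation of `train_test_val_split`; `split_ids`, its `fractions` argument and the fixed `seed` are assumptions made for illustration. Each id would then be read from the source database and written to the matching new database.

```python
import random


def split_ids(n_rows, fractions=(0.8, 0.1), seed=42):
    """Shuffle the ids 1..n_rows (ase db ids start at 1) and split them.

    fractions gives the train and test shares; validation gets the rest.
    """
    ids = list(range(1, n_rows + 1))
    random.Random(seed).shuffle(ids)  # seeded so the split is reproducible
    n_train = int(fractions[0] * n_rows)
    n_test = int(fractions[1] * n_rows)
    train = ids[:n_train]
    test = ids[n_train:n_train + n_test]
    val = ids[n_train + n_test:]
    return train, test, val


train_ids, test_ids, val_ids = split_ids(295)
print(len(train_ids), len(test_ids), len(val_ids))
```

Fixing the seed is a design choice: it keeps the three sets identical between runs, which helps when comparing models trained at different times.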
Setting up the configuration yaml file#
We have to create a yaml configuration file for the model we are using. The pre-trained checkpoints contain their config data, so we use this to get the base configuration, and then remove pieces we don’t need, and update pieces we do need.
yml = generate_yml_config(checkpoint, 'config.yml',
                          delete=['slurm', 'cmd', 'logger', 'task', 'model_attributes',
                                  'optim.loss_force',  # the checkpoint setting causes an error
                                  'dataset', 'test_dataset', 'val_dataset'],
                          update={'gpus': 1,
                                  'task.dataset': 'ase_db',
                                  'optim.eval_every': 1,
                                  'optim.max_epochs': 10,
                                  # Train data
                                  'dataset.train.src': 'train.db',
                                  'dataset.train.a2g_args.r_energy': True,
                                  'dataset.train.a2g_args.r_forces': True,
                                  # Test data - prediction only so no regression
                                  'dataset.test.src': 'test.db',
                                  'dataset.test.a2g_args.r_energy': False,
                                  'dataset.test.a2g_args.r_forces': False,
                                  # val data
                                  'dataset.val.src': 'val.db',
                                  'dataset.val.a2g_args.r_energy': True,
                                  'dataset.val.a2g_args.r_forces': True,
                                  })
yml
WARNING:root:Unable to identify OCP trainer, defaulting to `forces`. Specify the `trainer` argument into OCPCalculator if otherwise.
WARNING:root:Unrecognized arguments: ['symmetric_edge_symmetrization']
PosixPath('/home/jovyan/shared-scratch/jkitchin/tutorial/ocp-tutorial/fine-tuning/config.yml')
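The dotted keys in `delete` and `update` address entries in a nested configuration. One way such keys could be applied to a plain nested dictionary is sketched below. This is an illustration of the idea, not the actual implementation of `generate_yml_config`; `delete_key` and `update_key` are hypothetical names.

```python
def delete_key(config, dotted):
    """Remove a nested key such as 'optim.loss_force'; do nothing if it is absent."""
    *parents, last = dotted.split(".")
    node = config
    for key in parents:
        node = node.get(key)
        if not isinstance(node, dict):
            return
    node.pop(last, None)


def update_key(config, dotted, value):
    """Set a nested key such as 'dataset.train.src', creating parents as needed."""
    *parents, last = dotted.split(".")
    node = config
    for key in parents:
        node = node.setdefault(key, {})
    node[last] = value


config = {"optim": {"loss_force": "atomwisel2", "max_epochs": 80},
          "slurm": {"nodes": 8}}

for key in ["slurm", "optim.loss_force", "dataset"]:
    delete_key(config, key)
for key, value in {"optim.max_epochs": 10,
                   "dataset.train.src": "train.db"}.items():
    update_key(config, key, value)

print(config)
```

Applying the deletions before the updates, as here, lets an entry such as `dataset` be cleared and then rebuilt from scratch.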
! cat config.yml
amp: true
checkpoint: gnoc_oc22_oc20_all_s2ef.pt
dataset:
  test:
    a2g_args:
      r_energy: false
      r_forces: false
    src: test.db
  train:
    a2g_args:
      r_energy: true
      r_forces: true
    src: train.db
  val:
    a2g_args:
      r_energy: true
      r_forces: true
    src: val.db
gpus: 1
model:
  activation: silu
  atom_edge_interaction: true
  atom_interaction: true
  cbf:
    name: spherical_harmonics
  cutoff: 12.0
  cutoff_aeaint: 12.0
  cutoff_aint: 12.0
  cutoff_qint: 12.0
  direct_forces: true
  edge_atom_interaction: true
  emb_size_aint_in: 64
  emb_size_aint_out: 64
  emb_size_atom: 256
  emb_size_cbf: 16
  emb_size_edge: 512
  emb_size_quad_in: 32
  emb_size_quad_out: 32
  emb_size_rbf: 16
  emb_size_sbf: 32
  emb_size_trip_in: 64
  emb_size_trip_out: 64
  envelope:
    exponent: 5
    name: polynomial
  extensive: true
  forces_coupled: false
  max_neighbors: 30
  max_neighbors_aeaint: 20
  max_neighbors_aint: 1000
  max_neighbors_qint: 8
  name: gemnet_oc
  num_after_skip: 2
  num_atom: 3
  num_atom_emb_layers: 2
  num_before_skip: 2
  num_blocks: 4
  num_concat: 1
  num_global_out_layers: 2
  num_output_afteratom: 3
  num_radial: 128
  num_spherical: 7
  otf_graph: true
  output_init: HeOrthogonal
  qint_tags:
  - 1
  - 2
  quad_interaction: true
  rbf:
    name: gaussian
  regress_forces: true
  sbf:
    name: legendre_outer
  symmetric_edge_symmetrization: false
noddp: false
optim:
  batch_size: 16
  clip_grad_norm: 10
  ema_decay: 0.999
  energy_coefficient: 1
  eval_batch_size: 16
  eval_every: 1
  factor: 0.8
  force_coefficient: 1
  load_balancing: atoms
  loss_energy: mae
  lr_initial: 0.0005
  max_epochs: 10
  mode: min
  num_workers: 2
  optimizer: AdamW
  optimizer_params:
    amsgrad: true
  patience: 3
  scheduler: ReduceLROnPlateau
  weight_decay: 0
task:
  dataset: ase_db
trainer: forces
Running the training job#
ocp provides a main.py file that is used for training. Here we construct the Python command you need to run, and run it. main.py is not executable, so we have to run it with python, and you need the absolute path to it, which we get from the ocp_main() function that is defined in the utilities.
You must set a mode and provide a config-yml. We provide a checkpoint as a starting point; if you don’t do this, training will start from scratch.
By default the files that are created go in directories called checkpoints, logs and results. You can change that with a --run-dir preferred-place option.
You can also add an identifier tag to the end of the timestamps with --identifier tag. This can make it easier to find them later.
The cell below uses some IPython magic to put Python variables in the shell command.
This command takes some time to run even on a GPU, e.g. about 30 minutes.
It is advisable to redirect the outputs to files. If the outputs are very large, the notebook may fail to save, and keeping them in files also makes your notebooks more reproducible. The checkpoints are stored in time-stamped directories that change every time you run the training. Below we show how to reproducibly retrieve this directory name in a way that allows you to run the notebook again later, while automatically updating the directory name.
You can follow how the training is going by opening a terminal and running
tail -f train.txt
You can also visit it in a browser at train.txt. You have to periodically refresh the view to see updates though.
This can take up to 30 minutes for 80 epochs, so we only do a few here to see what happens.
import time
t0 = time.time()
! python {ocp_main()} --mode train --config-yml $yml --checkpoint $checkpoint \
--run-dir fine-tuning --identifier ft-oxides --amp > train.txt 2>&1
print(f'Elapsed time = {time.time() - t0:1.1f} seconds')
Elapsed time = 251.4 seconds
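Outside a notebook, the same command could be assembled as an argument list and run with `subprocess`. The sketch below only builds and prints the command. The flags are the ones used above, while `build_train_command` and the path `/path/to/main.py` are placeholders for illustration.

```python
import shlex


def build_train_command(main_py, config_yml, checkpoint=None,
                        run_dir=None, identifier=None, amp=False):
    """Assemble the argument list for a training run."""
    cmd = ["python", str(main_py), "--mode", "train",
           "--config-yml", str(config_yml)]
    if checkpoint is not None:
        cmd += ["--checkpoint", str(checkpoint)]  # start from a pre-trained model
    if run_dir is not None:
        cmd += ["--run-dir", str(run_dir)]
    if identifier is not None:
        cmd += ["--identifier", identifier]
    if amp:
        cmd.append("--amp")
    return cmd


cmd = build_train_command("/path/to/main.py", "config.yml",
                          checkpoint="model.pt", run_dir="fine-tuning",
                          identifier="ft-oxides", amp=True)
print(shlex.join(cmd))

# To run it and capture the output in a file, one could use, for example:
# import subprocess
# with open("train.txt", "w") as f:
#     subprocess.run(cmd, stdout=f, stderr=subprocess.STDOUT, check=True)
```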
Now since we have a file, we can find the training results in it. See train.txt. At the top, the config is printed, so we can get the checkpoint directory. I use shell commands and Python to get the line, split and strip it here.
cpline = !grep "checkpoint_dir:" train.txt
cpdir = cpline[0].split(':')[-1].strip()
cpdir
'fine-tuning/checkpoints/2023-08-11-14-41-04-ft-oxides'
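The same lookup can be done without the shell. The sketch below is a pure Python alternative, assuming the log contains a `checkpoint_dir:` line as shown above; `find_checkpoint_dir` is a hypothetical helper and the sample text is a shortened stand-in for train.txt.

```python
def find_checkpoint_dir(log_text):
    """Return the value of the first 'checkpoint_dir:' line, or None if absent."""
    for line in log_text.splitlines():
        stripped = line.strip()
        if stripped.startswith("checkpoint_dir:"):
            # Split only on the first colon so that paths containing ':' survive.
            return stripped.split(":", 1)[1].strip()
    return None


sample = """amp: true
cmd:
  checkpoint_dir: fine-tuning/checkpoints/2023-08-11-14-41-04-ft-oxides
  commit: 3973c79
"""

found = find_checkpoint_dir(sample)
print(found)
```

With the real log, the text would come from reading train.txt, for example with `open('train.txt').read()`.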
There will be two files in there: checkpoint.pt and best_checkpoint.pt.
The best_checkpoint.pt is the one that performs best on the validation dataset. The checkpoint.pt is the most recently saved one. It probably has the lowest training loss, but this could be an indication of overfitting. You have to use some judgement in determining which one to use, and whether it is sufficiently accurate for your needs. It may need additional training to improve it further.
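If you want the choice between the two files to be explicit, a small helper can prefer one and fall back to the other. This is only a sketch: `pick_checkpoint` and its `prefer_best` flag are hypothetical, and the example uses a temporary directory with empty files rather than real checkpoints.

```python
from pathlib import Path
import tempfile


def pick_checkpoint(cpdir, prefer_best=True):
    """Return the preferred checkpoint file in cpdir, falling back to the other."""
    names = ["best_checkpoint.pt", "checkpoint.pt"]
    if not prefer_best:
        names.reverse()
    for name in names:
        candidate = Path(cpdir) / name
        if candidate.exists():
            return candidate
    raise FileNotFoundError(f"No checkpoint found in {cpdir}")


with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "checkpoint.pt").touch()
    only_latest = pick_checkpoint(tmp).name
    (Path(tmp) / "best_checkpoint.pt").touch()
    with_best = pick_checkpoint(tmp).name
    latest_on_request = pick_checkpoint(tmp, prefer_best=False).name

print(only_latest, with_best, latest_on_request)
```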
newckpt = cpdir + '/checkpoint.pt'
newcalc = OCPCalculator(checkpoint=newckpt, cpu=False)
WARNING:root:Unrecognized arguments: ['symmetric_edge_symmetrization']
amp: true
cmd:
checkpoint_dir: /home/jovyan/shared-scratch/jkitchin/tutorial/ocp-tutorial/fine-tuning/checkpoints/2023-08-11-14-45-20
commit: 3973c79
identifier: ''
logs_dir: /home/jovyan/shared-scratch/jkitchin/tutorial/ocp-tutorial/fine-tuning/logs/tensorboard/2023-08-11-14-45-20
print_every: 100
results_dir: /home/jovyan/shared-scratch/jkitchin/tutorial/ocp-tutorial/fine-tuning/results/2023-08-11-14-45-20
seed: null
timestamp_id: 2023-08-11-14-45-20
dataset: null
gpus: 1
logger: tensorboard
model: gemnet_oc
model_attributes:
activation: silu
atom_edge_interaction: true
atom_interaction: true
cbf:
name: spherical_harmonics
cutoff: 12.0
cutoff_aeaint: 12.0
cutoff_aint: 12.0
cutoff_qint: 12.0
direct_forces: true
edge_atom_interaction: true
emb_size_aint_in: 64
emb_size_aint_out: 64
emb_size_atom: 256
emb_size_cbf: 16
emb_size_edge: 512
emb_size_quad_in: 32
emb_size_quad_out: 32
emb_size_rbf: 16
emb_size_sbf: 32
emb_size_trip_in: 64
emb_size_trip_out: 64
envelope:
exponent: 5
name: polynomial
extensive: true
forces_coupled: false
max_neighbors: 30
max_neighbors_aeaint: 20
max_neighbors_aint: 1000
max_neighbors_qint: 8
num_after_skip: 2
num_atom: 3
num_atom_emb_layers: 2
num_before_skip: 2
num_blocks: 4
num_concat: 1
num_global_out_layers: 2
num_output_afteratom: 3
num_radial: 128
num_spherical: 7
otf_graph: true
output_init: HeOrthogonal
qint_tags:
- 1
- 2
quad_interaction: true
rbf:
name: gaussian
regress_forces: true
sbf:
name: legendre_outer
symmetric_edge_symmetrization: false
noddp: false
optim:
batch_size: 16
clip_grad_norm: 10
ema_decay: 0.999
energy_coefficient: 1
eval_batch_size: 16
eval_every: 1
factor: 0.8
force_coefficient: 1
load_balancing: atoms
loss_energy: mae
lr_initial: 0.0005
max_epochs: 10
mode: min
num_workers: 2
optimizer: AdamW
optimizer_params:
amsgrad: true
patience: 3
scheduler: ReduceLROnPlateau
weight_decay: 0
slurm: {}
task:
dataset: ase_db
trainer: forces
eos_data = {}
for oxide in oxides:
    eos_data[oxide] = {}
    for polymorph in polymorphs:
        dft = []
        ocp = []
        vols = []
        calculations = d[oxide][polymorph]['PBE']['EOS']['calculations']
        for c in calculations:
            atoms = Atoms(symbols=c['atoms']['symbols'],
                          positions=c['atoms']['positions'],
                          cell=c['atoms']['cell'],
                          pbc=c['atoms']['pbc'])
            atoms.set_tags(np.ones(len(atoms)))
            atoms.calc = newcalc
            ocp += [atoms.get_potential_energy() / len(atoms)]
            dft += [c['data']['total_energy'] / len(atoms)]
            vols += [atoms.get_volume()]
        plt.plot(dft, ocp, marker='s' if oxide == 'VO2' else '.',
                 alpha=0.5, label=f'{oxide}-{polymorph}')
        eos_data[oxide][polymorph] = (vols, dft, ocp)
plt.xlabel('DFT (eV/atom)')
plt.ylabel('OCP (eV/atom)')
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), ncol=3);
# Note: as before, dft and ocp hold only the last oxide/polymorph pair from the loop.
mae = np.mean(np.abs(np.array(dft) - np.array(ocp)))
print(f'New MAE = {mae:1.3f} eV/atom')
New MAE = 0.030 eV/atom
This is a substantial improvement! We can see what it means in terms of the Equations of State. The agreement is still not perfect, but the curves are closer together. Additional fine-tuning or a better model could probably improve this further.
oxide, polymorph = 'VO2', 'fluorite'
V, D, O = eos_data[oxide][polymorph]
plt.plot(V, D, label='dft')
plt.plot(V, O, label='ocp')
plt.title(f'fine-tuned {oxide} - {polymorph}')
plt.legend();
It is possible to continue refining the fit. The simplest thing to do is to train for more epochs. Eventually the MAE will stabilize, and then it may be necessary to adjust other optimization parameters like the learning rate (usually you decrease it).
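The config above already lowers the learning rate automatically through the ReduceLROnPlateau scheduler (factor 0.8, patience 3, mode min). The sketch below is a simplified stand-in for that rule, written in plain Python to show the idea. It ignores details of the real scheduler such as improvement thresholds and cooldown, and `plateau_schedule` and the metric values are illustrative.

```python
def plateau_schedule(val_metrics, lr_initial=0.0005, factor=0.8, patience=3):
    """Simplified reduce-on-plateau rule for a metric that should decrease.

    The learning rate is multiplied by `factor` once the metric has failed
    to improve for more than `patience` consecutive evaluations.
    """
    lr = lr_initial
    best = float("inf")
    bad = 0
    history = []
    for metric in val_metrics:
        if metric < best:
            best = metric
            bad = 0
        else:
            bad += 1
        if bad > patience:
            lr *= factor
            bad = 0
        history.append(lr)
    return history


# Illustrative validation metrics: improvement, a plateau, then improvement.
metrics = [1.0, 0.8, 0.9, 0.9, 0.9, 0.9, 0.7]
lrs = plateau_schedule(metrics)
print(lrs)
```

In this example the rate stays at its initial value through the first three stalled evaluations and drops on the fourth.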
Depending on what is important to you, you may consider changing the relative importance of energy and forces; you can often trade off accuracy of one for the other. It may also be necessary to add data to cover the composition and configuration space more thoroughly.
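The `energy_coefficient` and `force_coefficient` entries in the `optim` section of the config weight the two loss terms. As a rough sketch, assuming the total loss is a weighted sum of an energy term and a force term, the effect of changing a coefficient looks like this. The function name and the numbers are illustrative and not taken from the OCP source.

```python
def combined_loss(energy_loss, force_loss,
                  energy_coefficient=1.0, force_coefficient=1.0):
    """Weighted sum of an energy loss term and a force loss term."""
    return energy_coefficient * energy_loss + force_coefficient * force_loss


# Equal weights, as in the config above.
equal = combined_loss(0.5, 0.25)
# Emphasizing forces makes the same force error count four times as much.
force_heavy = combined_loss(0.5, 0.25, force_coefficient=4.0)
print(equal, force_heavy)
```

With the helper used earlier, such a change might be requested through the `update` argument, for example with a key like `'optim.force_coefficient'`.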
There are also other models you could consider. Newer models tend to be more accurate, but they may also be more expensive to run. These are all compromises you can consider.
Next steps#
You might want to check out the gotchas notebook. It has some examples of things we have seen go wrong before.
The advanced folder covers several more advanced topics, including:
Mass inference which is more efficient than the ASE calculator
Working with embeddings which shows some ways to use the embeddings from OCP