Skip to content

Commit

Permalink
Various bugfixes
Browse files Browse the repository at this point in the history
* Ensure we calculate rotatable bonds on the version of the ligand with no hydrogens. Also fix spelling of rotable -> rotatable. Closes GH-220 (@Nobody-Zhang)

* Vectorize SO3 calculations. Closes PR GH-218 (@tornikeo)

* Pin pytorch-lightning version. Closes GH-193 (@mikael-h-christensen)

* Guard against divide by zero in torus.py. Closes GH-161 (@amorehead)

* Update e3nn version to 0.5.1. Closes GH-155 (@amorehead)

* Add a little more info on docker container to README.md
  • Loading branch information
jsilter committed Apr 30, 2024
1 parent 2c867df commit 561f70a
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 32 deletions.
21 changes: 20 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,30 @@ current repo

To set up an appropriate environment, navigate to the root of the repository and run the following commands:

conda env create --file environment.yml
conda env create --file environment.yml
conda activate diffdock

See [conda documentation](https://conda.io/projects/conda/en/latest/commands/env/create.html) for more information.

### Using a Docker container

A Dockerfile is provided for building a container:

docker build . -f Dockerfile -t diffdock

Alternatively, you can use a pre-built container to run the code.
First, download the container from Docker Hub:

docker pull rbgcsail/diffdock

Then, run the container:

docker run -it --entrypoint /bin/bash rbgcsail/diffdock
# Inside the container
micromamba activate diffdock

You can now run the code as described below.

### Docking Prediction <a name="inference"></a>

We support multiple input formats depending on whether you only want to make predictions for a single complex or for many at once.\
Expand Down
20 changes: 10 additions & 10 deletions datasets/conformer_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,42 +21,42 @@ def SetDihedral(conf, atom_idx, new_vale):
rdMolTransforms.SetDihedralRad(conf, atom_idx[0], atom_idx[1], atom_idx[2], atom_idx[3], new_vale)


def apply_changes(mol, values, rotable_bonds, conf_id):
def apply_changes(mol, values, rotatable_bonds, conf_id):
opt_mol = copy.copy(mol)
[SetDihedral(opt_mol.GetConformer(conf_id), rotable_bonds[r], values[r]) for r in range(len(rotable_bonds))]
[SetDihedral(opt_mol.GetConformer(conf_id), rotatable_bonds[r], values[r]) for r in range(len(rotatable_bonds))]
return opt_mol


def optimize_rotatable_bonds(mol, true_mol, rotable_bonds, probe_id=-1, ref_id=-1, seed=0, popsize=15, maxiter=500,
def optimize_rotatable_bonds(mol, true_mol, rotatable_bonds, probe_id=-1, ref_id=-1, seed=0, popsize=15, maxiter=500,
mutation=(0.5, 1), recombination=0.8):
opt = OptimizeConformer(mol, true_mol, rotable_bonds, seed=seed, probe_id=probe_id, ref_id=ref_id)
max_bound = [np.pi] * len(opt.rotable_bonds)
min_bound = [-np.pi] * len(opt.rotable_bonds)
opt = OptimizeConformer(mol, true_mol, rotatable_bonds, seed=seed, probe_id=probe_id, ref_id=ref_id)
max_bound = [np.pi] * len(opt.rotatable_bonds)
min_bound = [-np.pi] * len(opt.rotatable_bonds)
bounds = (min_bound, max_bound)
bounds = list(zip(bounds[0], bounds[1]))

# Optimize conformations
result = differential_evolution(opt.score_conformation, bounds,
maxiter=maxiter, popsize=popsize,
mutation=mutation, recombination=recombination, disp=False, seed=seed)
opt_mol = apply_changes(opt.mol, result['x'], opt.rotable_bonds, conf_id=probe_id)
opt_mol = apply_changes(opt.mol, result['x'], opt.rotatable_bonds, conf_id=probe_id)

return opt_mol


class OptimizeConformer:
def __init__(self, mol, true_mol, rotable_bonds, probe_id=-1, ref_id=-1, seed=None):
def __init__(self, mol, true_mol, rotatable_bonds, probe_id=-1, ref_id=-1, seed=None):
super(OptimizeConformer, self).__init__()
if seed:
np.random.seed(seed)
self.rotable_bonds = rotable_bonds
self.rotatable_bonds = rotatable_bonds
self.mol = mol
self.true_mol = true_mol
self.probe_id = probe_id
self.ref_id = ref_id

def score_conformation(self, values):
for i, r in enumerate(self.rotable_bonds):
for i, r in enumerate(self.rotatable_bonds):
SetDihedral(self.mol.GetConformer(self.probe_id), r, values[i])
return AllChem.AlignMol(self.mol, self.true_mol, self.probe_id, self.ref_id)

Expand Down
12 changes: 8 additions & 4 deletions datasets/process_mols.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,12 @@ def get_lig_graph_with_matching(mol_, complex_graph, popsize, maxiter, matching,
positions.append(conf.GetPositions())
complex_graph['ligand'].orig_pos = np.asarray(positions) if len(positions) > 1 else positions[0]

rotable_bonds = get_torsion_angles(mol_maybe_noh)
#if not rotable_bonds: print("no_rotable_bonds but still using it")
# rotatable_bonds = get_torsion_angles(mol_maybe_noh)
_tmp = copy.deepcopy(mol_)
if remove_hs:
_tmp = RemoveHs(_tmp, sanitize=True)
_tmp = AllChem.RemoveAllHs(_tmp)
rotatable_bonds = get_torsion_angles(_tmp)

for i in range(num_conformers):
mols, rmsds = [], []
Expand All @@ -347,8 +351,8 @@ def get_lig_graph_with_matching(mol_, complex_graph, popsize, maxiter, matching,
mol_rdkit = RemoveHs(mol_rdkit, sanitize=True)
mol_rdkit = AllChem.RemoveAllHs(mol_rdkit)
mol = AllChem.RemoveAllHs(copy.deepcopy(mol_maybe_noh))
if rotable_bonds and not skip_matching:
optimize_rotatable_bonds(mol_rdkit, mol, rotable_bonds, popsize=popsize, maxiter=maxiter)
if rotatable_bonds and not skip_matching:
optimize_rotatable_bonds(mol_rdkit, mol, rotatable_bonds, popsize=popsize, maxiter=maxiter)
mol.AddConformer(mol_rdkit.GetConformer())
rms_list = []
AllChem.AlignMolConformers(mol, RMSlist=rms_list)
Expand Down
3 changes: 2 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@ dependencies:
- --extra-index-url https://download.pytorch.org/whl/cu117
- --find-links https://pytorch-geometric.com/whl/torch-1.13.1+cu117.html
- dllogger @ git+https://github.com/NVIDIA/dllogger.git
- e3nn==0.5.0
- e3nn==0.5.1
- fair-esm[esmfold]==2.0.0
- networkx==2.8.4
- openfold @ git+https://github.com/aqlaboratory/openfold.git@4b41059694619831a7db195b7e0988fc4ff3a307
- pandas==1.5.1
- prody==2.2.0
- prody==2.2.0
- pybind11==2.11.1
- pytorch-lightning==1.9.5
- rdkit==2022.03.3
- scikit-learn==1.1.0
- scipy==1.12.0
Expand Down
30 changes: 26 additions & 4 deletions models/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
'silu': nn.SiLU
}


def FCBlock(in_dim, hidden_dim, out_dim, layers, dropout, activation='relu'):
activation = ACTIVATIONS[activation]
assert layers >= 2
Expand All @@ -29,10 +30,20 @@ def forward(self, dist):
return torch.exp(self.coeff * torch.pow(dist, 2))



class AtomEncoder(torch.nn.Module):
def __init__(self, emb_dim, feature_dims, sigma_embed_dim, lm_embedding_dim=0):
# first element of feature_dims tuple is a list with the lenght of each categorical feature and the second is the number of scalar features
"""
Parameters
----------
emb_dim
feature_dims
first element of feature_dims tuple is a list with the length of each categorical feature,
and the second is the number of scalar features
sigma_embed_dim
lm_embedding_dim
"""
#
super(AtomEncoder, self).__init__()
self.atom_embedding_list = torch.nn.ModuleList()
self.num_categorical_features = len(feature_dims[0])
Expand All @@ -58,8 +69,19 @@ def forward(self, x):

class OldAtomEncoder(torch.nn.Module):

def __init__(self, emb_dim, feature_dims, sigma_embed_dim, lm_embedding_type= None):
# first element of feature_dims tuple is a list with the lenght of each categorical feature and the second is the number of scalar features
def __init__(self, emb_dim, feature_dims, sigma_embed_dim, lm_embedding_type=None):
"""
Parameters
----------
emb_dim
feature_dims
first element of feature_dims tuple is a list with the length of each categorical feature,
and the second is the number of scalar features
sigma_embed_dim
lm_embedding_type
"""
#
super(OldAtomEncoder, self).__init__()
self.atom_embedding_list = torch.nn.ModuleList()
self.num_categorical_features = len(feature_dims[0])
Expand Down
20 changes: 9 additions & 11 deletions utils/so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ def _compose(r1, r2): # R1 @ R2 but for Euler vecs


def _expansion(omega, eps, L=2000): # the summation term only
p = 0
for l in range(L):
p += (2 * l + 1) * np.exp(-l * (l + 1) * eps**2 / 2) * np.sin(omega * (l + 1 / 2)) / np.sin(omega / 2)
l_vec = np.arange(L).reshape(-1, 1)
p = ((2 * l_vec + 1) * np.exp(-l_vec * (l_vec + 1) * eps ** 2 / 2)
* np.sin(omega * (l_vec + 1 / 2)) / np.sin(omega / 2)).sum(0)
return p


Expand All @@ -33,13 +33,12 @@ def _density(expansion, omega, marginal=True): # if marginal, density over [0,


def _score(exp, omega, eps, L=2000): # score of density over SO(3)
dSigma = 0
for l in range(L):
hi = np.sin(omega * (l + 1 / 2))
dhi = (l + 1 / 2) * np.cos(omega * (l + 1 / 2))
lo = np.sin(omega / 2)
dlo = 1 / 2 * np.cos(omega / 2)
dSigma += (2 * l + 1) * np.exp(-l * (l + 1) * eps**2 / 2) * (lo * dhi - hi * dlo) / lo ** 2
l_vec = np.arange(L).reshape(-1, 1)
hi = np.sin((l_vec + 1 / 2) * omega)
dhi = (l_vec + 1 / 2) * np.cos((l_vec + 1 / 2) * omega)
lo = np.sin(omega / 2)
dlo = 1 / 2 * np.cos(omega / 2)
dSigma = ((2 * l_vec + 1) * np.exp(-l_vec * (l_vec + 1) * eps**2 / 2) * (lo * dhi - hi * dlo) / lo ** 2).sum(0)
return dSigma / exp


Expand Down Expand Up @@ -92,4 +91,3 @@ def score_norm(eps):
eps_idx = (np.log10(eps) - np.log10(MIN_EPS)) / (np.log10(MAX_EPS) - np.log10(MIN_EPS)) * N_EPS
eps_idx = np.clip(np.around(eps_idx).astype(int), a_min=0, a_max=N_EPS-1)
return torch.from_numpy(_exp_score_norms[eps_idx]).float()

3 changes: 2 additions & 1 deletion utils/torus.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def grad(x, sigma, N=10):
p_ = p(x, sigma[:, None], N=100)
np.save('.p.npy', p_)

score_ = grad(x, sigma[:, None], N=100) / p_
eps = np.finfo(p_.dtype).eps
score_ = grad(x, sigma[:, None], N=100) / (p_ + eps)
np.save('.score.npy', score_)


Expand Down

0 comments on commit 561f70a

Please sign in to comment.