DTI 1b: Interim Drug System Tables

Processing molecules with RDKit.

Published May 16, 2024

This release of the drug-target interaction (DTI) series focuses on building the drug system tables: Ligand_Molecule, Ligand_Atom, and Ligand_Bond, along with a supporting Ring table.

erDiagram
    Reaction
    Target

    Ligand_Molecule
    Ligand_Atom
    Ligand_Bond

    Reaction }o--|| Ligand_Molecule : "reacts"
    Reaction }o--|| Target : "reacts"
    Ligand_Molecule ||--|{ Ligand_Atom : "contains"
    Ligand_Bond }|--|{ Ligand_Atom : "bonds"

Table schemas

The schemas of the processed tables are inspired by prior work on graph representations of molecules, namely Chen et al., “Graph Networks as a Universal Machine Learning Framework for Molecules and Crystals,” and Kearnes et al., “Molecular Graph Convolutions.”
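In both papers a molecule is treated as a graph, with atoms as nodes and bonds as edges. The sketch below is a rough illustration of how the tables here could feed such a model. It is an assumption about downstream use and not part of this post's pipeline. It turns Bond-table-style atom index pairs into a symmetric, two-row edge list, a layout many graph-network libraries accept. bonds_to_edge_index is a hypothetical helper.

```python
import numpy as np


def bonds_to_edge_index(atom1_index, atom2_index) -> np.ndarray:
    """Return a 2 x (2 * n_bonds) array listing each bond in both directions."""
    src = np.asarray(atom1_index, dtype=np.int64)
    dst = np.asarray(atom2_index, dtype=np.int64)
    # Each undirected bond becomes two directed edges: (a1 -> a2) and (a2 -> a1)
    return np.stack([np.concatenate([src, dst]), np.concatenate([dst, src])])


# First three bonds of the example ligand (Bond table rows 0-2)
edge_index = bonds_to_edge_index([0, 1, 1], [1, 2, 3])
```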

Molecule

erDiagram
    Molecule {
        int id PK
        varchar name
        varchar smiles
        int n_atoms
        int n_bonds
        float weight
        float mean_atomic_weight "Mean of the standard (average, not exact isotopic) atomic weights of the atoms in the molecule"
        float bonds_per_atom "Average number of bonds per atom in the molecule"
    }

The interim Molecule table will exclude mean_atomic_weight and bonds_per_atom. These features can be calculated for all molecules in a single vectorized query once the atom- and bond-level features exist as tables, which should be faster than generating them during the ingestion loop.

Atom

erDiagram
    Atom {
        int molecule_id PK, FK
        int index PK
        categorical symbol
        categorical chirality
        int ring_size_[n]_count
        categorical hybridization
        bool acceptor
        bool donor
        bool aromatic
        float x
        float y
        float z
    }

The interim Atom table will exclude the ring_size_[n]_count fields. These features can be calculated for all atoms in a vectorized query using the Ring table and then be joined to the interim Atom table. This is faster than generating these features during the ingestion loop.

Bond

erDiagram
    Bond {
        int molecule_id PK, FK
        int index PK
        int atom1_index FK
        int atom2_index FK
        categorical type
        bool same_ring
    }

The interim Bond table will exclude the same_ring field. This feature can be calculated for all bonds in a vectorized query using the Ring table and then joined to the interim Bond table, which should be faster than generating it during the ingestion loop.

Ring

erDiagram
    Ring {
        int molecule_id PK, FK
        int index PK
        int size
        int[] atom_indices
    }

The Ring table is needed for generating the ring_size_[n]_count atom features and the same_ring bond feature after the ingestion loop.

Data flow diagram

The interim ingest flow has an inner and an outer loop. The inner loop reads a single molecule from the RDKit molecule supplier, builds its molecule, atom, bond, and ring records, and appends them to in-memory lists. The outer loop writes each batch of processed records to file. Batching avoids an I/O operation per molecule and balances I/O against processing. A follow-up process (not shown below) will convert these interim tables to their final schemas.

%%{init: {'theme':'forest'}}%%
flowchart LR

mols[|borders:tb|molecule supplier]

proc_mol((process molecule))

mol[|borders:tb|interim molecule]
atom[|borders:tb|interim atom]
bond[|borders:tb|interim bond]
ring[|borders:tb|ring]

acc((accumulate))
write((write processed batch))

mol_batch[|borders:tb|interim molecule batch]
atom_batch[|borders:tb|interim atom batch]
bond_batch[|borders:tb|interim bond batch]
ring_batch[|borders:tb|interim ring batch]

subgraph inner[inner loop]
    mols --> proc_mol
    proc_mol --> mol
    proc_mol --> atom
    proc_mol --> bond
    proc_mol --> ring

    mol --> acc
    atom --> acc
    bond --> acc
    ring --> acc
end

acc --> write

subgraph outer[outer loop]
    write --> mol_batch
    write --> atom_batch
    write --> bond_batch
    write --> ring_batch
end

classDef Transparent fill:#FFFFFF;
class inner,outer Transparent;
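The inner and outer loops in the diagram can be sketched in plain Python. This is a minimal sketch under stated assumptions: process_mol and write_batch are hypothetical callables standing in for the record builders and the file writer, and batch_size is an illustrative default rather than a tuned value.

```python
from itertools import islice
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional

TABLES = ("molecule", "atom", "bond", "ring")


def batched(items: Iterable[Any], batch_size: int) -> Iterator[List[Any]]:
    """Yield lists of at most `batch_size` items until `items` is exhausted."""
    iterator = iter(items)
    while batch := list(islice(iterator, batch_size)):
        yield batch


def ingest(
    supplier: Iterable[Optional[Any]],
    process_mol: Callable[[Any, int], Dict[str, Any]],
    write_batch: Callable[[int, Dict[str, List[Any]]], None],
    batch_size: int = 1_000,
) -> None:
    # Outer loop: one write per batch of molecules
    for batch_number, indexed_mols in enumerate(batched(enumerate(supplier), batch_size)):
        accumulated: Dict[str, List[Any]] = {table: [] for table in TABLES}
        # Inner loop: process one molecule and accumulate its records in memory
        for supplier_index, mol in indexed_mols:
            if mol is None:  # RDKit suppliers yield None for records they fail to parse
                continue
            records = process_mol(mol, supplier_index)
            for table in TABLES:
                accumulated[table].append(records[table])
        write_batch(batch_number, accumulated)
```

Here process_mol would wrap build_interim_mol_record(mol, supplier_index=supplier_index), build_ring_record, build_interim_atom_record, and build_interim_bond_record, and write_batch might concatenate each table's records and write them out. Both callables are placeholders.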

Sourcing

from rdkit import Chem

bdb_path = raw_data_path / "BindingDB_All_3D_202404.sdf"
supplier = Chem.SDMolSupplier(str(bdb_path))

mol = supplier[0]
mol
[18:33:06] Warning: molecule is tagged as 2D, but at least one Z coordinate is not zero. Marking the mol as 3D.
From www.bindingDB.org
BindingDB Reactant_set_id 1
Ligand InChI InChI=1S/C31H42N2O7/c34-27(35)17-9-3-11-19-32-25(21-23-13-5-1-6-14-23)29(38)30(39)26(22-24-15-7-2-8-16-24)33(31(32)40)20-12-4-10-18-28(36)37/h1-2,5-8,13-16,25-26,29-30,38-39H,3-4,9-12,17-22H2,(H,34,35)(H,36,37)/t25-,26-,29+,30+/m1/s1
Ligand InChI Key XGEGDSLAQZJGCW-HHGOQMMWSA-N
BindingDB MonomerID 608734
BindingDB Ligand Name 6-[(4R,5S,6S,7R)-4,7-dibenzyl-3-(5-carboxypentyl)-5,6-dihydroxy-2-oxo-1,3-diazepan-1-yl]hexanoic acid::DMPC Cyclic Urea 1
Target Name Dimer of Gag-Pol polyprotein [501-599]
Target Source Organism According to Curator or DataSource Human immunodeficiency virus 1
Ki (nM) 0.24
IC50 (nM)
Property list truncated.
Increase IPythonConsole.ipython_maxProperties (or set it to -1) to see more properties.

Engineering

Molecule

import numpy as np
from collections import Counter

def get_mol_prop(mol: Chem.Mol, prop: str, default=None):
    try:
        return mol.GetProp(prop)
    except KeyError:
        return default


def get_mean_mol_atomic_weight(mol) -> float:
    pse = Chem.GetPeriodicTable()
    return np.mean([pse.GetAtomicWeight(atom.GetAtomicNum()) for atom in mol.GetAtoms()])


def get_mol_bonds_per_atom(mol) -> float:
    bond_count_by_atom_index = Counter(
        [
            atom_index
            for atom_index_pair in [
                (bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()) for bond in mol.GetBonds()
            ]
            for atom_index in atom_index_pair
        ]
    )
    return np.mean(list(bond_count_by_atom_index.values()))

Some of the required molecule features are available directly from RDKit Mol methods. Others, such as the mean atomic weight and bonds per atom, need the custom helpers above.

Wrapping these feature calls in a function allows easy molecule record generation.

Feature generation for all drug system tables is split into a lower-level function, build_{...}_record, and a higher-level function, build_{...}_df. The lower-level function works directly with structured NumPy arrays, which avoids the overhead of pandas DataFrame creation (namely indexing and type inference). The higher-level function simply wraps that output in a DataFrame.
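One consequence of working with structured arrays is that build_mol_record sizes its string fields per molecule (U{len(name)}), so records from different molecules can carry different dtypes. A minimal sketch, assuming records are batched with NumPy before any DataFrame is built, fixes the schema up front and stores variable-length strings in object ("O") fields so every record shares one dtype. MOL_DTYPE, make_record, and the second record's toy values are hypothetical.

```python
import numpy as np
import pandas as pd

# Hypothetical fixed schema: object fields hold variable-length strings, so every
# molecule's record shares one dtype and a batch can be concatenated directly.
MOL_DTYPE = np.dtype(
    [("id", "int64"), ("name", "O"), ("n_atoms", "int64"), ("weight", "float64")]
)


def make_record(mol_id: int, name: str, n_atoms: int, weight: float) -> np.ndarray:
    # A one-row structured array: no index and no type inference, unlike a DataFrame
    return np.array([(mol_id, name, n_atoms, weight)], dtype=MOL_DTYPE)


records = [
    make_record(608734, "DMPC Cyclic Urea 1", 40, 554.299202),
    make_record(1, "toy ligand", 3, 46.0),  # made-up values for illustration
]
batch = np.concatenate(records)  # one contiguous array for the whole batch
batch_df = pd.DataFrame(batch)   # DataFrame built once per batch, not per molecule
```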

import pandas as pd
from typing import Optional

from rdkit.Chem.Descriptors import ExactMolWt


def build_mol_record(mol: Chem.Mol, supplier_index: Optional[int] = None) -> np.ndarray:
    mol_id = get_mol_id(mol)
    name = get_mol_prop(mol, "BindingDB Ligand Name", default="")
    smiles = Chem.MolToSmiles(mol)

    # (field name, dtype) pairs; string widths are sized to this molecule's values
    dtypes = [
        ("id", "int"),
        ("supplier_index", "int"),
        ("name", f"U{max(len(name), 1)}"),
        ("smiles", f"U{max(len(smiles), 1)}"),
        ("n_atoms", "int"),
        ("n_bonds", "int"),
        ("weight", "float"),
        ("mean_atomic_weight", "float"),
        ("bonds_per_atom", "float"),
    ]
    values = [
        mol_id,
        supplier_index,
        name,
        smiles,
        mol.GetNumAtoms(),
        mol.GetNumBonds(),
        ExactMolWt(mol),
        get_mean_mol_atomic_weight(mol),
        get_mol_bonds_per_atom(mol),
    ]

    # Drop the supplier_index field when no supplier index is given
    if supplier_index is None:
        dtypes = [dtypes[0], *dtypes[2:]]
        values = [values[0], *values[2:]]

    return np.array([tuple(values)], dtype=dtypes)


def build_mol_df(mol: Chem.Mol, supplier_index: Optional[int] = None) -> pd.DataFrame:
    return pd.DataFrame(build_mol_record(mol, supplier_index=supplier_index))
build_mol_df(mol, supplier_index=0)
id supplier_index name smiles n_atoms n_bonds weight mean_atomic_weight bonds_per_atom
0 608734 0 6-[(4R,5S,6S,7R)-4,7-dibenzyl-3-(5-carboxypent... O=C(O)CCCCCN1C(=O)N(CCCCCC(=O)O)[C@H](Cc2ccccc... 40 42 554.299202 12.8087 1.05

The mean_atomic_weight and bonds_per_atom features require looping over every atom or bond in the molecule. The interim Molecule table omits them, since they can be calculated in a vectorized way once the Atom and Bond tables exist.
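A minimal pandas sketch of that vectorized step follows. It assumes the interim Molecule, Atom, and Bond tables have been loaded as DataFrames with the columns built in this post: the Atom table's weight column and the Bond table's atom1_index and atom2_index. add_mol_aggregates is a hypothetical helper. Its bonds_per_atom follows get_mol_bonds_per_atom, counting each bond once for each of its two atoms, so atoms without bonds are not counted.

```python
import pandas as pd


def add_mol_aggregates(
    mol_df: pd.DataFrame, atom_df: pd.DataFrame, bond_df: pd.DataFrame
) -> pd.DataFrame:
    # mean_atomic_weight: one groupby over the whole Atom table
    mean_atomic_weight = (
        atom_df.groupby("molecule_id")["weight"].mean().rename("mean_atomic_weight")
    )
    # bonds_per_atom: one row per bond endpoint, count per atom, then average per molecule
    endpoints = pd.concat(
        [
            bond_df[["molecule_id", "atom1_index"]].rename(columns={"atom1_index": "atom_index"}),
            bond_df[["molecule_id", "atom2_index"]].rename(columns={"atom2_index": "atom_index"}),
        ]
    )
    bonds_per_atom = (
        endpoints.groupby(["molecule_id", "atom_index"])
        .size()
        .groupby(level="molecule_id")
        .mean()
        .rename("bonds_per_atom")
    )
    # Left joins keep every molecule even if it has no atom or bond rows
    return mol_df.merge(
        mean_atomic_weight, left_on="id", right_index=True, how="left"
    ).merge(bonds_per_atom, left_on="id", right_index=True, how="left")
```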

# Omit the atom- and bond-aggregate features that require looping; they are
# cheaper to compute from the Atom and Bond tables once those are built
def build_interim_mol_record(mol: Chem.Mol, supplier_index: Optional[int] = None) -> np.ndarray:
    mol_id = get_mol_id(mol)
    name = get_mol_prop(mol, "BindingDB Ligand Name", default="")
    smiles = Chem.MolToSmiles(mol)

    # (field name, dtype) pairs; string widths are sized to this molecule's values
    dtypes = [
        ("id", "int"),
        ("supplier_index", "int"),
        ("name", f"U{max(len(name), 1)}"),
        ("smiles", f"U{max(len(smiles), 1)}"),
        ("n_atoms", "int"),
        ("n_bonds", "int"),
        ("weight", "float"),
    ]
    values = [
        mol_id,
        supplier_index,
        name,
        smiles,
        mol.GetNumAtoms(),
        mol.GetNumBonds(),
        ExactMolWt(mol),
    ]

    # Drop the supplier_index field when no supplier index is given
    if supplier_index is None:
        dtypes = [dtypes[0], *dtypes[2:]]
        values = [values[0], *values[2:]]

    return np.array([tuple(values)], dtype=dtypes)


def build_interim_mol_df(mol: Chem.Mol, supplier_index: Optional[int] = None) -> pd.DataFrame:
    return pd.DataFrame(build_interim_mol_record(mol, supplier_index=supplier_index))
build_interim_mol_df(mol, supplier_index=0)
id supplier_index name smiles n_atoms n_bonds weight
0 608734 0 6-[(4R,5S,6S,7R)-4,7-dibenzyl-3-(5-carboxypent... O=C(O)CCCCCN1C(=O)N(CCCCCC(=O)O)[C@H](Cc2ccccc... 40 42 554.299202

The fastest molecule feature function, build_interim_mol_record, runs about 3.4x faster than the slowest, build_mol_df (98.3 µs versus 333 µs). This isn’t an apples-to-apples comparison: build_mol_df also pays for DataFrame construction, and the interim record still lacks the mean_atomic_weight and bonds_per_atom features, which must be computed later. That said, generating these features in a vectorized way is unlikely to be slower than doing so in a loop per molecule.

%%timeit
build_interim_mol_record(mol, supplier_index=0)
98.3 µs ± 382 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
%%timeit
build_mol_df(mol, supplier_index=0)
333 µs ± 3.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Ring

Both the Atom and Bond tables require ring features. Constructing these from a separate Ring table with vectorized queries should be faster than constructing them per atom or bond.

def build_ring_record(mol: Chem.Mol) -> np.ndarray:
    molecule_id = get_mol_id(mol)
    mol_rings = mol.GetRingInfo().AtomRings()

    # (field name, dtype) pairs
    dtypes = [
        ("molecule_id", "int"),
        ("index", "int"),
        ("size", "int"),
        ("atom_indices", "O"),
    ]
    values = [
        (molecule_id, index, len(atom_indices), atom_indices)
        for index, atom_indices in enumerate(mol_rings)
    ]

    return np.array(values, dtype=dtypes)


def build_ring_df(mol: Chem.Mol) -> pd.DataFrame:
    return pd.DataFrame(build_ring_record(mol))
ring_df = build_ring_df(mol)
ring_df
molecule_id index size atom_indices
0 608734 0 7 (8, 9, 10, 19, 27, 29, 31)
1 608734 1 6 (22, 23, 24, 25, 26, 21)
2 608734 2 6 (34, 35, 36, 37, 38, 33)

Example uses

The Ring table will be used to generate the ring_size_[n]_count atom-level features and same_ring bond-level feature. Sample logic for generating these features is shown below.

ring_size_[n]_count Atom feature:

ring_size_range = (3, 8)
ring_size_count_columns = [
    f"ring_size_{i}_count" for i in range(ring_size_range[0], ring_size_range[1] + 1)
]

(
    ring_df.explode("atom_indices")
    .rename(columns={"atom_indices": "atom_index"})
    .groupby(["molecule_id", "atom_index", "size"])
    .count()
    .reset_index()
    .pivot(index=["molecule_id", "atom_index"], columns="size", values="index")
    .pipe(lambda df: df.rename(columns={i: f"ring_size_{i}_count" for i in df.columns}))
    .fillna(0)
    .astype(int)
    .pipe(
        lambda df: df.assign(
            **{
                ring_size_count_col: df[ring_size_count_col]
                if ring_size_count_col in df.columns
                else 0
                for ring_size_count_col in ring_size_count_columns
            }
        )
    )[ring_size_count_columns]
    .reset_index()
    .rename(columns={"atom_index": "index"})
    .rename_axis(columns="")
)
molecule_id index ring_size_3_count ring_size_4_count ring_size_5_count ring_size_6_count ring_size_7_count ring_size_8_count
0 608734 8 0 0 0 0 1 0
1 608734 9 0 0 0 0 1 0
2 608734 10 0 0 0 0 1 0
3 608734 19 0 0 0 0 1 0
4 608734 21 0 0 0 1 0 0
5 608734 22 0 0 0 1 0 0
6 608734 23 0 0 0 1 0 0
7 608734 24 0 0 0 1 0 0
8 608734 25 0 0 0 1 0 0
9 608734 26 0 0 0 1 0 0
10 608734 27 0 0 0 0 1 0
11 608734 29 0 0 0 0 1 0
12 608734 31 0 0 0 0 1 0
13 608734 33 0 0 0 1 0 0
14 608734 34 0 0 0 1 0 0
15 608734 35 0 0 0 1 0 0
16 608734 36 0 0 0 1 0 0
17 608734 37 0 0 0 1 0 0
18 608734 38 0 0 0 1 0 0

same_ring Bond feature:

# `same_ring` bond feature: are both atoms members of at least one common ring?
atom1_index = 8
atom2_index = 9

any(
    atom1_index in ring_atoms and atom2_index in ring_atoms
    for ring_atoms in ring_df["atom_indices"]
)
True
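The check above handles one atom pair. The sketch below applies the same rule to every bond at once: true when both atoms belong to at least one common ring, as in_same_ring does in the Bond section. It explodes the Ring table into memberships and joins them against each bond's two atoms. add_same_ring and the toy tables in the test are hypothetical.

```python
import pandas as pd


def add_same_ring(bond_df: pd.DataFrame, ring_df: pd.DataFrame) -> pd.DataFrame:
    # One row per (molecule, ring, member atom)
    memberships = (
        ring_df.explode("atom_indices")
        .rename(columns={"index": "ring_index", "atom_indices": "atom_index"})
        .astype({"atom_index": "int64"})[["molecule_id", "ring_index", "atom_index"]]
    )
    keys = ["molecule_id", "index", "ring_index"]
    # Rings containing each bond's first atom, then rings containing its second atom
    atom1_rings = bond_df.rename(columns={"atom1_index": "atom_index"}).merge(
        memberships, on=["molecule_id", "atom_index"]
    )[keys]
    atom2_rings = bond_df.rename(columns={"atom2_index": "atom_index"}).merge(
        memberships, on=["molecule_id", "atom_index"]
    )[keys]
    # Bonds whose two atoms share at least one ring_index
    shared = atom1_rings.merge(atom2_rings, on=keys)[["molecule_id", "index"]].drop_duplicates()
    out = bond_df.merge(shared, on=["molecule_id", "index"], how="left", indicator=True)
    out["same_ring"] = out["_merge"] == "both"
    return out.drop(columns="_merge")
```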

Atom

The Atom table includes boolean hydrogen bond acceptor and donor features. RDKit’s Lipinski descriptors return molecule-level counts rather than per-atom flags, so the custom functions below flag individual atoms using SMARTS patterns and substructure matching.

def get_hbd_atom_idxs(mol):
    hbd_smarts_pat = "[N&!H0&v3,N&!H0&+1&v4,O&H1&+0,S&H1&+0,n&H1&+0]"
    hbd_query_mol = Chem.MolFromSmarts(hbd_smarts_pat)
    return [
        index
        for index_group in list(mol.GetSubstructMatches(hbd_query_mol))
        for index in index_group
    ]


def get_hba_atom_idxs(mol):
    hba_smarts_pat = (
        "[$([O,S;H1;v2]-[!$(*=[O,N,P,S])]),$([O,S;H0;v2]),"
        "$([O,S;-]),$([N;v3;!$(N-*=!@[O,N,P,S])]),$([nH0,o,s;+0])]"
    )
    hba_query_mol = Chem.MolFromSmarts(hba_smarts_pat)
    return [
        index
        for index_group in list(mol.GetSubstructMatches(hba_query_mol))
        for index in index_group
    ]

mol_hbd_atom_indices = get_hbd_atom_idxs(mol)
mol_hba_atom_indices = get_hba_atom_idxs(mol)

print("Hydrogen bond donor atom indices: ", mol_hbd_atom_indices)
print("Hydrogen bond acceptor atom indices: ", mol_hba_atom_indices)
Hydrogen bond donor atom indices:  [13, 14]
Hydrogen bond acceptor atom indices:  [8, 13, 14]

RDKit’s molecule-level hydrogen bond donor and acceptor counts can be used to validate the custom functions.

from rdkit.Chem import Lipinski

assert len(mol_hbd_atom_indices) == Lipinski.NumHDonors(mol)
assert len(mol_hba_atom_indices) == Lipinski.NumHAcceptors(mol)

These and other features are abstracted within Atom record building functions for ease of processing.

def build_atom_record(mol, pse=None) -> np.ndarray:
    molecule_id = get_mol_id(mol)
    pse = Chem.GetPeriodicTable() if pse is None else pse
    atom_chirality_map = dict(Chem.FindMolChiralCenters(mol))
    ring_size_range = (3, 8)
    rings_by_size = get_mol_rings_by_size(mol)
    mol_hbd_atom_indexs = get_hbd_atom_idxs(mol)
    mol_hba_atom_indexs = get_hba_atom_idxs(mol)

    dtypes = [
        ("molecule_id", "int"),
        ("index", "int"),
        ("symbol", "<U8"),
        ("weight", "float"),
        ("chirality", "O"),
        *[
            (f"ring_size_{size}_count", "O")
            for size in range(ring_size_range[0], ring_size_range[1] + 1)
        ],
        ("hybridization", "O"),
        ("acceptor", "bool"),
        ("donor", "bool"),
        ("aromatic", "bool"),
        ("x", "float"),
        ("y", "float"),
        ("z", "float"),
    ]

    values = []
    for index, atom in enumerate(mol.GetAtoms()):
        symbol = pse.GetElementSymbol(atom.GetAtomicNum())
        weight = pse.GetAtomicWeight(atom.GetAtomicNum())
        chirality = atom_chirality_map.get(atom.GetIdx(), None)
        ring_counts = list(
            get_atom_ring_counts_by_size(
                mol, atom, rings_by_size=rings_by_size, ring_size_range=ring_size_range
            ).values()
        )
        rdk_hybridization = str(atom.GetHybridization())
        hybridization = (
            {"SP2D": "SP2", "SP3D": "SP3", "SP3D2": "SP3"}.get(rdk_hybridization, None)
            if rdk_hybridization not in ["SP", "SP2", "SP3"]
            else rdk_hybridization
        )
        acceptor = atom.GetIdx() in mol_hba_atom_indexs
        donor = atom.GetIdx() in mol_hbd_atom_indexs
        aromatic = atom.GetIsAromatic()
        positions = mol.GetConformer().GetAtomPosition(index)
        x = positions.x
        y = positions.y
        z = positions.z

        values.append(
            (
                molecule_id,
                index,
                symbol,
                weight,
                chirality,
                *ring_counts,
                hybridization,
                acceptor,
                donor,
                aromatic,
                x,
                y,
                z,
            )
        )

    return np.array(values, dtype=dtypes)


def build_atom_df(mol, pse=None):
    return pd.DataFrame(build_atom_record(mol, pse=pse)).pipe(
        lambda df: df.astype({col: "Int16" for col in df.columns if "ring_size" in col})
    )
pse = Chem.GetPeriodicTable()
build_atom_df(mol, pse=pse)
molecule_id index symbol weight chirality ring_size_3_count ring_size_4_count ring_size_5_count ring_size_6_count ring_size_7_count ring_size_8_count hybridization acceptor donor aromatic x y z
0 608734 0 O 15.999 None <NA> <NA> <NA> <NA> <NA> <NA> SP2 True False False -1.011 3.174 -6.577
1 608734 1 C 12.011 None <NA> <NA> <NA> <NA> <NA> <NA> SP2 False False False -2.049 3.469 -6.009
2 608734 2 O 15.999 None <NA> <NA> <NA> <NA> <NA> <NA> SP2 False True False -2.863 4.337 -6.577
3 608734 3 C 12.011 None <NA> <NA> <NA> <NA> <NA> <NA> SP3 False False False -2.393 2.867 -4.752
4 608734 4 C 12.011 None <NA> <NA> <NA> <NA> <NA> <NA> SP3 False False False -2.502 3.929 -3.645
5 608734 5 C 12.011 None <NA> <NA> <NA> <NA> <NA> <NA> SP3 False False False -2.880 3.267 -2.312
6 608734 6 C 12.011 None <NA> <NA> <NA> <NA> <NA> <NA> SP3 False False False -2.951 4.314 -1.190
7 608734 7 C 12.011 None <NA> <NA> <NA> <NA> <NA> <NA> SP3 False False False -3.229 3.648 0.169
8 608734 8 N 14.007 None <NA> <NA> <NA> <NA> 1 <NA> SP2 False False False -2.103 2.881 0.635
9 608734 9 C 12.011 None <NA> <NA> <NA> <NA> 1 <NA> SP2 False False False -0.996 3.554 1.092
10 608734 10 N 14.007 None <NA> <NA> <NA> <NA> 1 <NA> SP2 False False False 0.254 3.034 1.322
11 608734 11 C 12.011 None <NA> <NA> <NA> <NA> <NA> <NA> SP3 False False False 1.139 3.741 2.209
12 608734 12 C 12.011 None <NA> <NA> <NA> <NA> <NA> <NA> SP3 False False False 0.643 3.635 3.660
13 608734 13 C 12.011 None <NA> <NA> <NA> <NA> <NA> <NA> SP3 False False False 1.642 4.286 4.627
14 608734 14 C 12.011 None <NA> <NA> <NA> <NA> <NA> <NA> SP3 False False False 1.136 4.164 6.072
15 608734 15 C 12.011 None <NA> <NA> <NA> <NA> <NA> <NA> SP3 False False False 2.124 4.825 7.045
16 608734 16 C 12.011 None <NA> <NA> <NA> <NA> <NA> <NA> SP2 False False False 1.667 4.702 8.401
17 608734 17 O 15.999 None <NA> <NA> <NA> <NA> <NA> <NA> SP2 False True False 0.526 5.250 8.771
18 608734 18 O 15.999 None <NA> <NA> <NA> <NA> <NA> <NA> SP2 True False False 2.330 4.101 9.230
19 608734 19 C 12.011 R <NA> <NA> <NA> <NA> 1 <NA> SP3 False False False 0.808 1.949 0.558
20 608734 20 C 12.011 None <NA> <NA> <NA> <NA> <NA> <NA> SP3 False False False 0.928 2.329 -0.941
21 608734 21 C 12.011 None <NA> <NA> <NA> 1 <NA> <NA> SP2 False False True 1.788 3.510 -1.166
22 608734 22 C 12.011 None <NA> <NA> <NA> 1 <NA> <NA> SP2 False False True 3.167 3.452 -0.879
23 608734 23 C 12.011 None <NA> <NA> <NA> 1 <NA> <NA> SP2 False False True 3.984 4.577 -1.080
24 608734 24 C 12.011 None <NA> <NA> <NA> 1 <NA> <NA> SP2 False False True 3.429 5.770 -1.573
25 608734 25 C 12.011 None <NA> <NA> <NA> 1 <NA> <NA> SP2 False False True 2.057 5.836 -1.867
26 608734 26 C 12.011 None <NA> <NA> <NA> 1 <NA> <NA> SP2 False False True 1.241 4.710 -1.666
27 608734 27 C 12.011 S <NA> <NA> <NA> <NA> 1 <NA> SP3 False False False 0.104 0.598 0.758
28 608734 28 O 15.999 None <NA> <NA> <NA> <NA> <NA> <NA> SP3 True True False 0.884 -0.416 0.136
29 608734 29 C 12.011 S <NA> <NA> <NA> <NA> 1 <NA> SP3 False False False -1.305 0.543 0.158
30 608734 30 O 15.999 None <NA> <NA> <NA> <NA> <NA> <NA> SP3 True True False -1.778 -0.796 0.267
31 608734 31 C 12.011 R <NA> <NA> <NA> <NA> 1 <NA> SP3 False False False -2.310 1.475 0.859
32 608734 32 C 12.011 None <NA> <NA> <NA> <NA> <NA> <NA> SP3 False False False -2.416 1.172 2.377
33 608734 33 C 12.011 None <NA> <NA> <NA> 1 <NA> <NA> SP2 False False True -3.552 1.868 3.019
34 608734 34 C 12.011 None <NA> <NA> <NA> 1 <NA> <NA> SP2 False False True -3.328 2.882 3.973
35 608734 35 C 12.011 None <NA> <NA> <NA> 1 <NA> <NA> SP2 False False True -4.410 3.551 4.570
36 608734 36 C 12.011 None <NA> <NA> <NA> 1 <NA> <NA> SP2 False False True -5.726 3.207 4.222
37 608734 37 C 12.011 None <NA> <NA> <NA> 1 <NA> <NA> SP2 False False True -5.960 2.195 3.277
38 608734 38 C 12.011 None <NA> <NA> <NA> 1 <NA> <NA> SP2 False False True -4.877 1.528 2.679
39 608734 39 O 15.999 None <NA> <NA> <NA> <NA> <NA> <NA> SP2 True False False -1.147 4.771 1.350
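build_atom_record checks donor and acceptor membership with `in` against a Python list for every atom. For larger molecules, one option is to turn each index list into a boolean array once per molecule and read the flag by atom index inside the loop. This is a sketch and not what the post's functions do; index_mask is a hypothetical helper.

```python
import numpy as np


def index_mask(n_atoms: int, atom_indices) -> np.ndarray:
    """Boolean array of length n_atoms that is True at each listed atom index."""
    mask = np.zeros(n_atoms, dtype=bool)
    mask[np.asarray(atom_indices, dtype=np.intp)] = True  # duplicate indices are harmless
    return mask


# Acceptor flags from the Atom table output above (atoms 0, 18, 28, 30, 39 of 40)
acceptor_mask = index_mask(40, [0, 18, 28, 30, 39])
```

Inside the atom loop, acceptor = bool(acceptor_mask[index]) would then replace the list lookup.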
# remove ring features for speed (will be added via join later)
def build_interim_atom_record(mol, pse=None) -> np.ndarray:
    molecule_id = get_mol_id(mol)
    pse = Chem.GetPeriodicTable() if pse is None else pse
    atom_chirality_map = dict(Chem.FindMolChiralCenters(mol))
    mol_hbd_atom_indexs = get_hbd_atom_idxs(mol)
    mol_hba_atom_indexs = get_hba_atom_idxs(mol)

    dtypes = [
        ("molecule_id", "int"),
        ("index", "int"),
        ("symbol", "<U8"),
        ("weight", "float"),
        ("chirality", "O"),
        ("hybridization", "O"),
        ("acceptor", "bool"),
        ("donor", "bool"),
        ("aromatic", "bool"),
        ("x", "float"),
        ("y", "float"),
        ("z", "float"),
    ]

    values = []
    for index, atom in enumerate(mol.GetAtoms()):
        symbol = pse.GetElementSymbol(atom.GetAtomicNum())
        weight = pse.GetAtomicWeight(atom.GetAtomicNum())
        chirality = atom_chirality_map.get(atom.GetIdx(), None)
        rdk_hybridization = str(atom.GetHybridization())
        hybridization = (
            {"SP2D": "SP2", "SP3D": "SP3", "SP3D2": "SP3"}.get(rdk_hybridization, None)
            if rdk_hybridization not in ["SP", "SP2", "SP3"]
            else rdk_hybridization
        )
        acceptor = atom.GetIdx() in mol_hba_atom_indexs
        donor = atom.GetIdx() in mol_hbd_atom_indexs
        aromatic = atom.GetIsAromatic()
        positions = mol.GetConformer().GetAtomPosition(index)
        x = positions.x
        y = positions.y
        z = positions.z

        values.append(
            (
                molecule_id,
                index,
                symbol,
                weight,
                chirality,
                hybridization,
                acceptor,
                donor,
                aromatic,
                x,
                y,
                z,
            )
        )

    return np.array(values, dtype=dtypes)


def build_interim_atom_df(mol, pse=None):
    return pd.DataFrame(build_interim_atom_record(mol, pse=pse))
build_interim_atom_df(mol, pse=pse)
molecule_id index symbol weight chirality hybridization acceptor donor aromatic x y z
0 608734 0 O 15.999 None SP2 True False False -1.011 3.174 -6.577
1 608734 1 C 12.011 None SP2 False False False -2.049 3.469 -6.009
2 608734 2 O 15.999 None SP2 False True False -2.863 4.337 -6.577
3 608734 3 C 12.011 None SP3 False False False -2.393 2.867 -4.752
4 608734 4 C 12.011 None SP3 False False False -2.502 3.929 -3.645
5 608734 5 C 12.011 None SP3 False False False -2.880 3.267 -2.312
6 608734 6 C 12.011 None SP3 False False False -2.951 4.314 -1.190
7 608734 7 C 12.011 None SP3 False False False -3.229 3.648 0.169
8 608734 8 N 14.007 None SP2 False False False -2.103 2.881 0.635
9 608734 9 C 12.011 None SP2 False False False -0.996 3.554 1.092
10 608734 10 N 14.007 None SP2 False False False 0.254 3.034 1.322
11 608734 11 C 12.011 None SP3 False False False 1.139 3.741 2.209
12 608734 12 C 12.011 None SP3 False False False 0.643 3.635 3.660
13 608734 13 C 12.011 None SP3 False False False 1.642 4.286 4.627
14 608734 14 C 12.011 None SP3 False False False 1.136 4.164 6.072
15 608734 15 C 12.011 None SP3 False False False 2.124 4.825 7.045
16 608734 16 C 12.011 None SP2 False False False 1.667 4.702 8.401
17 608734 17 O 15.999 None SP2 False True False 0.526 5.250 8.771
18 608734 18 O 15.999 None SP2 True False False 2.330 4.101 9.230
19 608734 19 C 12.011 R SP3 False False False 0.808 1.949 0.558
20 608734 20 C 12.011 None SP3 False False False 0.928 2.329 -0.941
21 608734 21 C 12.011 None SP2 False False True 1.788 3.510 -1.166
22 608734 22 C 12.011 None SP2 False False True 3.167 3.452 -0.879
23 608734 23 C 12.011 None SP2 False False True 3.984 4.577 -1.080
24 608734 24 C 12.011 None SP2 False False True 3.429 5.770 -1.573
25 608734 25 C 12.011 None SP2 False False True 2.057 5.836 -1.867
26 608734 26 C 12.011 None SP2 False False True 1.241 4.710 -1.666
27 608734 27 C 12.011 S SP3 False False False 0.104 0.598 0.758
28 608734 28 O 15.999 None SP3 True True False 0.884 -0.416 0.136
29 608734 29 C 12.011 S SP3 False False False -1.305 0.543 0.158
30 608734 30 O 15.999 None SP3 True True False -1.778 -0.796 0.267
31 608734 31 C 12.011 R SP3 False False False -2.310 1.475 0.859
32 608734 32 C 12.011 None SP3 False False False -2.416 1.172 2.377
33 608734 33 C 12.011 None SP2 False False True -3.552 1.868 3.019
34 608734 34 C 12.011 None SP2 False False True -3.328 2.882 3.973
35 608734 35 C 12.011 None SP2 False False True -4.410 3.551 4.570
36 608734 36 C 12.011 None SP2 False False True -5.726 3.207 4.222
37 608734 37 C 12.011 None SP2 False False True -5.960 2.195 3.277
38 608734 38 C 12.011 None SP2 False False True -4.877 1.528 2.679
39 608734 39 O 15.999 None SP2 True False False -1.147 4.771 1.350

The interim version of the Atom record is roughly 5.7x faster than the single-pass version (314 µs versus 1.79 ms). Again, vectorized creation of the ring_size_[n]_count features is unlikely to close the gap between the interim and single-pass versions.

%%timeit
build_interim_atom_record(mol, pse=pse)
314 µs ± 1.97 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
%%timeit
build_atom_record(mol, pse=pse)
1.79 ms ± 23.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Bond

def in_same_ring(mol, atom1_index, atom2_index):
    rings = mol.GetRingInfo().AtomRings()
    for ring_atoms in rings:
        if (atom1_index in ring_atoms) and (atom2_index in ring_atoms):
            return True

    return False
from typing import Optional


def build_bond_record(mol) -> np.ndarray:
    molecule_id = get_mol_id(mol)

    dtypes = [
        ("molecule_id", "int"),
        ("index", "int"),
        ("atom1_index", "int"),
        ("atom2_index", "int"),
        ("type", "O"),
        ("stereochemistry", "O"),
        ("same_ring", "bool"),
    ]

    values = []
    for bond in mol.GetBonds():
        bond_index = bond.GetIdx()
        atom1_index = bond.GetBeginAtomIdx()
        atom2_index = bond.GetEndAtomIdx()

        raw_bond_type = str(bond.GetBondType()).lower()
        bond_type = (
            raw_bond_type if raw_bond_type in ["single", "double", "triple", "aromatic"] else None
        )

        raw_stereochemistry = str(bond.GetStereo()).lower().replace("stereo", "")
        stereochemistry = {"none": None}.get(raw_stereochemistry, raw_stereochemistry)

        same_ring = in_same_ring(mol, atom1_index, atom2_index)

        values.append(
            (
                molecule_id,
                bond_index,
                atom1_index,
                atom2_index,
                bond_type,
                stereochemistry,
                same_ring,
            )
        )

    return np.array(values, dtype=dtypes)


def build_bond_df(mol) -> pd.DataFrame:
    return pd.DataFrame(build_bond_record(mol))
bond_df = build_bond_df(mol)
bond_df
molecule_id index atom1_index atom2_index type stereochemistry same_ring
0 608734 0 0 1 double None False
1 608734 1 1 2 single None False
2 608734 2 1 3 single None False
3 608734 3 3 4 single None False
4 608734 4 4 5 single None False
5 608734 5 5 6 single None False
6 608734 6 6 7 single None False
7 608734 7 7 8 single None False
8 608734 8 8 9 single None True
9 608734 9 8 31 single None True
10 608734 10 9 10 single None True
11 608734 11 9 39 double None False
12 608734 12 10 11 single None False
13 608734 13 10 19 single None True
14 608734 14 11 12 single None False
15 608734 15 12 13 single None False
16 608734 16 13 14 single None False
17 608734 17 14 15 single None False
18 608734 18 15 16 single None False
19 608734 19 16 17 single None False
20 608734 20 16 18 double None False
21 608734 21 19 20 single None False
22 608734 22 19 27 single None True
23 608734 23 20 21 single None False
24 608734 24 21 22 aromatic None True
25 608734 25 21 26 aromatic None True
26 608734 26 22 23 aromatic None True
27 608734 27 23 24 aromatic None True
28 608734 28 24 25 aromatic None True
29 608734 29 25 26 aromatic None True
30 608734 30 27 28 single None False
31 608734 31 27 29 single None True
32 608734 32 29 30 single None False
33 608734 33 29 31 single None True
34 608734 34 31 32 single None False
35 608734 35 32 33 single None False
36 608734 36 33 34 aromatic None True
37 608734 37 33 38 aromatic None True
38 608734 38 34 35 aromatic None True
39 608734 39 35 36 aromatic None True
40 608734 40 36 37 aromatic None True
41 608734 41 37 38 aromatic None True
# remove ring features for speed (will be added in later with a join)
def build_interim_bond_record(mol) -> np.ndarray:
    molecule_id = get_mol_id(mol)

    dtypes = [
        ("molecule_id", "int"),
        ("index", "int"),
        ("atom1_index", "int"),
        ("atom2_index", "int"),
        ("type", "O"),
        ("stereochemistry", "O"),
    ]

    values = []
    for bond in mol.GetBonds():
        bond_index = bond.GetIdx()
        atom1_index = bond.GetBeginAtomIdx()
        atom2_index = bond.GetEndAtomIdx()

        raw_bond_type = str(bond.GetBondType()).lower()
        bond_type = (
            raw_bond_type if raw_bond_type in ["single", "double", "triple", "aromatic"] else None
        )

        raw_stereochemistry = str(bond.GetStereo()).lower().replace("stereo", "")
        stereochemistry = {"none": None}.get(raw_stereochemistry, raw_stereochemistry)

        values.append(
            (
                molecule_id,
                bond_index,
                atom1_index,
                atom2_index,
                bond_type,
                stereochemistry,
            )
        )

    return np.array(values, dtype=dtypes)


def build_interim_bond_df(mol) -> pd.DataFrame:
    return pd.DataFrame(build_interim_bond_record(mol))
build_interim_bond_df(mol)
molecule_id index atom1_index atom2_index type stereochemistry
0 608734 0 0 1 double None
1 608734 1 1 2 single None
2 608734 2 1 3 single None
3 608734 3 3 4 single None
4 608734 4 4 5 single None
5 608734 5 5 6 single None
6 608734 6 6 7 single None
7 608734 7 7 8 single None
8 608734 8 8 9 single None
9 608734 9 8 31 single None
10 608734 10 9 10 single None
11 608734 11 9 39 double None
12 608734 12 10 11 single None
13 608734 13 10 19 single None
14 608734 14 11 12 single None
15 608734 15 12 13 single None
16 608734 16 13 14 single None
17 608734 17 14 15 single None
18 608734 18 15 16 single None
19 608734 19 16 17 single None
20 608734 20 16 18 double None
21 608734 21 19 20 single None
22 608734 22 19 27 single None
23 608734 23 20 21 single None
24 608734 24 21 22 aromatic None
25 608734 25 21 26 aromatic None
26 608734 26 22 23 aromatic None
27 608734 27 23 24 aromatic None
28 608734 28 24 25 aromatic None
29 608734 29 25 26 aromatic None
30 608734 30 27 28 single None
31 608734 31 27 29 single None
32 608734 32 29 30 single None
33 608734 33 29 31 single None
34 608734 34 31 32 single None
35 608734 35 32 33 single None
36 608734 36 33 34 aromatic None
37 608734 37 33 38 aromatic None
38 608734 38 34 35 aromatic None
39 608734 39 35 36 aromatic None
40 608734 40 36 37 aromatic None
41 608734 41 37 38 aromatic None
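The string handling inside build_interim_bond_record can be pulled out into pure helpers so it can be unit-tested without constructing molecules. This is a minimal sketch, assuming RDKit's enum values print as upper-case names such as "SINGLE" and "STEREOE" (which is what the .lower() calls above rely on); the helper names normalize_bond_type and normalize_bond_stereo are hypothetical.

```python
from typing import Optional

# Bond types kept by build_interim_bond_record; anything else maps to None
KEPT_BOND_TYPES = {"single", "double", "triple", "aromatic"}


def normalize_bond_type(raw: str) -> Optional[str]:
    """Lower-case an RDKit BondType name (e.g. "SINGLE"), or return None if it isn't kept."""
    lowered = raw.lower()
    return lowered if lowered in KEPT_BOND_TYPES else None


def normalize_bond_stereo(raw: str) -> Optional[str]:
    """Map an RDKit BondStereo name (e.g. "STEREOE") to "e"; "STEREONONE" becomes None."""
    stripped = raw.lower().replace("stereo", "")
    return None if stripped == "none" else stripped


# Usage with the kind of strings str(bond.GetBondType()) / str(bond.GetStereo()) produce
print(normalize_bond_type("SINGLE"), normalize_bond_type("DATIVE"))  # -> single None
print(normalize_bond_stereo("STEREONONE"), normalize_bond_stereo("STEREOZ"))  # -> None z
```

Inside the loop, the two assignments would then become `bond_type = normalize_bond_type(str(bond.GetBondType()))` and `stereochemistry = normalize_bond_stereo(str(bond.GetStereo()))`.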

The interim Bond approach is roughly 500x faster than the single-pass approach (120 µs vs 59.7 ms per molecule).

%%timeit
build_interim_bond_record(mol)
120 µs ± 929 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
%%timeit
build_bond_record(mol)
59.7 ms ± 657 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Performance Comparison

The interim processing approach is about 30x faster than generating all features immediately (2.01 ms vs 61 ms per molecule). As mentioned earlier, this isn’t a perfect comparison, since the features the interim tables lack still need to be built eventually. Generating them for all records in a single vectorized query will take some time, but is unlikely to erase the entire lead seen here. Minimal feature engineering in the molecule ingestion loop, followed by vectorized feature engineering, is therefore likely to be faster than generating all features within the loop.

pse = Chem.GetPeriodicTable()
%%timeit
build_interim_mol_record(mol, supplier_index=0)
build_ring_record(mol)
build_interim_atom_record(mol, pse=pse)
build_interim_bond_record(mol)
2.01 ms ± 19.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%%timeit
build_mol_record(mol, supplier_index=0)
build_atom_record(mol, pse=pse)
build_bond_record(mol)
61 ms ± 126 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
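The argument above rests on the deferred molecule features being cheap to compute in bulk. As an illustration, here is a sketch of filling in mean_atomic_weight and bonds_per_atom for every molecule with one grouped pass over the atom and bond tables. It assumes the column names shown in atom_df and bond_df above and interprets bonds_per_atom as the mean atom degree (each bond counted once for each of its two atoms); add_molecule_aggregates and the toy values are hypothetical.

```python
import pandas as pd


def add_molecule_aggregates(
    mol_df: pd.DataFrame, atom_df: pd.DataFrame, bond_df: pd.DataFrame
) -> pd.DataFrame:
    """Hypothetical helper: attach mean_atomic_weight and bonds_per_atom to the Molecule table."""
    atom_groups = atom_df.groupby("molecule_id")
    aggregates = pd.DataFrame(
        {
            # average of the per-atom average (not exact) weights
            "mean_atomic_weight": atom_groups["weight"].mean(),
            "atom_count": atom_groups.size(),
        }
    )
    bond_count = bond_df.groupby("molecule_id").size().rename("bond_count")
    # molecules without any bonds (e.g. a lone ion) get a bond count of 0
    aggregates = aggregates.join(bond_count, how="left").fillna({"bond_count": 0})
    # assumption: bonds_per_atom = mean atom degree, and each bond touches two atoms
    aggregates["bonds_per_atom"] = 2 * aggregates["bond_count"] / aggregates["atom_count"]
    return mol_df.merge(
        aggregates[["mean_atomic_weight", "bonds_per_atom"]],
        left_on="id",
        right_index=True,
        how="left",
    )


# Toy tables shaped like mol_df / atom_df / bond_df above (made-up values)
toy_mol_df = pd.DataFrame({"id": [1, 2]})
toy_atom_df = pd.DataFrame(
    {
        "molecule_id": [1, 1, 1, 2, 2],
        "index": [0, 1, 2, 0, 1],
        "weight": [12.011, 15.999, 12.011, 14.007, 1.008],
    }
)
toy_bond_df = pd.DataFrame(
    {
        "molecule_id": [1, 1, 2],
        "index": [0, 1, 0],
        "atom1_index": [0, 1, 0],
        "atom2_index": [1, 2, 1],
    }
)
enriched = add_molecule_aggregates(toy_mol_df, toy_atom_df, toy_bond_df)
print(enriched)
```

Each grouped aggregation runs once over the whole table, so its cost is shared across all molecules instead of being paid inside the per-molecule ingestion loop.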

Naive Processing

from rdkit import RDLogger
pse = Chem.GetPeriodicTable()

mol_records = []
ring_records = []
atom_records = []
bond_records = []
max_iter = 10
RDLogger.DisableLog("rdApp.*")
for i, mol in enumerate(supplier):
    if i >= max_iter:
        break
    mol_records.append(build_interim_mol_record(mol, supplier_index=i))
    ring_records.append(build_ring_record(mol))
    atom_records.append(build_interim_atom_record(mol, pse=pse))
    bond_records.append(build_interim_bond_record(mol))
RDLogger.EnableLog("rdApp.*")
mol_df = pd.DataFrame(np.concatenate(mol_records))
ring_df = pd.DataFrame(np.concatenate(ring_records))
atom_df = pd.DataFrame(np.concatenate(atom_records))
bond_df = pd.DataFrame(np.concatenate(bond_records))
mol_df
id supplier_index name smiles n_atoms n_bonds weight
0 608734 0 6-[(4R,5S,6S,7R)-4,7-dibenzyl-3-(5-carboxypent... O=C(O)CCCCCN1C(=O)N(CCCCCC(=O)O)[C@H](Cc2ccccc... 40 42 554.299202
1 22 1 (4R,5S,6S,7R)-4,7-dibenzyl-5,6-dihydroxy-1,3-b... O=C1N(C/C=C/c2cn[nH]c2)[C@H](Cc2ccccc2)[C@H](O... 40 44 538.269239
2 23 2 (4R,5S,6S,7R)-4,7-dibenzyl-1-(cyclopropylmethy... O=C1N(C/C=C/c2cn[nH]c2)[C@H](Cc2ccccc2)[C@H](O... 36 40 486.263091
3 24 3 (4R,5S,6S,7R)-4,7-dibenzyl-1-(cyclopropylmethy... O=C1N(CCCCCCO)[C@H](Cc2ccccc2)[C@H](O)[C@@H](O... 35 38 480.298808
4 25 4 (4R,5S,6S,7R)-4,7-dibenzyl-1-(cyclopropylmethy... O=C1N(CCCCCO)[C@H](Cc2ccccc2)[C@H](O)[C@@H](O)... 34 37 466.283158
5 26 5 (4R,5S,6S,7R)-4,7-dibenzyl-1-butyl-3-(cyclopro... CCCCN1C(=O)N(CC2CC2)[C@H](Cc2ccccc2)[C@H](O)[C... 32 35 436.272593
6 27 6 (4R,5S,6S,7R)-4,7-dibenzyl-1,3-bis(cyclobutylm... O=C1N(CC2CCC2)[C@H](Cc2ccccc2)[C@H](O)[C@@H](O... 34 38 462.288243
7 28 7 (4R,5S,6S,7R)-4,7-dibenzyl-5,6-dihydroxy-1,3-b... O=C1N(CCCCCO)[C@H](Cc2ccccc2)[C@H](O)[C@@H](O)... 36 38 498.309372
8 29 8 (4R,5S,6S,7R)-4,7-dibenzyl-1,3-dibutyl-5,6-dih... CCCCN1C(=O)N(CCCC)[C@H](Cc2ccccc2)[C@H](O)[C@@... 32 34 438.288243
9 30 9 (4R,5S,6S,7R)-4,7-dibenzyl-5,6-dihydroxy-1,3-b... CC(C)=CCN1C(=O)N(CC=C(C)C)[C@H](Cc2ccccc2)[C@H... 34 36 462.288243
ring_df
molecule_id index size atom_indices
0 608734 0 7 (8, 9, 10, 19, 27, 29, 31)
1 608734 1 6 (22, 23, 24, 25, 26, 21)
2 608734 2 6 (34, 35, 36, 37, 38, 33)
3 22 0 7 (0, 1, 3, 5, 6, 4, 2)
4 22 1 5 (7, 11, 16, 10, 23)
5 22 2 5 (8, 12, 17, 9, 22)
6 22 3 6 (30, 37, 39, 34, 31, 29)
7 22 4 6 (32, 35, 38, 36, 33, 28)
8 23 0 7 (1, 0, 3, 5, 6, 4, 2)
9 23 1 5 (7, 10, 17, 9, 20)
10 23 2 3 (15, 12, 16)
11 23 3 6 (26, 33, 35, 30, 27, 25)
12 23 4 6 (28, 31, 34, 32, 29, 24)
13 24 0 7 (1, 0, 3, 5, 6, 4, 2)
14 24 1 3 (12, 9, 13)
15 24 2 6 (21, 32, 34, 29, 22, 18)
16 24 3 6 (23, 31, 33, 30, 24, 17)
17 25 0 7 (1, 0, 3, 5, 6, 4, 2)
18 25 1 3 (12, 9, 13)
19 25 2 6 (21, 28, 33, 31, 24, 17)
20 25 3 6 (22, 29, 32, 30, 23, 18)
21 26 0 7 (1, 0, 3, 5, 6, 4, 2)
22 26 1 3 (12, 9, 13)
23 26 2 6 (19, 27, 30, 26, 20, 17)
24 26 3 6 (21, 28, 31, 29, 22, 18)
25 27 0 7 (0, 1, 3, 6, 5, 4, 2)
26 27 1 4 (18, 23, 17, 20)
27 27 2 4 (19, 22, 16, 21)
28 27 3 6 (24, 29, 32, 28, 27, 14)
29 27 4 6 (25, 30, 33, 31, 26, 15)
30 28 0 7 (0, 1, 3, 5, 6, 4, 2)
31 28 1 6 (20, 31, 35, 30, 21, 15)
32 28 2 6 (22, 32, 34, 33, 23, 14)
33 29 0 7 (0, 1, 3, 5, 6, 4, 2)
34 29 1 6 (16, 27, 30, 28, 17, 14)
35 29 2 6 (18, 29, 31, 26, 19, 15)
36 30 0 7 (0, 1, 3, 5, 6, 4, 2)
37 30 1 6 (24, 30, 33, 31, 25, 18)
38 30 2 6 (26, 29, 32, 28, 27, 19)
atom_df
molecule_id index symbol weight chirality hybridization acceptor donor aromatic x y z
0 608734 0 O 15.999 None SP2 True False False -1.011 3.174 -6.577
1 608734 1 C 12.011 None SP2 False False False -2.049 3.469 -6.009
2 608734 2 O 15.999 None SP2 False True False -2.863 4.337 -6.577
3 608734 3 C 12.011 None SP3 False False False -2.393 2.867 -4.752
4 608734 4 C 12.011 None SP3 False False False -2.502 3.929 -3.645
... ... ... ... ... ... ... ... ... ... ... ... ...
348 30 29 C 12.011 None SP2 False False True -1.060 -2.547 3.635
349 30 30 C 12.011 None SP2 False False True -0.270 5.992 -1.587
350 30 31 C 12.011 None SP2 False False True -1.897 5.931 0.222
351 30 32 C 12.011 None SP2 False False True -2.163 -3.363 3.942
352 30 33 C 12.011 None SP2 False False True -0.866 6.594 -0.466

353 rows × 12 columns

bond_df
molecule_id index atom1_index atom2_index type stereochemistry
0 608734 0 0 1 double None
1 608734 1 1 2 single None
2 608734 2 1 3 single None
3 608734 3 3 4 single None
4 608734 4 4 5 single None
... ... ... ... ... ... ...
377 30 31 27 28 aromatic None
378 30 32 28 32 aromatic None
379 30 33 29 32 aromatic None
380 30 34 30 33 aromatic None
381 30 35 31 33 aromatic None

382 rows × 6 columns

Batching

Writes will be batched to prevent running an IO operation for each molecule.

from typing import Iterator


def batched_supplier(
    supplier: Chem.SDMolSupplier, size: int = 1, yield_index: bool = True
) -> Iterator:
    next_i = 0
    end = False
    while not end:
        batch = []
        indices = []
        for i in range(next_i, next_i + size):
            try:
                batch.append(supplier[i])
                indices.append(i)
            except IndexError:
                # past the last record: stop after yielding this partial batch
                end = True
                break

        next_i += size
        if not batch:
            # avoid yielding an empty trailing batch
            break
        if yield_index:
            yield tuple(zip(indices, batch))
        else:
            yield batch
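batched_supplier relies on random access (supplier[i]) and an IndexError to detect the end of the file. For inputs that only support forward iteration (for example, if one switched to RDKit's Chem.ForwardSDMolSupplier, which reads a file object sequentially), a possible alternative is to slice an enumerate iterator with itertools.islice. This is a sketch under that assumption; batched_enumerate is a hypothetical name, and it yields lists of (index, item) pairs rather than tuples, which process_batch iterates over in the same way.

```python
from itertools import islice
from typing import Iterable, Iterator, List, Tuple, TypeVar

T = TypeVar("T")


def batched_enumerate(items: Iterable[T], size: int) -> Iterator[List[Tuple[int, T]]]:
    """Yield lists of at most `size` (index, item) pairs, reading `items` once, front to back."""
    if size < 1:
        raise ValueError("size must be >= 1")
    iterator = enumerate(items)
    while True:
        batch = list(islice(iterator, size))
        if not batch:  # iterator exhausted: no empty trailing batch
            return
        yield batch


print(list(batched_enumerate("abcde", size=2)))
# -> [[(0, 'a'), (1, 'b')], [(2, 'c'), (3, 'd')], [(4, 'e')]]
```

Because enumerate keeps counting across batches, the yielded indices still match each record's position in the file, which is what the supplier_index column records.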
from joblib import Parallel, delayed
from rdkit import RDLogger
from pathlib import Path
from typing import Optional, Tuple, Union
import shutil
from pyarrow import ArrowInvalid
import glob


def process_mol(
    mol, supplier_index: Optional[int] = None, pse=None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    return (
        build_interim_mol_record(mol, supplier_index=supplier_index),
        build_ring_record(mol),
        build_interim_atom_record(mol, pse=pse),
        build_interim_bond_record(mol),
    )


def process_batch(
    batch,
    index: int,
    data_path: Union[str, Path],
    pse=None,
    multiprocessing: bool = False,
):
    pse = Chem.GetPeriodicTable() if pse is None else pse
    # accept str paths too; the `/` joins below require a Path
    data_path = Path(data_path)

    # This is a hack.
    # Chem.Mol objects can be pickled but they weren't being "materialized"
    # properly due to the lazy evaluation of `SDMolSupplier`.
    # `SDMolSupplier` cannot be pickled. As a workaround, we instantiate a new
    # supplier and read the Chem.Mol object in each worker.
    # This is slower than pickling "materialized" `Chem.Mol` objects
    # (evaluating them before passing to the workers), but still faster than sequential execution.
    if multiprocessing:
        supplier = Chem.SDMolSupplier(bdb_path)

    # any processed molecule parquet files exist
    if len(glob.glob(str(data_path / "mol/*.parquet"))):
        try:
            processed_ids = pd.read_parquet(data_path / "mol/", columns=["id"]).values
        # can occur when file exists but is still empty (race condition)
        except (ArrowInvalid, OSError):
            processed_ids = np.array([])
    else:
        processed_ids = np.array([])

    def _is_processed(mol) -> Optional[bool]:
        try:
            return get_mol_id(mol) in processed_ids
        except Exception:
            return None

    # > Generate Records
    mol_records = []
    ring_records = []
    atom_records = []
    bond_records = []
    unprocessed = []
    for supplier_index, mol in batch:
        if multiprocessing:
            RDLogger.DisableLog("rdApp.*")
            mol = supplier[supplier_index]
            RDLogger.EnableLog("rdApp.*")

        if (mol is not None) and (not _is_processed(mol)):
            try:
                mol_record, ring_record, atom_record, bond_record = process_mol(
                    mol, supplier_index, pse=pse
                )
                mol_records.append(mol_record)
                ring_records.append(ring_record)
                atom_records.append(atom_record)
                bond_records.append(bond_record)

                # prevent reprocessing within batch
                # most repeated IDs are co-located in the supplier
                processed_ids = np.append(processed_ids, get_mol_id(mol))

            except Exception as e:
                unprocessed.append((supplier_index, str(e.__class__.__name__), str(e)))
    # < Generate Records

    # > Write Records
    if len(mol_records):
        mol_df = pd.DataFrame(np.concatenate(mol_records))
        mol_df.to_parquet(data_path / f"mol/{index}.parquet")

    if len(ring_records):
        ring_df = pd.DataFrame(np.concatenate(ring_records))
        ring_df.to_parquet(data_path / f"ring/{index}.parquet")

    if len(atom_records):
        atom_df = pd.DataFrame(np.concatenate(atom_records))
        atom_df.to_parquet(data_path / f"atom/{index}.parquet")

    if len(bond_records):
        bond_df = pd.DataFrame(np.concatenate(bond_records))
        bond_df.to_parquet(data_path / f"bond/{index}.parquet")

    if len(unprocessed):
        print(f"{len(unprocessed)} unprocessed molecules in batch {index}.")
        unprocessed_df = pd.DataFrame(
            columns=["supplier_index", "exception_class", "exception_message"], data=unprocessed
        )
        unprocessed_df.to_csv(data_path / f"unprocessed/{index}.csv")
    # < Write Records
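process_batch keeps the processed IDs in a NumPy array: `in` compares against every element, and np.append copies the whole array after each successful molecule, so this bookkeeping grows with the number of IDs already on disk. A possible variation is to keep the IDs in a Python set, which gives average O(1) membership tests and in-place insertion. The sketch below is generic and hypothetical (process_unseen, get_id, and process are illustrative names, and dicts stand in for Chem.Mol objects). It mirrors process_batch's rules: skip None records, skip known IDs, and mark an ID as seen only after it is processed successfully. In process_batch itself, `seen` could be initialized with something like `set(processed_ids.ravel().tolist())`, assuming get_mol_id returns hashable values such as ints.

```python
from typing import Callable, Hashable, Iterable, List, Set, Tuple


def process_unseen(
    batch: Iterable[Tuple[int, object]],
    get_id: Callable[[object], Hashable],
    process: Callable[[object], object],
    seen: Set[Hashable],
) -> Tuple[List[object], List[Tuple[int, str, str]]]:
    """Process items whose ID isn't in `seen`, mirroring process_batch's skip/record logic."""
    results: List[object] = []
    failures: List[Tuple[int, str, str]] = []
    for supplier_index, item in batch:
        if item is None:  # unparsable record, like a None from the supplier
            continue
        try:
            item_id = get_id(item)
        except Exception:
            item_id = None  # unknown ID: attempt processing anyway, like _is_processed
        if item_id is not None and item_id in seen:  # average O(1) set lookup
            continue
        try:
            results.append(process(item))
            if item_id is not None:
                seen.add(item_id)  # in-place insert; no array copy per molecule
        except Exception as e:
            failures.append((supplier_index, e.__class__.__name__, str(e)))
    return results, failures


# Toy usage: dicts stand in for Chem.Mol objects, "id" for get_mol_id
seen_ids = {22}
toy_batch = [(0, {"id": 22}), (1, {"id": 25}), (2, {"id": 25}), (3, None), (4, {"id": 0})]
toy_results, toy_failures = process_unseen(
    toy_batch, get_id=lambda m: m["id"], process=lambda m: 100 // m["id"], seen=seen_ids
)
print(toy_results, sorted(seen_ids), toy_failures)
```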

Processing Tests

Multiprocessing

RDKit’s Chem.MultithreadedSDMolSupplier is not used here because it would complicate batched file writes and capturing the supplier index, and it isn’t notebook-friendly.

from os import makedirs
from dti.config import data_path

drug_system_test_path = data_path / f"interim/drug_system/test"
makedirs(drug_system_test_path, exist_ok=True)

makedirs(drug_system_test_path / "mol", exist_ok=True)
makedirs(drug_system_test_path / "ring", exist_ok=True)
makedirs(drug_system_test_path / "atom", exist_ok=True)
makedirs(drug_system_test_path / "bond", exist_ok=True)
makedirs(drug_system_test_path / "unprocessed", exist_ok=True)

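Multithreading runs without errors but is slower than sequential execution here (59.6 s vs 42.6 s for 50 batches of 1,000 molecules), plausibly because most of the per-molecule work is Python code that holds the GIL.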

%%timeit -n 2 -r 3
RDLogger.DisableLog("rdApp.*")
pse = Chem.GetPeriodicTable()
max_batches = 50
for batch_index, batch in enumerate(batched_supplier(supplier, size=1000)):
    if batch_index >= max_batches:
        break
    _ = process_batch(batch, batch_index, drug_system_test_path, pse=pse)
RDLogger.EnableLog("rdApp.*")
42.6 s ± 158 ms per loop (mean ± std. dev. of 3 runs, 2 loops each)
%%timeit -n 2 -r 3
RDLogger.DisableLog("rdApp.*")
proc = Parallel(n_jobs=-1, prefer="threads", verbose=0)

tasks = []
max_batches = 50
for batch_index, batch in enumerate(batched_supplier(supplier, size=1000)):
    if batch_index >= max_batches:
        break
    tasks.append(delayed(process_batch)(batch, batch_index, drug_system_test_path))


_ = proc(tasks)
RDLogger.EnableLog("rdApp.*")
59.6 s ± 114 ms per loop (mean ± std. dev. of 3 runs, 2 loops each)

Multiprocessing fails to process the molecules: the Chem.Mol objects yielded by the lazily evaluated Chem.SDMolSupplier aren’t properly materialized when they are pickled and sent to the workers.

RDLogger.DisableLog("rdApp.*")
proc = Parallel(n_jobs=-1, prefer="processes", verbose=0)

tasks = []
max_batches = 1
for batch_index, batch in enumerate(batched_supplier(supplier, size=1)):
    if batch_index >= max_batches:
        break
    tasks.append(delayed(process_batch)(batch, batch_index, drug_system_test_path))


_ = proc(tasks)
RDLogger.EnableLog("rdApp.*")
1 unprocessed molecules in batch 0.

As a workaround to the lazy-evaluation problem, each worker instantiates a new supplier and reads the Chem.Mol object itself. This is likely slower than pickling “materialized” Chem.Mol objects (evaluating them before passing them to the workers), but still faster than sequential execution.

# excerpt from process_batch()

if multiprocessing:
    supplier = Chem.SDMolSupplier(bdb_path)

for supplier_index, mol in batch:
    if multiprocessing:
        RDLogger.DisableLog("rdApp.*")
        mol = supplier[supplier_index]
        RDLogger.EnableLog("rdApp.*")

Multiprocessing is faster than multithreading for larger batch sizes (28.5 s vs 49.2 s for 5 batches of 10,000 molecules).

%%timeit -n 2 -r 3
RDLogger.DisableLog("rdApp.*")
proc = Parallel(n_jobs=-1, prefer="processes", verbose=0)

tasks = []
max_batches = 5
for batch_index, batch in enumerate(batched_supplier(supplier, size=10000)):
    if batch_index >= max_batches:
        break
    tasks.append(delayed(process_batch)(batch, batch_index, drug_system_test_path, multiprocessing=True))


_ = proc(tasks)
RDLogger.EnableLog("rdApp.*")
28.5 s ± 417 ms per loop (mean ± std. dev. of 3 runs, 2 loops each)
%%timeit -n 2 -r 3
RDLogger.DisableLog("rdApp.*")
proc = Parallel(n_jobs=-1, prefer="threads", verbose=0)

tasks = []
max_batches = 5
for batch_index, batch in enumerate(batched_supplier(supplier, size=10000)):
    if batch_index >= max_batches:
        break
    tasks.append(delayed(process_batch)(batch, batch_index, drug_system_test_path))


_ = proc(tasks)
RDLogger.EnableLog("rdApp.*")
49.2 s ± 164 ms per loop (mean ± std. dev. of 3 runs, 2 loops each)

Reprocessing

Some reprocessing can still occur due to race conditions between workers, since each call to process_batch only sees the IDs already written to disk when it starts. Within that limit, process_batch skips molecules that have already been written to file.

Testing this functionality:

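import glob
import shutil

# start from an empty test directory so files from the timing runs above don't count
shutil.rmtree(drug_system_test_path, ignore_errors=True)
makedirs(drug_system_test_path, exist_ok=True)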

makedirs(drug_system_test_path / "mol", exist_ok=True)
makedirs(drug_system_test_path / "ring", exist_ok=True)
makedirs(drug_system_test_path / "atom", exist_ok=True)
makedirs(drug_system_test_path / "bond", exist_ok=True)
makedirs(drug_system_test_path / "unprocessed", exist_ok=True)

RDLogger.DisableLog("rdApp.*")
pse = Chem.GetPeriodicTable()
max_batches = 5
for batch_index, batch in enumerate([[(0, supplier[0])], [(0, supplier[0])]]):
    if batch_index >= max_batches:
        break
    _ = process_batch(batch, batch_index, drug_system_test_path, pse=pse)
RDLogger.EnableLog("rdApp.*")

assert len(glob.glob(str(drug_system_test_path / "mol/*"))) == 1
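Because racing workers can still write the same molecule twice, a downstream cleanup step can drop duplicated rows once each table is loaded (for example with pd.read_parquet on its directory, as process_batch already does for mol/). This is a minimal sketch; the natural keys below are assumptions read off the schemas and outputs above (ring, atom, and bond indices restart at 0 for each molecule), and TABLE_KEYS and drop_reprocessed are hypothetical names.

```python
import pandas as pd

# Assumed natural key of each interim table
TABLE_KEYS = {
    "mol": ["id"],
    "ring": ["molecule_id", "index"],
    "atom": ["molecule_id", "index"],
    "bond": ["molecule_id", "index"],
}


def drop_reprocessed(df: pd.DataFrame, table: str) -> pd.DataFrame:
    """Keep the first row for each natural key, discarding copies written by racing workers."""
    return df.drop_duplicates(subset=TABLE_KEYS[table], keep="first").reset_index(drop=True)


# Toy atom table where molecule 22 was written twice
toy_atoms = pd.DataFrame(
    {"molecule_id": [22, 22, 22, 22], "index": [0, 1, 0, 1], "symbol": ["O", "C", "O", "C"]}
)
print(len(drop_reprocessed(toy_atoms, "atom")))  # -> 2
```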

Process

(Runtime: ~1.5 hours.)

from datetime import datetime

process_datetime_str = datetime.now().strftime("%Y%m%dT%H%M%S")
drug_system_path = data_path / f"interim/drug_system/{process_datetime_str}"
makedirs(drug_system_path)

makedirs(drug_system_path / "mol")
makedirs(drug_system_path / "ring")
makedirs(drug_system_path / "atom")
makedirs(drug_system_path / "bond")
makedirs(drug_system_path / "unprocessed")

RDLogger.DisableLog("rdApp.*")
with Parallel(n_jobs=-1, verbose=100) as parallel:
    parallel(
        delayed(process_batch)(batch, index, drug_system_path, multiprocessing=True)
        for index, batch in enumerate(batched_supplier(supplier, size=10000))
    )
RDLogger.EnableLog("rdApp.*")
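After the run finishes, the per-batch failure logs that process_batch writes to unprocessed/ can be gathered into one table to see which exception classes dominate. This is a minimal sketch, assuming the layout above (one CSV per batch, written by DataFrame.to_csv with its default index). load_unprocessed is a hypothetical helper, and the demo uses made-up rows in a temporary directory rather than real output; on the real run it would be called as `load_unprocessed(drug_system_path)`.

```python
import glob
import os
import tempfile

import pandas as pd

UNPROCESSED_COLUMNS = ["supplier_index", "exception_class", "exception_message"]


def load_unprocessed(run_path) -> pd.DataFrame:
    """Concatenate the per-batch unprocessed/*.csv files that process_batch writes."""
    # lexicographic file order ("10.csv" sorts before "2.csv"); sort rows afterwards if order matters
    paths = sorted(glob.glob(os.path.join(str(run_path), "unprocessed", "*.csv")))
    if not paths:
        return pd.DataFrame(columns=UNPROCESSED_COLUMNS)
    # DataFrame.to_csv wrote the default integer index as the first column
    frames = [pd.read_csv(path, index_col=0) for path in paths]
    return pd.concat(frames, ignore_index=True)


# Demo on a throwaway directory with made-up failure rows
demo_batches = [
    [(3, "ValueError", "bad record")],
    [(12, "ValueError", "bad record"), (15, "KeyError", "missing property")],
]
with tempfile.TemporaryDirectory() as tmp_dir:
    os.makedirs(os.path.join(tmp_dir, "unprocessed"))
    for i, rows in enumerate(demo_batches):
        pd.DataFrame(rows, columns=UNPROCESSED_COLUMNS).to_csv(
            os.path.join(tmp_dir, "unprocessed", f"{i}.csv")
        )
    demo_unprocessed = load_unprocessed(tmp_dir)

print(demo_unprocessed["exception_class"].value_counts())
```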