Datasets, Dataloaders and Training
You can find this notebook on Google Colab here
In this second block we’ll take a deep dive into graph neural networks. The block walks through the many technical steps needed to get from a cleanly described object (a graph representation of a molecule) to a trained neural network.
We’ll go over the full pipeline from raw data to a trained GNN, with a focus on the data preparation part. We’ll gloss over the details of the GNN itself, which we will go into in session 3.
List of contents
Data pipeline: Raw data to graph representations (copied from block 1)
The dataset class
Data loaders
The actual dataset: load, split & build
Implementing the GNN
Training the model
Getting predictions
The Graph Data
First, we’ll prepare by re-instantiating the functions and classes introduced in the first session.
In this example, we’ll look at creating simple PyTorch datasets which can be used to train our model. We will use the graph representations we defined in the previous notebook. We’ll also look a bit at data loaders and the collate function, and at how we ultimately want to represent the batches for our neural network.
import rdkit
from rdkit.Chem import MolFromSmiles
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem import Draw
IPythonConsole.ipython_useSVG = True # < use SVGs instead of PNGs
IPythonConsole.drawOptions.addAtomIndices = True # adding indices
IPythonConsole.molSize = 300, 300
import torch
from torch.utils.data import Dataset, DataLoader
from collections.abc import Set
float_type = torch.float32 # Hardcoding datatypes for tensors
categorical_type = torch.long
mask_type = torch.float32 # Hardcoding datatype for masks
labels_type = torch.float32 # Hardcoding datatype for labels
class ContinuousFeature:
    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return f'<ContinuousFeature: {self.name}>'

    def __eq__(self, other):
        # Features compare by name, so equal features act as the same dict key
        return isinstance(other, ContinuousFeature) and self.name == other.name

    def __hash__(self):
        return hash(self.name)


class CategoricalFeature:
    def __init__(self, name, values, add_null_value=True):
        self.name = name
        self.has_null_value = add_null_value
        if self.has_null_value:
            # None is reserved as the null value and always gets index 0
            self.null_value = None
            values = (None,) + tuple(values)
        self.values = tuple(values)
        self.value_to_idx_mapping = {v: i for i, v in enumerate(values)}
        self.inv_value_to_idx_mapping = {i: v for v, i in
                                         self.value_to_idx_mapping.items()}
        if self.has_null_value:
            self.null_value_idx = self.value_to_idx_mapping[self.null_value]

    def get_null_idx(self):
        if self.has_null_value:
            return self.null_value_idx
        else:
            raise RuntimeError(f"Categorical variable {self.name} has no null value")

    def value_to_idx(self, value):
        return self.value_to_idx_mapping[value]

    def idx_to_value(self, idx):
        return self.inv_value_to_idx_mapping[idx]

    def __len__(self):
        return len(self.values)

    def __repr__(self):
        return f'<CategoricalFeature: {self.name}>'

    def __eq__(self, other):
        # Two categorical features are equal if name and values both match
        return (isinstance(other, CategoricalFeature)
                and self.name == other.name and self.values == other.values)

    def __hash__(self):
        return hash((self.name, self.values))
# Atom types
ATOM_SYMBOLS = ['H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', 'Na',
'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar', 'K', 'Ca', 'Sc', 'Ti',
'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', 'Ga', 'Ge', 'As',
'Se', 'Br', 'Kr', 'Rb', 'Sr', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru',
'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'I', 'Xe', 'Cs',
'Ba', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Tl',
'Pb', 'Bi', 'Po', 'At', 'Rn', 'Fr', 'Ra', 'Rf', 'Db', 'Sg',
'Bh', 'Hs', 'Mt', 'Ds', 'Rg', 'Cn', 'Fl', 'Lv', 'La', 'Ce',
'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', 'Er',
'Tm', 'Yb', 'Lu', 'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm',
'Bk', 'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr']
SYMBOLS_FEATURE = CategoricalFeature('atom_symbol', ATOM_SYMBOLS)
# Aromaticity
AROMATIC_VALUES = [True, False]
AROMATIC_FEATURE = CategoricalFeature('is_aromatic', AROMATIC_VALUES)
# Explicit valence
EXPLICIT_VALANCE_FEATURE = ContinuousFeature('explicit_valance')
# Implicit valence
IMPLICIT_VALANCE_FEATURE = ContinuousFeature('implicit_valance')
# Combine all four into one list of features
ATOM_FEATURES = [SYMBOLS_FEATURE,
AROMATIC_FEATURE,
EXPLICIT_VALANCE_FEATURE,
IMPLICIT_VALANCE_FEATURE]
# Bond types
BOND_TYPES = ['UNSPECIFIED', 'SINGLE', 'DOUBLE', 'TRIPLE', 'QUADRUPLE',
'QUINTUPLE', 'HEXTUPLE', 'ONEANDAHALF', 'TWOANDAHALF',
'THREEANDAHALF','FOURANDAHALF', 'FIVEANDAHALF', 'AROMATIC',
'IONIC', 'HYDROGEN', 'THREECENTER', 'DATIVEONE', 'DATIVE',
'DATIVEL', 'DATIVER', 'OTHER', 'ZERO']
TYPE_FEATURE = CategoricalFeature('bond_type', BOND_TYPES)
# Bond directions
BOND_DIRECTIONS = ['NONE', 'BEGINWEDGE', 'BEGINDASH', 'ENDDOWNRIGHT',
'ENDUPRIGHT', 'EITHERDOUBLE' ]
DIRECTION_FEATURE = CategoricalFeature('bond_direction', BOND_DIRECTIONS)
# Bond stereochemistry
BOND_STEREO = ['STEREONONE', 'STEREOANY', 'STEREOZ', 'STEREOE',
'STEREOCIS', 'STEREOTRANS']
STEREO_FEATURE = CategoricalFeature('bond_stereo', BOND_STEREO)
# Aromaticity
AROMATIC_VALUES = [True, False]
AROMATIC_FEATURE = CategoricalFeature('is_aromatic', AROMATIC_VALUES)
# Combine all four into one list of features
BOND_FEATURES = [TYPE_FEATURE,
DIRECTION_FEATURE,
AROMATIC_FEATURE,
STEREO_FEATURE]
# Atom features
def get_atom_features(rd_atom):
    atom_symbol = rd_atom.GetSymbol()
    is_aromatic = rd_atom.GetIsAromatic()
    implicit_valance = float(rd_atom.GetImplicitValence())
    explicit_valance = float(rd_atom.GetExplicitValence())
    return {SYMBOLS_FEATURE: atom_symbol,
            AROMATIC_FEATURE: is_aromatic,
            EXPLICIT_VALANCE_FEATURE: explicit_valance,
            IMPLICIT_VALANCE_FEATURE: implicit_valance}

# Bond features
def get_bond_features(rd_bond):
    bond_type = str(rd_bond.GetBondType())
    bond_stereo_info = str(rd_bond.GetStereo())
    bond_direction = str(rd_bond.GetBondDir())
    is_aromatic = rd_bond.GetIsAromatic()
    return {TYPE_FEATURE: bond_type,
            DIRECTION_FEATURE: bond_direction,
            AROMATIC_FEATURE: is_aromatic,
            STEREO_FEATURE: bond_stereo_info}

# Create dictionaries of the atoms and bonds in a molecule
def rdmol_to_graph(rd_mol):
    atoms = {rd_atom.GetIdx(): get_atom_features(rd_atom)
             for rd_atom in rd_mol.GetAtoms()}
    bonds = {frozenset((rd_bond.GetBeginAtomIdx(), rd_bond.GetEndAtomIdx())):
             get_bond_features(rd_bond) for rd_bond in rd_mol.GetBonds()}
    return atoms, bonds

def smiles_to_graph(smiles):
    rd_mol = MolFromSmiles(smiles)
    graph = rdmol_to_graph(rd_mol)
    return graph
The Dataset class
The goal is to create representations of our graphs which are suited as inputs to Graph Neural Networks and Transformers. This means defining a class which can be used to create minibatches. We’ll use a map-style dataset, which means that it will have a __getitem__ method and a __len__ method. This makes the class follow Python’s sequence protocol and allows PyTorch’s DataLoader to use it.
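As a minimal sketch of what "having a __getitem__ method and a __len__ method" means in practice, the toy class below implements only those two methods. The class name SquaresDataset and its contents are made up for illustration, and it deliberately does not subclass PyTorch's Dataset so that it runs without any dependencies.

```python
class SquaresDataset:
    """Hypothetical map-style dataset: item i is the pair (i, i**2)."""

    def __init__(self, n_items):
        self.n_items = n_items

    def __len__(self):
        # Number of examples in the dataset
        return self.n_items

    def __getitem__(self, index):
        # Look up one example by its integer index
        if not 0 <= index < self.n_items:
            raise IndexError(index)
        return index, index ** 2


squares = SquaresDataset(4)
items = [squares[i] for i in range(len(squares))]
print(items)
```

A loader only needs these two operations: it asks for the length to decide which indices exist, and then requests the items one index at a time.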
A lot of the complexity comes from the fact that we want to be able to create mini-batches of our graph data, but the graphs are almost all different (differing numbers of nodes and edges), while the multidimensional arrays (ndarrays, tensors) we use as inputs to our GNN must have a regular shape.
Masking
To handle this difference in shape, we’ll make use of masking. For each tensor which contains multiple graph properties stacked along an axis, we provide an additional mask tensor which shows where there is relevant information. This mask will be used inside the neural network to zero out any irrelevant results of computations.
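The idea can be sketched without any tensor library. In the example below, plain Python lists stand in for tensors, and the helper names pad_with_mask and masked_mean are assumptions made for this illustration, not part of the notebook's code.

```python
def pad_with_mask(sequences, pad_value=0.0):
    """Pad lists to a common length and return a 0/1 mask per list."""
    max_len = max(len(seq) for seq in sequences)
    padded = [list(seq) + [pad_value] * (max_len - len(seq))
              for seq in sequences]
    masks = [[1.0] * len(seq) + [0.0] * (max_len - len(seq))
             for seq in sequences]
    return padded, masks


def masked_mean(values, mask):
    """Average only the entries where the mask is 1."""
    total = sum(v * m for v, m in zip(values, mask))
    return total / sum(mask)


# Two "graphs" with three nodes and one node respectively
padded, masks = pad_with_mask([[1.0, 2.0, 3.0], [4.0]])
naive_mean = sum(padded[1]) / len(padded[1])    # polluted by padding
correct_mean = masked_mean(padded[1], masks[1])  # ignores padding
print(naive_mean, correct_mean)
```

The naive average of the second row is pulled down by the two padding zeros, while the masked average recovers the value of the single real entry.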
Defining the dataset
We define our dataset class as a PyTorch Dataset, with init arguments being a list of graphs (our molecules) and an equal-length list of labels, one per graph.
Additionally we supply a list of atom features and bond features. The reason for supplying the latter is that we can dynamically choose to limit the features our dataset will use without having to recompute the graph representations for our whole dataset.
We also support adding a list of metadata objects, arbitrary python objects which will be supplied together with the graph. This can be useful to give information about what row in a CSV the molecule was defined on or what the original SMILES was.
Continuous vs. Categorical features
From the neural network perspective, there’s an important distinction between continuous and categorical features, so our Dataset needs to return items which make it clear which is which.
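A rough sketch of why the distinction matters: a categorical value is first turned into an integer index, which a network typically uses to look up a learned vector, while a continuous value can be used as a number directly. The lookup table and its numbers below are made up for illustration, and plain lists stand in for an embedding layer.

```python
# Categorical feature: values are mapped to integer indices (None is the
# null value, as in CategoricalFeature above).
symbol_values = (None, 'C', 'N', 'O')
symbol_to_idx = {value: i for i, value in enumerate(symbol_values)}

# Hypothetical lookup table with one row per category. In a real network
# these rows would be learned parameters.
embedding_table = [[0.0, 0.0],
                   [0.1, 0.9],
                   [0.7, 0.2],
                   [0.4, 0.4]]


def encode_atom(symbol, explicit_valence):
    """Combine an embedded categorical value with a raw continuous value."""
    idx = symbol_to_idx[symbol]              # categorical -> index
    embedded = embedding_table[idx]          # index -> vector
    return embedded + [float(explicit_valence)]  # continuous used as-is


print(encode_atom('N', 3))
```

Treating the index of a categorical value as a number would imply an ordering between the categories that does not exist, which is why the two kinds of features are kept in separate tensors with separate dtypes.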
The __getitem__ method
The most important part of the dataset is the __getitem__ method. Its task is to take a graph from our dataset (chosen by the argument to the method) and create the tensor representation of this graph.
These tensor representations will be complex. We need to keep track of all 4 feature types (categorical vs. continuous, edges vs. nodes) independently, so we need to return at least 4 different tensors. Additionally, the GNN will need to know the structure of the graph, so we have to return tensors for that as well. In this example we use the adjacency matrix as well as a list of nodes to encode the graph structure.
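To make the structural part concrete, here is a small sketch of how an adjacency matrix can be built from a collection of edges. Nested lists stand in for a tensor, and the function name build_adjacency is an assumption for this example. As in the graph representation from the first session, an edge stored as a set (frozenset) is treated as undirected, while a tuple is treated as directed.

```python
from collections.abc import Set


def build_adjacency(n_nodes, edges):
    """Return an n_nodes x n_nodes 0/1 matrix for the given edges."""
    adjacency = [[0] * n_nodes for _ in range(n_nodes)]
    for edge in edges:
        u, v = edge
        adjacency[u][v] = 1
        if isinstance(edge, Set):
            # Unordered edge: mirror it to make the matrix symmetric
            adjacency[v][u] = 1
    return adjacency


# One undirected edge 0-1 and one directed edge 1->2
example_adjacency = build_adjacency(3, [frozenset((0, 1)), (1, 2)])
for row in example_adjacency:
    print(row)
```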
All in all, the __getitem__ method will return these values:
List of nodes
Adjacency matrix
Categorical node features
Continuous node features
Categorical edge features
Continuous edge features
[Optional] Metadata entry
The primary work of this method is to collect the node and edge features into tensors, as illustrated in the figures below.
Stacking node features
Stacking edge features
What to return?
When we return this kind of complex value from a method, just using a sequence (tuple) is prone to error. If we change how many values we return or their order, downstream code will break. A common choice is to define a lightweight class (using __slots__) or a named tuple for this.
In our case, to make the code more accessible we’ll use a dictionary instead. The downside to this in practice is that the user needs to know the dictionary structure and can’t get help with this from an IDE, but the upside is that it’s easy to inspect in a notebook. There’s also computational overhead in using a dictionary which another data structure could avoid.
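For comparison, this is roughly what the named tuple alternative could look like. The class name GraphRecord and its reduced set of fields are hypothetical and only meant to show the trade-off, not how the notebook stores its records.

```python
from typing import Any, NamedTuple, Optional


class GraphRecord(NamedTuple):
    """Hypothetical record type with named, fixed fields."""
    nodes: list
    adjacency_matrix: Any
    label: float
    metadata: Optional[dict] = None


record = GraphRecord(nodes=[0, 1],
                     adjacency_matrix=[[0, 1], [1, 0]],
                     label=1.0)

# Fields are accessed by name, so an IDE can autocomplete them...
print(record.label)
# ...and the record can still be converted to a dict for inspection.
print(record._asdict())
```

Adding a new optional field at the end does not break code that accesses fields by name, but code that unpacks the tuple positionally would still need to be updated.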
class GraphDataset(Dataset):
    def __init__(self, *, graphs, labels, node_variables, edge_variables,
                 metadata=None):
        '''
        Create a new graph dataset from a list of graphs and a list of labels.
        '''
        self.graphs = graphs
        self.labels = labels
        assert len(self.graphs) == len(self.labels), \
            "The graphs and labels lists must be the same length"
        self.metadata = metadata
        if self.metadata is not None:
            assert len(self.metadata) == len(self.graphs),\
                "The metadata list needs to be as long as the graphs"
        self.node_variables = node_variables
        self.edge_variables = edge_variables
        self.categorical_node_variables = [var for var in self.node_variables
                                           if isinstance(var, CategoricalFeature)]
        self.continuous_node_variables = [var for var in self.node_variables
                                          if isinstance(var, ContinuousFeature)]
        self.categorical_edge_variables = [var for var in self.edge_variables
                                           if isinstance(var, CategoricalFeature)]
        self.continuous_edge_variables = [var for var in self.edge_variables
                                          if isinstance(var, ContinuousFeature)]

    def __len__(self):
        return len(self.graphs)

    def make_continuous_node_features(self, nodes):
        if len(self.continuous_node_variables) == 0:
            return None
        n_nodes = len(nodes)
        n_features = len(self.continuous_node_variables)
        continuous_node_features = torch.zeros((n_nodes, n_features),
                                               dtype=float_type)
        for node_idx, features in nodes.items():
            node_features = torch.tensor([features[continuous_feature]
                                          for continuous_feature
                                          in self.continuous_node_variables],
                                         dtype=float_type)
            continuous_node_features[node_idx] = node_features
        return continuous_node_features

    def make_categorical_node_features(self, nodes):
        if len(self.categorical_node_variables) == 0:
            return None
        n_nodes = len(nodes)
        n_features = len(self.categorical_node_variables)
        categorical_node_features = torch.zeros((n_nodes, n_features),
                                                dtype=categorical_type)
        for node_idx, features in nodes.items():
            for i, categorical_variable in enumerate(self.categorical_node_variables):
                value = features[categorical_variable]
                value_index = categorical_variable.value_to_idx(value)
                categorical_node_features[node_idx, i] = value_index
        return categorical_node_features

    def make_continuous_edge_features(self, n_nodes, edges):
        if len(self.continuous_edge_variables) == 0:
            return None
        n_features = len(self.continuous_edge_variables)
        continuous_edge_features = torch.zeros((n_nodes, n_nodes, n_features),
                                               dtype=float_type)
        for edge, features in edges.items():
            edge_features = torch.tensor([features[continuous_feature]
                                          for continuous_feature in
                                          self.continuous_edge_variables],
                                         dtype=float_type)
            u, v = edge
            continuous_edge_features[u, v] = edge_features
            if isinstance(edge, Set):
                # Unordered edge: store the features in both directions
                continuous_edge_features[v, u] = edge_features
        return continuous_edge_features

    def make_categorical_edge_features(self, n_nodes, edges):
        if len(self.categorical_edge_variables) == 0:
            return None
        n_features = len(self.categorical_edge_variables)
        categorical_edge_features = torch.zeros((n_nodes, n_nodes, n_features),
                                                dtype=categorical_type)
        for edge, features in edges.items():
            u, v = edge
            for i, categorical_variable in enumerate(self.categorical_edge_variables):
                value = features[categorical_variable]
                value_index = categorical_variable.value_to_idx(value)
                categorical_edge_features[u, v, i] = value_index
                if isinstance(edge, Set):
                    # Unordered edge: store the features in both directions
                    categorical_edge_features[v, u, i] = value_index
        return categorical_edge_features

    def __getitem__(self, index):
        # This is where the important stuff happens. We use our node and
        # edge variable attributes to select what node and edge features to use.
        # In practice, we often do this as a pre-processing step, but here we do it
        # in the getitem function for clarity
        graph = self.graphs[index]
        nodes, edges = graph
        n_nodes = len(nodes)
        continuous_node_features = self.make_continuous_node_features(nodes)
        categorical_node_features = self.make_categorical_node_features(nodes)
        continuous_edge_features = self.make_continuous_edge_features(n_nodes,
                                                                      edges)
        categorical_edge_features = self.make_categorical_edge_features(n_nodes,
                                                                        edges)
        label = self.labels[index]
        nodes_idx = sorted(nodes.keys())
        adjacency_matrix = torch.zeros((n_nodes, n_nodes), dtype=float_type)
        for edge in edges:
            u, v = edge
            adjacency_matrix[u, v] = 1
            if isinstance(edge, Set):
                # This edge is unordered, assume this is an undirected graph
                adjacency_matrix[v, u] = 1
        data_record = {'nodes': nodes_idx,
                       'adjacency_matrix': adjacency_matrix,
                       'categorical_node_features': categorical_node_features,
                       'continuous_node_features': continuous_node_features,
                       'categorical_edge_features': categorical_edge_features,
                       'continuous_edge_features': continuous_edge_features,
                       'label': label}
        # If you need to add extra information (metadata about this graph) you can
        # add an extra key-value pair here. The advantage of using a dict compared
        # to a tuple is that the downstream code doesn't break as long as at least
        # the expected keys are present. The downside is that using a dict adds
        # overhead (accessing a dict compared to unpacking a tuple).
        # A more robust implementation might actually make a separate class for
        # dataset entries
        if self.metadata is not None:
            data_record['metadata'] = self.metadata[index]
        return data_record

    def get_node_variables(self):
        return {'continuous': self.continuous_node_variables,
                'categorical': self.categorical_node_variables}

    def get_edge_variables(self):
        return {'continuous': self.continuous_edge_variables,
                'categorical': self.categorical_edge_variables}
def make_molecular_graph_dataset(smiles_records, atom_features=ATOM_FEATURES,
                                 bond_features=BOND_FEATURES):
    '''
    Create a new GraphDataset from a list of smiles_records dictionaries.
    These records should contain the keys 'smiles' and 'label'. The whole
    record, including any other keys, is saved as the 'metadata' entry.
    '''
    graphs = []
    labels = []
    metadata = []
    for smiles_record in smiles_records:
        smiles = smiles_record['smiles']
        label = smiles_record['label']
        graph = smiles_to_graph(smiles)
        graphs.append(graph)
        labels.append(label)
        metadata.append(smiles_record)
    return GraphDataset(graphs=graphs,
                        labels=labels,
                        node_variables=atom_features,
                        edge_variables=bond_features,
                        metadata=metadata)
We’ll try this out with some simple molecules
dataset = make_molecular_graph_dataset([{'smiles': 'c1ccccc1', 'label':1},
{'smiles':'C=C', 'label': 0}])
dataset[0]
dataset[1]
The DataLoader
We’ve now implemented a dataset which essentially allows us to ask for a single graph from the collection of graphs. The next problem is how to make mini-batches of graphs to feed into our neural networks
Packing graphs
One way of creating minibatches is to actually pack all the separate graphs into a single, multi-component graph. Since a Graph Neural Network uses the graph structure to propagate information, this will have the desired effect of not leaking information across the mini-batch examples. If you have efficient computational kernels for sparse matrix multiplications, this makes a lot of sense and is the approach used in PyTorch Geometric.
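The packing idea can be sketched with nested lists standing in for (sparse) matrices. The function name pack_block_diagonal and the graph_index output are assumptions made for this illustration and are not taken from PyTorch Geometric's API.

```python
def pack_block_diagonal(adjacency_matrices):
    """Pack several adjacency matrices into one block-diagonal matrix.

    Also returns, for every node in the packed graph, the index of the
    graph it came from.
    """
    total_nodes = sum(len(adj) for adj in adjacency_matrices)
    packed = [[0] * total_nodes for _ in range(total_nodes)]
    graph_index = []
    offset = 0
    for g, adj in enumerate(adjacency_matrices):
        n_nodes = len(adj)
        for u in range(n_nodes):
            for v in range(n_nodes):
                # Shift node indices so the graphs do not overlap
                packed[offset + u][offset + v] = adj[u][v]
        graph_index.extend([g] * n_nodes)
        offset += n_nodes
    return packed, graph_index


two_nodes = [[0, 1], [1, 0]]
one_node = [[0]]
packed, graph_index = pack_block_diagonal([two_nodes, one_node])
for row in packed:
    print(row)
print(graph_index)
```

Because all entries outside the diagonal blocks are zero, there is no edge between nodes of different graphs, which is why no information leaks between the examples.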
In our case, our ultimate goal is to take the step from GNNs to Transformers, and for that to work we can’t rely on the assumption that our neural network will only aggregate information along the edges. Instead we’ll pack the information from all graphs into the same dense tensors (e.g. all adjacency matrices will be packed into a tensor of size (batch_size, max_nodes, max_nodes)).
Masks
When we do this, we also need to account for the graphs being different sizes and will make heavy use of masks to make sure the computation is still doing the correct thing. We will create one node mask which can be used to zero out any computation based on nodes and one edge mask which can be used to zero out computations where edges are the results.
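A dependency-free sketch of the two masks, with nested lists standing in for tensors. The function name pad_adjacency_batch is hypothetical; the collate function implemented in this notebook does the same thing with PyTorch tensors.

```python
def pad_adjacency_batch(adjacency_matrices):
    """Pad square adjacency matrices to a common size and build both masks."""
    max_nodes = max(len(adj) for adj in adjacency_matrices)
    padded_batch, node_masks, edge_masks = [], [], []
    for adj in adjacency_matrices:
        n_nodes = len(adj)
        padded = [[0.0] * max_nodes for _ in range(max_nodes)]
        for u in range(n_nodes):
            for v in range(n_nodes):
                padded[u][v] = float(adj[u][v])
        # Node mask: 1 for real nodes, 0 for padding
        node_mask = [1.0 if u < n_nodes else 0.0 for u in range(max_nodes)]
        # Edge mask: 1 where both endpoints are real nodes
        edge_mask = [[node_mask[u] * node_mask[v] for v in range(max_nodes)]
                     for u in range(max_nodes)]
        padded_batch.append(padded)
        node_masks.append(node_mask)
        edge_masks.append(edge_mask)
    return padded_batch, node_masks, edge_masks


two_node_graph = [[0, 1], [1, 0]]
three_node_graph = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
batch, node_masks, edge_masks = pad_adjacency_batch(
    [two_node_graph, three_node_graph])
print(node_masks)
print(edge_masks[0])
```

Note that the edge mask marks every pair of real nodes, not only pairs that are connected; whether two real nodes are connected is the job of the adjacency matrix.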
Where to implement the packing
In PyTorch, a DataLoader gives us convenient access to parallel data pre-processing. The DataLoader also has an option to supply a collate function, and this is where we typically implement the batch packing. Here we’ll implement the collate function we’re going to use with our GNN. Note that we should not move the tensors to the target device at this stage. That should be left to the training loop, because multiprocessing interacts poorly with CUDA contexts.
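The division of labour between the collate function and the training loop can be sketched with a toy dataset. The dataset contents and the name collate_pairs are made up for this example; only the DataLoader arguments batch_size, shuffle and collate_fn are relied upon.

```python
import torch
from torch.utils.data import DataLoader


def collate_pairs(batch):
    """Toy collate function: batch is a list of (input, label) pairs."""
    inputs = torch.tensor([x for x, _ in batch], dtype=torch.float32)
    labels = torch.tensor([y for _, y in batch], dtype=torch.float32)
    # Tensors are left on the CPU here on purpose
    return {'inputs': inputs, 'labels': labels}


# A plain list already behaves as a map-style dataset
toy_dataset = [(float(i), float(i % 2)) for i in range(5)]
toy_loader = DataLoader(toy_dataset, batch_size=2, shuffle=False,
                        collate_fn=collate_pairs)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_sizes = []
for batch in toy_loader:
    # Moving to the target device happens in the (training) loop
    batch = {key: value.to(device) for key, value in batch.items()}
    batch_sizes.append(batch['inputs'].shape[0])
print(batch_sizes)
```

With five examples and a batch size of two, the last batch contains a single example, which is one more reason the collate function should not assume a fixed batch size.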
from collections.abc import Set # edges as sets are for undirected graphs
def collate_graph_batch(batch):
    '''Collate a batch of graph dictionaries produced by a GraphDataset'''
    batch_size = len(batch)
    max_nodes = max(len(graph['nodes']) for graph in batch)
    # We start by allocating the tensors we'll use. We defer allocating feature
    # tensors until we know the graphs actually have those kinds of features.
    adjacency_matrices = torch.zeros((batch_size, max_nodes, max_nodes),
                                     dtype=float_type)
    labels = torch.tensor([graph['label'] for graph in batch], dtype=labels_type)
    stacked_continuous_node_features = None
    stacked_categorical_node_features = None
    stacked_continuous_edge_features = None
    stacked_categorical_edge_features = None
    nodes_mask = torch.zeros((batch_size, max_nodes), dtype=mask_type)
    edge_mask = torch.zeros((batch_size, max_nodes, max_nodes), dtype=mask_type)
    has_metadata = False
    for i, graph in enumerate(batch):
        if 'metadata' in graph:
            has_metadata = True
        # We'll take basic information about the different graphs from the
        # adjacency matrix
        adjacency_matrix = graph['adjacency_matrix']
        g_nodes = adjacency_matrix.shape[0]
        adjacency_matrices[i, :g_nodes, :g_nodes] = adjacency_matrix
        # Now that we know how many of the entries are valid, we set those to 1s
        # in the masks
        edge_mask[i, :g_nodes, :g_nodes] = 1
        nodes_mask[i, :g_nodes] = 1
        # All the feature constructions follow the same recipe. We essentially
        # locate the entries in the stacked feature tensor (containing all graphs)
        # and set them to the features from the current graph.
        g_continuous_node_features = graph['continuous_node_features']
        if g_continuous_node_features is not None:
            if stacked_continuous_node_features is None:
                num_features = g_continuous_node_features.shape[-1]
                stacked_continuous_node_features = torch.zeros(
                    (batch_size, max_nodes, num_features), dtype=float_type)
            stacked_continuous_node_features[i, :g_nodes] = g_continuous_node_features
        g_categorical_node_features = graph['categorical_node_features']
        if g_categorical_node_features is not None:
            if stacked_categorical_node_features is None:
                num_features = g_categorical_node_features.shape[-1]
                stacked_categorical_node_features = torch.zeros(
                    (batch_size, max_nodes, num_features), dtype=categorical_type)
            stacked_categorical_node_features[i, :g_nodes] =\
                g_categorical_node_features
        g_continuous_edge_features = graph['continuous_edge_features']
        if g_continuous_edge_features is not None:
            if stacked_continuous_edge_features is None:
                num_features = g_continuous_edge_features.shape[-1]
                stacked_continuous_edge_features = torch.zeros(
                    (batch_size, max_nodes, max_nodes, num_features),
                    dtype=float_type)
            stacked_continuous_edge_features[i, :g_nodes, :g_nodes] =\
                g_continuous_edge_features
        g_categorical_edge_features = graph['categorical_edge_features']
        if g_categorical_edge_features is not None:
            if stacked_categorical_edge_features is None:
                num_features = g_categorical_edge_features.shape[-1]
                stacked_categorical_edge_features = torch.zeros(
                    (batch_size, max_nodes, max_nodes, num_features),
                    dtype=categorical_type)
            stacked_categorical_edge_features[i, :g_nodes, :g_nodes] =\
                g_categorical_edge_features
    batch_record = {'adjacency_matrices': adjacency_matrices,
                    'categorical_node_features': stacked_categorical_node_features,
                    'continuous_node_features': stacked_continuous_node_features,
                    'categorical_edge_features': stacked_categorical_edge_features,
                    'continuous_edge_features': stacked_continuous_edge_features,
                    'nodes_mask': nodes_mask,
                    'edge_mask': edge_mask,
                    'labels': labels}
    if has_metadata:
        batch_record['metadata'] = [g['metadata'] for g in batch]
    return batch_record
# Here's an example of how these batches look
dataloader_test_batch = collate_graph_batch([dataset[0], dataset[1]])
dataloader_test_batch
Experiment dataset
To illustrate training a neural network on graphs, we’ll use a molecular property prediction task. Such models are commonly used in drug development to screen candidate molecules in silico for properties of interest. The task we’ll look at is called Blood Brain Barrier Penetration, and we’ll use a dataset from the MoleculeNet benchmark suite (https://moleculenet.org) originally assembled by Martins et al. [Martins, I. F.; Teixeira, A. L.; Pinheiro, L.; Falcao, A. O. Journal of Chemical Information and Modeling 2012, 52, 1686–1697].
Blood Brain Barrier Penetration
An important property to look at for molecules in drug development is whether they can pass through the membrane separating the blood stream from the brain extracellular fluid. This barrier blocks many molecules, and for drugs which have the purpose of targeting the central nervous system it’s important to find those which can pass through the barrier. For drugs which should not penetrate the barrier it’s important to make sure they do not.
from pathlib import Path
import pandas as pd
potential_paths = [Path('BBBP.csv'), Path('../dataset/BBBP.csv')]
bbbp_table = None
for p in potential_paths:
    if p.exists():
        bbbp_table = pd.read_csv(p)
        break
bbbp_table
The dataset is in a simple CSV file. The columns we want to use are ‘smiles’ and ‘p_np’. Many datasets like this unfortunately contain malformed SMILES (the chemistry of the SMILES isn’t valid because some fragments might have been removed during pre-processing). In this dataset there are relatively few, and we’ll filter them out while we’re doing the splitting. RDKit tells us when a molecule is invalid by returning None from the MolFromSmiles function.
labeled_smiles_list = bbbp_table[['smiles', 'p_np']]
labeled_smiles_list
Data splitting
In this example we will just randomly split the dataset. For the more advanced scaffold split strategy recommended for the MoleculeNet benchmark, you can look at the notebook GNN_02-Appendix.ipynb.
# There are about 50 problematic SMILES, and RDKit will emit about as many ERRORs
# and WARNINGs. This is expected behaviour, so we suppress these messages here.
# 11 of the SMILES can't be parsed at all.
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
smiles_records = []
for i, num, name, p_np, smiles in bbbp_table.to_records():
    # Check if RDKit accepts this SMILES
    if MolFromSmiles(smiles) is not None:
        smiles_record = {'smiles': smiles, 'label': p_np, 'metadata': {'row': i}}
        smiles_records.append(smiles_record)
    else:
        print(f'Molecule {smiles} on row {i} could not be parsed by RDKit')
import random
random.seed(1729)
training_fraction = 0.8
dev_fraction = 0.1
n_examples = len(smiles_records)
n_training_examples = int(n_examples*training_fraction)
n_dev_examples = int(n_examples*dev_fraction)
indices = list(range(n_examples))
random.shuffle(indices) # shuffle is in place
training_indices = indices[:n_training_examples]
dev_indices = indices[n_training_examples:n_training_examples+n_dev_examples]
test_indices = indices[n_training_examples+n_dev_examples:]
training_smiles_records = [smiles_records[i] for i in training_indices]
dev_smiles_records = [smiles_records[i] for i in dev_indices]
test_smiles_records = [smiles_records[i] for i in test_indices]
training_smiles_records[:10]
Building the datasets
We now have all the information needed to build our dataset splits. We’ll create the graphs from the SMILES of these records, but also save the original SMILES and the row in the bbbp_table dataset, in case we want to go back from a specific graph and see where it comes from (e.g. to figure out whether a misprediction can be explained by domain knowledge).
training_graph_dataset = make_molecular_graph_dataset(training_smiles_records)
dev_graph_dataset = make_molecular_graph_dataset(dev_smiles_records)
test_graph_dataset = make_molecular_graph_dataset(test_smiles_records)
We now have our dataset splits loaded and ready as graph datasets. We can have a look at one of the examples. See if you can identify which parts of our GraphDataset representation correspond to which parts of the pictorial representation.
example_graph = training_graph_dataset[0]
example_graph
rdmol = MolFromSmiles(example_graph['metadata']['smiles'])
rdmol
Questions
Answer these questions:
Why are the features returned as separate values in __getitem__ of the dataset class?
The dataset returns a 2D array for the node features, what is the shape (dimensionality of the axes) for this array?
The dataset returns a 3D array for the edge features, what is the shape (dimensionality of the axes) for this array?
How are node feature matrices with different numbers of rows combined in the collate_graph_batch() function?
Why do we need mask arrays? How could we solve this problem without using masks?
The dataset returns the edge features as dense 3D arrays, when is this a bad idea in practice? Is it ever a good idea?
Optional coding task
The dataset returns the edge features as dense 3D arrays. Implement an alternative dataset and collate function which instead uses a sparse representation.
The dataset will need to return two separate data structures to do this:
Edge attributes matrix: A matrix where each edge feature vector is a row. It should have as many rows as edges in the graph. The order should be the same as in the next data structure
Node coordinate list: A matrix with two columns which gives the node indices for the edge feature vector in the edge attribute matrix
The collate function needs to combine these data structures into batches. Will you need to add additional masks for these data structures?
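To illustrate only the layout of the two data structures for a single graph, here is a small sketch using plain lists. The function name edges_to_sparse and the feature name 'order' are made up for this illustration. Turning this into tensors and batching several graphs is the actual exercise and is left open.

```python
from collections.abc import Set


def edges_to_sparse(edges, feature_names):
    """Return (coordinate list, attribute matrix) for a dict of edges."""
    coordinates = []  # one [u, v] row per directed edge
    attributes = []   # one feature row per directed edge, same order
    for edge, features in edges.items():
        row = [features[name] for name in feature_names]
        if isinstance(edge, Set):
            # Undirected edge: list it once in each direction
            u, v = sorted(edge)
            pairs = [(u, v), (v, u)]
        else:
            pairs = [tuple(edge)]
        for u, v in pairs:
            coordinates.append([u, v])
            attributes.append(list(row))
    return coordinates, attributes


toy_edges = {frozenset((0, 1)): {'order': 1.0},
             frozenset((1, 2)): {'order': 2.0}}
coordinates, attributes = edges_to_sparse(toy_edges, ['order'])
print(coordinates)
print(attributes)
```

Here the number of rows depends on the number of edges rather than on the square of the number of nodes, which is the point of the sparse representation.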
class SparseEdgesGraphDataset(GraphDataset):
    def make_continuous_edge_features(self, n_nodes, edges):
        if len(self.continuous_edge_variables) == 0:
            return None
        n_features = len(self.continuous_edge_variables)
        # Implement sparse edge features here
        return continuous_edge_features

    def make_categorical_edge_features(self, n_nodes, edges):
        if len(self.categorical_edge_variables) == 0:
            return None
        n_features = len(self.categorical_edge_variables)
        # Implement sparse edge features here
        return categorical_edge_features
from collections.abc import Set # edges as sets are for undirected graphs
def collate_sparse_graph_batch(batch):
    '''Collate a batch of graph dictionaries produced by a SparseEdgesGraphDataset'''
    batch_size = len(batch)
    max_nodes = max(len(graph['nodes']) for graph in batch)
    # We start by allocating the tensors we'll use. We defer allocating feature
    # tensors until we know the graphs actually have those kinds of features.
    adjacency_matrices = torch.zeros((batch_size, max_nodes, max_nodes),
                                     dtype=float_type)
    labels = torch.tensor([graph['label'] for graph in batch], dtype=labels_type)
    stacked_continuous_node_features = None
    stacked_categorical_node_features = None
    sparse_continuous_edge_features = None
    sparse_categorical_edge_features = None
    nodes_mask = torch.zeros((batch_size, max_nodes), dtype=mask_type)
    edge_mask = torch.zeros((batch_size, max_nodes, max_nodes), dtype=mask_type)
    has_metadata = False
    for i, graph in enumerate(batch):
        if 'metadata' in graph:
            has_metadata = True
        # We'll take basic information about the different graphs from the
        # adjacency matrix
        adjacency_matrix = graph['adjacency_matrix']
        g_nodes = adjacency_matrix.shape[0]
        adjacency_matrices[i, :g_nodes, :g_nodes] = adjacency_matrix
        edge_mask[i, :g_nodes, :g_nodes] = 1
        nodes_mask[i, :g_nodes] = 1
        g_continuous_node_features = graph['continuous_node_features']
        if g_continuous_node_features is not None:
            if stacked_continuous_node_features is None:
                num_features = g_continuous_node_features.shape[-1]
                stacked_continuous_node_features = torch.zeros(
                    (batch_size, max_nodes, num_features), dtype=float_type)
            stacked_continuous_node_features[i, :g_nodes] = g_continuous_node_features
        g_categorical_node_features = graph['categorical_node_features']
        if g_categorical_node_features is not None:
            if stacked_categorical_node_features is None:
                num_features = g_categorical_node_features.shape[-1]
                stacked_categorical_node_features = torch.zeros(
                    (batch_size, max_nodes, num_features), dtype=categorical_type)
            stacked_categorical_node_features[i, :g_nodes] =\
                g_categorical_node_features
        # Here you need to figure out how to combine the sparse features
        # Do you need a separate mask for these or can you make it work without?
    batch_record = {'adjacency_matrices': adjacency_matrices,
                    'categorical_node_features': stacked_categorical_node_features,
                    'continuous_node_features': stacked_continuous_node_features,
                    'categorical_edge_features': sparse_categorical_edge_features,
                    'continuous_edge_features': sparse_continuous_edge_features,
                    'nodes_mask': nodes_mask,
                    'edge_mask': edge_mask,
                    'labels': labels}
    if has_metadata:
        batch_record['metadata'] = [g['metadata'] for g in batch]
    return batch_record
Network and Training
The remainder of the cells in the notebook are the same as in the previous notebook. We leave them here in case you want to try out training the network.
The Graph Neural Network
We have now built the data pipelines for our graph data, all the way to creating tensors out of the graphs and packing them into batches.
We’ll now briefly create a Graph Neural Network. This will be a basic network which does not make use of the edge features, only the adjacency. We will cover the details of this architecture as well as extensions in the next notebook session.
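Using "only the adjacency" boils down to multiplying the adjacency matrix with the matrix of node states, which sums the state vectors of each node's neighbours. The sketch below shows this for a single small graph with nested lists standing in for tensors. The function name aggregate_neighbours is an assumption, and the learned linear transformations of the real layer are left out.

```python
def aggregate_neighbours(adjacency, node_states):
    """Matrix product adjacency @ node_states, written out with lists."""
    n_nodes = len(adjacency)
    d_model = len(node_states[0])
    return [[sum(adjacency[u][v] * node_states[v][k] for v in range(n_nodes))
             for k in range(d_model)]
            for u in range(n_nodes)]


# Path graph 0 - 1 - 2 with a two-dimensional state per node
path_adjacency = [[0, 1, 0],
                  [1, 0, 1],
                  [0, 1, 0]]
states = [[1.0, 0.0],
          [0.0, 1.0],
          [2.0, 2.0]]
aggregated = aggregate_neighbours(path_adjacency, states)
for row in aggregated:
    print(row)
```

Node 1 has neighbours 0 and 2, so its aggregated vector is the sum of their states, while nodes 0 and 2 each just receive the state of node 1. A node's own state is not included unless the adjacency matrix has ones on the diagonal, which is why the layer adds a separately transformed center state.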
from torch.nn import Module, Embedding, ModuleList, Linear, Sequential, ReLU
from torch.nn import LayerNorm, Dropout
from torch.nn.functional import layer_norm, dropout
class BasicGNNConfig:
    def __init__(self, *,
                 d_model: int,
                 n_layers: int,
                 residual_connections: bool,
                 ffn_dim: int,
                 layer_norm: bool,
                 prediction_head_dim: int,
                 dropout_rate: float = 0,
                 graph_pooling: str = 'mean'):
        self.d_model = d_model
        self.n_layers = n_layers
        self.residual_connections = residual_connections
        self.ffn_dim = ffn_dim
        self.prediction_head_dim = prediction_head_dim
        self.layer_norm = layer_norm
        self.dropout_rate = dropout_rate
        self.graph_pooling = graph_pooling
class GraphLayer(Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # Since we are going to aggregate the results of these transformations,
        # we skip the bias term
        self.neighbour_transform = Linear(config.d_model, config.ffn_dim,
                                          bias=False)
        self.center_node_transform = Linear(config.d_model, config.ffn_dim)
        self.output_transform = Linear(config.ffn_dim, config.d_model)
        self.dropout = Dropout(config.dropout_rate)
        self.layer_norm = LayerNorm(config.d_model)
        self.nonlinearity = ReLU()

    def forward(self, memory_state, adjacency_matrices, edge_features=None,
                node_mask=None, edge_mask=None):
        neighbour_state = self.neighbour_transform(memory_state)
        center_state = self.center_node_transform(memory_state)
        # This is the heart of the GNN layer. By doing a batched multiplication
        # with the adjacency matrices we're essentially summing all the vectors
        # of a neighbourhood for all the nodes
        aggregated_neighbourhood = torch.bmm(adjacency_matrices, neighbour_state)\
            + center_state
        transformed_state = self.nonlinearity(aggregated_neighbourhood)
        updated_memory_state = self.output_transform(transformed_state)
        if self.config.dropout_rate > 0:
            updated_memory_state = self.dropout(updated_memory_state)
        if self.config.layer_norm:
            updated_memory_state = self.layer_norm(updated_memory_state)
        if node_mask is None:
            # No mask supplied, so there is nothing to zero out
            return updated_memory_state
        # Zero out the states of padded nodes
        masked_memory_state = updated_memory_state * node_mask.unsqueeze(-1)
        return masked_memory_state
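To see why multiplying by the adjacency matrix sums over each neighbourhood, here is a small sketch in plain Python (no PyTorch, so every sum is visible). The three-node path graph and the feature values are made up for illustration; torch.bmm computes this same product for every graph in the batch at once.

```python
# Path graph 0 - 1 - 2, written as a symmetric adjacency matrix.
adjacency = [
    [0, 1, 0],
    [1, 0, 1],
    [0, 1, 0],
]
# One two-dimensional feature vector per node (illustrative values).
features = [
    [1.0, 10.0],
    [2.0, 20.0],
    [3.0, 30.0],
]


def aggregate_neighbours(adjacency, features):
    """Compute the matrix product adjacency @ features with explicit loops."""
    n_nodes = len(adjacency)
    dim = len(features[0])
    return [
        [sum(adjacency[i][j] * features[j][d] for j in range(n_nodes))
         for d in range(dim)]
        for i in range(n_nodes)
    ]


aggregated = aggregate_neighbours(adjacency, features)
for node, row in enumerate(aggregated):
    print(node, row)
```

Row i of the result is the sum of the feature vectors of the nodes adjacent to node i. A node's own vector is not part of that sum unless the adjacency matrix has self-loops, which is why GraphLayer adds center_state separately.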
class BasicGNN(torch.nn.Module):
    def __init__(self, *, output_dim,
                 config: BasicGNNConfig,
                 continuous_node_variables=None,
                 categorical_node_variables=None,
                 continuous_edge_variables=None,
                 categorical_edge_variables=None,
                 layer_type=GraphLayer):
        super().__init__()
        self.output_dim = output_dim
        self.config = config
        self.layer_type = layer_type
        if categorical_node_variables is None:
            categorical_node_variables = []
        if continuous_node_variables is None:
            continuous_node_variables = []
        if continuous_edge_variables is None:
            continuous_edge_variables = []
        if categorical_edge_variables is None:
            categorical_edge_variables = []
        self.continuous_node_variables = continuous_node_variables
        self.categorical_node_variables = categorical_node_variables
        self.continuous_edge_variables = continuous_edge_variables
        self.categorical_edge_variables = categorical_edge_variables
        # We want the embeddings together with the continuous values to be of
        # dimension d_model, therefore we allocate
        # d_model - len(continuous_variables) as the embeddings dim
        self.categorical_node_embeddings_dim = config.d_model\
            - len(self.continuous_node_variables)
        self.categorical_edge_embeddings_dim = config.d_model\
            - len(self.continuous_edge_variables)
        self.node_embeddings = ModuleList(
            [Embedding(len(var), self.categorical_node_embeddings_dim)
             for var in self.categorical_node_variables])
        self.edge_embeddings = ModuleList(
            [Embedding(len(var), self.categorical_edge_embeddings_dim)
             for var in self.categorical_edge_variables])
        # Notice that we use the supplied layer type when creating the graph
        # layers. This allows us to easily change the kind of graph layers
        # we use later on
        self.graph_layers = ModuleList([layer_type(config)
                                        for _ in range(config.n_layers)])
        self.prediction_head = Sequential(Linear(config.d_model,
                                                 config.prediction_head_dim),
                                          ReLU(),
                                          Linear(config.prediction_head_dim,
                                                 output_dim))

    def forward(self, batch):
        # First order of business is to embed the node features
        node_mask = batch['nodes_mask']
        batch_size, max_nodes = node_mask.shape
        continuous_node_features = batch['continuous_node_features']
        categorical_node_features = batch['categorical_node_features']
        if categorical_node_features is not None:
            node_embeddings = []
            for i, embedding in enumerate(self.node_embeddings):
                node_features = categorical_node_features[:, :, i]
                embedded_features = embedding(node_features)
                node_embeddings.append(embedded_features)
            node_features = torch.sum(torch.stack(node_embeddings, dim=-1), dim=-1)
        else:
            # Create the placeholder on the same device as the rest of the batch
            node_features = torch.zeros((batch_size, max_nodes,
                                         self.categorical_node_embeddings_dim),
                                        device=node_mask.device)
        if continuous_node_features is not None:
            # We need to make sure the continuous embeddings are valid
            node_features = torch.cat([node_features, continuous_node_features],
                                      dim=-1)
        # The node features are now set. However, there will be invalid entries
        # for graphs smaller than the one of maximum size in the batch. We use
        # the node mask to zero out any such entries. The node features are of
        # shape (batch_size, max_nodes, d_model). We want to broadcast the node
        # mask along the last axis (the feature vectors), so we add a trailing
        # dimension of size 1
        masked_node_features = node_features * node_mask.unsqueeze(-1)
        # We have now embedded the node features, we'll propagate them through
        # our graph layers. Note that we start from the masked features
        adjacency_matrix = batch['adjacency_matrices']
        memory_state = masked_node_features
        for layer in self.graph_layers:
            if self.config.residual_connections:
                memory_state = memory_state + layer(memory_state, adjacency_matrix,
                                                    node_mask=node_mask)
            else:
                memory_state = layer(memory_state, adjacency_matrix,
                                     node_mask=node_mask)
        if self.config.graph_pooling == 'mean' or\
                self.config.graph_pooling == 'sum':
            # The memory state has shape (batch_dim, max_nodes, d_model).
            # Since the graphs have different sizes, we need to use the node
            # mask to calculate the mean.
            summed_memory_state = memory_state.sum(dim=1)  # dim=1 is the node dimension
            if self.config.graph_pooling == 'sum':
                memory_state = summed_memory_state
            else:
                n_nodes = node_mask.sum(dim=1)
                mean_memory_state = summed_memory_state / n_nodes.unsqueeze(-1)
                memory_state = mean_memory_state
        prediction = self.prediction_head(memory_state)
        return prediction
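The mean pooling step divides by the number of real nodes (the sum of the node mask) rather than by max_nodes. A minimal sketch in plain Python, with made-up numbers, of why this matters: padded rows are zero after masking, so they do not change the sum, but they would dilute the mean if they were counted.

```python
def masked_mean(node_states, node_mask):
    """Average the node vectors of one graph, ignoring padded positions."""
    dim = len(node_states[0])
    n_real = sum(node_mask)  # number of real (unpadded) nodes
    summed = [
        sum(mask * state[d] for state, mask in zip(node_states, node_mask))
        for d in range(dim)
    ]
    return [s / n_real for s in summed]


# A graph with two real nodes, padded to max_nodes = 4 (illustrative values).
node_states = [[2.0, 4.0], [4.0, 8.0], [0.0, 0.0], [0.0, 0.0]]
node_mask = [1.0, 1.0, 0.0, 0.0]

pooled = masked_mean(node_states, node_mask)
naive = [sum(column) / len(node_states) for column in zip(*node_states)]
print("masked mean:", pooled)
print("naive mean over max_nodes:", naive)
```

The naive mean is half the masked mean here because half of the rows are padding; with the mask, graphs of different sizes in the same batch are pooled consistently.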
Training the model
We’ve put a lot of infrastructure in place and are finally ready to train the model. Note that all the tensors in a batch need to be moved to the same device as the model. The code below uses ‘cuda’ when a GPU is available and falls back to ‘cpu’ otherwise; we assume this runs on a Colab GPU instance. If the printed device is ‘cpu’, change the runtime (under Runtime->Change runtime type) and set ‘accelerator’ to GPU.
The training loop
To make training models a bit easier, we encapsulate most of the training loop and its state in a Trainer class.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Device is", device)
from tqdm.notebook import tqdm, trange
from torch.nn import BCEWithLogitsLoss, MSELoss
from torch.optim import AdamW
from sklearn.metrics import roc_auc_score, mean_squared_error
from sklearn.metrics import mean_absolute_error, median_absolute_error
import pandas as pd  # evaluate_model below builds a DataFrame
import matplotlib.pyplot as plt
import seaborn as sns
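def batch_to_device(batch, device):
    # Move every tensor in the batch to the device; other values (for example
    # metadata lists) are passed through unchanged
    moved_batch = {}
    for k, v in batch.items():
        if torch.is_tensor(v):
            v = v.to(device)
        moved_batch[k] = v
    return moved_batch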
class Trainer:
    def __init__(self, *, model,
                 loss_fn, training_dataloader,
                 dev_dataloader, device=device):
        self.model = model
        self.training_dataloader = training_dataloader
        self.dev_dataloader = dev_dataloader
        self.device = device
        self.model.to(device)
        self.total_epochs = 0
        self.optimizer = AdamW(self.model.parameters(), lr=1e-4)
        self.loss_fn = loss_fn

    def train(self, epochs):
        with trange(epochs, desc='Epoch', position=0) as epoch_progress:
            batches_per_epoch = len(self.training_dataloader)\
                + len(self.dev_dataloader)
            for epoch in epoch_progress:
                train_loss = 0
                train_n = 0
                for i, training_batch in enumerate(tqdm(self.training_dataloader,
                                                        desc='Training batch',
                                                        leave=False)):
                    self.optimizer.zero_grad()
                    self.model.train()
                    # Move all tensors to the device
                    training_batch = batch_to_device(training_batch, self.device)
                    prediction = self.model(training_batch)
                    labels = training_batch['labels']
                    # The predictions have shape (batch_size, 1); drop only the
                    # trailing dimension so a batch of size 1 keeps its shape
                    loss = self.loss_fn(prediction.squeeze(-1), labels)
                    loss.backward()
                    self.optimizer.step()
                    batch_n = len(labels)
                    train_loss += batch_n * loss.cpu().item()
                    train_n += batch_n
                #print(f"Training loss for epoch {total_epochs}", train_loss/train_n)
                self.total_epochs += 1
                dev_predictions = []
                dev_labels = []
                dev_n = 0
                dev_loss = 0
                for i, dev_batch in enumerate(tqdm(self.dev_dataloader,
                                                   desc="Dev batch", leave=False)):
                    self.model.eval()
                    with torch.no_grad():
                        dev_batch = batch_to_device(dev_batch, self.device)
                        # Predictions have shape (batch_size, 1) before squeezing
                        prediction = self.model(dev_batch).squeeze(-1)
                        dev_predictions.extend(prediction.tolist())
                        labels = dev_batch['labels']
                        dev_labels.extend(labels.tolist())
                        loss = self.loss_fn(prediction, labels)
                        batch_n = len(labels)
                        dev_loss += batch_n*loss.cpu().item()
                        dev_n += batch_n
                epoch_progress.set_description(
                    f"Epoch: train loss {train_loss/train_n: .3f}, "
                    f"dev loss {dev_loss/dev_n: .3f}")

def evaluate_model(trainer, dataloader, label=None, hue_order=[0,1]):
    eval_predictions = []
    eval_labels = []
    eval_loss = 0
    eval_n = 0
    model = trainer.model
    loss_fn = trainer.loss_fn
    total_epochs = trainer.total_epochs
    for i, eval_batch in enumerate(tqdm(dataloader, desc='batch')):
        model.eval()
        with torch.no_grad():
            # Use the trainer's device so the batch ends up where the model is
            eval_batch = batch_to_device(eval_batch, trainer.device)
            # Predictions have shape (batch_size, 1) before squeezing
            prediction = model(eval_batch).squeeze(-1)
            eval_predictions.extend(prediction.tolist())
            labels = eval_batch['labels']
            eval_labels.extend(labels.tolist())
            loss = loss_fn(prediction, labels)
            batch_n = len(labels)
            eval_loss += batch_n*loss.cpu().item()
            eval_n += batch_n
    average_loss = eval_loss/eval_n
    roc_auc = roc_auc_score(eval_labels, eval_predictions)
    eval_df = pd.DataFrame(data={'target': eval_labels,
                                 'predictions': eval_predictions})
    sns.kdeplot(data=eval_df, x='predictions', hue='target', hue_order=hue_order)
    sns.rugplot(data=eval_df, x='predictions', hue='target', hue_order=hue_order)
    if label is not None:
        title = f"{label} dataset after {total_epochs} epochs\nloss {average_loss} \nROC AUC {roc_auc}"
    else:
        title = f"After {total_epochs} epochs\nloss {average_loss}\nROC AUC {roc_auc}"
    plt.title(title)
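Both Trainer.train and evaluate_model accumulate batch_n * loss rather than the loss itself. The loss functions used here return the mean over a batch by default, so weighting each batch by its size recovers the per-example mean even when the last batch is smaller than the others. A small sketch with made-up per-example losses:

```python
per_example_losses = [0.2, 0.4, 0.6, 0.8, 1.0]  # illustrative values
batch_size = 2

# Split into batches; the last batch only has one example.
batches = [per_example_losses[i:i + batch_size]
           for i in range(0, len(per_example_losses), batch_size)]
batch_means = [sum(batch) / len(batch) for batch in batches]

# Same bookkeeping as in the training loop: weight each batch mean by its size.
total_loss, n_examples = 0.0, 0
for batch, mean_loss in zip(batches, batch_means):
    total_loss += len(batch) * mean_loss
    n_examples += len(batch)
weighted_mean = total_loss / n_examples

# Averaging the batch means directly over-weights the small last batch.
naive_mean = sum(batch_means) / len(batch_means)
true_mean = sum(per_example_losses) / len(per_example_losses)

print(f"weighted {weighted_mean:.4f}, naive {naive_mean:.4f}, true {true_mean:.4f}")
```

With drop_last=False on the dev and test dataloaders, a smaller final batch is the normal case, so the weighted version is the one that matches the dataset-level loss.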
Model configuration
The cell below contains the settings for the model. In particular, the line config = BasicGNNConfig(...) is where we set the hyper parameters:
d_model: The dimensionality of the vectors in the model, essentially the capacity of the Graph Neural Network.
n_layers: The number of GNN layers in the network.
residual_connections: Toggles whether residual connections are used. It is often a good idea to use residual connections in GNNs, especially if you make them deep.
ffn_dim: Each graph layer contains a small 2-layer MLP; this controls the dimensionality of the “hidden” layer of that MLP. Setting this to a value larger than d_model is often done in practice.
prediction_head_dim: After the final graph aggregation, a 2-layer MLP transforms the output to our graph prediction; this controls the dimensionality of the hidden layer of that MLP.
layer_norm: Toggles whether layer norm is used. Can help to improve training and generalization.
dropout_rate: The dropout rate to use in the network. 0 means no dropout.
graph_pooling: The pooling method to use on the output of the last GNN layer. Supported values are sum and mean.
batch_size = 32
num_dataloader_workers = 2  # colab instances are very limited in number of cpus
training_dataloader = DataLoader(training_graph_dataset,
                                 batch_size=batch_size,
                                 shuffle=True,
                                 num_workers=num_dataloader_workers,
                                 collate_fn=collate_graph_batch)
dev_dataloader = DataLoader(dev_graph_dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=num_dataloader_workers,
                            drop_last=False,
                            collate_fn=collate_graph_batch)
test_dataloader = DataLoader(test_graph_dataset,
                             batch_size=batch_size,
                             shuffle=False,
                             num_workers=num_dataloader_workers,
                             drop_last=False,
                             collate_fn=collate_graph_batch)
torch.manual_seed(1729)
# The config below contains most of the hyper parameters you might be interested in tuning
config = BasicGNNConfig(d_model=32,
                        n_layers=2,
                        residual_connections=False,
                        ffn_dim=32,
                        prediction_head_dim=32,
                        layer_norm=True,
                        dropout_rate=0.2,
                        graph_pooling='sum')
model = BasicGNN(output_dim=1,
                 config=config,
                 continuous_node_variables=training_graph_dataset.continuous_node_variables,
                 categorical_node_variables=training_graph_dataset.categorical_node_variables,
                 continuous_edge_variables=training_graph_dataset.continuous_edge_variables,
                 categorical_edge_variables=training_graph_dataset.categorical_edge_variables)
loss_fn = BCEWithLogitsLoss()
trainer = Trainer(model=model,
                  loss_fn=loss_fn,
                  training_dataloader=training_dataloader,
                  dev_dataloader=dev_dataloader)
Note that the trainer does not reset the model. Each call to trainer.train takes the existing model and continues training it for the specified number of epochs. If you want to train the model from scratch, re-run the code block above, which re-initializes the model with random weights.
# The argument is the number of epochs to train, note that the model does not
# reset between calls, so you can train one epoch at a time
trainer.train(1)
Looking at the model predictions
To evaluate our model, we can look at how well it is actually able to separate the positive and the negative class. To get a better view of this than a single summary statistic (like the ROC AUC score) gives, we have defined a function above called evaluate_model, which takes our trainer and a dataloader as input. It shows a rug plot and a kernel density estimate of the distribution of model predictions on the supplied dataloader, separated by class.
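ROC AUC can be read as the probability that a randomly chosen positive example gets a higher score than a randomly chosen negative one, with ties counting as one half. It therefore depends only on the ranking of the scores, which is why evaluate_model can pass raw logits to roc_auc_score. Below is a minimal pure-Python sketch of that definition with made-up labels and scores; the helper name pairwise_auc is hypothetical and not part of scikit-learn.

```python
def pairwise_auc(labels, scores):
    """ROC AUC as the fraction of (positive, negative) pairs ranked correctly."""
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    wins = 0.0
    for p in positives:
        for q in negatives:
            if p > q:
                wins += 1.0
            elif p == q:
                wins += 0.5  # ties count as half a win
    return wins / (len(positives) * len(negatives))


labels = [0, 0, 1, 1]
logits = [-2.0, 0.5, 0.0, 3.0]  # illustrative model outputs

auc = pairwise_auc(labels, logits)
# Any strictly increasing transformation of the scores leaves the AUC unchanged.
auc_rescaled = pairwise_auc(labels, [2 * s + 1 for s in logits])
print(auc, auc_rescaled)
```

The double loop is quadratic in the number of examples, so this is only meant to illustrate the definition; for real evaluation keep using roc_auc_score.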
Predictions on the training dataset
evaluate_model(trainer, training_dataloader, label='Train')
This plot shows us the prediction score (logit) our model assigns to the training examples, colored by their class. A kernel density estimator is fitted to give us an idea of how the logits for the training data are distributed, conditioned on class.
If you’ve only trained a couple of epochs, the model will likely have a hard time separating the data points.
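The scores on the x-axis are logits: the model is trained with BCEWithLogitsLoss and has no final sigmoid. If you prefer to read them as probabilities, the logistic sigmoid maps them to the interval (0, 1), and a logit of 0 corresponds to a probability of 0.5. A small sketch using only the standard library:

```python
import math


def sigmoid(logit):
    """Map a logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-logit))


for logit in (-2.0, 0.0, 2.0):
    print(f"logit {logit:+.1f} -> probability {sigmoid(logit):.3f}")
```

On tensors, torch.sigmoid applies the same function element-wise.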
Predictions on the dev set
evaluate_model(trainer, dev_dataloader, label='Dev')
Performance on the dev set looks quite similar, so we’re unlikely to be overfitting just yet.
Predictions on the test set
Here’s how you can evaluate the model on the test set once you are done developing it
evaluate_model(trainer, test_dataloader, label='Test')
Task
We’ve trained a Graph Neural Network using a single set of hyper parameters. Experiment with different hyper parameters and see how good you can make the predictions. You can find the hyper parameters in the “Model configuration” subsection above, in the cell that contains a call like this:
config = BasicGNNConfig(d_model=32, n_layers=4, residual_connections=True,
                        ffn_dim=32,
                        prediction_head_dim=32,
                        layer_norm=True,
                        dropout_rate=0.2,
                        graph_pooling='sum')
What settings do you need to overfit and can you find a sweet spot where you don’t?
Keep in mind that this is a difficult problem on real data, so perfect separation is not to be expected.
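One way to organise these experiments is to enumerate a small grid of settings up front. The sketch below only builds the keyword-argument dictionaries. The grid values are arbitrary starting points, not recommendations, and creating the model and trainer for each entry is done as in the “Model configuration” cell.

```python
from itertools import product

# Hypothetical grid; pick values that fit your time budget.
grid = {
    'd_model': [32, 64],
    'n_layers': [2, 4],
    'dropout_rate': [0.0, 0.2],
}
# Settings held fixed across the sweep (same values as in the cell above).
fixed = {
    'residual_connections': True,
    'ffn_dim': 32,
    'prediction_head_dim': 32,
    'layer_norm': True,
    'graph_pooling': 'sum',
}

# One dictionary per combination of the grid values.
config_kwargs = [
    {**fixed, **dict(zip(grid.keys(), values))}
    for values in product(*grid.values())
]

for kwargs in config_kwargs:
    print(kwargs)  # each dict can be passed as BasicGNNConfig(**kwargs)
```

Remember to create a fresh model and trainer for every configuration, since the trainer keeps training whatever model it was given. Compare configurations on the dev set and keep the test set for the final evaluation.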