GNNs to Transformers
We’ve previously seen how we can implement graph convolutions using multiplication with the adjacency matrix. In this notebook we will look at how this can be extended to a Transformer by computing the adjacency matrix dynamically.
You can find this notebook on Google Colab here
Code from previous notebooks
We will use the same example as in the previous notebooks.
from collections import defaultdict
from collections.abc import Set
import rdkit
from rdkit.Chem import MolFromSmiles
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem import Draw
IPythonConsole.ipython_useSVG=True #< set this to False if you want PNGs instead of SVGs
IPythonConsole.drawOptions.addAtomIndices = True # This will help when looking at the Mol graph representation
IPythonConsole.molSize = 600, 600
# We suppress RDKit errors for this notebook
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
import torch
from torch.utils.data import Dataset, DataLoader
float_type = torch.float32 # We're hardcoding types in the tensors further down
categorical_type = torch.long
mask_type = torch.float32 # We're going to be multiplying our internal calculations with a mask using this type
labels_type = torch.float32 # We're going to use BCEWithLogitsLoss, which expects the labels to be of the same type as the predictions
class ContinuousVariable:
def __init__(self, name):
self.name = name
def __repr__(self):
return f'<ContinuousVariable: {self.name}>'
def __eq__(self, other):
return self.name == other.name
def __hash__(self):
return hash(self.name)
class CategoricalVariable:
def __init__(self, name, values, add_null_value=True):
self.name = name
self.has_null_value = add_null_value
if self.has_null_value:
self.null_value = None
values = (None,) + tuple(values)
self.values = tuple(values)
self.value_to_idx_mapping = {v: i for i, v in enumerate(values)}
self.inv_value_to_idx_mapping = {i: v for v, i in self.value_to_idx_mapping.items()}
if self.has_null_value:
self.null_value_idx = self.value_to_idx_mapping[self.null_value]
def get_null_idx(self):
if self.has_null_value:
return self.null_value_idx
else:
raise RuntimeError(f"Categorical variable {self.name} has no null value")
def value_to_idx(self, value):
return self.value_to_idx_mapping[value]
def idx_to_value(self, idx):
return self.inv_value_to_idx_mapping[idx]
def __len__(self):
return len(self.values)
def __repr__(self):
return f'<CategoricalVariable: {self.name}>'
def __eq__(self, other):
return self.name == other.name and self.values == other.values
def __hash__(self):
return hash((self.name, self.values))
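As a quick illustration of how the null value behaves (a small sketch added here, not part of the original code): the null value is prepended at index 0, which is what later lets us use it as a padding index in embeddings.
example_var = CategoricalVariable('example_bond_type', ['SINGLE', 'DOUBLE'])
print(example_var.value_to_idx(None)) # 0, the null value index
print(example_var.value_to_idx('SINGLE')) # 1
print(example_var.idx_to_value(2)) # 'DOUBLE'
print(len(example_var)) # 3, the null value is counted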
ATOM_SYMBOLS = ['H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', 'Na',
'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar', 'K', 'Ca', 'Sc', 'Ti',
'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', 'Ga', 'Ge', 'As',
'Se', 'Br', 'Kr', 'Rb', 'Sr', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru',
'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'I', 'Xe', 'Cs',
'Ba', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Tl',
'Pb', 'Bi', 'Po', 'At', 'Rn', 'Fr', 'Ra', 'Rf', 'Db', 'Sg',
'Bh', 'Hs', 'Mt', 'Ds', 'Rg', 'Cn', 'Fl', 'Lv', 'La', 'Ce', 'Pr',
'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', 'Er', 'Tm', 'Yb',
'Lu', 'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm', 'Bk',
'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr']
ATOM_SYMBOLS_FEATURE = CategoricalVariable('atom_symbol', ATOM_SYMBOLS)
ATOM_AROMATIC_VALUES = [True, False]
ATOM_AROMATIC_FEATURE = CategoricalVariable('is_aromatic', ATOM_AROMATIC_VALUES)
# In practice you might like to use categorical features for valence, but we use continuous here for demonstration
ATOM_EXPLICIT_VALENCE_FEATURE = ContinuousVariable('explicit_valence')
ATOM_IMPLICIT_VALENCE_FEATURE = ContinuousVariable('implicit_valence')
ATOM_FEATURES = [ATOM_SYMBOLS_FEATURE, ATOM_AROMATIC_FEATURE, ATOM_EXPLICIT_VALENCE_FEATURE, ATOM_IMPLICIT_VALENCE_FEATURE]
def get_atom_features(rd_atom):
atom_symbol = rd_atom.GetSymbol()
is_aromatic = rd_atom.GetIsAromatic()
implicit_valence = float(rd_atom.GetImplicitValence())
explicit_valence = float(rd_atom.GetExplicitValence())
return {ATOM_SYMBOLS_FEATURE: atom_symbol,
ATOM_AROMATIC_FEATURE: is_aromatic,
ATOM_EXPLICIT_VALENCE_FEATURE: explicit_valence,
ATOM_IMPLICIT_VALENCE_FEATURE: implicit_valence}
# We could use the RDKit enumeration types instead of strings, but the advantage
# of doing it like this is that our representation becomes independent of RDKit
BOND_TYPES = ['UNSPECIFIED', 'SINGLE', 'DOUBLE', 'TRIPLE', 'QUADRUPLE',
'QUINTUPLE', 'HEXTUPLE', 'ONEANDAHALF', 'TWOANDAHALF',
'THREEANDAHALF','FOURANDAHALF', 'FIVEANDAHALF', 'AROMATIC',
'IONIC', 'HYDROGEN', 'THREECENTER', 'DATIVEONE', 'DATIVE',
'DATIVEL', 'DATIVER', 'OTHER', 'ZERO']
TYPE_FEATURE = CategoricalVariable('bond_type', BOND_TYPES)
BOND_DIRECTIONS = ['NONE', 'BEGINWEDGE', 'BEGINDASH', 'ENDDOWNRIGHT', 'ENDUPRIGHT', 'EITHERDOUBLE' ]
DIRECTION_FEATURE = CategoricalVariable('bond_direction', BOND_DIRECTIONS)
BOND_STEREO = ['STEREONONE', 'STEREOANY', 'STEREOZ', 'STEREOE',
'STEREOCIS', 'STEREOTRANS']
STEREO_FEATURE = CategoricalVariable('bond_stereo', BOND_STEREO)
AROMATIC_VALUES = [True, False]
AROMATIC_FEATURE = CategoricalVariable('is_aromatic', AROMATIC_VALUES)
BOND_FEATURES = [TYPE_FEATURE, DIRECTION_FEATURE, AROMATIC_FEATURE, STEREO_FEATURE]
def get_bond_features(rd_bond):
bond_type = str(rd_bond.GetBondType())
bond_stereo_info = str(rd_bond.GetStereo())
bond_direction = str(rd_bond.GetBondDir())
is_aromatic = rd_bond.GetIsAromatic()
return {TYPE_FEATURE: bond_type,
DIRECTION_FEATURE: bond_direction,
AROMATIC_FEATURE: is_aromatic,
STEREO_FEATURE: bond_stereo_info}
def rdmol_to_graph(mol):
atoms = {rd_atom.GetIdx(): get_atom_features(rd_atom) for rd_atom in mol.GetAtoms()}
bonds = {frozenset((rd_bond.GetBeginAtomIdx(), rd_bond.GetEndAtomIdx())): get_bond_features(rd_bond) for rd_bond in mol.GetBonds()}
return atoms, bonds
def smiles_to_graph(smiles):
rd_mol = MolFromSmiles(smiles)
graph = rdmol_to_graph(rd_mol)
return graph
g = smiles_to_graph('c1ccccc1')
class GraphDataset(Dataset):
def __init__(self, *, graphs, labels, node_variables, edge_variables, metadata=None):
'''
Create a new graph dataset.
'''
self.graphs = graphs
self.labels = labels
assert len(self.graphs) == len(self.labels), "The graphs and labels lists must be the same length"
self.metadata = metadata
if self.metadata is not None:
assert len(self.metadata) == len(self.graphs), "The metadata list needs to be as long as the graphs"
self.node_variables = node_variables
self.edge_variables = edge_variables
self.categorical_node_variables = [var for var in self.node_variables if isinstance(var, CategoricalVariable)]
self.continuous_node_variables = [var for var in self.node_variables if isinstance(var, ContinuousVariable)]
self.categorical_edge_variables = [var for var in self.edge_variables if isinstance(var, CategoricalVariable)]
self.continuous_edge_variables = [var for var in self.edge_variables if isinstance(var, ContinuousVariable)]
def __len__(self):
return len(self.graphs)
def make_continuous_node_features(self, nodes):
if len(self.continuous_node_variables) == 0:
return None
n_nodes = len(nodes)
n_features = len(self.continuous_node_variables)
continuous_node_features = torch.zeros((n_nodes, n_features), dtype=float_type)
for node_idx, features in nodes.items():
node_features = torch.tensor([features[continuous_feature] for continuous_feature in self.continuous_node_variables], dtype=float_type)
continuous_node_features[node_idx] = node_features
return continuous_node_features
def make_categorical_node_features(self, nodes):
if len(self.categorical_node_variables) == 0:
return None
n_nodes = len(nodes)
n_features = len(self.categorical_node_variables)
categorical_node_features = torch.zeros((n_nodes, n_features), dtype=categorical_type)
for node_idx, features in nodes.items():
for i, categorical_variable in enumerate(self.categorical_node_variables):
value = features[categorical_variable]
value_index = categorical_variable.value_to_idx(value)
categorical_node_features[node_idx, i] = value_index
return categorical_node_features
def make_continuous_edge_features(self, n_nodes, edges):
if len(self.continuous_edge_variables) == 0:
return None
n_features = len(self.continuous_edge_variables)
continuous_edge_features = torch.zeros((n_nodes, n_nodes, n_features), dtype=float_type)
for edge, features in edges.items():
edge_features = torch.tensor([features[continuous_feature] for continuous_feature in self.continuous_edge_variables], dtype=float_type)
u,v = edge
continuous_edge_features[u, v] = edge_features
if isinstance(edge, Set):
continuous_edge_features[v, u] = edge_features
return continuous_edge_features
def make_categorical_edge_features(self, n_nodes, edges):
if len(self.categorical_edge_variables) == 0:
return None
n_features = len(self.categorical_edge_variables)
categorical_edge_features = torch.zeros((n_nodes, n_nodes, n_features), dtype=categorical_type)
for edge, features in edges.items():
u,v = edge
for i, categorical_variable in enumerate(self.categorical_edge_variables):
value = features[categorical_variable]
value_index = categorical_variable.value_to_idx(value)
categorical_edge_features[u, v, i] = value_index
if isinstance(edge, Set):
categorical_edge_features[v, u, i] = value_index
return categorical_edge_features
def __getitem__(self, index):
# This is where the important stuff happens. We use our node and
# edge variable attributes to select what node and edge features to use.
# In practice, we often do this as a pre-processing step, but here we do it
# in the getitem function for clarity
graph = self.graphs[index]
nodes, edges = graph
n_nodes = len(nodes)
continuous_node_features = self.make_continuous_node_features(nodes)
categorical_node_features = self.make_categorical_node_features(nodes)
continuous_edge_features = self.make_continuous_edge_features(n_nodes, edges)
categorical_edge_features = self.make_categorical_edge_features(n_nodes, edges)
label = self.labels[index]
nodes_idx = sorted(nodes.keys())
edge_list = sorted(edges.keys())
n_nodes = len(nodes)
adjacency_matrix = torch.zeros((n_nodes, n_nodes), dtype=float_type)
for edge in edges:
u, v = edge
adjacency_matrix[u,v] = 1
if isinstance(edge, Set):
# This edge is unordered, assume this is an undirected graph
adjacency_matrix[v,u] = 1
adjacency_list = defaultdict(list)
for edge in edges:
u,v = edge
adjacency_list[u].append(v)
# Assume an undirected graph if the edge is a set
if isinstance(edge, Set):
adjacency_list[v].append(u)
data_record = {'nodes': nodes_idx,
'adjacency_matrix': adjacency_matrix,
'adjacency_list': adjacency_list,
'categorical_node_features': categorical_node_features,
'continuous_node_features': continuous_node_features,
'categorical_edge_features': categorical_edge_features,
'continuous_edge_features': continuous_edge_features,
'label': label}
# If you need to add extra information (metadata about this graph) you can
# add an extra key-value pair here. The advantage of using a dict compared
# to a tuple is that the downstream code doesn't break as long as at least
# the expected keys are present. The downside is that using a dict adds
# overhead (accessing a dict compared to unpacking a tuple).
# A more robust implementation might actually make a separate class for
# dataset entries
if self.metadata is not None:
data_record['metadata'] = self.metadata[index]
return data_record
def get_node_variables(self):
return {'continuous': self.continuous_node_variables,
'categorical': self.categorical_node_variables}
def get_edge_variables(self):
return {'continuous': self.continuous_edge_variables,
'categorical': self.categorical_edge_variables}
def make_molecular_graph_dataset(smiles_records, atom_features=ATOM_FEATURES, bond_features=BOND_FEATURES):
'''
Create a new GraphDataset from a list of smiles_records dictionaries.
These records should contain the key 'smiles' and 'label'. Any other keys will be saved as a 'metadata' record.
'''
graphs = []
labels = []
metadata = []
for smiles_record in smiles_records:
smiles = smiles_record['smiles']
label = smiles_record['label']
graph = smiles_to_graph(smiles)
graphs.append(graph)
labels.append(label)
metadata.append(smiles_record)
return GraphDataset(graphs=graphs,
labels=labels,
node_variables=atom_features,
edge_variables=bond_features,
metadata=metadata)
dataset = make_molecular_graph_dataset([{'smiles': 'c1ccccc1', 'label':1},{'smiles':'OS(=O)(=O)O', 'label': 0}])
from collections.abc import Set # We assume that edges as sets are for undirected graphs
def collate_graph_batch(batch):
'''Collate a batch of graph dictionaries produced by a GraphDataset'''
batch_size = len(batch)
max_nodes = max(len(graph['nodes']) for graph in batch)
# We start by allocating the tensors we'll use. We defer allocating feature
# tensors until we know the graphs actually have those kinds of features.
adjacency_matrices = torch.zeros((batch_size, max_nodes, max_nodes), dtype=float_type)
labels = torch.tensor([graph['label'] for graph in batch], dtype=labels_type)
stacked_continuous_node_features = None
stacked_categorical_node_features = None
stacked_continuous_edge_features = None
stacked_categorical_edge_features = None
nodes_mask = torch.zeros((batch_size, max_nodes), dtype=mask_type)
edge_mask = torch.zeros((batch_size, max_nodes, max_nodes), dtype=mask_type)
has_metadata = False
for i, graph in enumerate(batch):
if 'metadata' in graph:
has_metadata = True
# We'll take basic information about the different graphs from the adjacency
# matrix
adjacency_matrix = graph['adjacency_matrix']
g_nodes, g_nodes = adjacency_matrix.shape
adjacency_matrices[i, :g_nodes, :g_nodes] = adjacency_matrix
# Now that we know how many of the entries are valid, we set those to 1s in
# the masks
edge_mask[i, :g_nodes, :g_nodes] = 1
nodes_mask[i, :g_nodes] = 1
# All the feature constructions follow the same recipe. We essentially
# locate the entries in the stacked feature tensor (containing all graphs)
# and set it with the features from the current graph.
g_continuous_node_features = graph['continuous_node_features']
if g_continuous_node_features is not None:
if stacked_continuous_node_features is None:
g_nodes, num_features = g_continuous_node_features.shape
stacked_continuous_node_features = torch.zeros((batch_size, max_nodes, num_features))
stacked_continuous_node_features[i, :g_nodes] = g_continuous_node_features
g_categorical_node_features = graph['categorical_node_features']
if g_categorical_node_features is not None:
if stacked_categorical_node_features is None:
g_nodes, num_features = g_categorical_node_features.shape
stacked_categorical_node_features = torch.zeros((batch_size, max_nodes, num_features), dtype=categorical_type)
stacked_categorical_node_features[i, :g_nodes] = g_categorical_node_features
g_continuous_edge_features = graph['continuous_edge_features']
if g_continuous_edge_features is not None:
if stacked_continuous_edge_features is None:
g_nodes, g_nodes, num_features = g_continuous_edge_features.shape
stacked_continuous_edge_features = torch.zeros((batch_size, max_nodes, max_nodes, num_features))
stacked_continuous_edge_features[i, :g_nodes, :g_nodes] = g_continuous_edge_features
g_categorical_edge_features = graph['categorical_edge_features']
if g_categorical_edge_features is not None:
if stacked_categorical_edge_features is None:
g_nodes, g_nodes, num_features = g_categorical_edge_features.shape
stacked_categorical_edge_features = torch.zeros((batch_size, max_nodes, max_nodes, num_features), dtype=categorical_type)
stacked_categorical_edge_features[i, :g_nodes, :g_nodes] = g_categorical_edge_features
batch_record = {'adjacency_matrices': adjacency_matrices,
'categorical_node_features': stacked_categorical_node_features,
'continuous_node_features': stacked_continuous_node_features,
'categorical_edge_features': stacked_categorical_edge_features,
'continuous_edge_features': stacked_continuous_edge_features,
'nodes_mask': nodes_mask,
'edge_mask': edge_mask,
'labels': labels}
if has_metadata:
batch_record['metadata'] = [g['metadata'] for g in batch]
return batch_record
example_batch = collate_graph_batch([dataset[0], dataset[1]])
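To get a feel for what the collate function produces, we can inspect the shapes of the batched tensors (an illustrative check; the sizes follow from benzene having 6 heavy atoms and sulfuric acid 5, padded to 6):
print(example_batch['adjacency_matrices'].shape) # torch.Size([2, 6, 6])
print(example_batch['categorical_node_features'].shape) # torch.Size([2, 6, 2])
print(example_batch['continuous_node_features'].shape) # torch.Size([2, 6, 2])
print(example_batch['categorical_edge_features'].shape) # torch.Size([2, 6, 6, 4])
print(example_batch['nodes_mask'].shape) # torch.Size([2, 6])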
from torch.nn import Module, Embedding, ModuleList, Linear, Sequential, ReLU
class Embedder(Module):
def __init__(self, categorical_variables, embedding_dim):
super().__init__()
self.categorical_variables = categorical_variables
embeddings = []
for var in categorical_variables:
num_embeddings = len(var)
if var.has_null_value:
# It's not uncommon to have missing values; we support this by assigning a special 0-index which has the zero-vector as its embedding
embedding = Embedding(num_embeddings, embedding_dim, padding_idx=var.get_null_idx())
else:
embedding = Embedding(num_embeddings, embedding_dim)
embeddings.append(embedding)
self.embeddings = ModuleList(embeddings)
def forward(self, categorical_features):
# The node features form a matrix with as many rows as nodes in our graph
# and as many columns as we have categorical features
all_embedded_vars = []
for i, embedding in enumerate(self.embeddings):
# We pick out just the i'th column. The ellipsis '...' in a numpy-style
# slice is a useful way of saying you want the full range over all other axes
# We use it so that this can actually take a categorical_features array
# with an arbitrary number of trailing axes to support both the node
# features, the edge features and the mini-batched version of both
var_indices = categorical_features[..., i]
embedded_vars = embedding(var_indices)
all_embedded_vars.append(embedded_vars)
# If you like, you can implement concatenation instead of sum here
stacked_embedded_vars = torch.stack(all_embedded_vars, dim=0)
embedded_vars = torch.sum(stacked_embedded_vars, dim=0)
return embedded_vars
class FeatureCombiner(Module):
def __init__(self, categorical_variables, embedding_dim):
super().__init__()
self.categorical_variables = categorical_variables
self.embedder = Embedder(self.categorical_variables, embedding_dim)
def forward(self, continuous_features, categorical_features):
# We need to be agnostic to whether we have categorical features and continuous features (it's not uncommon to only use one kind)
features = []
if categorical_features is not None:
embedded_features = self.embedder(categorical_features)
features.append(embedded_features)
# The embedded features are now of shape (n_nodes, embedding_dim)
if continuous_features is not None:
features.append(continuous_features)
if len(features) == 0:
raise RuntimeError('No features to combine')
full_features = torch.cat(features, dim=-1) # Now we concatenate along the feature dimension
return full_features
Training step
This section downloads the data and implements the training procedures. This code should also be familiar from the previous lessons.
from collections import defaultdict, deque # We'll use this to construct the dataset splits
from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmilesFromSmiles
from pathlib import Path
import pandas as pd
potential_paths = [Path('BBBP.csv'), Path('../dataset/BBBP.csv')]
bbbp_table = None
for p in potential_paths:
if p.exists():
bbbp_table = pd.read_csv(p)
break
bbbp_table
 | num | name | p_np | smiles
---|---|---|---|---
0 | 1 | Propanolol | 1 | [Cl].CC(C)NCC(O)COc1cccc2ccccc12 |
1 | 2 | Terbutylchlorambucil | 1 | C(=O)(OC(C)(C)C)CCCc1ccc(cc1)N(CCCl)CCCl |
2 | 3 | 40730 | 1 | c12c3c(N4CCN(C)CC4)c(F)cc1c(c(C(O)=O)cn2C(C)CO... |
3 | 4 | 24 | 1 | C1CCN(CC1)Cc1cccc(c1)OCCCNC(=O)C |
4 | 5 | cloxacillin | 1 | Cc1onc(c2ccccc2Cl)c1C(=O)N[C@H]3[C@H]4SC(C)(C)... |
... | ... | ... | ... | ... |
2045 | 2049 | licostinel | 1 | C1=C(Cl)C(=C(C2=C1NC(=O)C(N2)=O)[N+](=O)[O-])Cl |
2046 | 2050 | ademetionine(adenosyl-methionine) | 1 | [C@H]3([N]2C1=C(C(=NC=N1)N)N=C2)[C@@H]([C@@H](... |
2047 | 2051 | mesocarb | 1 | [O+]1=N[N](C=C1[N-]C(NC2=CC=CC=C2)=O)C(CC3=CC=... |
2048 | 2052 | tofisoline | 1 | C1=C(OC)C(=CC2=C1C(=[N+](C(=C2CC)C)[NH-])C3=CC... |
2049 | 2053 | azidamfenicol | 1 | [N+](=NCC(=O)N[C@@H]([C@H](O)C1=CC=C([N+]([O-]... |
2050 rows × 4 columns
# There are about 50 problematic SMILES. We've suppressed RDKit logger outputs
# but if you haven't you'll get a lot of printouts here
# 11 SMILES can't be processed at all so we throw them away
smiles_records = []
for i, num, name, p_np, smiles in bbbp_table.to_records():
# check if RDKit accepts this smiles
if MolFromSmiles(smiles) is not None:
smiles_record = {'smiles': smiles, 'label': p_np, 'metadata': {'row': i}}
smiles_records.append(smiles_record)
else:
print(f'Molecule {smiles} on row {i} could not be parsed by RDKit')
Molecule O=N([O-])C1=C(CN=C1NCCSCc2ncccc2)Cc3ccccc3 on row 59 could not be parsed by RDKit
Molecule c1(nc(NC(N)=[NH2])sc1)CSCCNC(=[NH]C#N)NC on row 61 could not be parsed by RDKit
Molecule Cc1nc(sc1)\[NH]=C(\N)N on row 391 could not be parsed by RDKit
Molecule s1cc(CSCCN\C(NC)=[NH]\C#N)nc1\[NH]=C(\N)N on row 614 could not be parsed by RDKit
Molecule c1c(c(ncc1)CSCCN\C(=[NH]\C#N)NCC)Br on row 642 could not be parsed by RDKit
Molecule n1c(csc1\[NH]=C(\N)N)c1ccccc1 on row 645 could not be parsed by RDKit
Molecule n1c(csc1\[NH]=C(\N)N)c1cccc(c1)N on row 646 could not be parsed by RDKit
Molecule n1c(csc1\[NH]=C(\N)N)c1cccc(c1)NC(C)=O on row 647 could not be parsed by RDKit
Molecule n1c(csc1\[NH]=C(\N)N)c1cccc(c1)N\C(NC)=[NH]\C#N on row 648 could not be parsed by RDKit
Molecule s1cc(nc1\[NH]=C(\N)N)C on row 649 could not be parsed by RDKit
Molecule c1(cc(N\C(=[NH]\c2cccc(c2)CC)C)ccc1)CC on row 685 could not be parsed by RDKit
import random
random.seed(1729)
training_fraction = 0.8
dev_fraction = 0.1
n_examples = len(smiles_records)
n_training_examples = int(n_examples*training_fraction)
n_dev_examples = int(n_examples*dev_fraction)
indices = list(range(n_examples))
random.shuffle(indices) # shuffle is in place
training_indices = indices[:n_training_examples]
dev_indices = indices[n_training_examples:n_training_examples+n_dev_examples]
test_indices = indices[n_training_examples+n_dev_examples:]
training_smiles_records = [smiles_records[i] for i in training_indices]
dev_smiles_records = [smiles_records[i] for i in dev_indices]
test_smiles_records = [smiles_records[i] for i in test_indices]
training_graph_dataset = make_molecular_graph_dataset(training_smiles_records)
dev_graph_dataset = make_molecular_graph_dataset(dev_smiles_records)
test_graph_dataset = make_molecular_graph_dataset(test_smiles_records)
from torch.utils.data import DataLoader
batch_size=32
num_dataloader_workers=2 # The colab instances are very limited in number of CPUs
training_dataloader = DataLoader(training_graph_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_dataloader_workers,
collate_fn=collate_graph_batch)
dev_dataloader = DataLoader(dev_graph_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_dataloader_workers,
drop_last=False,
collate_fn=collate_graph_batch)
test_dataloader = DataLoader(test_graph_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_dataloader_workers,
drop_last=False,
collate_fn=collate_graph_batch)
class GraphPredictionHeadConfig:
def __init__(self, *, d_model, ffn_dim, pooling_type='sum'):
# Pooling type can be 'sum' or 'mean'
self.d_model = d_model
self.ffn_dim = ffn_dim
self.pooling_type = pooling_type
class GraphPredictionHead(Module):
def __init__(self, input_dim, output_dim, config):
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.config = config
self.predictor = Sequential(Linear(self.input_dim, self.config.ffn_dim),
ReLU(),
Linear(self.config.ffn_dim, self.output_dim))
def forward(self, node_features, node_mask):
# The node_features is a tensor of shape (*leading_axises, max_nodes, d_model)
# We want to 'pool' this along the node-axis, which is dim=-2 in pytorch terms
# In this case we assume these features are valid, i.e. that they have been
# masked at a previous step.
if self.config.pooling_type == 'sum':
pooled_nodes = node_features.sum(dim=-2)
elif self.config.pooling_type == 'mean':
# We can't take just the mean along dim=-2, since if this is a batch, some of
# the graphs need to be divided by a smaller number than max_nodes.
# Thankfully we have the information about how many nodes each graph has in
# the node_mask, and since it has 1's and 0's, just summing it along the
# max_nodes axis gives the count of nodes for the corresponding graph
# node_mask has the shape (batch_size, max_nodes), or just (max_nodes,)
# if it's not a batch. We get the count per graph by reducing along the
# last dimension, and by setting keepdims=True, we get a shape
# (batch_size, 1) or (1,) which will allow for broadcasting this with
# division over the summed feature vectors to calculate their mean
node_counts = node_mask.sum(dim=-1, keepdim=True)
summed_feature_vectors = node_features.sum(dim=-2)
pooled_nodes = summed_feature_vectors/node_counts
else:
raise ValueError(f'Unsupported pooling type {self.config.pooling_type}')
prediction = self.predictor(pooled_nodes)
return prediction
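The masked mean pooling can be checked in isolation. A standalone sketch with toy values (not part of the training pipeline), showing that the mask-derived counts divide out only the valid nodes:
toy_features = torch.tensor([[[1., 1.], [3., 3.], [0., 0.]]]) # one graph padded to 3 nodes
toy_mask = torch.tensor([[1., 1., 0.]]) # only the first 2 nodes are valid
toy_counts = toy_mask.sum(dim=-1, keepdim=True) # tensor([[2.]])
print(toy_features.sum(dim=-2) / toy_counts) # tensor([[2., 2.]]), the mean over valid nodes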
from tqdm.notebook import tqdm, trange
from torch.nn import BCEWithLogitsLoss
from torch.optim import AdamW
from sklearn.metrics import roc_auc_score
if torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')
print("Device is", device)
Device is cuda
from tqdm.notebook import tqdm, trange
from torch.nn import BCEWithLogitsLoss, MSELoss
from torch.optim import AdamW
from sklearn.metrics import roc_auc_score, mean_squared_error, mean_absolute_error, median_absolute_error
import matplotlib.pyplot as plt
import seaborn as sns
def batch_to_device(batch, device):
moved_batch = {}
for k, v in batch.items():
if torch.is_tensor(v):
v = v.to(device)
moved_batch[k] = v
return moved_batch
class Trainer:
def __init__(self, *, model,
loss_fn, training_dataloader,
dev_dataloader, device=device):
self.model = model
self.training_dataloader = training_dataloader
self.dev_dataloader = dev_dataloader
self.device = device
self.model.to(device)
self.total_epochs = 0
self.optimizer = AdamW(self.model.parameters(), lr=1e-4)
self.loss_fn = loss_fn
def train(self, epochs):
with trange(epochs, desc='Epoch', position=0) as epoch_progress:
batches_per_epoch = len(self.training_dataloader) + len(self.dev_dataloader)
for epoch in epoch_progress:
train_loss = 0
train_n = 0
for i, training_batch in enumerate(tqdm(self.training_dataloader, desc='Training batch', leave=False)):
self.optimizer.zero_grad()
# Move all tensors to the device
self.model.train()
training_batch = batch_to_device(training_batch, self.device)
prediction = self.model(training_batch)
labels = training_batch['labels']
loss = self.loss_fn(prediction.squeeze(), labels) # By default the predictions have shape (batch_size, 1)
loss.backward()
self.optimizer.step()
batch_n = len(labels)
train_loss += batch_n * loss.cpu().item()
train_n += batch_n
#print(f"Training loss for epoch {total_epochs}", train_loss/train_n)
self.total_epochs += 1
dev_predictions = []
dev_labels = []
dev_n = 0
dev_loss = 0
for i, dev_batch in enumerate(tqdm(self.dev_dataloader, desc="Dev batch", leave=False)):
self.model.eval()
with torch.no_grad():
dev_batch = batch_to_device(dev_batch, self.device)
prediction = self.model(dev_batch).squeeze()
dev_predictions.extend(prediction.tolist())
labels = dev_batch['labels']
dev_labels.extend(labels.tolist())
loss = self.loss_fn(prediction, labels) # By default the predictions have shape (batch_size, 1)
batch_n = len(labels)
dev_loss += batch_n*loss.cpu().item()
dev_n += batch_n
epoch_progress.set_description(f"Epoch: train loss {train_loss/train_n: .3f}, dev loss {dev_loss/dev_n: .3f}")
def evaluate_model(trainer, dataloader, label=None, hue_order=[0,1]):
eval_predictions = []
eval_labels = []
eval_loss = 0
eval_n = 0
model = trainer.model
loss_fn = trainer.loss_fn
total_epochs = trainer.total_epochs
for i, eval_batch in enumerate(tqdm(dataloader, desc='batch')):
model.eval()
with torch.no_grad():
eval_batch = batch_to_device(eval_batch, device)
prediction = model(eval_batch).squeeze()
eval_predictions.extend(prediction.tolist())
labels = eval_batch['labels']
eval_labels.extend(labels.tolist())
loss = loss_fn(prediction, labels) # By default the predictions have shape (batch_size, 1)
batch_n = len(labels)
eval_loss += batch_n*loss.cpu().item()
eval_n += batch_n
average_loss = eval_loss/eval_n
roc_auc = roc_auc_score(eval_labels, eval_predictions)
eval_df = pd.DataFrame(data={'target': eval_labels, 'predictions': eval_predictions})
sns.kdeplot(data=eval_df, x='predictions', hue='target', hue_order=hue_order)
sns.rugplot(data=eval_df, x='predictions', hue='target', hue_order=hue_order)
if label is not None:
title = f"{label} dataset after {total_epochs} epochs\nloss {average_loss}\nROC AUC {roc_auc}"
else:
title = f"After {total_epochs} epochs\nloss {average_loss}\nROC AUC {roc_auc}"
plt.title(title)
From GNNs to Transformers
We’ve seen how the idea of using sums of vectors can be used to learn things from a graph: we used it on graph neighbourhoods to introduce more structure, and we used it as a pooling method for graph predictions. We also saw how this summing of local neighbourhoods can be expressed as a matrix multiplication of the graph's adjacency matrix with the neighbouring node representations:
$$ \mathbf{h}_i^L = f\left(\sum_{j \in N(i)} \mathbf{h}_j^{L-1}, \mathbf{h}_i^{L-1}\right) $$
$$ \sum_{j \in N(i)} \mathbf{h}_j^{L-1} = A_i H^{L-1} $$
$$ H^{L-1} = \begin{bmatrix}\mathbf{h}_1^{L-1} \\ \mathbf{h}_2^{L-1} \\ \vdots \\ \mathbf{h}_n^{L-1}\end{bmatrix} $$
This assumes that the adjacency matrix $A$ is a binary indicator matrix, with 1 for pairs $i,j$ which are connected by an edge and 0 for pairs which are not.
A way to think about the “convolutional” part of a Graph Neural Network, then, is that it essentially performs a matrix multiplication with the adjacency matrix: $$H^L = f(AH^{L-1}, H^{L-1})$$
Where $f$ is a function we wish to learn (e.g. a neural network with two input vectors).
The multiplication $AH^{L-1}$ is what takes the graph structure into account by only aggregating the local neighbourhoods. We’ve seen previously how this method has some fundamental limitations; in particular, the “convolution” has a limited receptive field, so it is biased towards learning local patterns.
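To make the matrix multiplication view concrete, here is a minimal sketch (with made-up values, added for illustration) showing that row $i$ of $AH$ is exactly the sum of node $i$'s neighbour vectors:
# A path graph 0 - 1 - 2 with 2-dimensional node features
A_example = torch.tensor([[0., 1., 0.], [1., 0., 1.], [0., 1., 0.]])
H_example = torch.tensor([[1., 0.], [0., 1.], [1., 1.]])
print(A_example @ H_example)
# Row 1 is h_0 + h_2 = [2., 1.], the sum over node 1's neighbourhood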
Dynamically computed “adjacency matrix”
What if, instead of assuming the matrix $A$ to be the adjacency matrix, we calculate its values based on the values of the node pairs and the edge between them? This would effectively mean that we potentially aggregate “neighbourhoods” which span the complete graph.
$$ A = \begin{bmatrix} g(h_1, h_1, e_{1,1}) & g(h_1, h_2, e_{1,2}) & \dots \\ \vdots & \ddots & \\ g(h_n, h_1, e_{n,1}) & \dots & g(h_n, h_n, e_{n,n}) \end{bmatrix} $$
This way we can learn to use node information as well as edge information when deciding how to aggregate the node set. While this allows us to take the graph structure into account, it also allows the model to learn relationships between more distant nodes.
This is the fundamental idea of the Transformer architecture. By dynamically setting the values of this “adjacency matrix”, we can have a model which can learn to induce structure from arbitrary sets.
Depending on how we construct $g$, we can choose to inject knowledge about the elements of the input set $H$, such as the relative position between tokens if $H$ is actually a sequence, or information about a particular pair, such as whether they are connected by an edge in a graph.
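As a sketch of this idea (simplified: a bilinear score stands in for $g$ and edge features are ignored), we can compute an “adjacency matrix” from the node vectors themselves and row-normalize it:
n_nodes, k = 4, 8
H_set = torch.randn(n_nodes, k)
W_score = torch.randn(k, k) # a pairwise scoring function g(h_i, h_j) = h_i^T W h_j
scores = H_set @ W_score @ H_set.T # shape (n_nodes, n_nodes), one score per node pair
A_dynamic = torch.softmax(scores, dim=-1) # rows sum to 1: a computed, weighted "adjacency matrix"
aggregated = A_dynamic @ H_set # weighted aggregation over the full node set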
The downside
While this idea is very powerful, it comes with a major limitation. A GNN typically only implicitly multiplies the node vectors $H$ with the adjacency matrix; in practice this is implemented as a sparse operation, as we saw in the previous notebook. This means that the computational cost is $O(n d k)$, where $n$ is the number of nodes, $d$ the average degree and $k$ the dimensionality of the node vectors.
To multiply with a dense matrix, this instead becomes $O(n^2 k)$, and this quadratic scaling in the number of input nodes severely limits the application of this idea.
Generating $A$ also scales quadratically with the number of nodes, since we need to apply the function to all pairs of node vectors, regardless of edges.
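A rough back-of-the-envelope comparison for a molecule-sized graph (illustrative numbers, not a benchmark):
n, d_avg, k = 30, 2, 16 # ~30 heavy atoms, average degree ~2, 16-dimensional features
sparse_cost = n * d_avg * k # O(ndk) for sparse neighbourhood aggregation
dense_cost = n * n * k # O(n^2 k) for the dense multiplication
print(sparse_cost, dense_cost) # 960 14400: a 15x difference already at this size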
This fact has not stopped the idea from becoming wildly successful in the domain of Natural Language Processing, and as long as the domain we’re working in has relatively small inputs (like organic molecules in medicinal chemistry), we can handle the quadratic scaling with brute force (a lot of computational capacity).
Implementing a Transformer
Below is the implementation of the Transformer. In particular, the BasicTransformerLayer
is where the main difference from a GNN is implemented. The BasicTransformerEncoder
is very similar to the GNN encoder.
import math
import torch
from torch.nn import Module, Embedding, ModuleList, Linear, Sequential, ReLU, LayerNorm, Dropout
from torch.nn.functional import softmax
class BasicTransformerConfig:
def __init__(self, *,
d_model: int,
n_layers: int,
ffn_dim: int,
head_dim: int,
layer_normalization: bool = True,
dropout_rate: float = 0.1,
residual_connections: bool=True):
self.d_model = d_model
self.n_layers = n_layers
self.ffn_dim = ffn_dim
# Note that we introduce a new hyperparameter called *head_dim*; in
# Transformers we typically transform the "node" feature vectors to
# a lower-dimensional space
self.head_dim = head_dim
self.layer_normalization = layer_normalization
self.dropout_rate = dropout_rate
self.residual_connections = residual_connections
class BasicTransformerLayer(Module):
def __init__(self, config):
super().__init__()
self.config = config
self.input_dim = config.d_model
self.output_dim = config.d_model
self.ffn_dim = config.ffn_dim
self.head_dim = config.head_dim
# Transformers typically don't use MLPs to create the neighbour and center
# embeddings, instead relying on just linear transformations
self.neighbour_transform = Linear(self.input_dim, self.head_dim, bias=False)
self.center_transform = Linear(self.input_dim, self.head_dim, bias=False)
# The transformer uses layer normalization by default
self.attention_norm = LayerNorm(self.input_dim)
self.output_transform = Sequential(Linear(self.input_dim, self.ffn_dim),
ReLU(),
Linear(self.ffn_dim, self.output_dim))
self.output_norm = LayerNorm(self.output_dim)
self.dropout = Dropout(self.config.dropout_rate)
self.scaling_factor = math.sqrt(self.input_dim)
def attention_function(self, adjacency_matrix, center_node_features,
neighbour_node_features, edge_features, node_mask, edge_mask):
# The standard Transformer just takes the dot product between the center
# node and the neighbour nodes, scaled by the square root of the model
# dimension.
# To take the dot products between all center nodes and all neighbour nodes
# we perform an outer product by first transposing one of the matrices along
# the last two axes.
# In the single-graph case the matrix multiplication of
# a matrix with shape (n_nodes, head_dim) with one of shape (head_dim, n_nodes)
# gives a resulting matrix of shape (n_nodes, n_nodes), where each element
# is the dot product of the row for a node in the first with the column of a
# node in the second
attention_logits = torch.matmul(center_node_features,
neighbour_node_features.transpose(-1, -2))/self.scaling_factor
# The "adjcency matrix" of the transformer is actually using weighted means
# for the aggregation, and the way we achieve this is to make sure that the
# rows of the matrix are normalized using softmax
# However, if we have a batch of graphs as inputs, we will have some
# "positions" in the node features which we should not include in our
# aggregation, and should therefore mask out in our attention matrix.
# If we did that after the softmax calculations, the rows would no longer
# add up to 1. Instead we do it before the softmax by essentially
# setting the masked values to a number so low it will end up as a
# 0 in the softmax output. The lowest value we can imagine is negative
# infinity, so let's use that.
# The goal is to mask out parts of the different batch attention_logits which
# are not part of the nodes, so essentially have a resulting matrix per batch
# example which looks something like
# [[ a, b, c, -inf, -inf],
# [ d, e, f, -inf, -inf],
# [ g, h, i, -inf, -inf],
# [-inf, -inf, -inf, -inf, -inf],
# [-inf, -inf, -inf, -inf, -inf],
# ]
# To do this we will first create a boolean mask which has True in the
# places we want to fill with '-inf'. We do this by using the node mask
# like in the GNN examples, doing something very much like an outer product
nodemask_2d = node_mask.unsqueeze(dim=-2) * node_mask.unsqueeze(dim=-1)
# The nodemask has 1's where there are valid elements and 0's where there
# are none; we invert this and convert it to a bool tensor
fill_mask = (1 - nodemask_2d).to(torch.bool)
# We're now ready to 'mask out' the logits. Notice that masked_fill_ is
# in place
attention_logits.masked_fill_(fill_mask, float('-inf'))
attention_matrix = softmax(attention_logits, dim=-1)
# There will be rows of the smaller attention matrices which were filled
# completely with -inf values, these will now be rows of 'nan' values
# We perform a new fill of the attention matrix, but this time with 0s
attention_matrix = attention_matrix.masked_fill(fill_mask, 0.)
return attention_matrix
def forward(self, adjacency_matrix, node_features, edge_features, node_mask, edge_mask):
# In this basic Transformer layer we'll not use the edge features, and instead
# focus on the basic transformer formulation of this problem. This will pretty
# much treat the graph as just a node set. We should not expect this to
# be able to learn anything which relies on the graph structure
center_node_features = self.center_transform(node_features)
neighbour_node_features = self.neighbour_transform(node_features)
# The transformed node features are either a 3-tensor
# (batch_size, max_nodes, head_dim) or a matrix (n_nodes, head_dim) when
# it's a single graph. We'll make this code agnostic to that
# The goal now is to _compute_ an "adjacency matrix" using the node features
# This could be done by any function. We define a method on this class
# which is the attention function which has the purpose of giving us an
# "adjacency matrix"
attention_matrix = self.attention_function(adjacency_matrix,
center_node_features,
neighbour_node_features,
edge_features,
node_mask,
edge_mask)
# Now we aggregate the neighbourhoods using the attention matrix (our computed "adjacency matrix")
# The transformer doesn't transform the node features at this stage, instead doing it in a separate MLP after residual connections and aggregation
aggregated_neighbourhoods = torch.matmul(attention_matrix, node_features)
# and mask the result
masked_features = aggregated_neighbourhoods * node_mask.unsqueeze(dim=-1)
# Followed by a dropout and layer normalization
masked_features = self.dropout(masked_features)
masked_features = self.attention_norm(masked_features)
# The transformer by default uses residual connections
if self.config.residual_connections:
masked_features = masked_features + node_features
# Now we apply the output transform as in the GNN
updated_node_features = self.output_transform(masked_features)
# Mask again
updated_node_features = updated_node_features * node_mask.unsqueeze(dim=-1)
# Followed by a dropout and normalization
updated_node_features = self.dropout(updated_node_features)
updated_node_features = self.output_norm(updated_node_features)
# And the residual connection from the input to the output MLP
if self.config.residual_connections:
updated_node_features = updated_node_features + masked_features
return updated_node_features
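The -inf masking trick in attention_function can be illustrated in isolation (a standalone sketch with toy values, independent of the layer above):
toy_logits = torch.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])
toy_node_mask = torch.tensor([1., 1., 0.]) # the third node is padding
toy_mask_2d = toy_node_mask.unsqueeze(-2) * toy_node_mask.unsqueeze(-1)
toy_fill_mask = (1 - toy_mask_2d).to(torch.bool)
toy_masked = toy_logits.masked_fill(toy_fill_mask, float('-inf'))
toy_attention = softmax(toy_masked, dim=-1).masked_fill(toy_fill_mask, 0.)
print(toy_attention) # valid rows sum to 1 over valid columns; the padded row and column are all 0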
class BasicTransformerEncoder(torch.nn.Module):
def __init__(self, *,
config: BasicTransformerConfig,
continuous_node_variables=None,
categorical_node_variables=None,
continuous_edge_variables=None,
categorical_edge_variables=None,
layer_type=BasicTransformerLayer):
super().__init__()
self.config = config
self.layer_type = layer_type
self.continuous_node_variables = continuous_node_variables
self.categorical_node_variables = categorical_node_variables
self.continuous_edge_variables = continuous_edge_variables
self.categorical_edge_variables = categorical_edge_variables
# We want the embeddings together with the continuous values to be of dimension d_model, therefore we allocate d_model - len(continuous_variables) as the embedding dim
self.categorical_node_embeddings_dim = config.d_model - len(self.continuous_node_variables)
self.categorical_edge_embeddings_dim = config.d_model - len(self.continuous_edge_variables)
self.node_featurizer = FeatureCombiner(self.categorical_node_variables,
self.categorical_node_embeddings_dim)
self.edge_featurizer = FeatureCombiner(self.categorical_edge_variables,
self.categorical_edge_embeddings_dim)
# Notice that we use the supplied layer type above when creating the graph
# layers. This allows us to easily change the kind of graph layers
# we use later on
self.graph_layers = ModuleList([layer_type(config) for l in range(config.n_layers)])
def forward(self, batch):
# First order of business is to embed the node features
node_mask = batch['nodes_mask']
batch_size, max_nodes = node_mask.shape
continuous_node_features = batch['continuous_node_features']
categorical_node_features = batch['categorical_node_features']
node_features = self.node_featurizer(continuous_node_features, categorical_node_features)
masked_node_features = node_features * node_mask.unsqueeze(-1)
continuous_edge_features = batch['continuous_edge_features']
categorical_edge_features = batch['categorical_edge_features']
edge_features = self.edge_featurizer(continuous_edge_features, categorical_edge_features)
edge_mask = batch['edge_mask']
masked_edge_features = edge_features * edge_mask.unsqueeze(-1)
# We have now embedded the node features, we'll propagate them through our
# graph layers
adjacency_matrix = batch['adjacency_matrices']
memory_state = masked_node_features
for l in self.graph_layers:
memory_state = l(adjacency_matrix, memory_state, masked_edge_features, node_mask, edge_mask)
return memory_state
class GraphPredictionNeuralNetwork(Module):
def __init__(self, encoder, prediction_head):
super().__init__()
self.encoder = encoder
self.prediction_head = prediction_head
def forward(self, batch):
encoded_graph = self.encoder(batch)
prediction = self.prediction_head(encoded_graph, batch['nodes_mask'])
return prediction
torch.manual_seed(1729)
d_model = 16
basic_encoder_config = BasicTransformerConfig(d_model=d_model,
n_layers=2,
ffn_dim=16,
head_dim=8,
layer_normalization=True,
dropout_rate=0.1,
residual_connections=True)
basic_transformer = BasicTransformerEncoder(config=basic_encoder_config,
continuous_node_variables=dataset.continuous_node_variables,
categorical_node_variables=dataset.categorical_node_variables,
continuous_edge_variables=dataset.continuous_edge_variables,
categorical_edge_variables=dataset.categorical_edge_variables)
head_config = GraphPredictionHeadConfig(d_model=d_model, ffn_dim=32, pooling_type='sum')
prediction_head = GraphPredictionHead(input_dim=d_model, output_dim=1, config=head_config)
model = GraphPredictionNeuralNetwork(basic_transformer, prediction_head)
loss_fn = BCEWithLogitsLoss()
trainer = Trainer(model=model,
loss_fn=loss_fn,
training_dataloader=training_dataloader,
dev_dataloader=dev_dataloader)
trainer.train(1)
Adding the graph structure to Transformers
In basic GNNs, we take the graph structure into account by using the given adjacency matrix. In the Transformer we’ve now seen, we compute an attention matrix, which plays the same role as the adjacency matrix of the basic graph neural network.
We can simply modify our attention function of the Transformer to also take the adjacency matrix into account.
This attention function could really be anything, but for now we’ll stick with a very simple idea: modifying the logits of our scaled dot-product attention by adding a scalar wherever there is a 1 in the corresponding entry of the adjacency matrix.
If $a_{i,j}$ is the value, before the softmax, of the attention logit matrix entry governing how much information from node $j$ we include when aggregating for node $i$, we define the function:
$$ a_{i,j} = f(\mathbf{x}_i, \mathbf{x}_j) = \frac{\langle \mathbf{x}_i, \mathbf{x}_j \rangle}{\sqrt{d_{\text{model}}}} + w\,\mathbf{1}_{\{i, j\} \in E} $$
Here $\mathbf{1}_{\{i, j\} \in E}$ is the indicator function taking the value $1$ if there is an edge between $i$ and $j$ (in other words, the $i,j$ entry of the adjacency matrix).
We also add a learnable scalar $w$ which the network can adjust to change what influence the presence of an edge has on the attention score.
from torch.nn import Parameter
class AdjacencyTransformerLayer(BasicTransformerLayer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.adjacency_weight = Parameter(torch.tensor(1.))
def attention_function(self, adjacency_matrix, center_node_features,
neighbour_node_features, edge_features, node_mask, edge_mask):
# We still take the scaled dot product between the center and neighbour nodes
dot_product_logits = torch.matmul(center_node_features,
neighbour_node_features.transpose(-1, -2))/self.scaling_factor
# The adjacency_matrix is scaled by a parameter so that the network can
# learn to adjust the influence of edge presence
adjacency_logits = self.adjacency_weight*adjacency_matrix
attention_logits = dot_product_logits + adjacency_logits
nodemask_2d = node_mask.unsqueeze(dim=-2) * node_mask.unsqueeze(dim=-1)
fill_mask = (1 - nodemask_2d).to(torch.bool)
attention_logits.masked_fill_(fill_mask, float('-inf'))
attention_matrix = softmax(attention_logits, dim=-1)
# There will be rows of the smaller attention matrices which were filled
# completely with -inf values, these will now be rows of 'nan' values
# We perform a new fill of the attention matrix, but this time with 0s
attention_matrix = attention_matrix.masked_fill(fill_mask, 0.)
return attention_matrix
torch.manual_seed(1729)
d_model = 16
basic_encoder_config = BasicTransformerConfig(d_model=d_model,
n_layers=2,
ffn_dim=16,
head_dim=8,
layer_normalization=True,
dropout_rate=0.1,
residual_connections=True)
basic_transformer = BasicTransformerEncoder(config=basic_encoder_config,
continuous_node_variables=dataset.continuous_node_variables,
categorical_node_variables=dataset.categorical_node_variables,
continuous_edge_variables=dataset.continuous_edge_variables,
categorical_edge_variables=dataset.categorical_edge_variables,
layer_type=AdjacencyTransformerLayer)
head_config = GraphPredictionHeadConfig(d_model=d_model, ffn_dim=32, pooling_type='sum')
prediction_head = GraphPredictionHead(input_dim=d_model, output_dim=1, config=head_config)
model = GraphPredictionNeuralNetwork(basic_transformer, prediction_head)
loss_fn = BCEWithLogitsLoss()
trainer = Trainer(model=model,
loss_fn=loss_fn,
training_dataloader=training_dataloader,
dev_dataloader=dev_dataloader)
trainer.train(1)
Edge features in the attention
In the previous example, we only used the adjacency matrix to influence the attention function. Now we’ll extend this to actually include edge features as well.
To do this we simply extend the notion of using a function to compute the values of the attention matrix; in particular, we will use a function which takes the feature vectors of the center and neighbour nodes as well as the edge feature vector between these nodes. For node pairs which have no edge feature vector, the zero-vector will be used.
This function will be implemented by a simple 2-layer MLP, which we choose to implement as below:
$$ a_{i,j} = f(\mathbf{x}_i, \mathbf{x}_j, \mathbf{x}_{e_{i,j}}) = W_2\, \sigma(W_1 (\mathbf{x}_i + \mathbf{x}_j + \mathbf{x}_{e_{i,j}}) + \mathbf{b}_1) + \mathbf{b}_2 $$
You can see that we apply the neural network on the sum of the node feature and edge feature vectors. This makes the function permutation invariant in its inputs (i.e. $f(\mathbf{x}_i, \mathbf{x}_j, \mathbf{x}_{e_{i,j}}) = f(\mathbf{x}_j, \mathbf{x}_{e_{i,j}}, \mathbf{x}_i)$). If this is not desired, concatenation can be used instead, but it makes the implementation a bit more complex.
class EdgeAttributesAttentionTransformerLayer(BasicTransformerLayer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# We're going to sum our center and neighbour vectors with the edge feature
# vector, so we have to be mindful of the dimensionality
self.center_transform = Linear(self.input_dim, self.input_dim)
self.neighbour_transform = Linear(self.input_dim, self.input_dim)
self.attention_score_function = Sequential(Linear(self.input_dim,
self.config.ffn_dim),
ReLU(),
Linear(self.config.ffn_dim,
1))
def attention_function(self, adjacency_matrix, center_node_features,
neighbour_node_features, edge_features,
node_mask, edge_mask):
# Our goal is to first build up the input tensor to our simple MLP.
# This will be a tensor of shape
# (*leading_dimension, n_nodes, n_nodes, feature_dim)
# The edge_feature's tensor already has this shape, and we've made sure
# the edge features on that tensor is already zeroed for places where
# there are no edges.
# We want each element [...,i,j,:] of this matrix to be
# center_node_features[i] + neighbour_node_features[j] + edge_features[i,j]
# We achieve this by broadcasting once again, the center node features
# are broadcasted along the -3'rd axis and the neighbour node features
# along the -2'nd axis which gives us the desired result
# This is one of the most annoying things with the frameworks we use,
# having to be very conscious about the order of axes
# unsqueeze(-2) broadcasts along the "row" of this batch of edge feature "matrices"
attention_score_input = edge_features + center_node_features.unsqueeze(dim=-2)
# unsqueeze(-3) broadcasts along the "columns" of the 3-tensor
attention_score_input = attention_score_input + neighbour_node_features.unsqueeze(dim=-3)
# The attention MLP output has a trailing dimension of size 1, which we squeeze away
attention_logits = self.attention_score_function(attention_score_input).squeeze(dim=-1)
# We need to perform the same masking as before
nodemask_2d = node_mask.unsqueeze(dim=-2) * node_mask.unsqueeze(dim=-1)
fill_mask = (1 - nodemask_2d).to(torch.bool)
attention_logits.masked_fill_(fill_mask, float('-inf'))
attention_matrix = softmax(attention_logits, dim=-1)
attention_matrix = attention_matrix.masked_fill(fill_mask, 0.)
return attention_matrix
torch.manual_seed(1729)
d_model = 16
basic_encoder_config = BasicTransformerConfig(d_model=d_model,
n_layers=2,
ffn_dim=16,
head_dim=8,
layer_normalization=True,
dropout_rate=0.1,
residual_connections=True)
basic_transformer = BasicTransformerEncoder(config=basic_encoder_config,
continuous_node_variables=dataset.continuous_node_variables,
categorical_node_variables=dataset.categorical_node_variables,
continuous_edge_variables=dataset.continuous_edge_variables,
categorical_edge_variables=dataset.categorical_edge_variables,
layer_type=EdgeAttributesAttentionTransformerLayer)
head_config = GraphPredictionHeadConfig(d_model=d_model, ffn_dim=32, pooling_type='sum')
prediction_head = GraphPredictionHead(input_dim=d_model, output_dim=1, config=head_config)
model = GraphPredictionNeuralNetwork(basic_transformer, prediction_head)
loss_fn = BCEWithLogitsLoss()
trainer = Trainer(model=model,
loss_fn=loss_fn,
training_dataloader=training_dataloader,
dev_dataloader=dev_dataloader)
trainer.train(1)
Edge features in the transformations
While the basic Transformer can include pairwise information (such as the adjacency matrix) in its attention function, this is only used to decide what node vectors to aggregate, i.e. the values of the attention matrix (our dynamic adjacency matrix).
In the GNN we looked at in the last notebook, edge feature vectors were integrated into the message “passed” from the neighbour nodes, and could be used by the MLP parts of the network to learn a combined representation of node and edge features, eventually becoming a direct part of the output.
We can extend the Transformer architecture to also do this, just like we did in the example on GNNs. We will still use the underlying idea of computing an adjacency matrix to get a dynamic aggregation function, but in addition make the aggregated neighbourhood contextually different based on which node we’re currently aggregating for.
from torch.nn import Parameter
class EdgeAttributesTransformerLayer(EdgeAttributesAttentionTransformerLayer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, adjacency_matrix, node_features, edge_features, node_mask, edge_mask):
center_node_features = self.center_transform(node_features)
neighbour_node_features = self.neighbour_transform(node_features)
attention_matrix = self.attention_function(adjacency_matrix,
center_node_features,
neighbour_node_features,
edge_features,
node_mask,
edge_mask)
# Just like when we computed the attention scores using the edge feature
# vectors we also want to create context-dependent neighbourhoods now
# We do that by creating the full pairwise tensor of shape
# (*leading_dimensions, n_nodes, n_nodes, num_features). We use the same
# procedure with broadcasting in the
# EdgeAttributesAttentionTransformerLayer attention_function above
# unsqueeze(-2) broadcasts along the "row" of this batch of edge feature "matrices"
context_dependent_features = edge_features + center_node_features.unsqueeze(dim=-2)
# unsqueeze(-3) broadcasts along the "columns" of the 3-tensor
context_dependent_features = context_dependent_features + neighbour_node_features.unsqueeze(dim=-3)
# Now the aggregation becomes a bit more tricky. We did this in the GNN layer
# which used edge features as well: essentially performing explicitly what we
# previously used a matrix multiplication to do. We broadcast the attention
# matrix over the feature dimension, multiplying the feature vector at
# position i,j with the value in the attention matrix at position i,j
attended_features = context_dependent_features * attention_matrix.unsqueeze(dim=-1)
# Now the aggregation is performed by summing these attended features along
# the "rows", i.e. reducing away the "column" axis which is dim=-2
aggregated_neighbourhoods = attended_features.sum(dim=-2)
# Now mask the result as before
masked_features = aggregated_neighbourhoods * node_mask.unsqueeze(dim=-1)
# Followed by a dropout, layer normalization and residual sum
masked_features = self.dropout(masked_features)
masked_features = self.attention_norm(masked_features)
if self.config.residual_connections:
masked_features = masked_features + node_features
# Transform the features with our MLP
updated_node_features = self.output_transform(masked_features)
updated_node_features = updated_node_features * node_mask.unsqueeze(dim=-1)
# Followed by a dropout and normalization
updated_node_features = self.dropout(updated_node_features)
updated_node_features = self.output_norm(updated_node_features)
# And the residual connection from the input to the output MLP
if self.config.residual_connections:
updated_node_features = updated_node_features + masked_features
return updated_node_features
torch.manual_seed(1729)
d_model = 16
basic_encoder_config = BasicTransformerConfig(d_model=d_model,
n_layers=2,
ffn_dim=16,
head_dim=8,
layer_normalization=True,
dropout_rate=0.1,
residual_connections=True)
basic_transformer = BasicTransformerEncoder(config=basic_encoder_config,
continuous_node_variables=dataset.continuous_node_variables,
categorical_node_variables=dataset.categorical_node_variables,
continuous_edge_variables=dataset.continuous_edge_variables,
categorical_edge_variables=dataset.categorical_edge_variables,
layer_type=EdgeAttributesTransformerLayer)
head_config = GraphPredictionHeadConfig(d_model=d_model, ffn_dim=32, pooling_type='sum')
prediction_head = GraphPredictionHead(input_dim=d_model, output_dim=1, config=head_config)
model = GraphPredictionNeuralNetwork(basic_transformer, prediction_head)
loss_fn = BCEWithLogitsLoss()
trainer = Trainer(model=model,
loss_fn=loss_fn,
training_dataloader=training_dataloader,
dev_dataloader=dev_dataloader)
trainer.train(1)
Pairwise features
In the Graph Neural Network framework, we tend to focus on node attributes and edge attributes. With the Transformer framework we instead have an architecture based on all pairwise interactions between nodes. This allows us to extend the idea of what edge features could be.
Instead of only specifying features for pairs of nodes which have an edge, we can think of features between any pair of nodes. This can be used to give the network information about node relationships which it struggles to learn but which can easily be computed with regular algorithms, such as the distance between nodes if they are embedded in a Euclidean space, or the shortest path between them if they are in a graph.
This latter idea is that of the path-augmented graph transformer (Chen, Benson, et al., “Path-Augmented Graph Transformer Network”).
We can use the architecture we’ve already defined above, but allow our edge features to actually be pairwise features. Instead of only having features for pairs of nodes which are directly connected in the graph, we introduce features between any pair of nodes.
In this case the pairwise features will contain information about the edges along the shortest path of our graph, to at most max_path_length
steps (longer paths are just truncated).
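As a quick standalone illustration of the truncation, RDKit's GetShortestPath (which we also use below) returns the atom indices along the shortest path:
from rdkit.Chem import MolFromSmiles
from rdkit.Chem.rdmolops import GetShortestPath
hexane = MolFromSmiles('CCCCCC')
print(GetShortestPath(hexane, 0, 5))  # (0, 1, 2, 3, 4, 5)
# Five bonds along this path, so with max_path_length = 3 only the
# features of the first three bonds are kept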
Representing the path information
In the graph, the shortest path between two nodes is a sequence of edges. A natural way of representing the path is then as the sequence of edge features along this shortest path.
To handle the sequence information of these edges we could aggregate the path information with a sequence model such as a Recurrent Neural Network or a Transformer. In this example we keep things simpler and represent the path features as separate categorical variables, one set per path step. We add these variables between all pairs of nodes.
Since we have to fix exactly which categorical variables to use when defining the network architecture, we choose some maximum length $k$ and only take edge features along the path up to that many steps.
max_path_length = 3 # This is the k we limit our path lengths to
from itertools import combinations
from rdkit.Chem.rdmolops import GetShortestPath
from tqdm.notebook import tqdm
PAIRWISE_FEATURES = []
BOND_FEATURES = {'type_feature': TYPE_FEATURE,
'direction_feature': DIRECTION_FEATURE,
'aromatic_feature': AROMATIC_FEATURE,
'stereo_feature': STEREO_FEATURE}
PAIRWISE_FEATURES.extend(BOND_FEATURES.values())
# We create a copy of the bond features for each path step
PATH_FEATURES = []
for i in range(max_path_length):
path_vars = {}
for feature_kw, var in BOND_FEATURES.items():
name = var.name
path_var_name = f"{name}_p{i}"
if isinstance(var, ContinuousVariable):
path_var = ContinuousVariable(path_var_name)
elif isinstance(var, CategoricalVariable):
path_var_values = var.values
path_var = CategoricalVariable(path_var_name, path_var_values, add_null_value=False)
path_vars[feature_kw] = path_var
PAIRWISE_FEATURES.append(path_var)
PATH_FEATURES.append(path_vars)
def get_shortest_paths_bond_features(rd_bond,
*,
type_feature,
direction_feature,
aromatic_feature,
stereo_feature):
if rd_bond is not None:
bond_type = str(rd_bond.GetBondType())
bond_stereo_info = str(rd_bond.GetStereo())
bond_direction = str(rd_bond.GetBondDir())
is_aromatic = rd_bond.GetIsAromatic()
else:
bond_type = None
bond_stereo_info = None
bond_direction = None
is_aromatic = None
return {type_feature: bond_type,
direction_feature: bond_direction,
aromatic_feature: is_aromatic,
stereo_feature: bond_stereo_info}
def get_pairwise_features(rd_mol, rd_atom_a, rd_atom_b):
pairwise_features = {}
# First we create the features for the bond (or missing such) between
# the two atoms
bond = rd_mol.GetBondBetweenAtoms(rd_atom_a.GetIdx(), rd_atom_b.GetIdx())
bond_features = get_shortest_paths_bond_features(bond, **BOND_FEATURES)
pairwise_features.update(bond_features)
# Now we create bond features for the path between rd_atom_a and rd_atom_b
# We iterate over the bonds of the shortest path, up to max_path_length steps
# If the shortest path is shorter than max_path_length, we add None-valued
# features for the remaining steps
shortest_path = GetShortestPath(rd_mol, rd_atom_a.GetIdx(), rd_atom_b.GetIdx())
for i in range(max_path_length):
path_bond_variables = PATH_FEATURES[i]
if i < (len(shortest_path) - 1):
a, b = shortest_path[i], shortest_path[i+1]
path_bond = rd_mol.GetBondBetweenAtoms(a, b)
else:
path_bond = None
path_bond_features = get_shortest_paths_bond_features(path_bond, **path_bond_variables)
pairwise_features.update(path_bond_features)
return pairwise_features
rd_mol = MolFromSmiles('CCCCCC')
atom_a = rd_mol.GetAtomWithIdx(0)
atom_b = rd_mol.GetAtomWithIdx(1)
get_pairwise_features(rd_mol, atom_a, atom_b)
{<CategoricalVariable: bond_direction>: 'NONE',
<CategoricalVariable: bond_direction_p0>: 'NONE',
<CategoricalVariable: bond_direction_p1>: None,
<CategoricalVariable: bond_direction_p2>: None,
<CategoricalVariable: bond_stereo>: 'STEREONONE',
<CategoricalVariable: bond_stereo_p0>: 'STEREONONE',
<CategoricalVariable: bond_stereo_p1>: None,
<CategoricalVariable: bond_stereo_p2>: None,
<CategoricalVariable: bond_type>: 'SINGLE',
<CategoricalVariable: bond_type_p0>: 'SINGLE',
<CategoricalVariable: bond_type_p1>: None,
<CategoricalVariable: bond_type_p2>: None,
<CategoricalVariable: is_aromatic>: False,
<CategoricalVariable: is_aromatic_p0>: False,
<CategoricalVariable: is_aromatic_p1>: None,
<CategoricalVariable: is_aromatic_p2>: None}
atom_a = rd_mol.GetAtomWithIdx(0)
atom_b = rd_mol.GetAtomWithIdx(5)
get_pairwise_features(rd_mol, atom_a, atom_b)
{<CategoricalVariable: bond_direction>: None,
<CategoricalVariable: bond_direction_p0>: 'NONE',
<CategoricalVariable: bond_direction_p1>: 'NONE',
<CategoricalVariable: bond_direction_p2>: 'NONE',
<CategoricalVariable: bond_stereo>: None,
<CategoricalVariable: bond_stereo_p0>: 'STEREONONE',
<CategoricalVariable: bond_stereo_p1>: 'STEREONONE',
<CategoricalVariable: bond_stereo_p2>: 'STEREONONE',
<CategoricalVariable: bond_type>: None,
<CategoricalVariable: bond_type_p0>: 'SINGLE',
<CategoricalVariable: bond_type_p1>: 'SINGLE',
<CategoricalVariable: bond_type_p2>: 'SINGLE',
<CategoricalVariable: is_aromatic>: None,
<CategoricalVariable: is_aromatic_p0>: False,
<CategoricalVariable: is_aromatic_p1>: False,
<CategoricalVariable: is_aromatic_p2>: False}
def rdmol_to_complete_graph(mol):
atoms = {rd_atom.GetIdx(): get_atom_features(rd_atom) for rd_atom in mol.GetAtoms()}
all_pairwise_features = {}
for atom_a, atom_b in combinations(mol.GetAtoms(), 2):
all_pairwise_features[frozenset((atom_a.GetIdx(), atom_b.GetIdx()))] = get_pairwise_features(mol, atom_a, atom_b)
return atoms, all_pairwise_features
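A quick sanity check (this assumes get_atom_features from the earlier notebooks is in scope): a complete graph on $n$ atoms should have $n(n-1)/2$ pairwise feature entries:
atoms, pairwise = rdmol_to_complete_graph(MolFromSmiles('c1ccccc1'))
print(len(atoms), len(pairwise))  # 6 atoms, 15 node pairs (6 * 5 / 2)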
def smiles_to_complete_graph(smiles):
rd_mol = MolFromSmiles(smiles)
graph = rdmol_to_complete_graph(rd_mol)
return graph
import multiprocessing
from tqdm.notebook import tqdm
def process_smiles_record(smiles_record):
smiles = smiles_record['smiles']
rdmol = MolFromSmiles(smiles)
label = smiles_record['label']
graph = rdmol_to_complete_graph(rdmol)
return graph, label, smiles_record
def make_graph_shortest_path_dataset(smiles_records,
atom_features=ATOM_FEATURES,
bond_features=PAIRWISE_FEATURES):
'''
Create a new GraphDataset from a list of smiles_records dictionaries.
These records should contain the keys 'smiles' and 'label'. The full record is also stored as 'metadata'.
'''
graphs = []
labels = []
metadata = []
with multiprocessing.Pool() as pool:
for graph, label, smiles_record in tqdm(pool.imap(process_smiles_record, smiles_records), total=len(smiles_records)):
graphs.append(graph)
labels.append(label)
metadata.append(smiles_record)
return GraphDataset(graphs=graphs,
labels=labels,
node_variables=atom_features,
edge_variables=bond_features,
metadata=metadata)
dataset = make_graph_shortest_path_dataset([{'smiles': 'c1ccccc1', 'label':1},{'smiles':'OS(=O)(=O)O', 'label': 0}])
training_shortest_path_dataset = make_graph_shortest_path_dataset(training_smiles_records)
dev_shortest_path_dataset = make_graph_shortest_path_dataset(dev_smiles_records)
test_shortest_path_dataset = make_graph_shortest_path_dataset(test_smiles_records)
from torch.utils.data import DataLoader
batch_size=32
num_dataloader_workers=2 # The Colab instances are very limited in the number of CPUs
training_shortest_path_dataloader = DataLoader(training_shortest_path_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_dataloader_workers,
collate_fn=collate_graph_batch)
dev_shortest_path_dataloader = DataLoader(dev_shortest_path_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_dataloader_workers,
drop_last=False,
collate_fn=collate_graph_batch)
test_shortest_path_dataloader = DataLoader(test_shortest_path_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_dataloader_workers,
drop_last=False,
collate_fn=collate_graph_batch)
torch.manual_seed(1729)
d_model = 16
basic_encoder_config = BasicTransformerConfig(d_model=d_model,
n_layers=2,
ffn_dim=16,
head_dim=8,
layer_normalization=True,
dropout_rate=0.1,
residual_connections=True)
basic_transformer = BasicTransformerEncoder(config=basic_encoder_config,
continuous_node_variables=dataset.continuous_node_variables,
categorical_node_variables=dataset.categorical_node_variables,
continuous_edge_variables=dataset.continuous_edge_variables,
categorical_edge_variables=dataset.categorical_edge_variables,
layer_type=EdgeAttributesTransformerLayer)
head_config = GraphPredictionHeadConfig(d_model=d_model, ffn_dim=32, pooling_type='sum')
prediction_head = GraphPredictionHead(input_dim=d_model, output_dim=1, config=head_config)
model = GraphPredictionNeuralNetwork(basic_transformer, prediction_head)
loss_fn = BCEWithLogitsLoss()
trainer = Trainer(model=model,
loss_fn=loss_fn,
training_dataloader=training_shortest_path_dataloader,
dev_dataloader=dev_shortest_path_dataloader)
Slow training
You will notice that this model is much slower to train than the previous one. That's not because of changes to the neural network (it is exactly the same as before) but because the input is now much more complex: each node pair now has (max_path_length + 1) times as many features as an edge had before (the direct bond features plus one copy per path step), and we store features for every pair of nodes, not just the bonded ones. This makes the dataloading and embedding steps more time consuming.
In particular, the Colab instances we're using only have 2 CPUs, so we don't gain much from preprocessing batches in parallel. On proper hardware, this cost would largely be hidden by parallel data loading.
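To see the blow-up concretely, we can compare the number of variables per node pair to the original bond features (using BOND_FEATURES and PAIRWISE_FEATURES as defined above):
print(len(BOND_FEATURES))      # 4 bond features
print(len(PAIRWISE_FEATURES))  # 16: 4 * (1 + max_path_length) features per node pair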
trainer.train(1)
A note on efficiency
We’ve seen how we can effectively extend our graph neural network to not only use the neighbourhood of aggregation, and this allows us to embed a lot of domain knowledge into how we represent our problems.
The downside is that this method scales poorly with the size of our graphs: both computation time and memory demand grow quadratically with the number of nodes.
With a regular GNN, we can make the neighbourhood aggregation sparse, which is what practical GNN frameworks such as PyTorch Geometric do.
In the field of NLP, much research is dedicated to making the Transformer architecture more efficient, and we will likely see this method become more efficient in the future.
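As a rough back-of-the-envelope sketch of that quadratic growth, consider just the memory of a single float32 attention matrix:
for n_nodes in [10, 100, 1000, 10000]:
    attention_bytes = 4 * n_nodes * n_nodes  # float32 attention matrix
    print(f"{n_nodes:>6} nodes -> {attention_bytes / 2**20:8.2f} MiB")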
Task
Experiment with this Path-augmented Transformer. Can you do better than the GNN from the previous notebook?
Learning outcomes
In this notebook we took a deep dive into Transformers. As you can see, we’ve essentially extended the idea of a graph neural network by computing an “adjacency matrix” (the attention matrix). The choice of what function to use to compute each element of this matrix is up to us, but for graphs it’s reasonable to use any edge information between the pair of nodes.
We also saw how we could include arbitrary pairwise information: while we used the shortest path between the nodes in this case, for graphs embedded in a Euclidean space it could have been the Euclidean distance, or any other quantity we can compute from a pair of nodes. The sketch below condenses this parallel between graph convolution and self-attention.
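In its simplest form the parallel looks like this (a standalone sketch, not the classes used above): a graph convolution aggregates with a fixed adjacency matrix, while self-attention aggregates with a computed one:
import torch
n_nodes, d_model = 5, 16
node_features = torch.randn(n_nodes, d_model)
# GNN: aggregate neighbours with a fixed binary adjacency matrix
adjacency = torch.randint(0, 2, (n_nodes, n_nodes)).float()
gnn_aggregated = adjacency @ node_features
# Transformer: aggregate with a computed, "soft" adjacency (attention) matrix
w_query = torch.randn(d_model, d_model)
w_key = torch.randn(d_model, d_model)
scores = (node_features @ w_query) @ (node_features @ w_key).T / d_model ** 0.5
attention_aggregated = torch.softmax(scores, dim=-1) @ node_features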
Important concepts
Attention
Functions on pairs of nodes
Shortest path
What about this multi-head attention?
When talking about Transformers, a lot of attention is usually given to multi-head self-attention. We haven't covered it in this notebook, mostly because the difference between multi-head and single-head self-attention is conceptually small, but quite a bit messier to implement efficiently.
What is multi-head self-attention?
With multiple heads, we effectively add parallel networks in the aggregation part. Think of it as having multiple parallel GNN layers, each performing its own MSG and AGG step. In the end, the resulting vectors of the parallel heads are combined (concatenated, in the Transformer) to form the new vector for the node.
This is equivalent to using "block-diagonal" weight matrices (along with some shuffling around of results), which is an interesting way of using sparse computation in neural networks. A minimal sketch is given below.
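Here is a minimal multi-head self-attention sketch, to make the head splitting and concatenation concrete. It leaves out masking, dropout and biases, and names such as n_heads are our own rather than from the code above:
import torch
from torch import nn

class MinimalMultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        batch, n_nodes, d_model = x.shape
        # Project to queries, keys and values for all heads at once
        q, k, v = self.qkv(x).chunk(3, dim=-1)

        def split_heads(t):
            # (batch, nodes, d_model) -> (batch, heads, nodes, head_dim)
            return t.view(batch, n_nodes, self.n_heads, self.head_dim).transpose(1, 2)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)
        # Each head computes its own attention matrix -- its own "adjacency"
        scores = q @ k.transpose(-2, -1) / self.head_dim ** 0.5
        per_head = torch.softmax(scores, dim=-1) @ v
        # Concatenate the heads back into one d_model vector per node
        concatenated = per_head.transpose(1, 2).reshape(batch, n_nodes, d_model)
        return self.out(concatenated)

mha = MinimalMultiHeadSelfAttention(d_model=16, n_heads=4)
print(mha(torch.randn(2, 5, 16)).shape)  # torch.Size([2, 5, 16])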