GNNs to Transformers

We’ve previously seen how we can implement graph convolutions using multiplication with the adjacency matrix. In this notebook we will look at how this can be extended to a Transformer by computing the adjacency matrix dynamically.

You can find this notebook on Google Colab here

Code from previous notebooks

We will use the same example as in the previous notebooks.

from collections import defaultdict
from collections.abc import Set

import rdkit
from rdkit.Chem import MolFromSmiles
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem import Draw
IPythonConsole.ipython_useSVG=True  #< set this to False if you want PNGs instead of SVGs
IPythonConsole.drawOptions.addAtomIndices = True  # This will help when looking at the Mol graph representation
IPythonConsole.molSize = 600, 600

# We suppress RDKit errors for this notebook
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')  


import torch
from torch.utils.data import Dataset, DataLoader

float_type = torch.float32  # We're hardcoding types in the tensors further down
categorical_type = torch.long
mask_type = torch.float32  # We're going to be multiplying our internal calculations with a mask using this type
labels_type = torch.float32 # We're going to use BCEWithLogitsLoss, which expects the labels to be of the same type as the predictions
class ContinuousVariable:
  def __init__(self, name):
    self.name = name

  def __repr__(self):
    return f'<ContinuousVariable: {self.name}>'

  def __eq__(self, other):
    return self.name == other.name

  def __hash__(self):
    return hash(self.name)

class CategoricalVariable:
  def __init__(self, name, values, add_null_value=True):
    self.name = name
    self.has_null_value = add_null_value
    if self.has_null_value:
      self.null_value = None
      values = (None,) + tuple(values)
    self.values = tuple(values)
    self.value_to_idx_mapping = {v: i for i, v in enumerate(values)}
    self.inv_value_to_idx_mapping = {i: v for v, i in self.value_to_idx_mapping.items()}
    
    if self.has_null_value:
      self.null_value_idx = self.value_to_idx_mapping[self.null_value]
  
  def get_null_idx(self):
    if self.has_null_value:
      return self.null_value_idx
    else:
      raise RuntimeError(f"Categorical variable {self.name} has no null value")

  def value_to_idx(self, value):
    return self.value_to_idx_mapping[value]
  
  def idx_to_value(self, idx):
    return self.inv_value_to_idx_mapping[idx]
  
  def __len__(self):
    return len(self.values)
  
  def __repr__(self):
    return f'<CategoricalVariable: {self.name}>'

  def __eq__(self, other):
    return self.name == other.name and self.values == other.values

  def __hash__(self):
    return hash((self.name, self.values))
ATOM_SYMBOLS = ['H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', 'Na', 
                'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar', 'K', 'Ca', 'Sc', 'Ti', 
                'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', 'Ga', 'Ge', 'As', 
                'Se', 'Br', 'Kr', 'Rb', 'Sr', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 
                'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'I', 'Xe', 'Cs', 
                'Ba', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Tl', 
                'Pb', 'Bi', 'Po', 'At', 'Rn', 'Fr', 'Ra', 'Rf', 'Db', 'Sg', 
                'Bh', 'Hs', 'Mt', 'Ds', 'Rg', 'Cn', 'Fl', 'Lv', 'La', 'Ce', 'Pr', 
                'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', 'Er', 'Tm', 'Yb', 
                'Lu', 'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm', 'Bk', 
                'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr']
ATOM_SYMBOLS_FEATURE = CategoricalVariable('atom_symbol', ATOM_SYMBOLS)

ATOM_AROMATIC_VALUES = [True, False]
ATOM_AROMATIC_FEATURE = CategoricalVariable('is_aromatic', ATOM_AROMATIC_VALUES)

# In practice you might like to use categorical features for valence, but we use continuous here for demonstration
ATOM_EXPLICIT_VALENCE_FEATURE = ContinuousVariable('explicit_valence')

ATOM_IMPLICIT_VALENCE_FEATURE = ContinuousVariable('implicit_valence')

ATOM_FEATURES = [ATOM_SYMBOLS_FEATURE, ATOM_AROMATIC_FEATURE, ATOM_EXPLICIT_VALENCE_FEATURE, ATOM_IMPLICIT_VALENCE_FEATURE]

def get_atom_features(rd_atom):
  atom_symbol = rd_atom.GetSymbol()
  is_aromatic = rd_atom.GetIsAromatic()
  implicit_valence = float(rd_atom.GetImplicitValence())
  explicit_valence = float(rd_atom.GetExplicitValence())
  return {ATOM_SYMBOLS_FEATURE: atom_symbol,
          ATOM_AROMATIC_FEATURE: is_aromatic,
          ATOM_EXPLICIT_VALENCE_FEATURE: explicit_valence,
          ATOM_IMPLICIT_VALENCE_FEATURE: implicit_valence}
# We could use the RDKit enumeration types instead of strings, but the advantage
# of doing it like this is that our representation becomes independent of RDKit
BOND_TYPES = ['UNSPECIFIED', 'SINGLE', 'DOUBLE', 'TRIPLE', 'QUADRUPLE', 
              'QUINTUPLE', 'HEXTUPLE', 'ONEANDAHALF', 'TWOANDAHALF',
              'THREEANDAHALF','FOURANDAHALF', 'FIVEANDAHALF', 'AROMATIC', 
              'IONIC', 'HYDROGEN', 'THREECENTER',	'DATIVEONE', 'DATIVE',
              'DATIVEL', 'DATIVER', 'OTHER', 'ZERO']
TYPE_FEATURE = CategoricalVariable('bond_type', BOND_TYPES)

BOND_DIRECTIONS = ['NONE', 'BEGINWEDGE', 'BEGINDASH', 'ENDDOWNRIGHT', 'ENDUPRIGHT', 'EITHERDOUBLE' ]
DIRECTION_FEATURE = CategoricalVariable('bond_direction', BOND_DIRECTIONS)

BOND_STEREO = ['STEREONONE', 'STEREOANY', 'STEREOZ', 'STEREOE', 
               'STEREOCIS', 'STEREOTRANS']
STEREO_FEATURE = CategoricalVariable('bond_stereo', BOND_STEREO)

AROMATIC_VALUES = [True, False]
AROMATIC_FEATURE = CategoricalVariable('is_aromatic', AROMATIC_VALUES)

BOND_FEATURES = [TYPE_FEATURE, DIRECTION_FEATURE, AROMATIC_FEATURE, STEREO_FEATURE]

def get_bond_features(rd_bond):
  bond_type = str(rd_bond.GetBondType())
  bond_stereo_info = str(rd_bond.GetStereo())
  bond_direction = str(rd_bond.GetBondDir())
  is_aromatic = rd_bond.GetIsAromatic()
  return {TYPE_FEATURE: bond_type,
          DIRECTION_FEATURE: bond_direction,
          AROMATIC_FEATURE: is_aromatic,
          STEREO_FEATURE: bond_stereo_info}
def rdmol_to_graph(mol):
  atoms = {rd_atom.GetIdx(): get_atom_features(rd_atom) for rd_atom in mol.GetAtoms()}
  bonds = {frozenset((rd_bond.GetBeginAtomIdx(), rd_bond.GetEndAtomIdx())): get_bond_features(rd_bond) for rd_bond in mol.GetBonds()}
  return atoms, bonds
def smiles_to_graph(smiles):
  rd_mol = MolFromSmiles(smiles)
  graph = rdmol_to_graph(rd_mol)
  return graph
g = smiles_to_graph('c1ccccc1')
class GraphDataset(Dataset):
  def __init__(self, *, graphs, labels, node_variables, edge_variables, metadata=None):
    '''
    Create a new graph dataset from parallel lists of graphs and labels.
    The node_variables and edge_variables decide which features are used.
    '''
    self.graphs = graphs
    self.labels = labels
    assert len(self.graphs) == len(self.labels), "The graphs and labels lists must be the same length"
    self.metadata = metadata
    if self.metadata is not None:
      assert len(self.metadata) == len(self.graphs), "The metadata list needs to be as long as the graphs"
    self.node_variables = node_variables
    self.edge_variables = edge_variables
    self.categorical_node_variables = [var for var in self.node_variables if isinstance(var, CategoricalVariable)]
    self.continuous_node_variables = [var for var in self.node_variables if isinstance(var, ContinuousVariable)]
    self.categorical_edge_variables = [var for var in self.edge_variables if isinstance(var, CategoricalVariable)]
    self.continuous_edge_variables = [var for var in self.edge_variables if isinstance(var, ContinuousVariable)]

  def __len__(self):
    return len(self.graphs)

  def make_continuous_node_features(self, nodes):
    if len(self.continuous_node_variables) == 0:
      return None
    n_nodes = len(nodes)
    n_features = len(self.continuous_node_variables)
    continuous_node_features = torch.zeros((n_nodes, n_features), dtype=float_type)
    for node_idx, features in nodes.items():
      node_features = torch.tensor([features[continuous_feature] for continuous_feature in self.continuous_node_variables], dtype=float_type)
      continuous_node_features[node_idx] = node_features
    return continuous_node_features
      
  def make_categorical_node_features(self, nodes):
    if len(self.categorical_node_variables) == 0:
      return None
    n_nodes = len(nodes)
    n_features = len(self.categorical_node_variables)
    categorical_node_features = torch.zeros((n_nodes, n_features), dtype=categorical_type)

    for node_idx, features in nodes.items():
      for i, categorical_variable in enumerate(self.categorical_node_variables):
          value = features[categorical_variable]
          value_index = categorical_variable.value_to_idx(value)
          categorical_node_features[node_idx, i] = value_index

    return categorical_node_features

  def make_continuous_edge_features(self, n_nodes, edges):
    if len(self.continuous_edge_variables) == 0:
      return None
    n_features = len(self.continuous_edge_variables)
    continuous_edge_features = torch.zeros((n_nodes, n_nodes, n_features), dtype=float_type)
    for edge, features in edges.items():
      edge_features = torch.tensor([features[continuous_feature] for continuous_feature in self.continuous_edge_variables], dtype=float_type)
      u,v = edge
      continuous_edge_features[u, v] = edge_features
      if isinstance(edge, Set):
        continuous_edge_features[v, u] = edge_features

    return continuous_edge_features

  def make_categorical_edge_features(self, n_nodes, edges):
    if len(self.categorical_edge_variables) == 0:
      return None
    n_features = len(self.categorical_edge_variables)
    categorical_edge_features = torch.zeros((n_nodes, n_nodes, n_features), dtype=categorical_type)

    for edge, features in edges.items():
      u,v = edge
      for i, categorical_variable in enumerate(self.categorical_edge_variables):
          value = features[categorical_variable]
          value_index = categorical_variable.value_to_idx(value)
          categorical_edge_features[u, v, i] = value_index
          if isinstance(edge, Set):
            categorical_edge_features[v, u, i] = value_index

    return categorical_edge_features
  
  def __getitem__(self, index):
    # This is where the important stuff happens. We use our node and 
    # edge variable attributes to select what node and edge features to use.
    # In practice, we often do this as a pre-processing step, but here we do it 
    # in the getitem function for clarity

    graph = self.graphs[index]
    nodes, edges = graph
    n_nodes = len(nodes)
    continuous_node_features = self.make_continuous_node_features(nodes)
    categorical_node_features = self.make_categorical_node_features(nodes)
    continuous_edge_features = self.make_continuous_edge_features(n_nodes, edges)
    categorical_edge_features = self.make_categorical_edge_features(n_nodes, edges)

    label = self.labels[index]

    nodes_idx = sorted(nodes.keys())
    edge_list = sorted(edges.keys())

    n_nodes = len(nodes)
    adjacency_matrix = torch.zeros((n_nodes, n_nodes), dtype=float_type)
    for edge in edges:
      u, v = edge
      adjacency_matrix[u,v] = 1
      if isinstance(edge, Set):
        # This edge is unordered, assume this is an undirected graph
        adjacency_matrix[v,u] = 1

    adjacency_list = defaultdict(list)
    for edge in edges:
      u,v = edge
      adjacency_list[u].append(v)
      # Assume an undirected graph if the edge is a set
      if isinstance(edge, Set):
        adjacency_list[v].append(u)

    data_record = {'nodes': nodes_idx,
                   'adjacency_matrix': adjacency_matrix,
                   'adjacency_list': adjacency_list,
                   'categorical_node_features': categorical_node_features,
                   'continuous_node_features': continuous_node_features,
                   'categorical_edge_features': categorical_edge_features,
                   'continuous_edge_features': continuous_edge_features,
                   'label': label}

    # If you need to add extra information (metadata about this graph) you can 
    # add an extra key-value pair here. The advantage of using a dict compared 
    # to a tuple is that the downstream code doesn't break as long as at least 
    # the expected keys are present. The downside is that using a dict adds 
    # overhead (accessing a dict compared to unpacking a tuple).
    # A more robust implementation might actually make a separate class for 
    # dataset entries
    if self.metadata is not None:
      data_record['metadata'] = self.metadata[index]
    return data_record

  def get_node_variables(self):
    return {'continuous': self.continuous_node_variables,
            'categorical': self.categorical_node_variables}
  
  def get_edge_variables(self):
    return {'continuous': self.continuous_edge_variables,
            'categorical': self.categorical_edge_variables}
def make_molecular_graph_dataset(smiles_records, atom_features=ATOM_FEATURES, bond_features=BOND_FEATURES):
  '''
  Create a new GraphDataset from a list of smiles_records dictionaries.
  These records should contain the keys 'smiles' and 'label'. Any other keys will be saved as a 'metadata' record.
  '''
  graphs = []
  labels = []
  metadata = []
  for smiles_record in smiles_records:
    smiles = smiles_record['smiles']
    label = smiles_record['label']
    graph = smiles_to_graph(smiles)
    graphs.append(graph)
    labels.append(label)
    metadata.append(smiles_record)
  return GraphDataset(graphs=graphs, 
                      labels=labels, 
                      node_variables=atom_features, 
                      edge_variables=bond_features, 
                      metadata=metadata)
  
dataset = make_molecular_graph_dataset([{'smiles': 'c1ccccc1', 'label':1},{'smiles':'OS(=O)(=O)O', 'label': 0}])
from collections.abc import Set # We assume that edges as sets are for undirected graphs

def collate_graph_batch(batch):
  '''Collate a batch of graph dictionaries produced by a GraphDataset'''
  batch_size = len(batch)

  max_nodes = max(len(graph['nodes']) for graph in batch)
  
  # We start by allocating the tensors we'll use. We defer allocating feature
  # tensors until we know the graphs actually have those kinds of features.
  adjacency_matrices = torch.zeros((batch_size, max_nodes, max_nodes), dtype=float_type)
  labels = torch.tensor([graph['label'] for graph in batch], dtype=labels_type)
  stacked_continuous_node_features = None
  stacked_categorical_node_features = None
  stacked_continuous_edge_features = None
  stacked_categorical_edge_features = None

  nodes_mask = torch.zeros((batch_size, max_nodes), dtype=mask_type)
  edge_mask = torch.zeros((batch_size, max_nodes, max_nodes), dtype=mask_type)
  
  has_metadata = False

  for i, graph in enumerate(batch):
    if 'metadata' in graph:
      has_metadata = True
    # We'll take basic information about the different graphs from the adjacency 
    # matrix
    adjacency_matrix = graph['adjacency_matrix']
    g_nodes, _ = adjacency_matrix.shape  # the adjacency matrix is square
    adjacency_matrices[i, :g_nodes, :g_nodes] = adjacency_matrix

    # Now that we know how many of the entries are valid, we set those to 1s in
    # the masks
    edge_mask[i, :g_nodes, :g_nodes] = 1
    nodes_mask[i, :g_nodes] = 1
    

    # All the feature constructions follow the same recipe. We essentially
    # locate the entries in the stacked feature tensor (containing all graphs)
    # and set them with the features from the current graph.
    g_continuous_node_features = graph['continuous_node_features']
    if g_continuous_node_features is not None:
      if stacked_continuous_node_features is None:
        g_nodes, num_features = g_continuous_node_features.shape
        stacked_continuous_node_features = torch.zeros((batch_size, max_nodes, num_features))
      stacked_continuous_node_features[i, :g_nodes] = g_continuous_node_features
    
    g_categorical_node_features = graph['categorical_node_features']
    if g_categorical_node_features is not None:
      if stacked_categorical_node_features is None:
        g_nodes, num_features = g_categorical_node_features.shape
        stacked_categorical_node_features = torch.zeros((batch_size, max_nodes, num_features), dtype=categorical_type)
      stacked_categorical_node_features[i, :g_nodes] = g_categorical_node_features

    g_continuous_edge_features = graph['continuous_edge_features']
    if g_continuous_edge_features is not None:
      if stacked_continuous_edge_features is None:
        g_nodes, _, num_features = g_continuous_edge_features.shape
        stacked_continuous_edge_features = torch.zeros((batch_size, max_nodes, max_nodes, num_features))
      stacked_continuous_edge_features[i, :g_nodes, :g_nodes] = g_continuous_edge_features

    g_categorical_edge_features = graph['categorical_edge_features']
    if g_categorical_edge_features is not None:
      if stacked_categorical_edge_features is None:
        g_nodes, _, num_features = g_categorical_edge_features.shape
        stacked_categorical_edge_features = torch.zeros((batch_size, max_nodes, max_nodes, num_features), dtype=categorical_type)
      stacked_categorical_edge_features[i, :g_nodes, :g_nodes] = g_categorical_edge_features


  batch_record = {'adjacency_matrices': adjacency_matrices,
          'categorical_node_features': stacked_categorical_node_features,
          'continuous_node_features': stacked_continuous_node_features,
          'categorical_edge_features': stacked_categorical_edge_features,
          'continuous_edge_features': stacked_continuous_edge_features,
          'nodes_mask': nodes_mask,
          'edge_mask': edge_mask,
          'labels': labels}
  if has_metadata:
    batch_record['metadata'] = [g['metadata'] for g in batch]

  return batch_record
example_batch = collate_graph_batch([dataset[0], dataset[1]])
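
As a quick sanity check (purely illustrative, not needed for the rest of the notebook) we can print the shapes of the collated tensors. Benzene has 6 heavy atoms and sulfuric acid 5, so max_nodes should be 6:

# Illustrative: inspect the shapes produced by the collate function
for key, value in example_batch.items():
  if torch.is_tensor(value):
    print(key, tuple(value.shape))
# e.g. adjacency_matrices -> (2, 6, 6), nodes_mask -> (2, 6),
# categorical_edge_features -> (2, 6, 6, 4), labels -> (2,)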
from torch.nn import Module, Embedding, ModuleList, Linear, Sequential, ReLU
class Embedder(Module):
  def __init__(self, categorical_variables, embedding_dim):
    super().__init__()
    self.categorical_variables = categorical_variables
    embeddings = []
    for var in categorical_variables:
      num_embeddings = len(var)
      if var.has_null_value:
        # It's not uncommon to have missing values; we support this by assigning a special null index whose embedding is the zero-vector
        embedding = Embedding(num_embeddings, embedding_dim, padding_idx=var.get_null_idx())
      else:
        embedding = Embedding(num_embeddings, embedding_dim)
      embeddings.append(embedding)
    self.embeddings = ModuleList(embeddings)
    
  
  def forward(self, categorical_features):
    # The categorical features tensor has as many rows as there are nodes in 
    # our graph and as many columns as we have categorical features
    all_embedded_vars = []
    for i, embedding in enumerate(self.embeddings):
      # We pick out just the i'th column. The ellipsis '...' in a numpy-style 
      # slice is a useful way of saying you want the full range over all other 
      # axes. We use it so that this can take a categorical_features array
      # with an arbitrary number of leading axes, to support the node 
      # features, the edge features and the mini-batched versions of both
      var_indices = categorical_features[..., i]  
      embedded_vars = embedding(var_indices)
      all_embedded_vars.append(embedded_vars)

    # If you like, you can implement concatenation instead of sum here
    stacked_embedded_vars = torch.stack(all_embedded_vars, dim=0)
    embedded_vars = torch.sum(stacked_embedded_vars, dim=0)
    return embedded_vars
class FeatureCombiner(Module):
  def __init__(self, categorical_variables, embedding_dim):
    super().__init__()
    self.categorical_variables = categorical_variables
    self.embedder = Embedder(self.categorical_variables, embedding_dim)
    
  def forward(self, continuous_features, categorical_features):
    # We need to be agnostic to whether we have categorical features and continuous features (it's not uncommon to only use one kind)
    features = []
    if categorical_features is not None:
      embedded_features = self.embedder(categorical_features)
      features.append(embedded_features)
      # The embedded features are now of shape (n_nodes, embedding_dim)
    if continuous_features is not None:
      features.append(continuous_features)
    if len(features) == 0:
      raise RuntimeError('No features to combine')
    full_features = torch.cat(features, dim=-1)  # Now we concatenate along the feature dimension
    return full_features

Training step

This section loads the data and implements the training procedure. This code should also be familiar from the previous lessons.

from collections import defaultdict, deque # We'll use this to construct the dataset splits
from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmilesFromSmiles
from pathlib import Path
import pandas as pd
potential_paths = [Path('BBBP.csv'), Path('../dataset/BBBP.csv')]
bbbp_table = None
for p in potential_paths:
    if p.exists():
        bbbp_table = pd.read_csv(p)
        break
bbbp_table
         num  name                               p_np  smiles
0          1  Propanolol                            1  [Cl].CC(C)NCC(O)COc1cccc2ccccc12
1          2  Terbutylchlorambucil                  1  C(=O)(OC(C)(C)C)CCCc1ccc(cc1)N(CCCl)CCCl
2          3  40730                                 1  c12c3c(N4CCN(C)CC4)c(F)cc1c(c(C(O)=O)cn2C(C)CO...
3          4  24                                    1  C1CCN(CC1)Cc1cccc(c1)OCCCNC(=O)C
4          5  cloxacillin                           1  Cc1onc(c2ccccc2Cl)c1C(=O)N[C@H]3[C@H]4SC(C)(C)...
...      ...  ...                                 ...  ...
2045    2049  licostinel                            1  C1=C(Cl)C(=C(C2=C1NC(=O)C(N2)=O)[N+](=O)[O-])Cl
2046    2050  ademetionine(adenosyl-methionine)     1  [C@H]3([N]2C1=C(C(=NC=N1)N)N=C2)[C@@H]([C@@H](...
2047    2051  mesocarb                              1  [O+]1=N[N](C=C1[N-]C(NC2=CC=CC=C2)=O)C(CC3=CC=...
2048    2052  tofisoline                            1  C1=C(OC)C(=CC2=C1C(=[N+](C(=C2CC)C)[NH-])C3=CC...
2049    2053  azidamfenicol                         1  [N+](=NCC(=O)N[C@@H]([C@H](O)C1=CC=C([N+]([O-]...

2050 rows × 4 columns

# There are about 50 problematic SMILES. We've suppressed RDKit logger outputs,
# but if you haven't, you'll get a lot of printouts here.
# 11 SMILES can't be processed at all so we throw them away

smiles_records = []
for i, num, name, p_np, smiles in bbbp_table.to_records():
  # check if RDKit accepts this smiles
  if MolFromSmiles(smiles) is not None:
    smiles_record = {'smiles': smiles, 'label': p_np, 'metadata': {'row': i}}
    smiles_records.append(smiles_record)
  else:
    print(f'Molecule {smiles} on row {i} could not be parsed by RDKit')
  
Molecule O=N([O-])C1=C(CN=C1NCCSCc2ncccc2)Cc3ccccc3 on row 59 could not be parsed by RDKit
Molecule c1(nc(NC(N)=[NH2])sc1)CSCCNC(=[NH]C#N)NC on row 61 could not be parsed by RDKit
Molecule Cc1nc(sc1)\[NH]=C(\N)N on row 391 could not be parsed by RDKit
Molecule s1cc(CSCCN\C(NC)=[NH]\C#N)nc1\[NH]=C(\N)N on row 614 could not be parsed by RDKit
Molecule c1c(c(ncc1)CSCCN\C(=[NH]\C#N)NCC)Br on row 642 could not be parsed by RDKit
Molecule n1c(csc1\[NH]=C(\N)N)c1ccccc1 on row 645 could not be parsed by RDKit
Molecule n1c(csc1\[NH]=C(\N)N)c1cccc(c1)N on row 646 could not be parsed by RDKit
Molecule n1c(csc1\[NH]=C(\N)N)c1cccc(c1)NC(C)=O on row 647 could not be parsed by RDKit
Molecule n1c(csc1\[NH]=C(\N)N)c1cccc(c1)N\C(NC)=[NH]\C#N on row 648 could not be parsed by RDKit
Molecule s1cc(nc1\[NH]=C(\N)N)C on row 649 could not be parsed by RDKit
Molecule c1(cc(N\C(=[NH]\c2cccc(c2)CC)C)ccc1)CC on row 685 could not be parsed by RDKit
import random
random.seed(1729)

training_fraction = 0.8
dev_fraction = 0.1
n_examples = len(smiles_records)
n_training_examples = int(n_examples*training_fraction)
n_dev_examples = int(n_examples*dev_fraction)

indices = list(range(n_examples))
random.shuffle(indices)  # shuffle is in place
training_indices = indices[:n_training_examples]
dev_indices = indices[n_training_examples:n_training_examples+n_dev_examples]
test_indices = indices[n_training_examples+n_dev_examples:]

training_smiles_records = [smiles_records[i] for i in training_indices]
dev_smiles_records = [smiles_records[i] for i in dev_indices]
test_smiles_records = [smiles_records[i] for i in test_indices]
training_graph_dataset = make_molecular_graph_dataset(training_smiles_records)
dev_graph_dataset = make_molecular_graph_dataset(dev_smiles_records)
test_graph_dataset = make_molecular_graph_dataset(test_smiles_records)
from torch.utils.data import DataLoader

batch_size=32
num_dataloader_workers=2 # The colab instances are very limited in number of cpus

training_dataloader = DataLoader(training_graph_dataset, 
                                          batch_size=batch_size, 
                                          shuffle=True, 
                                          num_workers=num_dataloader_workers, 
                                          collate_fn=collate_graph_batch)
dev_dataloader = DataLoader(dev_graph_dataset, 
                                     batch_size=batch_size, 
                                     shuffle=False, 
                                     num_workers=num_dataloader_workers, 
                                     drop_last=False, 
                                     collate_fn=collate_graph_batch)
test_dataloader = DataLoader(test_graph_dataset, 
                                     batch_size=batch_size, 
                                     shuffle=False, 
                                     num_workers=num_dataloader_workers, 
                                     drop_last=False, 
                                     collate_fn=collate_graph_batch)
class GraphPredictionHeadConfig:
  def __init__(self, *, d_model, ffn_dim, pooling_type='sum'):
    # Pooling type can be 'sum' or 'mean'
    self.d_model = d_model
    self.ffn_dim = ffn_dim
    self.pooling_type = pooling_type

class GraphPredictionHead(Module):
  def __init__(self, input_dim, output_dim, config):
    super().__init__()
    self.input_dim = input_dim
    self.output_dim = output_dim
    self.config = config
    self.predictor = Sequential(Linear(self.input_dim, self.config.ffn_dim), 
                                ReLU(), 
                                Linear(self.config.ffn_dim, self.output_dim))

  def forward(self, node_features, node_mask):
    # The node_features is a tensor of shape (*leading_axes, max_nodes, d_model)
    # We want to 'pool' this along the node-axis, which is dim=-2 in pytorch terms
    # In this case we assume these features are valid, i.e. that they have been
    # masked at a previous step.
    if self.config.pooling_type == 'sum':
      pooled_nodes = node_features.sum(dim=-2)
    elif self.config.pooling_type == 'mean':
      # We can't take just the mean along dim=-2, since if this is a batch, some of 
      # the graphs need to be divided by a smaller number than max_nodes. 
      # Thankfully we have the information about how many nodes each graph has in
      # the node_mask, and since it has 1's and 0's, just summing it along the 
      # max_nodes axis gives the count of nodes for the corresponding graph
      
      # node_mask has the shape (batch_size, max_nodes), or just (max_nodes,) 
      # if it's not a batch. We get the count per graph by reducing along the 
      # last dimension, and by setting keepdim=True, we get a shape 
      # (batch_size, 1) or (1,) which will allow for broadcasting this with 
      # division over the summed feature vectors to calculate their mean 
      node_counts = node_mask.sum(dim=-1, keepdim=True)
      summed_feature_vectors = node_features.sum(dim=-2)
      pooled_nodes = summed_feature_vectors/node_counts
    else:
      raise ValueError(f'Unsupported pooling type {self.config.pooling_type}')
    
    prediction = self.predictor(pooled_nodes)
    return prediction
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')
print("Device is", device)
Device is cuda
from tqdm.notebook import tqdm, trange
from torch.nn import BCEWithLogitsLoss, MSELoss
from torch.optim import AdamW
from sklearn.metrics import roc_auc_score, mean_squared_error, mean_absolute_error, median_absolute_error
import matplotlib.pyplot as plt
import seaborn as sns

def batch_to_device(batch, device):
  moved_batch = {}
  for k, v in batch.items():
    if torch.is_tensor(v):
      v = v.to(device)
    moved_batch[k] = v
  return moved_batch

class Trainer:
  def __init__(self, *, model, 
               loss_fn, training_dataloader, 
               dev_dataloader, device=device):
    self.model = model
    self.training_dataloader = training_dataloader
    self.dev_dataloader = dev_dataloader
    self.device = device
    self.model.to(device)
    self.total_epochs = 0
    self.optimizer = AdamW(self.model.parameters(), lr=1e-4)
    self.loss_fn = loss_fn

  def train(self, epochs):
    with trange(epochs, desc='Epoch', position=0) as epoch_progress:
      batches_per_epoch = len(self.training_dataloader) + len(self.dev_dataloader)
      for epoch in epoch_progress:
        train_loss = 0
        train_n = 0
        for i, training_batch in enumerate(tqdm(self.training_dataloader, desc='Training batch', leave=False)):
          self.optimizer.zero_grad()
          # Move all tensors to the device
          self.model.train()
          training_batch = batch_to_device(training_batch, self.device)
          prediction = self.model(training_batch)
          labels = training_batch['labels']
          loss = self.loss_fn(prediction.squeeze(), labels) # By default the predictions have shape (batch_size, 1)
          loss.backward()
          self.optimizer.step()
          batch_n = len(labels)
          train_loss += batch_n * loss.cpu().item()
          train_n += batch_n
        #print(f"Training loss for epoch {total_epochs}", train_loss/train_n)
        self.total_epochs += 1

        dev_predictions = []
        dev_labels = []
        dev_n = 0
        dev_loss = 0
        for i, dev_batch in enumerate(tqdm(self.dev_dataloader, desc="Dev batch", leave=False)):
          self.model.eval()
          with torch.no_grad():
            dev_batch = batch_to_device(dev_batch, self.device)
            prediction = self.model(dev_batch).squeeze()
            dev_predictions.extend(prediction.tolist())
            labels = dev_batch['labels']
            dev_labels.extend(labels.tolist())
            loss = self.loss_fn(prediction, labels) # By default the predictions have shape (batch_size, 1)
            batch_n = len(labels)
            dev_loss += batch_n*loss.cpu().item()
            dev_n += batch_n
        epoch_progress.set_description(f"Epoch: train loss {train_loss/train_n: .3f}, dev loss {dev_loss/dev_n: .3f}")


def evaluate_model(trainer, dataloader, label=None, hue_order=[0,1]):
  eval_predictions = []
  eval_labels = []
  eval_loss = 0
  eval_n = 0
  model = trainer.model
  loss_fn = trainer.loss_fn
  total_epochs = trainer.total_epochs
  for i, eval_batch in enumerate(tqdm(dataloader, desc='batch')):
    model.eval()
    with torch.no_grad():
      eval_batch = batch_to_device(eval_batch, device)
      prediction = model(eval_batch).squeeze()
      eval_predictions.extend(prediction.tolist())
      labels = eval_batch['labels']
      eval_labels.extend(labels.tolist())
      loss = loss_fn(prediction, labels) # By default the predictions have shape (batch_size, 1)
      batch_n = len(labels)
      eval_loss += batch_n*loss.cpu().item()
      eval_n += batch_n
  average_loss = eval_loss/eval_n
  roc_auc = roc_auc_score(eval_labels, eval_predictions)
  eval_df = pd.DataFrame(data={'target': eval_labels, 'predictions': eval_predictions})
  sns.kdeplot(data=eval_df, x='predictions', hue='target', hue_order=hue_order)
  sns.rugplot(data=eval_df, x='predictions', hue='target', hue_order=hue_order)
  
  if label is not None:
    title = f"{label} dataset after {total_epochs} epochs\nloss {average_loss}\nROC AUC {roc_auc}"
  else:
    title = f"After {total_epochs} epochs\nloss {average_loss}\nROC AUC {roc_auc}"
  plt.title(title)

From GNNs to Transformers

We’ve seen how the idea of using sums of vectors can be used to learn things from a graph: we used it on graph neighbourhoods to introduce more structure, and we used it as a pooling method for graph predictions. We also saw how the aggregation of a graph neighbourhood could be done by multiplying with the adjacency matrix.

We saw how this summing of local neighbourhoods in the graph can be done using a matrix multiplication of the graph’s adjacency matrix with the node representations:

$$ \mathbf{h}_i^L = f(\sum_{j \in N(i)} \mathbf{h}_j^{L-1}, \mathbf{h}_i^{L-1}) $$
$$ \sum_{j \in N(i)} \mathbf{h}_j^{L-1} = A_i H^{L-1} $$
$$ H^{L-1} = \begin{bmatrix}\mathbf{h}_1^{L-1} \\ \mathbf{h}_2^{L-1} \\ \vdots \\ \mathbf{h}_n^{L-1}\end{bmatrix} $$

Here we assume that the adjacency matrix $A$ is a binary indicator matrix, with 1’s for pairs $i,j$ which are connected by an edge and 0’s for pairs which are not.

A way to think about the “convolutional” part of a Graph Neural Network is then that it’s essentially performing a matrix multiplication with the adjacency matrix: $$H^L = f(AH^{L-1}, H^{L-1})$$

Where $f$ is a function we wish to learn (e.g. a neural network with two input vectors).

The multiplication $AH^{L-1}$ is what takes the graph structure into account by only aggregating the local neighbourhoods. We’ve seen previously how this method has some fundamental limitations; in particular, the “convolution” has a limited receptive field, so it is biased towards learning local patterns.
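
As a small toy example (purely illustrative, not part of the model code), multiplying an adjacency matrix with a node feature matrix sums each node’s neighbour vectors:

# A minimal sketch: aggregating neighbourhoods by multiplying with the adjacency matrix
# Toy path graph 0 - 1 - 2 with two features per node
A = torch.tensor([[0., 1., 0.],
                  [1., 0., 1.],
                  [0., 1., 0.]])
H = torch.tensor([[1., 0.],
                  [0., 1.],
                  [1., 1.]])
neighbour_sums = A @ H
# Row 1 is the sum of the features of nodes 0 and 2: tensor([2., 1.])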

Dynamically computed “adjacency matrix”

What if, instead of assuming the matrix $A$ to be the adjacency matrix, we calculate its values based on the values of the node pairs and the edge between them? This would effectively mean that we potentially aggregate “neighbourhoods” that span the complete graph.

$$ A = \begin{bmatrix} g(h_1, h_1, e_{1,1}) & g(h_1, h_2, e_{1,2}) & \dots \\ \vdots & \ddots & \\ g(h_n, h_1, e_{n,1}) & \dots & g(h_n, h_n, e_{n,n}) \end{bmatrix} $$

This way we can learn to use node information as well as edge information when deciding how to aggregate the node set. Not only does this allow us to take the graph structure into account, it also allows the model to learn relationships between more distant nodes.

This is the fundamental idea of the Transformer architecture. By dynamically setting the values of this “adjacency matrix”, we can have a model which can learn to induce structure from arbitrary sets.

Depending on how we construct $g$, we can choose to inject knowledge about the elements of the input set $H$ such as relative position between tokens if H is actually a sequence, or information about a particular pair, such as whether they are connected by an edge in a graph.
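
As a minimal sketch of this idea (assuming $g$ is just a dot product between node vectors and ignoring edge features for the moment), computing such a matrix could look like this:

# Sketch: a dense "adjacency matrix" computed from the node features themselves
# Here g(h_i, h_j) is simply a dot product, followed by a row-wise softmax so
# that every row defines a weighted aggregation over all nodes
H = torch.randn(5, 8)                      # 5 nodes with 8 features each
pairwise_scores = H @ H.transpose(-1, -2)  # shape (5, 5): one score per node pair
A_dynamic = torch.softmax(pairwise_scores, dim=-1)
aggregated = A_dynamic @ H                 # every node aggregates over the full node set

This is already essentially (unscaled) dot-product attention, which is what we will implement properly further down.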

The downside

While this idea is very powerful, it comes with a major limitation. A GNN typically only implicitly multiplies the node vectors $H$ with the adjacency matrix; in practice this is implemented as a sparse operation, as we saw in the previous notebook. This means that the computational cost is $O(n d k)$, where $n$ is the number of nodes, $d$ the average degree and $k$ the dimensionality of the node vectors.

To multiply with a dense matrix, this instead becomes $O(n^2 k)$, and this quadratic scaling in the number of nodes in the input severely limits the application of this idea.

Generating $A$ also scales quadratically with the number of nodes, since we need to apply the function to all pairs of node vectors, regardless of edges.
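
To put rough numbers on this (a back-of-the-envelope estimate): for a drug-like molecule with $n = 30$ heavy atoms, average degree $d \approx 2$ and $k = 64$ features, the sparse aggregation needs on the order of $30 \cdot 2 \cdot 64 \approx 4\,000$ multiply-adds, while the dense version needs $30^2 \cdot 64 \approx 58\,000$, and building $A$ itself means evaluating $g$ for all $900$ node pairs. For inputs with thousands of nodes this difference becomes prohibitive.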

This has not stopped the idea from becoming wildly successful in the domain of Natural Language Processing, and as long as the domain we’re working in has relatively small inputs (like organic molecules in medicinal chemistry), we can handle the quadratic scaling with brute force (a lot of computational capacity).

Implementing a Transformer

Below is the implementation of the Transformer. In particular, the BasicTransformerLayer is where the main difference from a GNN is implemented. The BasicTransformerEncoder is very similar to the GNN encoder.

import math

import torch
from torch.nn import Module, Embedding, ModuleList, Linear, Sequential, ReLU, LayerNorm, Dropout
from torch.nn.functional import softmax

class BasicTransformerConfig:
  def __init__(self, *, 
               d_model: int, 
               n_layers: int, 
               ffn_dim: int,
               head_dim: int,
               layer_normalization: bool = True,
               dropout_rate: float = 0.1,
               residual_connections: bool=True):
    self.d_model = d_model
    self.n_layers = n_layers
    self.ffn_dim = ffn_dim
    # Note that we introduce a new hyperparameter called *head_dim*: in
    # Transformers we typically transform the "node" feature vectors to
    # a lower-dimensional space
    self.head_dim = head_dim
    self.layer_normalization = layer_normalization
    self.dropout_rate = dropout_rate
    self.residual_connections = residual_connections
    
class BasicTransformerLayer(Module):
  def __init__(self, config):
    super().__init__()    
    self.config = config
    self.input_dim = config.d_model
    self.output_dim = config.d_model
    self.ffn_dim = config.ffn_dim
    self.head_dim = config.head_dim

    # Transformers typically don't use MLPs to create the neighbour and center
    # embeddings, instead relying on just linear transformations
    self.neighbour_transform = Linear(self.input_dim, self.head_dim, bias=False)
    self.center_transform = Linear(self.input_dim, self.head_dim, bias=False)
    
    # The transformer uses layer normalization by default
    self.attention_norm = LayerNorm(self.input_dim)
    
    self.output_transform = Sequential(Linear(self.input_dim, self.ffn_dim),
                                       ReLU(), 
                                       Linear(self.ffn_dim, self.output_dim))
                                       
    self.output_norm = LayerNorm(self.output_dim)
    self.dropout = Dropout(self.config.dropout_rate)
   
    self.scaling_factor = math.sqrt(self.input_dim)
  
  def attention_function(self, adjacency_matrix, center_node_features, 
                         neighbour_node_features, edge_features, node_mask, edge_mask):
    # The standard Transformer simply takes the dot product between the center 
    # node and the neighbour nodes, scaled by the square root of the model 
    # dimension.
    # To take the dot products between all center nodes and all neighbour nodes
    # we perform an outer product by first transposing one of the matrices along
    # the last two axes.
    # In the single-graph case the matrix multiplication of
    # a matrix with shape (n_nodes, head_dim) with one of shape (head_dim, n_nodes)
    # gives a resulting matrix of shape (n_nodes, n_nodes), where each element
    # is the dot product between the row for a node in the first matrix and the 
    # column for a node in the second
    attention_logits = torch.matmul(center_node_features, 
                                    neighbour_node_features.transpose(-1, -2))/self.scaling_factor

    # The "adjcency matrix" of the transformer is actually using weighted means
    # for the aggregation, and the way we achieve this is to make sure that the
    # rows of the matrix are  normalized using softmax
    # However, if we have a batch of graphs as inputs, we will have some 
    # "positions" in the node features which we should not include in our 
    # aggregation, and should therefore mask out in our attention matrix.
    # If we did that after the softmax calculations, the rows would no longer
    # add up to 1. Instead we do it before the softmax by essentially 
    # setting the masked values to va number so low it will end up as a 
    # 0 in the softmax output. The lowest value we can imagine is negative
    # infinity, so let's use that. 
    # The goal is to mask out parts of the different batch attention_logits which
    # are not part of the nodes, so essentially have a resulting matrix per batch
    # example which looks something like
    # [[   a,    b,    c, -inf, -inf],
    #  [   d,    e,    f, -inf, -inf],
    #  [   g,    h,    i, -inf, -inf],
    #  [-inf, -inf, -inf, -inf, -inf],
    #  [-inf, -inf, -inf, -inf, -inf],
    # ]
    # To do this we will first create a boolean mask which has True in the
    # places we want to fill with '-inf'. We do this by using the node mask
    # like in the GNN examples, doing something very much like an outer product
    nodemask_2d = node_mask.unsqueeze(dim=-2) * node_mask.unsqueeze(dim=-1)
    
    # The nodemask has 1's where there are valid elements and 0's where there 
    # are none; we invert this and convert it to a bool tensor
    fill_mask = (1 - nodemask_2d).to(torch.bool)
    
    # We're now ready to 'mask out' the logits. Notice that masked_fill_ is 
    # in place
    attention_logits.masked_fill_(fill_mask, float('-inf'))
    attention_matrix = softmax(attention_logits, dim=-1)

    # There will be rows of the smaller attention matrices which were filled
    # completely with -inf values; these will now be rows of 'nan' values.
    # We perform a new fill of the attention matrix, but this time with 0s 
    attention_matrix = attention_matrix.masked_fill(fill_mask, 0.)
    return attention_matrix


  def forward(self, adjacency_matrix, node_features, edge_features, node_mask, edge_mask):
    # In this basic Transformer layer we'll not use the edge features, and instead
    # focus on the basic transformer formulation of this problem. This will pretty
    # much treat the graph as just a node set. We should not expect this to 
    # be able to learn anything which relies on the graph structure
    center_node_features = self.center_transform(node_features)
    neighbour_node_features = self.neighbour_transform(node_features)

    # The transformed node features are either a 3-tensor 
    # (batch_size, max_nodes, head_dim) or a matrix (n_nodes, head_dim) when
    # it's a single graph. We'll make this code agnostic to that
    # The goal now is to _compute_ an "adjacency matrix" using the node features
    # This could be done by any function. We define a method on this class
    # which is the attention function which has the purpose of giving us an
    # "adjacency matrix"
    attention_matrix = self.attention_function(adjacency_matrix, 
                                               center_node_features,
                                               neighbour_node_features, 
                                               edge_features, 
                                               node_mask, 
                                               edge_mask)
    
    # Now we aggregate the neighbourhoods using the attention matrix (our computed "adjacency matrix")
    # The transformer doesn't transform the node features at this stage, instead doing it in a separate MLP after residual connections and aggregation
    aggregated_neighbourhoods = torch.matmul(attention_matrix, node_features)

    # and mask the result    
    masked_features = aggregated_neighbourhoods * node_mask.unsqueeze(dim=-1)
    
    # Followed by a dropout and layer normalization
    masked_features = self.dropout(masked_features)
    masked_features = self.attention_norm(masked_features)

    # The transformer by default uses residual connections
    if self.config.residual_connections:
      masked_features = masked_features + node_features

    # Now we apply the output transform as in the GNN
    updated_node_features = self.output_transform(masked_features)

    # Mask again
    updated_node_features = updated_node_features * node_mask.unsqueeze(dim=-1)

    # Followed by a dropout and normalization
    updated_node_features = self.dropout(updated_node_features)
    updated_node_features = self.output_norm(updated_node_features)

    # And the residual connection from the input to the output MLP
    if self.config.residual_connections:
      updated_node_features = updated_node_features + masked_features

    return updated_node_features


class BasicTransformerEncoder(torch.nn.Module):
  def __init__(self, *,
               config: BasicTransformerConfig, 
               continuous_node_variables=None,
               categorical_node_variables=None,
               continuous_edge_variables=None,
               categorical_edge_variables=None,
               layer_type=BasicTransformerLayer):
    super().__init__()

    self.config = config
    self.layer_type = layer_type

    self.continuous_node_variables = continuous_node_variables
    self.categorical_node_variables = categorical_node_variables
    self.continuous_edge_variables = continuous_edge_variables
    self.categorical_edge_variables = categorical_edge_variables

    # We want the embeddings together with the continuous values to be of dimension d_model, therefore we allocate d_model - len(continuous_variables) as the embedding dim
    self.categorical_node_embeddings_dim = config.d_model - len(self.continuous_node_variables)
    self.categorical_edge_embeddings_dim = config.d_model - len(self.continuous_edge_variables)

    self.node_featurizer = FeatureCombiner(self.categorical_node_variables, 
                                           self.categorical_node_embeddings_dim)
    self.edge_featurizer = FeatureCombiner(self.categorical_edge_variables, 
                                           self.categorical_edge_embeddings_dim)
    
    # Notice that we use the supplied layer type above when creating the graph
    # layers. This allows us to easily change the kind of graph layers
    # we use later on
    self.graph_layers = ModuleList([layer_type(config) for l in range(config.n_layers)])
    
  def forward(self, batch):
    # First order of business is to embed the node features
    node_mask = batch['nodes_mask']
    batch_size, max_nodes = node_mask.shape
    
    continuous_node_features = batch['continuous_node_features']
    categorical_node_features = batch['categorical_node_features']
    node_features = self.node_featurizer(continuous_node_features, categorical_node_features)
    masked_node_features = node_features * node_mask.unsqueeze(-1)
    
    continuous_edge_features = batch['continuous_edge_features']
    categorical_edge_features = batch['categorical_edge_features']
    edge_features = self.edge_featurizer(continuous_edge_features, categorical_edge_features)
    edge_mask = batch['edge_mask']
    masked_edge_features = edge_features * edge_mask.unsqueeze(-1)

    # We have now embedded the node features, we'll propagate them through our 
    # graph layers
    adjacency_matrix = batch['adjacency_matrices']
    memory_state = masked_node_features
    for l in self.graph_layers:
      memory_state = l(adjacency_matrix, memory_state, masked_edge_features , node_mask, edge_mask)

    return memory_state
class GraphPredictionNeuralNetwork(Module):
  def __init__(self, encoder, prediction_head):
    super().__init__()
    self.encoder = encoder
    self.prediction_head = prediction_head

  def forward(self, batch):
    encoded_graph = self.encoder(batch)
    prediction = self.prediction_head(encoded_graph, batch['nodes_mask'])
    return prediction
torch.manual_seed(1729)
d_model = 16
basic_encoder_config = BasicTransformerConfig(d_model=d_model, 
                                      n_layers=2, 
                                      ffn_dim=16,
                                      head_dim=8,
                                      layer_normalization=True,
                                      dropout_rate=0.1,
                                      residual_connections=True)
basic_transformer = BasicTransformerEncoder(config=basic_encoder_config, 
                                            continuous_node_variables=dataset.continuous_node_variables,
                                            categorical_node_variables=dataset.categorical_node_variables,
                                            continuous_edge_variables=dataset.continuous_edge_variables,
                                            categorical_edge_variables=dataset.categorical_edge_variables)

head_config = GraphPredictionHeadConfig(d_model=d_model, ffn_dim=32, pooling_type='sum')
prediction_head = GraphPredictionHead(input_dim=d_model, output_dim=1, config=head_config)

model = GraphPredictionNeuralNetwork(basic_transformer, prediction_head)

loss_fn = BCEWithLogitsLoss()
trainer = Trainer(model=model, 
                  loss_fn=loss_fn, 
                  training_dataloader=training_dataloader,
                  dev_dataloader=dev_dataloader)
trainer.train(1)
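
If you want to look at how this baseline behaves (one epoch is of course far too little training to draw any conclusions), you can reuse the evaluate_model helper defined above, for example:

# Illustrative: plot predictions of the (barely trained) basic transformer on the dev set
evaluate_model(trainer, dev_dataloader, label='dev')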

Adding the graph structure to Transformers

In basic GNNs, we take the graph structure into account by using the given adjacency matrix. In the Transformer we’ve seen now, we compute an attention matrix, a matrix which plays the same role as the adjacency matrix of the basic graph neural network.

We can simply modify our attention function of the Transformer to also take the adjacency matrix into account.

This attention function could really be anything, but for now we’ll stick with a very simple idea of modifying the logits of our scaled dot-product attention by adding a scalar to the logits if there is a 1 in the adjacency matrix for that place.

If $a_{i,j}$ is the value (before the softmax) of the attention logit matrix entry determining how much information from node $j$ we include when aggregating for node $i$, we define the function:

$$ a_{i,j} = f(\mathbf{x}_i, \mathbf{x}_j) = \frac{\langle\mathbf{x}_i, \mathbf{x}_j\rangle}{\sqrt{d_{\text{model}}}} + w\,\mathbf{1}_{\{i, j\} \in E} $$

Here $\mathbf{1}_{\{i, j\} \in E}$ is the indicator function taking the value $1$ if there is an edge between $i$ and $j$ (in other words, the $i,j$ entry of the adjacency matrix).

We also add a learnable scalar $w$ which the network can adjust to control how much influence the presence of an edge has on the attention score.

from torch.nn import Parameter
class AdjacencyTransformerLayer(BasicTransformerLayer):
  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.adjacency_weight = Parameter(torch.tensor(1.))

  def attention_function(self, adjacency_matrix, center_node_features, 
                         neighbour_node_features, edge_features, node_mask, edge_mask):
    # We still take the scaled dot product between the center and neighbour 
    # node features, just like in the basic layer
    dot_product_logits = torch.matmul(center_node_features, 
                                                neighbour_node_features.transpose(-1, -2))/self.scaling_factor

    # The adjacency_matrix is scaled by a parameter so that the network can
    # learn to adjust the influence of edge presence
    adjacency_logits = self.adjacency_weight*adjacency_matrix
    attention_logits = dot_product_logits + adjacency_logits

    nodemask_2d = node_mask.unsqueeze(dim=-2) * node_mask.unsqueeze(dim=-1)
    fill_mask = (1 - nodemask_2d).to(torch.bool)
    attention_logits.masked_fill_(fill_mask, float('-inf'))
    attention_matrix = softmax(attention_logits, dim=-1)

    # There will be rows of the smaller attention matrices which were filled
    # completely with -inf values; these will now be rows of 'nan' values.
    # We perform a new fill of the attention matrix, but this time with 0s 
    attention_matrix = attention_matrix.masked_fill(fill_mask, 0.)
    return attention_matrix
torch.manual_seed(1729)
d_model = 16
basic_encoder_config = BasicTransformerConfig(d_model=d_model, 
                                      n_layers=2, 
                                      ffn_dim=16,
                                      head_dim=8,
                                      layer_normalization=True,
                                      dropout_rate=0.1,
                                      residual_connections=True)
basic_transformer = BasicTransformerEncoder(config=basic_encoder_config, 
                                            continuous_node_variables=dataset.continuous_node_variables,
                                            categorical_node_variables=dataset.categorical_node_variables,
                                            continuous_edge_variables=dataset.continuous_edge_variables,
                                            categorical_edge_variables=dataset.categorical_edge_variables,
                                            layer_type=AdjacencyTransformerLayer)

head_config = GraphPredictionHeadConfig(d_model=d_model, ffn_dim=32, pooling_type='sum')
prediction_head = GraphPredictionHead(input_dim=d_model, output_dim=1, config=head_config)

model = GraphPredictionNeuralNetwork(basic_transformer, prediction_head)

loss_fn = BCEWithLogitsLoss()
trainer = Trainer(model=model, 
                  loss_fn=loss_fn, 
                  training_dataloader=training_dataloader,
                  dev_dataloader=dev_dataloader)
trainer.train(1)

Edge features in the attention

In the previous example, we only used the adjacency matrix to influence the attention function. Now we’ll extend this to actually include edge features as well.

To do this we simply extend the notion of using a function to compute the values of the attention matrix: in particular, we will use a function which takes the feature vectors of the center and neighbour nodes as well as the edge feature vector between them. For node pairs which have no edge (and therefore no edge feature vector), the zero-vector is used.

This function will be implemented by a simple 2-layer MLP, and we choose to implement it as below:

$$ a_{i,j} = f(\mathbf{x_i}, \mathbf{x_j}, \mathbf{x_{e_{i,j}}}) = W_2 \sigma(W_1 (\mathbf{x_i} + \mathbf{x_j} + \mathbf{x_{e_{i,j}}}) + \mathbf{b_1}) + \mathbf{b_2} $$

You can see that we apply the neural network to the sum of the node feature and edge feature vectors. This makes the function permutation invariant in its inputs (i.e. $f(\mathbf{x}_i, \mathbf{x}_j, \mathbf{x}_{e_{i,j}}) = f(\mathbf{x}_j, \mathbf{x}_{e_{i,j}}, \mathbf{x}_i)$). If this is not desired, concatenation can be used instead, but it makes the implementation a bit more complex.
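
Before looking at the layer itself, here is a small standalone sketch (with illustrative names, not part of the layer below) of the broadcasting trick used to build the pairwise input tensor:

# Sketch of the broadcasting used to build the (n_nodes, n_nodes, d) pairwise tensor
n_nodes, d = 4, 8
center = torch.randn(n_nodes, d)
neighbour = torch.randn(n_nodes, d)
edge = torch.randn(n_nodes, n_nodes, d)
pairwise = edge + center.unsqueeze(-2) + neighbour.unsqueeze(-3)
# pairwise[i, j] == edge[i, j] + center[i] + neighbour[j]
assert torch.allclose(pairwise[1, 2], edge[1, 2] + center[1] + neighbour[2])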

class EdgeAttributesAttentionTransformerLayer(BasicTransformerLayer):
  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)

    # We're going to add our center and neighbour vectors to the edge feature
    # vector, so we have to be mindful of the dimensionality
    self.center_transform = Linear(self.input_dim, self.input_dim)
    self.neighbour_transform = Linear(self.input_dim, self.input_dim)
    self.attention_score_function = Sequential(Linear(self.input_dim, 
                                                      self.config.ffn_dim), 
                                               ReLU(), 
                                               Linear(self.config.ffn_dim, 
                                                      1))

    
  def attention_function(self, adjacency_matrix, center_node_features, 
                         neighbour_node_features, edge_features, 
                         node_mask, edge_mask):
    # Our goal is to first build up the input tensor to our simple MLP. 
    # This will be a tensor of shape 
    # (*leading_dimension, n_nodes, n_nodes, feature_dim)
    # The edge_features tensor already has this shape, and we've made sure
    # that the edge features in that tensor are already zeroed in places where
    # there are no edges.
    # We want each element [...,i,j,:] of this matrix to be 
    # center_node_features[i] + neighbour_node_features[j] + edge_features[i,j]
    # We achieve this by broadcasting once again: the center node features
    # are broadcast along the -2'nd axis and the neighbour node features
    # along the -3'rd axis, which gives us the desired result
    # This is one of the most annoying things with the frameworks we use: 
    # having to be very conscious about the order of axes
    # unsqueeze(-2) broadcasts along the "row" of this batch of edge feature "matrices"
    attention_score_input = edge_features + center_node_features.unsqueeze(dim=-2) 
    # unsqueeze(-3) broadcasts along the "columns" of the 3-tensor
    attention_score_input = attention_score_input + neighbour_node_features.unsqueeze(dim=-3)
    
    # The MLP output has a trailing dimension of size 1 which we remove
    attention_logits = self.attention_score_function(attention_score_input).squeeze(dim=-1)
    
    # We need to perform the same masking as before
    nodemask_2d = node_mask.unsqueeze(dim=-2) * node_mask.unsqueeze(dim=-1)
    fill_mask = (1 - nodemask_2d).to(torch.bool)
    
    attention_logits.masked_fill_(fill_mask, float('-inf'))
    attention_matrix = softmax(attention_logits, dim=-1)
    attention_matrix = attention_matrix.masked_fill(fill_mask, 0.)
    return attention_matrix
torch.manual_seed(1729)
d_model = 16
basic_encoder_config = BasicTransformerConfig(d_model=d_model, 
                                      n_layers=2, 
                                      ffn_dim=16,
                                      head_dim=8,
                                      layer_normalization=True,
                                      dropout_rate=0.1,
                                      residual_connections=True)
basic_transformer = BasicTransformerEncoder(config=basic_encoder_config, 
                                            continuous_node_variables=dataset.continuous_node_variables,
                                            categorical_node_variables=dataset.categorical_node_variables,
                                            continuous_edge_variables=dataset.continuous_edge_variables,
                                            categorical_edge_variables=dataset.categorical_edge_variables,
                                            layer_type=EdgeAttributesAttentionTransformerLayer)

head_config = GraphPredictionHeadConfig(d_model=d_model, ffn_dim=32, pooling_type='sum')
prediction_head = GraphPredictionHead(input_dim=d_model, output_dim=1, config=head_config)

model = GraphPredictionNeuralNetwork(basic_transformer, prediction_head)

loss_fn = BCEWithLogitsLoss()
trainer = Trainer(model=model, 
                  loss_fn=loss_fn, 
                  training_dataloader=training_dataloader,
                  dev_dataloader=dev_dataloader)
trainer.train(1)

Edge features in the transformations

While the basic Transformer can include pairwise information (such as the adjacency matrix) in its attention function, this is only used to decide what node vectors to aggregate, i.e. the values of the attention matrix (our dynamic adjacency matrix).

In the GNN we looked at in the last notebook, edge feature vectors were integrated into the message “passed” from the neighbour nodes, so the MLP parts of the network could learn a combined representation of node and edge features, which eventually became a direct part of the output.

We can extend the Transformer architecture to do this as well, just like we did in the GNN example. We will still use the underlying idea of computing an adjacency matrix to get a dynamic aggregation function, but in addition we make the aggregated neighbourhood contextually different depending on which node we’re currently aggregating for.

from torch.nn import Parameter
class EdgeAttributesTransformerLayer(EdgeAttributesAttentionTransformerLayer):
  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)

  def forward(self, adjacency_matrix, node_features, edge_features, node_mask, edge_mask):
    center_node_features = self.center_transform(node_features)
    neighbour_node_features = self.neighbour_transform(node_features)
    attention_matrix = self.attention_function(adjacency_matrix, 
                                               center_node_features,
                                               neighbour_node_features, 
                                               edge_features, 
                                               node_mask, 
                                               edge_mask)
    # Just like when we computed the attention scores using the edge feature 
    # vectors we also want to create context-dependent neighbourhoods now
    # We do that by creating the full pairwise tensor of shape
    # (*leading_dimensions, n_nodes, n_nodes, num_features). We use the same
    # broadcasting procedure as in the
    # EdgeAttributesAttentionTransformerLayer attention_function above
    # unsqueeze(-2) broadcasts along the "row" of this batch of edge feature "matrices"
    context_dependent_features = edge_features + center_node_features.unsqueeze(dim=-2) 
    # unsqueeze(-3) broadcasts along the "columns" of the 3-tensor
    context_dependent_features = context_dependent_features + neighbour_node_features.unsqueeze(dim=-3)
    
    # Now the aggregation becomes a bit more tricky. As in the GNN layer which
    # used edge features, we explicitly perform what the matrix multiplication
    # did for us previously: broadcast the attention matrix over the feature
    # dimension, multiplying the feature vector at position i,j with the value
    # in the attention matrix at position i,j
    attended_features = context_dependent_features * attention_matrix.unsqueeze(dim=-1)

    # Now the aggregation is performed by summing these attended features along 
    # the "rows", i.e. reducing away the "column" axis which is dim=-2
    aggregated_neighbourhoods = attended_features.sum(dim=-2)  

    # Now mask the result as before
    masked_features = aggregated_neighbourhoods * node_mask.unsqueeze(dim=-1)
    
    # Followed by a dropout, layer normalization and residual sum
    masked_features = self.dropout(masked_features)
    masked_features = self.attention_norm(masked_features)
    if self.config.residual_connections:
      masked_features = masked_features + node_features

    # Transform the features with our MLP
    updated_node_features = self.output_transform(masked_features)

    updated_node_features = updated_node_features * node_mask.unsqueeze(dim=-1)

    # Followed by a dropout and normalization
    updated_node_features = self.dropout(updated_node_features)
    updated_node_features = self.output_norm(updated_node_features)

    # And the residual connection from the input to the output MLP
    if self.config.residual_connections:
      updated_node_features = updated_node_features + masked_features

    return updated_node_features
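
As a quick check of the comment above about the matrix multiplication, we can verify on toy tensors that when the attended features do not depend on the centre node (i.e. the value at position [i, j] is just v_j for every i), the broadcast-and-sum aggregation is exactly attention_matrix @ values:

import torch

n_nodes, d = 4, 8
attention_matrix = torch.softmax(torch.randn(n_nodes, n_nodes), dim=-1)
values = torch.randn(n_nodes, d)

# Expand the values to the pairwise shape (n_nodes, n_nodes, d) used in forward()
pairwise_values = values.unsqueeze(dim=-3).expand(n_nodes, n_nodes, d)
aggregated = (pairwise_values * attention_matrix.unsqueeze(dim=-1)).sum(dim=-2)

print(torch.allclose(aggregated, attention_matrix @ values))  # True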

  
torch.manual_seed(1729)
d_model = 16
basic_encoder_config = BasicTransformerConfig(d_model=d_model, 
                                      n_layers=2, 
                                      ffn_dim=16,
                                      head_dim=8,
                                      layer_normalization=True,
                                      dropout_rate=0.1,
                                      residual_connections=True)
basic_transformer = BasicTransformerEncoder(config=basic_encoder_config, 
                                            continuous_node_variables=dataset.continuous_node_variables,
                                            categorical_node_variables=dataset.categorical_node_variables,
                                            continuous_edge_variables=dataset.continuous_edge_variables,
                                            categorical_edge_variables=dataset.categorical_edge_variables,
                                            layer_type=EdgeAttributesTransformerLayer)

head_config = GraphPredictionHeadConfig(d_model=d_model, ffn_dim=32, pooling_type='sum')
prediction_head = GraphPredictionHead(input_dim=d_model, output_dim=1, config=head_config)

model = GraphPredictionNeuralNetwork(basic_transformer, prediction_head)

loss_fn = BCEWithLogitsLoss()
trainer = Trainer(model=model, 
                  loss_fn=loss_fn, 
                  training_dataloader=training_dataloader,
                  dev_dataloader=dev_dataloader)
trainer.train(1)

Pairwise features

In the Graph Neural Network framework, we tend to focus on node attributes and edge attributes. With the Transformer framework we instead have an architecture which is based on all pairwise interactions between nodes. This allows us to extend the idea of what edge features could be.

Instead of only specifying features for pairs of nodes which have an edge, we can think of features between any pair of nodes. These can be used to give the network information about node relationships which it struggles to learn on its own, but which we can easily compute with standard algorithms, such as the distance between nodes if they are embedded in a Euclidean space, or the shortest path between them if they are in a graph.
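
As a small illustration of the Euclidean case, here is a hypothetical sketch (with made-up 3D coordinates; nothing we use in this notebook) of how pairwise distances between embedded nodes could be computed as a pairwise feature matrix:

import torch

coordinates = torch.randn(5, 3)  # 5 nodes embedded in 3D space, e.g. a molecular conformer
pairwise_distances = torch.cdist(coordinates, coordinates)  # shape (5, 5), entry [i, j] = ||p_i - p_j||
print(pairwise_distances.shape)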

The shortest-path idea is the basis of the path-augmented graph Transformer (Benson Chen et al., “Path-Augmented Graph Transformer Network”).

We can use the architecture we’ve already defined above, but allow our edge features to actually be pairwise features. Instead of only having features for pairs of nodes which are directly connected in the graph, we introduce features between any pair of nodes.

In this case the pairwise features will contain information about the edges along the shortest path in the graph, up to at most max_path_length steps (longer paths are simply truncated).

Representing the path information

In a graph, the shortest path between two nodes is defined by a sequence of edges. A natural way of representing the path is then the sequence of edge features along this shortest path.

To handle the sequence information of these edges we could aggregate the path information using a sequence model such as a Recurrent Neural Network or a Transformer. In this example we make things a bit simpler and just represent the path features as separate categorical variables. We’ll add these variables between all pairs of nodes.

Since we have to decide on exactly what categorical variables to use when defining the network architecture, we decide on some max length $k$ and only take edge features along the path up to that number.
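
Before we use it below, here is a quick look at what RDKit's GetShortestPath returns: a tuple of atom indices along the shortest path between two atoms.

from rdkit.Chem import MolFromSmiles
from rdkit.Chem.rdmolops import GetShortestPath

hexane = MolFromSmiles('CCCCCC')
print(GetShortestPath(hexane, 0, 5))  # (0, 1, 2, 3, 4, 5)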

max_path_length = 3  # This is the k we limit our path lengths to
from itertools import combinations
from rdkit.Chem.rdmolops import GetShortestPath
from tqdm.notebook import tqdm


PAIRWISE_FEATURES = []

BOND_FEATURES = {'type_feature': TYPE_FEATURE, 
                 'direction_feature': DIRECTION_FEATURE, 
                 'aromatic_feature': AROMATIC_FEATURE, 
                 'stereo_feature': STEREO_FEATURE}

PAIRWISE_FEATURES.extend(BOND_FEATURES.values())

# We create a copy of the bond features for each path step
PATH_FEATURES = []
for i in range(max_path_length):
  path_vars = {}
  for feature_kw, var in BOND_FEATURES.items():
    name = var.name
    path_var_name = f"{name}_p{i}"
    if isinstance(var, ContinuousVariable):
      path_var = ContinuousVariable(path_var_name)
    if isinstance(var, CategoricalVariable):
      path_var_values = var.values
      path_var = CategoricalVariable(path_var_name, path_var_values, add_null_value=False)
    path_vars[feature_kw] = path_var
    PAIRWISE_FEATURES.append(path_var)
  PATH_FEATURES.append(path_vars)


def get_shortest_paths_bond_features(rd_bond,
                      *,
                      type_feature,
                      direction_feature,
                      aromatic_feature,
                      stereo_feature):
  
  if rd_bond is not None:
    bond_type = str(rd_bond.GetBondType())
    bond_stereo_info = str(rd_bond.GetStereo())
    bond_direction = str(rd_bond.GetBondDir())
    is_aromatic = rd_bond.GetIsAromatic()
  else:
    bond_type = None
    bond_stereo_info = None
    bond_direction = None
    is_aromatic = None

  return {type_feature: bond_type,
          direction_feature: bond_direction,
          aromatic_feature: is_aromatic,
          stereo_feature: bond_stereo_info}


def get_pairwise_features(rd_mol, rd_atom_a, rd_atom_b):
  pairwise_features = {}
  # First we create the features for the bond (or missing such) between
  # the two atoms
  bond = rd_mol.GetBondBetweenAtoms(rd_atom_a.GetIdx(), rd_atom_b.GetIdx())
  bond_features = get_shortest_paths_bond_features(bond, **BOND_FEATURES)
  
  pairwise_features.update(bond_features)
  # Now we create bond features for the path between rd_atom_a and rd_atom_b
  # We iterate over atoms of the shortest path up till max_path_length
  # If the shortest path is shorter than max_path_length, we add None-valued
  # features for the remaining ones
  shortest_path = GetShortestPath(rd_mol, rd_atom_a.GetIdx(), rd_atom_b.GetIdx())
  for i in range(max_path_length):
    path_bond_variables = PATH_FEATURES[i]
    if i < (len(shortest_path) - 1):
      a, b = shortest_path[i], shortest_path[i+1]
      path_bond = rd_mol.GetBondBetweenAtoms(a, b)
    else:
      path_bond = None
    path_bond_features = get_shortest_paths_bond_features(path_bond, **path_bond_variables)
    pairwise_features.update(path_bond_features)

  
  return pairwise_features
rd_mol = MolFromSmiles('CCCCCC')
atom_a = rd_mol.GetAtomWithIdx(0)
atom_b = rd_mol.GetAtomWithIdx(1)
get_pairwise_features(rd_mol, atom_a, atom_b)
{<CategoricalVariable: bond_direction>: 'NONE',
 <CategoricalVariable: bond_direction_p0>: 'NONE',
 <CategoricalVariable: bond_direction_p1>: None,
 <CategoricalVariable: bond_direction_p2>: None,
 <CategoricalVariable: bond_stereo>: 'STEREONONE',
 <CategoricalVariable: bond_stereo_p0>: 'STEREONONE',
 <CategoricalVariable: bond_stereo_p1>: None,
 <CategoricalVariable: bond_stereo_p2>: None,
 <CategoricalVariable: bond_type>: 'SINGLE',
 <CategoricalVariable: bond_type_p0>: 'SINGLE',
 <CategoricalVariable: bond_type_p1>: None,
 <CategoricalVariable: bond_type_p2>: None,
 <CategoricalVariable: is_aromatic>: False,
 <CategoricalVariable: is_aromatic_p0>: False,
 <CategoricalVariable: is_aromatic_p1>: None,
 <CategoricalVariable: is_aromatic_p2>: None}
atom_a = rd_mol.GetAtomWithIdx(0)
atom_b = rd_mol.GetAtomWithIdx(5)
get_pairwise_features(rd_mol, atom_a, atom_b)
{<CategoricalVariable: bond_direction>: None,
 <CategoricalVariable: bond_direction_p0>: 'NONE',
 <CategoricalVariable: bond_direction_p1>: 'NONE',
 <CategoricalVariable: bond_direction_p2>: 'NONE',
 <CategoricalVariable: bond_stereo>: None,
 <CategoricalVariable: bond_stereo_p0>: 'STEREONONE',
 <CategoricalVariable: bond_stereo_p1>: 'STEREONONE',
 <CategoricalVariable: bond_stereo_p2>: 'STEREONONE',
 <CategoricalVariable: bond_type>: None,
 <CategoricalVariable: bond_type_p0>: 'SINGLE',
 <CategoricalVariable: bond_type_p1>: 'SINGLE',
 <CategoricalVariable: bond_type_p2>: 'SINGLE',
 <CategoricalVariable: is_aromatic>: None,
 <CategoricalVariable: is_aromatic_p0>: False,
 <CategoricalVariable: is_aromatic_p1>: False,
 <CategoricalVariable: is_aromatic_p2>: False}
def rdmol_to_complete_graph(mol):
  atoms = {rd_atom.GetIdx(): get_atom_features(rd_atom) for rd_atom in mol.GetAtoms()}
  all_pairwise_features = {}
  for atom_a, atom_b in combinations(mol.GetAtoms(), 2):
    all_pairwise_features[frozenset((atom_a.GetIdx(), atom_b.GetIdx()))] = get_pairwise_features(mol, atom_a, atom_b)
  return atoms, all_pairwise_features
def smiles_to_complete_graph(smiles):
  rd_mol = MolFromSmiles(smiles)
  graph = rdmol_to_complete_graph(rd_mol)
  return graph
import multiprocessing
from tqdm.notebook import tqdm

def process_smiles_record(smiles_record):
  smiles = smiles_record['smiles']
  rdmol = MolFromSmiles(smiles)
  label = smiles_record['label']
  graph = rdmol_to_complete_graph(rdmol)
  return graph, label, smiles_record

def make_graph_shortest_path_dataset(smiles_records, 
                                              atom_features=ATOM_FEATURES, 
                                              bond_features=PAIRWISE_FEATURES):
  '''
  Create a new GraphDataset from a list of smiles_records dictionaries.
  These records should contain the key 'smiles' and 'label'. Any other keys will be saved as a 'metadata' record.
  The 'label' value is used as the prediction target for each graph.
  '''
  graphs = []
  labels = []
  metadata = []
  with multiprocessing.Pool() as pool:
    for graph, label, smiles_record in tqdm(pool.imap(process_smiles_record, smiles_records), total=len(smiles_records)):
      graphs.append(graph)
      labels.append(label)
      metadata.append(smiles_record)
  return GraphDataset(graphs=graphs, 
                      labels=labels, 
                      node_variables=atom_features, 
                      edge_variables=bond_features, 
                      metadata=metadata)
  
# A small sanity-check dataset; we also reuse its node and edge variable definitions when building the encoder below
dataset = make_graph_shortest_path_dataset([{'smiles': 'c1ccccc1', 'label':1},{'smiles':'OS(=O)(=O)O', 'label': 0}])
training_shortest_path_dataset = make_graph_shortest_path_dataset(training_smiles_records)
dev_shortest_path_dataset = make_graph_shortest_path_dataset(dev_smiles_records)
test_shortest_path_dataset = make_graph_shortest_path_dataset(test_smiles_records)
from torch.utils.data import DataLoader

batch_size=32
num_dataloader_workers=2 # The colab instances are very limited in number of cpus

training_shortest_path_dataloader = DataLoader(training_shortest_path_dataset, 
                                          batch_size=batch_size, 
                                          shuffle=True, 
                                          num_workers=num_dataloader_workers, 
                                          collate_fn=collate_graph_batch)
dev_shortest_path_dataloader = DataLoader(dev_shortest_path_dataset, 
                                     batch_size=batch_size, 
                                     shuffle=False, 
                                     num_workers=num_dataloader_workers, 
                                     drop_last=False, 
                                     collate_fn=collate_graph_batch)
test_shortest_path_dataloader = DataLoader(test_shortest_path_dataset, 
                                     batch_size=batch_size, 
                                     shuffle=False, 
                                     num_workers=num_dataloader_workers, 
                                     drop_last=False, 
                                     collate_fn=collate_graph_batch)
torch.manual_seed(1729)
d_model = 16
basic_encoder_config = BasicTransformerConfig(d_model=d_model, 
                                      n_layers=2, 
                                      ffn_dim=16,
                                      head_dim=8,
                                      layer_normalization=True,
                                      dropout_rate=0.1,
                                      residual_connections=True)
basic_transformer = BasicTransformerEncoder(config=basic_encoder_config, 
                                            continuous_node_variables=dataset.continuous_node_variables,
                                            categorical_node_variables=dataset.categorical_node_variables,
                                            continuous_edge_variables=dataset.continuous_edge_variables,
                                            categorical_edge_variables=dataset.categorical_edge_variables,
                                            layer_type=EdgeAttributesTransformerLayer)

head_config = GraphPredictionHeadConfig(d_model=d_model, ffn_dim=32, pooling_type='sum')
prediction_head = GraphPredictionHead(input_dim=d_model, output_dim=1, config=head_config)

model = GraphPredictionNeuralNetwork(basic_transformer, prediction_head)

loss_fn = BCEWithLogitsLoss()
trainer = Trainer(model=model, 
                  loss_fn=loss_fn, 
                  training_dataloader=training_shortest_path_dataloader,
                  dev_dataloader=dev_shortest_path_dataloader)

Slow training

You will notice that this model is much slower to train than the previous one. That’s not because of changes to the neural network (it’s exactly the same as before), but because the input is now much more complex. Every pair of nodes now carries the original bond features plus one copy per path step, i.e. (max_path_length + 1) times as many features as before, and these features exist for all pairs rather than only for bonded atoms. This makes the data loading and embedding steps more time consuming.
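
As a quick sanity check of that claim (using the variables defined above), we can count the pairwise categorical variables:

# 4 bond features for the direct bond plus 4 for each of the max_path_length = 3 path steps
print(len(BOND_FEATURES) * (1 + max_path_length))  # 16
print(len(PAIRWISE_FEATURES))                      # 16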

In particular, the Colab instances we’re using only have 2 processors, so we can’t get much help from preprocessing batches in parallel. When training on proper hardware, this time would largely be hidden by the parallel dataloader workers.

trainer.train(1)

A note on efficiency

We’ve seen how we can extend our graph neural network to aggregate over more than just the immediate neighbourhood, and this allows us to embed a lot of domain knowledge into how we represent our problems.

The downside is that this method scales poorly with the size of our graphs. Both computation time and memory demand scale quadratically with the number of nodes.
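
A rough back-of-envelope calculation (with hypothetical sizes) shows why: the context-dependent feature tensor alone has shape (batch, n_nodes, n_nodes, d_model) in float32, per layer.

example_batch, example_nodes, example_d_model, bytes_per_float = 32, 100, 16, 4
pairwise_tensor_bytes = example_batch * example_nodes * example_nodes * example_d_model * bytes_per_float
print(f"{pairwise_tensor_bytes / 2**20:.1f} MiB per layer")  # ~19.5 MiB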

With a regular GNN, we can make the neighbourhood aggregation sparse, and this is what practical GNN frameworks such as PyTorch Geometric do.

In the field of NLP, much research is dedicated to making the Transformer architecture more efficient, and we will likely see developments that make this method more efficient in the future.

Task

Experiment with this Path-augmented Transformer. Can you do better than the GNN from the previous notebook?

Learning outcomes

In this notebook we took a deep dive into Transformers. As you can see, we’ve essentially extended the idea of a graph neural network by computing an “adjacency matrix” (the attention matrix). The choice of what function to use to compute each element of this matrix is up to us, but for graphs it’s reasonable to use any edge information between the pair of nodes.

We saw how we could also include any pairwise information, and while we used the shortest path between the nodes in this case, for graphs embedded in a Euclidean space it could have been the Euclidean distance or any other information we might compute from a pair of nodes.

Important concepts

  • Attention

  • Function on pair of nodes

  • Shortest path

What about this multi-head attention?

When talking about Transformers, a lot of time is often devoted to the multi-head self-attention. In this notebook we haven’t covered it, mostly because the difference between multi-head and single-head self-attention is conceptually small, while multi-head attention is quite a bit more messy to implement efficiently.

What is multi-head self-attention?

In multi-head attention, we effectively add parallel networks in the aggregation part. Think of it as having multiple parallel GNN layers, each performing its own MSG and AGG step. In the end, the resulting vectors of the parallel heads are combined (concatenated in the Transformer) to form the new vector for the node.

This is equivalent to using “block-diagonal” weight matrices (as well as some shuffling around of the results), which is an interesting way of using sparse computation in neural networks.
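
To make the head-splitting concrete, here is a minimal, self-contained sketch (toy sizes; not tied to the classes in this notebook) of the split-into-heads-then-concatenate pattern for standard scaled dot-product attention:

import torch

torch.manual_seed(0)
batch, n_nodes, d_model, n_heads = 2, 5, 16, 4
head_dim = d_model // n_heads

x = torch.randn(batch, n_nodes, d_model)
q_proj = torch.nn.Linear(d_model, d_model)
k_proj = torch.nn.Linear(d_model, d_model)
v_proj = torch.nn.Linear(d_model, d_model)

def split_heads(t):
  # (batch, n_nodes, d_model) -> (batch, n_heads, n_nodes, head_dim)
  return t.view(batch, n_nodes, n_heads, head_dim).transpose(1, 2)

q, k, v = split_heads(q_proj(x)), split_heads(k_proj(x)), split_heads(v_proj(x))
# Each head computes its own attention matrix over its head_dim-sized slice
attention = torch.softmax(q @ k.transpose(-1, -2) / head_dim ** 0.5, dim=-1)  # (batch, n_heads, n_nodes, n_nodes)
per_head = attention @ v                                                      # (batch, n_heads, n_nodes, head_dim)

# Concatenate the heads back into a single d_model-sized vector per node
combined = per_head.transpose(1, 2).reshape(batch, n_nodes, d_model)
print(combined.shape)  # torch.Size([2, 5, 16])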