Training a Neural Network in parallel using DistributedDataParallel

We will use a simple recurrent neural network for sequence classification as an example, and show how it can be modified to use the DistributedDataParallel feature of PyTorch. You can find the code we will be using here.

Installation

The code relies on Anaconda. Create an environment by running

$ conda env create -f environment.yml

Once this is done, activate the environment by running

$ conda activate partorch

Now we need to make the code in the archive available. After extracting the files, go to the directory you extracted them to and run

$ pip install -e .

This installs the partorch package into the active Anaconda environment and makes it importable from anywhere in that environment.

Modifying an existing network

We will mainly work with the file scripts/basic_neural_network.py. The important part from the file is replicated below:

```python
visible_index, heldout_indices = train_test_split(
    np.arange(len(smiles_list)), stratify=labels, shuffle=True,
    test_size=0.1, random_state=args.random_seed)

visible_labels = [labels[i] for i in visible_index]
train_indices, dev_indices = train_test_split(
    visible_index, stratify=visible_labels, shuffle=True,
    test_size=0.2, random_state=args.random_seed)

train_dataloader = get_dataloader(smiles_list=smiles_list, labels=labels, indices=train_indices,
                                  tokenizer=tokenizer, batch_size=batch_size, num_workers=num_workers, shuffle=True)
dev_dataloader = get_dataloader(smiles_list=smiles_list, labels=labels, indices=dev_indices,
                                tokenizer=tokenizer, batch_size=batch_size, num_workers=num_workers)
model_kwargs = dict(tokenizer=tokenizer, device=device)

model_hparams = dict(embedding_dim=128,
                     d_model=128,
                     num_layers=3,
                     bidirectional=True,
                     dropout=0.2,
                     learning_rate=0.001,
                     weight_decay=0.0001)

tb_writer = SummaryWriter('basic_runs')
best_model, best_iteration = train(train_dataloader=train_dataloader, dev_dataloader=dev_dataloader, writer=tb_writer,
                                   max_epochs=max_epochs, model_class=RNNPredictor, model_args=tuple(),
                                   model_kwargs=model_kwargs, model_hparams=model_hparams)

heldout_dataloader = get_dataloader(smiles_list=smiles_list, labels=labels, indices=heldout_indices,
                                    tokenizer=tokenizer, batch_size=batch_size, num_workers=num_workers)

heldout_losses = []
heldout_targets = []
heldout_predictions = []
for batch in heldout_dataloader:
    loss, batch_targets, batch_predictions = best_model.eval_and_predict_batch(batch)
    heldout_losses.append(loss)
    heldout_targets.extend(batch_targets)
    heldout_predictions.extend(batch_predictions)

heldout_roc_auc = roc_auc_score(heldout_targets, heldout_predictions)

tb_writer.add_scalar('Loss/test', np.mean(heldout_losses), best_iteration)
tb_writer.add_scalar('ROC_AUC/test', np.mean(heldout_roc_auc), best_iteration)
tb_writer.add_hparams(hparam_dict=model_hparams, metric_dict={'hparam/roc_auc': heldout_roc_auc})
print(f"Final test ROC AUC: {heldout_roc_auc}")
tb_writer.close()
```

The parts we will focus on are the creation of the data loaders, the call to train() and the evaluation on the held-out set. These are the ones we need to take into account when adding the parallelization.

Parallel semantics

Our parallel neural network training will consist of multiple processes running concurrently. These will be spawned from our main process, but they all execute the same code. To make the different processes work on different parts of the data, we distinguish them by an identifier called the rank. We often want to perform some step only once for the whole group, so it's customary to give one of the ranks a special role; for convenience, this is typically the process with rank 0.
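To make the rank idea concrete, here is a minimal plain-Python sketch (no PyTorch involved) in which one hypothetical function, run_worker, plays the role of the code every process executes. The processes are simulated by calling it once per rank; the strided split of the data and the rank-0-only summary are illustrative choices, not code from the workshop.

```python
def run_worker(rank, world_size, items):
    """Code that every process runs; only `rank` differs between processes."""
    # Each rank takes every world_size-th item, starting at its own rank
    my_items = items[rank::world_size]
    # Work that should happen once for the whole group is reserved for rank 0
    summary = None
    if rank == 0:
        summary = f"{len(items)} items split over {world_size} ranks"
    return my_items, summary


world_size = 3
items = list(range(8))
# Simulate the process group sequentially: same function, different ranks
results = [run_worker(rank, world_size, items) for rank in range(world_size)]
for rank, (my_items, summary) in enumerate(results):
    print(rank, my_items, summary)
# 0 [0, 3, 6] 8 items split over 3 ranks
# 1 [1, 4, 7] None
# 2 [2, 5] None
```

Every item is handled by exactly one rank, and only rank 0 produces the summary.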

Initializing the distributed framework

We start by adding the distributed functionality. The code we want to execute in parallel is wrapped in a function, here called distributed_training(), which will be the entry point for all spawned processes. We use PyTorch's multiprocessing package (torch.multiprocessing) to spawn processes with this function as the target, one process per visible GPU, and we create a dictionary with all the arguments our training function will need. torch.multiprocessing.spawn() supplies the rank of each process as the first argument of the function; we also pass the total size of the process group (the world size) for convenience.

```python
import torch
import torch.multiprocessing as mp

visible_index, heldout_indices = train_test_split(
    np.arange(len(smiles_list)), stratify=labels, shuffle=True,
    test_size=0.1, random_state=args.random_seed)
visible_labels = [labels[i] for i in visible_index]
train_indices, dev_indices = train_test_split(
    visible_index, stratify=visible_labels, shuffle=True,
    test_size=0.2, random_state=args.random_seed)

# One process per visible GPU
world_size = torch.cuda.device_count()

# Everything the spawned processes need, passed as a single argument
distributed_kwargs = dict(tokenizer=tokenizer,
                          smiles_list=smiles_list, labels=labels, train_indices=train_indices, batch_size=batch_size,
                          dev_indices=dev_indices, heldout_indices=heldout_indices, max_epochs=max_epochs, backend='nccl')

# spawn() calls distributed_training(rank, world_size, distributed_kwargs) in each process
mp.spawn(distributed_training,
         args=(world_size, distributed_kwargs),
         join=True, nprocs=world_size)
```
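Because spawn() passes the process index as the first positional argument, args only holds world_size and distributed_kwargs. The sketch below imitates that calling convention sequentially, using hypothetical fake_spawn and distributed_training_stub functions; it starts no processes and, unlike the real spawn(), returns the results so they can be inspected. Something along these lines can be handy for checking the argument unpacking before launching on GPUs.

```python
def fake_spawn(fn, args=(), nprocs=1):
    """Sequential stand-in for the calling convention of torch.multiprocessing.spawn.

    The real spawn() starts nprocs processes and calls fn(i, *args) in each,
    where i is the process index. Here the calls simply run one after another.
    """
    return [fn(i, *args) for i in range(nprocs)]


def distributed_training_stub(rank, world_size, kwargs):
    # Only records what each process would receive; no process group is created
    return {"rank": rank, "world_size": world_size, "backend": kwargs["backend"]}


world_size = 2
distributed_kwargs = dict(backend="nccl", batch_size=32)
calls = fake_spawn(distributed_training_stub, args=(world_size, distributed_kwargs), nprocs=world_size)
for call in calls:
    print(call)
```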

The distributed training

We need to define the distributed_training() function and start with something like this:

```python
import torch
import torch.distributed as dist


def distributed_training(rank, world_size, kwargs):
    # Join this process to the process group
    dist.init_process_group(kwargs['backend'], rank=rank, world_size=world_size)

    # Each process uses the GPU with the same index as its rank
    device = torch.device(f'cuda:{rank}')

    smiles_list, labels = kwargs['smiles_list'], kwargs['labels']
    tokenizer = kwargs['tokenizer']
    train_indices, dev_indices, heldout_indices = kwargs[
        'train_indices'], kwargs['dev_indices'], kwargs['heldout_indices']
    batch_size, max_epochs = kwargs['batch_size'], kwargs['max_epochs']
```

Most of this code is just unpacking the parameters we gave in the kwargs argument, but the vital part is the call to dist.init_process_group(). This is what actually sets up the current process as part of the process group. There’s a lot of machinery beneath this which we will not cover in this workshop.

One important question is how PyTorch should communicate between the processes, and the call to init_process_group() is where we specify this. There are multiple backends which can be used for the interprocess communication, but the recommended one when training on multiple GPUs is ‘nccl’ (the NVIDIA Collective Communications Library), which is what we'll use in this workshop.
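The script hard-codes backend='nccl', which requires GPUs. For trying things out on a CPU-only machine, Gloo is the backend PyTorch recommends for CPU training. A minimal sketch of choosing both backend and device per process, using a hypothetical plan_process helper that is not part of the workshop code; in the real script num_gpus would come from torch.cuda.device_count(), and on a CPU-only machine that returns 0, so you would also have to choose world_size yourself.

```python
def plan_process(rank, num_gpus):
    """Hypothetical helper: choose (backend, device) for the process with this rank.

    Follows PyTorch's general advice of NCCL for GPU training and Gloo for
    CPU training.
    """
    if num_gpus == 0:
        # CPU-only machine: Gloo works without CUDA
        return "gloo", "cpu"
    if not 0 <= rank < num_gpus:
        raise ValueError(f"rank {rank} has no GPU of its own ({num_gpus} available)")
    # One GPU per process, selected by rank
    return "nccl", f"cuda:{rank}"


print(plan_process(1, num_gpus=4))  # ('nccl', 'cuda:1')
print(plan_process(0, num_gpus=0))  # ('gloo', 'cpu')
```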

We also set the device at this point. Each process should drive its own GPU, so we create a device reference from the rank of the process. If you need to limit your program to a subset of your GPUs, set the environment variable CUDA_VISIBLE_DEVICES=id1[,id2] before starting the script.
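Note that CUDA renumbers the devices listed in CUDA_VISIBLE_DEVICES starting from 0, so cuda:0 in the script refers to the first listed GPU rather than physical GPU 0, and torch.cuda.device_count() (our world size) only counts the listed devices. The hypothetical helper below spells out that mapping; it only handles integer ids, although the variable can also hold GPU UUIDs.

```python
def physical_gpu_for_rank(rank, cuda_visible_devices=None, total_gpus=8):
    """Return the physical GPU id that `cuda:{rank}` refers to.

    CUDA renumbers the devices listed in CUDA_VISIBLE_DEVICES from 0, so
    cuda:0 is the first listed id, cuda:1 the second, and so on.
    Assumption: the variable holds integer ids (it may also hold UUIDs).
    """
    if cuda_visible_devices is None:
        # Variable unset: all GPUs are visible in their physical order
        visible = list(range(total_gpus))
    else:
        visible = [int(d) for d in cuda_visible_devices.split(",") if d.strip()]
    if not 0 <= rank < len(visible):
        raise ValueError(f"rank {rank} has no visible GPU ({len(visible)} visible)")
    return visible[rank]


print(physical_gpu_for_rank(0, "2,3"))  # 2
print(physical_gpu_for_rank(1, "2,3"))  # 3
```

For example, restricting a run to GPUs 2 and 3 would look like this, giving a world size of 2:

$ CUDA_VISIBLE_DEVICES=2,3 torchrun --master_port 31133 scripts/basic_neural_network_ddp.py dataset/BBBP.csv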

To simplify setting up the underlying process group, PyTorch supplies a convenience launcher, torchrun, which tells the backend where to find the master process that coordinates the other processes.

We can test our script by running:

$ torchrun --master_port 31133 scripts/basic_neural_network_ddp.py dataset/BBBP.csv

This starts the script with some environment variables set that allow the process group to coordinate. In particular, we tell it to use a specific port for the master process (the arbitrary 31133 argument to --master_port). We may need to set this port to different values if we're running multiple parallel training jobs at the same time.

We can also use torchrun to manually spawn multiple processes on different compute nodes. In that case we also tell the program at which IP address to find our master node by supplying a --master_addr argument.
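Since distributed_training() passes rank and world_size explicitly, what init_process_group() still needs from its default env:// initialization is where to find the master: the MASTER_ADDR and MASTER_PORT environment variables, which torchrun sets for us. As an alternative, the script could set them itself before calling mp.spawn(), since spawned processes inherit the parent's environment. The hypothetical helper below sketches that, with localhost and the port from the command above as assumed defaults.

```python
import os


def ensure_rendezvous_env(env=None, addr="localhost", port="31133"):
    """Fill in MASTER_ADDR/MASTER_PORT if a launcher such as torchrun has not.

    `env` defaults to os.environ; a plain dict can be passed for testing.
    Values must be strings, as os.environ only accepts strings.
    Returns the names of the variables that had to be filled in.
    """
    env = os.environ if env is None else env
    filled = []
    for key, default in (("MASTER_ADDR", addr), ("MASTER_PORT", port)):
        if key not in env:
            env[key] = default
            filled.append(key)
    return filled


env = {"MASTER_PORT": "29500"}
print(ensure_rendezvous_env(env))  # ['MASTER_ADDR']
print(env)  # {'MASTER_PORT': '29500', 'MASTER_ADDR': 'localhost'}
```

Called without arguments before mp.spawn(), it would fill in os.environ, which should let the script start with a plain python command instead of torchrun; treat this as an untested assumption about the workshop script.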

Now we're ready to implement more of distributed_training(). The main goal of our data-parallel training is to let the different processes work on different parts of the data. This means that we need to partition our data based on which process is running the code. Here's the outline of what we'll implement next:

```python
def distributed_training(rank, world_size, kwargs):
    dist.init_process_group(kwargs['backend'], rank=rank, world_size=world_size)

    device = torch.device(f'cuda:{rank}')

    smiles_list, labels = kwargs['smiles_list'], kwargs['labels']
    tokenizer = kwargs['tokenizer']
    train_indices, dev_indices, heldout_indices = kwargs[
        'train_indices'], kwargs['dev_indices'], kwargs['heldout_indices']
    batch_size, max_epochs = kwargs['batch_size'], kwargs['max_epochs']

    # Each rank gets its own partition of the training set
    train_dataloader = get_ddp_dataloader(rank=rank, world_size=world_size,
                                          smiles_list=smiles_list,
                                          labels=labels, indices=train_indices,
                                          tokenizer=tokenizer, batch_size=batch_size, shuffle=True)
    dev_dataloader = None
    if rank == 0:
        # We will only do the evaluations on the rank 0 process, so we don't have to pass predictions around
        dev_dataloader = get_dataloader(smiles_list=smiles_list, labels=labels, indices=dev_indices,
                                        tokenizer=tokenizer, batch_size=batch_size)

    model_kwargs = dict(tokenizer=tokenizer)

    model_hparams = dict(embedding_dim=128,
                         d_model=128,
                         num_layers=3,
                         bidirectional=True,
                         dropout=0.2,
                         learning_rate=0.001,
                         weight_decay=0.0001)

    # One TensorBoard event file per rank, told apart by the filename suffix
    tb_writer = SummaryWriter('basic_runs', filename_suffix=f'rank{rank}')

    best_model, best_iteration = train_ddp(train_dataloader=train_dataloader,
                                           dev_dataloader=dev_dataloader,
                                           writer=tb_writer,
                                           max_epochs=max_epochs,
                                           device=device,
                                           model_class=RNNPredictor,
                                           model_args=tuple(),
                                           model_kwargs=model_kwargs,
                                           model_hparams=model_hparams)

    if rank == 0:
        # Final test-set evaluation, only on rank 0 (best_model is None on the other ranks)
        test_dataloader = get_dataloader(smiles_list=smiles_list, labels=labels, indices=heldout_indices,
                                         tokenizer=tokenizer, batch_size=batch_size)
        heldout_losses = []
        heldout_targets = []
        heldout_predictions = []
        loss_fn = nn.BCEWithLogitsLoss()
        for batch in test_dataloader:
            with torch.no_grad():
                # batch_labels avoids shadowing the full `labels` list
                sequence_batch, lengths, batch_labels = batch
                logit_prediction = best_model(sequence_batch.to(best_model.device), lengths)
                loss = loss_fn(logit_prediction.squeeze(), batch_labels.to(best_model.device))
                prob_predictions = torch.sigmoid(logit_prediction)
            heldout_losses.append(loss.item())
            heldout_targets.extend(batch_labels.cpu().numpy())
            heldout_predictions.extend(prob_predictions.cpu().numpy())

        heldout_roc_auc = roc_auc_score(heldout_targets, heldout_predictions)

        tb_writer.add_scalar('Loss/test', np.mean(heldout_losses), best_iteration)
        tb_writer.add_scalar('ROC_AUC/test', heldout_roc_auc, best_iteration)
        tb_writer.add_hparams(hparam_dict=model_hparams,
                              metric_dict={'hparam/roc_auc': heldout_roc_auc})
        print(f"Final test ROC AUC: {heldout_roc_auc}")
    tb_writer.close()
    # Leave the process group cleanly before the process exits
    dist.destroy_process_group()
```

We will go through the three parts that change in order: the data loaders, the training function and the evaluation on rank 0.

Distributed data loaders

First we will have a look at get_ddp_dataloader next to get_dataloader:

```python
from torch.utils.data import DataLoader, DistributedSampler


def get_dataloader(*, smiles_list, labels, tokenizer, batch_size, num_workers=0, indices=None, shuffle=False):
    if indices is not None:
        smiles_list = [smiles_list[i] for i in indices]
        labels = [labels[i] for i in indices]
    smiles_dataset = SmilesDataset(smiles_list, labels, tokenizer)
    dataloader = DataLoader(smiles_dataset, batch_size=batch_size, shuffle=shuffle,
                            collate_fn=collate_function, num_workers=num_workers, drop_last=False)
    return dataloader


def get_ddp_dataloader(*, rank, world_size, smiles_list, labels, tokenizer, batch_size, num_workers=0, indices=None, shuffle=False):
    if indices is not None:
        smiles_list = [smiles_list[i] for i in indices]
        labels = [labels[i] for i in indices]
    smiles_dataset = SmilesDataset(smiles_list, labels, tokenizer)
    # The sampler picks this rank's share of the dataset; shuffling is handled
    # by the sampler, so it must not also be passed to the DataLoader
    sampler = DistributedSampler(smiles_dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, drop_last=False)
    dataloader = DataLoader(smiles_dataset, sampler=sampler, batch_size=batch_size,
                            collate_fn=collate_function, num_workers=num_workers)
    return dataloader
```

Conveniently, PyTorch already has the functionality we need to split our data in a distributed setting. By giving the DataLoader a DistributedSampler with the appropriate rank and world size, the data loader instantiated in each process gets its own dedicated part of the dataset to work on.
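It can help to see which indices each process ends up with. The function below is a plain-Python imitation of the index assignment DistributedSampler performs, written for illustration (the name shard_indices is hypothetical): with drop_last=False the index list is padded by repeating indices until it divides evenly, and rank r then takes every world_size-th index starting at position r. For shuffling, the real sampler seeds a torch generator with seed + epoch; here random.Random stands in for it, so the shuffled orders will not match PyTorch's. That epoch dependence is why the PyTorch documentation asks you to call the sampler's set_epoch() at the start of every epoch; otherwise every epoch reuses the same order.

```python
import math
import random


def shard_indices(dataset_len, num_replicas, rank, shuffle=False, seed=0, epoch=0):
    """Plain-Python sketch of DistributedSampler's index assignment (drop_last=False).

    Assumption: random.Random(seed + epoch) stands in for the torch generator
    the real sampler uses, so shuffled orders differ from PyTorch's.
    """
    indices = list(range(dataset_len))
    if shuffle:
        # Same seed + epoch on every rank, so all ranks agree on the permutation
        random.Random(seed + epoch).shuffle(indices)
    # Every rank gets the same number of samples
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    # Pad by repeating indices from the start until the list divides evenly
    padding = total_size - len(indices)
    if padding > 0:
        repeats = math.ceil(padding / len(indices))
        indices += (indices * repeats)[:padding]
    # Rank r takes positions r, r + num_replicas, r + 2 * num_replicas, ...
    return indices[rank:total_size:num_replicas]


for rank in range(3):
    print(rank, shard_indices(10, num_replicas=3, rank=rank))
# 0 [0, 3, 6, 9]
# 1 [1, 4, 7, 0]
# 2 [2, 5, 8, 1]
```

Without shuffling, the result should match list(DistributedSampler(list(range(10)), num_replicas=3, rank=1, shuffle=False)), i.e. [1, 4, 7, 0]. Note the padding: with 10 examples on 3 processes, examples 0 and 1 are seen twice per epoch.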

Distributed optimization

Now that we've set up partitioned data loaders in the different processes, we will wrap our model in DistributedDataParallel so that the optimization is distributed over our processes. Let's compare the old training loop with the updated one:

```python
from copy import deepcopy

import numpy as np
import torch
from sklearn.metrics import roc_auc_score
from torch import nn
from torch.optim import AdamW
from tqdm import tqdm, trange


def train(*, train_dataloader, dev_dataloader, writer, max_epochs, model_class, model_args=None, model_hparams=None, model_kwargs=None):
    if model_args is None:
        model_args = tuple()
    if model_kwargs is None:
        model_kwargs = dict()
    if model_hparams is None:
        model_hparams = dict()

    model = model_class(*model_args, **model_kwargs, **model_hparams)

    best_roc_auc = 0
    best_model = None
    best_iteration = 0
    iteration = 0

    learning_rate = model_hparams['learning_rate']
    weight_decay = model_hparams['weight_decay']
    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    loss_fn = nn.BCEWithLogitsLoss()
    for e in trange(max_epochs, desc='epoch'):
        training_losses = []
        dev_losses = []
        dev_targets = []
        dev_predictions = []

        model.train()
        for batch in tqdm(train_dataloader, desc="Training batch"):
            optimizer.zero_grad()
            sequence_batch, lengths, labels = batch
            logit_prediction = model(sequence_batch.to(model.device), lengths)
            loss = loss_fn(logit_prediction.squeeze(), labels.to(model.device))
            loss.backward()
            optimizer.step()

            writer.add_scalar('Loss/train', loss.item(), iteration)
            training_losses.append(loss.item())
            iteration += 1

        model.eval()
        for batch in tqdm(dev_dataloader, desc="Dev batch"):
            with torch.no_grad():
                sequence_batch, lengths, labels = batch
                logit_prediction = model(sequence_batch.to(model.device), lengths)
                loss = loss_fn(logit_prediction.squeeze(), labels.to(model.device))
                prob_predictions = torch.sigmoid(logit_prediction)

            dev_losses.append(loss.item())
            dev_targets.extend(labels.cpu().numpy())
            dev_predictions.extend(prob_predictions.cpu().numpy())

        dev_roc_auc = roc_auc_score(dev_targets, dev_predictions)

        writer.add_scalar('Loss/dev', np.mean(dev_losses), iteration)
        writer.add_scalar('ROC_AUC/dev', dev_roc_auc, iteration)
        print(f"Training loss {np.mean(training_losses)}\tDev loss: {np.mean(dev_losses)}\tDev ROC AUC:{dev_roc_auc}")

        if dev_roc_auc > best_roc_auc:
            best_roc_auc = dev_roc_auc
            best_model = deepcopy(model)
            best_model.recurrent_layers.flatten_parameters()  # After the deepcopy, the weight matrices are not necessarily in contiguous memory, this fixes that issue
            best_iteration = iteration

    return best_model, best_iteration
```
```python
from copy import deepcopy

import numpy as np
import torch
import torch.distributed as dist
from sklearn.metrics import roc_auc_score
from torch import nn
from torch.nn.parallel import DistributedDataParallel
from torch.optim import AdamW


def train_ddp(*, train_dataloader, dev_dataloader, writer, max_epochs, model_class, device, model_args=None, model_hparams=None, model_kwargs=None):
    if model_args is None:
        model_args = tuple()
    if model_kwargs is None:
        model_kwargs = dict()
    if model_hparams is None:
        model_hparams = dict()

    model = model_class(*model_args, **model_kwargs, device=device, **model_hparams)
    # The wrapper averages gradients across all processes during backward()
    ddp_model = DistributedDataParallel(model)

    best_roc_auc = 0
    best_model = None
    best_iteration = 0
    iteration = 0

    learning_rate = model_hparams['learning_rate']
    weight_decay = model_hparams['weight_decay']
    optimizer = AdamW(ddp_model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    loss_fn = nn.BCEWithLogitsLoss()
    for e in range(max_epochs):
        # DistributedSampler derives its shuffle from the epoch; without this,
        # every epoch would see the data in the same order
        train_dataloader.sampler.set_epoch(e)
        training_losses = []
        dev_losses = []
        dev_targets = []
        dev_predictions = []

        ddp_model.train()
        for batch in train_dataloader:
            optimizer.zero_grad()
            sequence_batch, lengths, labels = batch
            logit_prediction = ddp_model(sequence_batch.to(model.device), lengths)
            loss = loss_fn(logit_prediction.squeeze(), labels.to(model.device))
            loss.backward()
            optimizer.step()

            writer.add_scalar('Loss/train', loss.item(), iteration)
            training_losses.append(loss.item())
            iteration += 1

        if dist.get_rank() == 0:
            # Only rank 0 evaluates, so use the wrapped module directly: a forward
            # pass through ddp_model may communicate with the other ranks (for
            # example to broadcast module buffers), which would block here
            model.eval()
            for batch in dev_dataloader:
                with torch.no_grad():
                    sequence_batch, lengths, labels = batch
                    logit_prediction = model(sequence_batch.to(model.device), lengths)
                    loss = loss_fn(logit_prediction.squeeze(), labels.to(model.device))
                    prob_predictions = torch.sigmoid(logit_prediction)

                dev_losses.append(loss.item())
                dev_targets.extend(labels.cpu().numpy())
                dev_predictions.extend(prob_predictions.cpu().numpy())

            dev_roc_auc = roc_auc_score(dev_targets, dev_predictions)

            writer.add_scalar('Loss/dev', np.mean(dev_losses), iteration)
            writer.add_scalar('ROC_AUC/dev', dev_roc_auc, iteration)
            print(f"Training loss {np.mean(training_losses)}\tDev loss: {np.mean(dev_losses)}\tDev ROC AUC:{dev_roc_auc}")

            if dev_roc_auc > best_roc_auc:
                best_roc_auc = dev_roc_auc
                best_model = deepcopy(model)
                best_model.recurrent_layers.flatten_parameters()  # After the deepcopy, the weight matrices are not necessarily in contiguous memory, this fixes that issue
                best_iteration = iteration

    return best_model, best_iteration
```

If you compare the two listings, you can see that we're basically just wrapping our model in a DistributedDataParallel object, which gives us a new model we call ddp_model, and calling ddp_model instead of model in the training loop. The gradient synchronization happens inside loss.backward(): DistributedDataParallel averages the gradients across the worker processes there. Each process's optimizer then updates ddp_model.parameters() (the same tensors as model.parameters()) with these identical averaged gradients, so all model replicas stay in sync.
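Why does averaging gradients give the same result as training on one big batch? A toy, plain-Python simulation with a single parameter w and a mean squared error loss shows it (the functions are hypothetical illustrations, not DDP internals): when the shards have equal size and the loss is a mean over the batch, the average of the per-shard gradients equals the full-batch gradient, and since every replica applies the same averaged gradient, the copies of w stay identical.

```python
def mse_grad(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) over one shard, with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def simulated_ddp_step(w, shards, lr):
    """One data-parallel SGD step for the toy model y = w * x.

    Every simulated rank computes a gradient on its own shard, the gradients
    are averaged (the role of DDP's all-reduce during backward), and each
    replica applies the identical update, so all copies of w remain equal.
    """
    local_grads = [mse_grad(w, xs, ys) for xs, ys in shards]
    avg_grad = sum(local_grads) / len(local_grads)
    return [w - lr * avg_grad for _ in shards]


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
# Two equally sized, strided shards, like a DistributedSampler with 2 replicas
shards = [(xs[0::2], ys[0::2]), (xs[1::2], ys[1::2])]

replicas = simulated_ddp_step(0.0, shards, lr=0.01)
single_process = 0.0 - 0.01 * mse_grad(0.0, xs, ys)
print(replicas, single_process)
```

Both replicas and the single-process update end up at the same value. A consequence is that with world_size processes and a per-process batch_size, each optimizer step effectively sees world_size * batch_size examples, which may call for revisiting the learning rate.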

Centralizing evaluation

Note that we only run the evaluation on the dev set, and update the best_model copy, in the process with rank 0. This way we don't have to send predictions from the other processes around.

This is also what we do in the final block of the distributed_training() function: we only perform the final test set evaluation in the process with rank 0.
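Evaluating on a single process also sidesteps a subtle issue: ROC AUC is not an average over examples, so per-rank scores cannot simply be averaged afterwards, and combining them correctly requires all predictions in one place. The plain-Python roc_auc below (a hypothetical helper computing AUC as the fraction of correctly ordered positive/negative pairs, with ties counting half) shows a small case where the two disagree.

```python
def roc_auc(labels, scores):
    """ROC AUC as the fraction of (positive, negative) pairs ranked correctly.

    Ties count as half a correct pair.
    """
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    correct = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return correct / (len(pos) * len(neg))


labels = [1, 0, 1, 0]
scores = [0.9, 0.8, 0.3, 0.2]
# Suppose rank 0 evaluated the first two examples and rank 1 the last two
per_rank = [roc_auc(labels[:2], scores[:2]), roc_auc(labels[2:], scores[2:])]
print(per_rank)                       # [1.0, 1.0]
print(sum(per_rank) / len(per_rank))  # 1.0
print(roc_auc(labels, scores))        # 0.75
```

Both simulated ranks order their own examples perfectly, yet the pooled predictions only reach 0.75, because rank 0's negative example scores higher than rank 1's positive one.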

Running the code

We have now completed our modifications and can run the script using torchrun:

$ torchrun --master_port 31133 scripts/basic_neural_network_ddp.py dataset/BBBP.csv