Training a Neural Network in parallel using DistributedDataParallel
We will use a simple recurrent neural network for sequence classification as an example and show how to modify it to use the DistributedDataParallel feature of PyTorch. You can find the code we will be using here.
Installation
The code relies on Anaconda. Create an environment by running
$ conda env create -f environment.yml
Once this is done, activate the environment by running
$ conda activate partorch
Now we need to make the code in the archive available. After extracting the files, go to the directory you extracted them to and run
$ pip install -e .
This installs the partorch package into the active conda environment, making it importable from anywhere while that environment is active.
Modifying an existing network
We will mainly work with the file scripts/basic_neural_network.py. The important part of the file is replicated below:
visible_index, heldout_indices = train_test_split(np.arange(len(
    smiles_list)), stratify=labels, shuffle=True, test_size=0.1, random_state=args.random_seed)

visible_labels = [labels[i] for i in visible_index]
train_indices, dev_indices = train_test_split(
    visible_index, stratify=visible_labels, shuffle=True, test_size=0.2, random_state=args.random_seed)

train_dataloader = get_dataloader(smiles_list=smiles_list, labels=labels, indices=train_indices,
                                  tokenizer=tokenizer, batch_size=batch_size, num_workers=num_workers, shuffle=True)
dev_dataloader = get_dataloader(smiles_list=smiles_list, labels=labels, indices=dev_indices,
                                tokenizer=tokenizer, batch_size=batch_size, num_workers=num_workers)
model_kwargs = dict(tokenizer=tokenizer, device=device)

model_hparams = dict(embedding_dim=128,
                     d_model=128,
                     num_layers=3,
                     bidirectional=True,
                     dropout=0.2,
                     learning_rate=0.001,
                     weight_decay=0.0001)

tb_writer = SummaryWriter('basic_runs')
best_model, best_iteration = train(train_dataloader=train_dataloader, dev_dataloader=dev_dataloader, writer=tb_writer,
                                   max_epochs=max_epochs, model_class=RNNPredictor, model_args=tuple(), model_kwargs=model_kwargs, model_hparams=model_hparams)

heldout_dataloader = get_dataloader(smiles_list=smiles_list, labels=labels, indices=heldout_indices,
                                    tokenizer=tokenizer, batch_size=batch_size, num_workers=num_workers)

heldout_losses = []
heldout_targets = []
heldout_predictions = []
for batch in heldout_dataloader:
    loss, batch_targets, batch_predictions = best_model.eval_and_predict_batch(
        batch)
    heldout_losses.append(loss)
    heldout_targets.extend(batch_targets)
    heldout_predictions.extend(batch_predictions)

heldout_roc_auc = roc_auc_score(heldout_targets, heldout_predictions)

tb_writer.add_scalar('Loss/test', np.mean(heldout_losses), best_iteration)
tb_writer.add_scalar('ROC_AUC/test', heldout_roc_auc, best_iteration)
tb_writer.add_hparams(hparam_dict=model_hparams, metric_dict={'hparam/roc_auc': heldout_roc_auc})
print(f"Final test ROC AUC: {heldout_roc_auc}")
tb_writer.close()
The parts we will focus on are the construction of the data loaders, the call to train(), and the held-out test evaluation at the end. These are the pieces we need to take into account when adding the parallelization.
Parallel semantics
Our parallel neural network training will consist of multiple processes running concurrently. These will be spawned from our main process but will execute the same code. To make the different processes work on different parts of the data, we differentiate them through an identifier called the rank. We often want to perform some step only once for the whole group, so it is customary to assign one of the ranks special importance; by convention this is the process with rank 0.
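As a minimal sketch (not part of the workshop code), rank-dependent behaviour after the process group has been initialized typically looks like this:

import torch.distributed as dist

def report_group():
    rank = dist.get_rank()               # this process's identifier within the group
    world_size = dist.get_world_size()   # total number of processes in the group

    # All processes run the same code, but we can branch on the rank
    if rank == 0:
        print(f"Coordinating {world_size} worker processes from rank 0")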
Initializing the distributed framework
We start by adding distributed functionality. The code we want to execute in parallel is wrapped in a function, here called distributed_training(), which will be the entry point for all spawned processes. We use PyTorch's multiprocessing package to spawn processes with this function as the target. We also create a dictionary with all the arguments our training function will need. The function will be supplied with the rank of the process by the torch.multiprocessing.spawn() function, but we also supply the total size of the process group (the world size) for convenience.
visible_index, heldout_indices = train_test_split(np.arange(len(
    smiles_list)), stratify=labels, shuffle=True, test_size=0.1, random_state=args.random_seed)
visible_labels = [labels[i] for i in visible_index]
train_indices, dev_indices = train_test_split(
    visible_index, stratify=visible_labels, shuffle=True, test_size=0.2, random_state=args.random_seed)

world_size = torch.cuda.device_count()

distributed_kwargs = dict(tokenizer=tokenizer,
                          smiles_list=smiles_list, labels=labels, train_indices=train_indices, batch_size=batch_size,
                          dev_indices=dev_indices, heldout_indices=heldout_indices, max_epochs=max_epochs, backend='nccl')

mp.spawn(distributed_training,
         args=(world_size, distributed_kwargs),
         join=True, nprocs=world_size)
The distributed training
We need to define the distributed_training() function, and start with something like this:
def distributed_training(rank, world_size, kwargs):
    dist.init_process_group(
        kwargs['backend'], rank=rank, world_size=world_size)

    device = torch.device(f'cuda:{rank}')

    smiles_list, labels = kwargs['smiles_list'], kwargs['labels']
    tokenizer = kwargs['tokenizer']
    train_indices, dev_indices, heldout_indices = kwargs[
        'train_indices'], kwargs['dev_indices'], kwargs['heldout_indices']
    batch_size, max_epochs = kwargs['batch_size'], kwargs['max_epochs']
Most of this code is just unpacking the parameters we passed in the kwargs argument, but the vital part is the call to dist.init_process_group(). This is what actually sets up the current process as part of the process group. There is a lot of machinery beneath this call which we will not cover in this workshop.
One important question is how PyTorch should communicate between the processes, and the call to init_process_group() is where we specify this. There are multiple backends that can be used for the interprocess communication, but the recommended one when training on multiple GPUs is 'nccl', which is developed by NVIDIA and is what we'll use in this workshop.
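Under the hood, init_process_group() needs to know where to find the other processes in the group. As a minimal sketch (assuming we set the rendezvous address ourselves instead of letting torchrun, introduced below, do it for us), this could look like:

import os
import torch.distributed as dist

def init_group(rank, world_size):
    # Placeholder address and port; torchrun normally sets these variables for us
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '31133')

    # 'nccl' is the backend recommended for multi-GPU training
    dist.init_process_group('nccl', rank=rank, world_size=world_size)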
We also set the device at this point. A GPU may only be used by one process, so we instantiate a device reference using the rank of the process. If you need to limit your program to a subset of your GPUs, you can set the environment variable CUDA_VISIBLE_DEVICES=id1[,id2] before starting the script.
To simplify setting up the underlying process group, PyTorch supplies a convenience script, torchrun, which informs the backend where the master process used to coordinate the other processes is located.
We can test our script by running:
$ torchrun --master_port 31133 scripts/basic_neural_network_ddp.py dataset/BBBP.csv
This starts the script with some underlying environment variables set that allow the process group to coordinate; in particular, we tell it to use a specific port for the master process (the arbitrary 31133 passed to --master_port). We might need to set this port to a different value if we are running multiple parallel trainings at the same time.
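If we also want to restrict the run to a subset of the GPUs, the environment variable from the previous section can be combined with the same command, for example:
$ CUDA_VISIBLE_DEVICES=0,1 torchrun --master_port 31133 scripts/basic_neural_network_ddp.py dataset/BBBP.csv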
We can also use torchrun to manually spawn multiple processes on different compute nodes; in that case we also tell the program at what IP address to find our master node by supplying a --master_addr argument.
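As a rough sketch (the node count, ranks and IP address below are placeholders, and the exact flags depend on how your cluster is set up), a two-node launch could look something like this:
$ torchrun --nnodes 2 --node_rank 0 --master_addr 192.0.2.10 --master_port 31133 scripts/basic_neural_network_ddp.py dataset/BBBP.csv
$ torchrun --nnodes 2 --node_rank 1 --master_addr 192.0.2.10 --master_port 31133 scripts/basic_neural_network_ddp.py dataset/BBBP.csv
The first command is run on the node hosting the master process, the second on the other node.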
Now we're ready to implement more of distributed_training(). The main goal of our data-parallel training is to let the different processes work on different parts of the batch. This means that we need to partition our data based on which process is running the code. Here's the outline of what we'll implement next:
def distributed_training(rank, world_size, kwargs):
    dist.init_process_group(kwargs['backend'], rank=rank, world_size=world_size)

    device = torch.device(f'cuda:{rank}')

    smiles_list, labels = kwargs['smiles_list'], kwargs['labels']
    tokenizer = kwargs['tokenizer']
    train_indices, dev_indices, heldout_indices = kwargs[
        'train_indices'], kwargs['dev_indices'], kwargs['heldout_indices']
    batch_size, max_epochs = kwargs['batch_size'], kwargs['max_epochs']

    train_dataloader = get_ddp_dataloader(rank=rank, world_size=world_size,
                                          smiles_list=smiles_list,
                                          labels=labels, indices=train_indices,
                                          tokenizer=tokenizer, batch_size=batch_size, shuffle=True)
    dev_dataloader = None
    if rank == 0:
        # We will only do the evaluations on the rank 0 process, so we don't have to pass predictions around
        dev_dataloader = get_dataloader(smiles_list=smiles_list, labels=labels, indices=dev_indices,
                                        tokenizer=tokenizer, batch_size=batch_size)

    model_kwargs = dict(tokenizer=tokenizer)

    model_hparams = dict(embedding_dim=128,
                         d_model=128,
                         num_layers=3,
                         bidirectional=True,
                         dropout=0.2,
                         learning_rate=0.001,
                         weight_decay=0.0001)

    tb_writer = SummaryWriter('basic_runs', filename_suffix=f'rank{rank}')

    best_model, best_iteration = train_ddp(train_dataloader=train_dataloader,
                                           dev_dataloader=dev_dataloader,
                                           writer=tb_writer,
                                           max_epochs=max_epochs,
                                           device=device,
                                           model_class=RNNPredictor,
                                           model_args=tuple(),
                                           model_kwargs=model_kwargs,
                                           model_hparams=model_hparams)

    if rank == 0:
        test_dataloader = get_dataloader(smiles_list=smiles_list, labels=labels, indices=heldout_indices,
                                         tokenizer=tokenizer, batch_size=batch_size)
        heldout_losses = []
        heldout_targets = []
        heldout_predictions = []
        loss_fn = nn.BCEWithLogitsLoss()
        for batch in test_dataloader:
            with torch.no_grad():
                sequence_batch, lengths, labels = batch
                logit_prediction = best_model(sequence_batch.to(best_model.device), lengths)
                loss = loss_fn(logit_prediction.squeeze(), labels.to(best_model.device))
                prob_predictions = torch.sigmoid(logit_prediction)
                heldout_losses.append(loss.item())
                heldout_targets.extend(labels.cpu().numpy())
                heldout_predictions.extend(prob_predictions.cpu().numpy())

        heldout_roc_auc = roc_auc_score(heldout_targets, heldout_predictions)

        tb_writer.add_scalar(
            'Loss/test', np.mean(heldout_losses), best_iteration)
        tb_writer.add_scalar(
            'ROC_AUC/test', heldout_roc_auc, best_iteration)
        tb_writer.add_hparams(hparam_dict=model_hparams, metric_dict={
            'hparam/roc_auc': heldout_roc_auc})
        print(f"Final test ROC AUC: {heldout_roc_auc}")
    tb_writer.close()
We will go through these three parts in order: the distributed data loaders, the distributed optimization, and the centralized evaluation.
Distributed data loaders
First we will have a look at get_ddp_dataloader next to get_dataloader:
def get_dataloader(*, smiles_list, labels, tokenizer, batch_size, num_workers=0, indices=None, shuffle=False):
    if indices is not None:
        smiles_list = [smiles_list[i] for i in indices]
        labels = [labels[i] for i in indices]
    smiles_dataset = SmilesDataset(smiles_list, labels, tokenizer)
    dataloader = DataLoader(smiles_dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_function, num_workers=num_workers, drop_last=False)
    return dataloader


def get_ddp_dataloader(*, rank, world_size, smiles_list, labels, tokenizer, batch_size, num_workers=0, indices=None, shuffle=False):
    if indices is not None:
        smiles_list = [smiles_list[i] for i in indices]
        labels = [labels[i] for i in indices]
    smiles_dataset = SmilesDataset(smiles_list, labels, tokenizer)
    sampler = DistributedSampler(smiles_dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, drop_last=False)
    dataloader = DataLoader(smiles_dataset, sampler=sampler, batch_size=batch_size, collate_fn=collate_function, num_workers=num_workers)
    return dataloader
Conveniently, PyTorch already has the functionality we need to split our batches in a distributed setting. By telling the DataLoader to use a DistributedSampler with the appropriate rank and world size arguments, the dataloader instantiated in the current process gets its own dedicated part of the dataset to work on.
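A minimal sketch (using a toy dataset rather than the workshop's SmilesDataset) of how the sampler partitions the indices across ranks:

import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Toy dataset of 10 items; each rank receives a disjoint subset of the indices
dataset = TensorDataset(torch.arange(10))

world_size = 2
for rank in range(world_size):
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False)
    print(f"rank {rank} indices: {list(sampler)}")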
Distributed optimization
Now that we've set up partitioned data loaders in the different processes, we will wrap our model in DistributedDataParallel so that the optimization is distributed over our processes. Let's compare the old training loop with the updated one:
def train(*, train_dataloader, dev_dataloader, writer, max_epochs, model_class, model_args=None, model_hparams=None, model_kwargs=None):
    if model_args is None:
        model_args = tuple()
    if model_kwargs is None:
        model_kwargs = dict()
    if model_hparams is None:
        model_hparams = dict()

    model = model_class(*model_args, **model_kwargs, **model_hparams)

    best_roc_auc = 0
    best_model = None
    best_iteration = 0
    iteration = 0

    learning_rate = model_hparams['learning_rate']
    weight_decay = model_hparams['weight_decay']
    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    loss_fn = nn.BCEWithLogitsLoss()
    for e in trange(max_epochs, desc='epoch'):
        training_losses = []
        dev_losses = []
        dev_targets = []
        dev_predictions = []

        model.train()
        for batch in tqdm(train_dataloader, desc="Training batch"):
            optimizer.zero_grad()
            sequence_batch, lengths, labels = batch
            logit_prediction = model(sequence_batch.to(model.device), lengths)
            loss = loss_fn(logit_prediction.squeeze(), labels.to(model.device))
            loss.backward()
            optimizer.step()

            writer.add_scalar('Loss/train', loss.item(), iteration)
            training_losses.append(loss.item())
            iteration += 1

        model.eval()
        for batch in tqdm(dev_dataloader, desc="Dev batch"):
            with torch.no_grad():
                sequence_batch, lengths, labels = batch
                logit_prediction = model(sequence_batch.to(model.device), lengths)
                loss = loss_fn(logit_prediction.squeeze(), labels.to(model.device))
                prob_predictions = torch.sigmoid(logit_prediction)

                dev_losses.append(loss.item())
                dev_targets.extend(labels.cpu().numpy())
                dev_predictions.extend(prob_predictions.cpu().numpy())

        dev_roc_auc = roc_auc_score(dev_targets, dev_predictions)

        writer.add_scalar('Loss/dev', np.mean(dev_losses), iteration)
        writer.add_scalar('ROC_AUC/dev', dev_roc_auc, iteration)
        print(f"Training loss {np.mean(training_losses)}\tDev loss: {np.mean(dev_losses)}\tDev ROC AUC:{dev_roc_auc}")

        if dev_roc_auc > best_roc_auc:
            best_roc_auc = dev_roc_auc
            best_model = deepcopy(model)
            best_model.recurrent_layers.flatten_parameters()  # After the deepcopy, the weight matrices are not necessarily in contiguous memory, this fixes that issue
            best_iteration = iteration

    return best_model, best_iteration
def train_ddp(*, train_dataloader, dev_dataloader, writer, max_epochs, model_class, device, model_args=None, model_hparams=None, model_kwargs=None):
    if model_args is None:
        model_args = tuple()
    if model_kwargs is None:
        model_kwargs = dict()
    if model_hparams is None:
        model_hparams = dict()

    model = model_class(*model_args, **model_kwargs, device=device, **model_hparams)
    ddp_model = DistributedDataParallel(model)

    best_roc_auc = 0
    best_model = None
    best_iteration = 0
    iteration = 0

    learning_rate = model_hparams['learning_rate']
    weight_decay = model_hparams['weight_decay']
    optimizer = AdamW(ddp_model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    loss_fn = nn.BCEWithLogitsLoss()
    for e in range(max_epochs):
        training_losses = []
        dev_losses = []
        dev_targets = []
        dev_predictions = []

        ddp_model.train()
        for batch in train_dataloader:
            optimizer.zero_grad()
            sequence_batch, lengths, labels = batch
            logit_prediction = ddp_model(sequence_batch.to(model.device), lengths)
            loss = loss_fn(logit_prediction.squeeze(), labels.to(model.device))
            loss.backward()
            optimizer.step()

            writer.add_scalar('Loss/train', loss.item(), iteration)
            training_losses.append(loss.item())
            iteration += 1

        if dist.get_rank() == 0:
            ddp_model.eval()
            for batch in dev_dataloader:
                with torch.no_grad():
                    sequence_batch, lengths, labels = batch
                    logit_prediction = ddp_model(sequence_batch.to(model.device), lengths)
                    loss = loss_fn(logit_prediction.squeeze(), labels.to(model.device))
                    prob_predictions = torch.sigmoid(logit_prediction)

                    dev_losses.append(loss.item())
                    dev_targets.extend(labels.cpu().numpy())
                    dev_predictions.extend(prob_predictions.cpu().numpy())

            dev_roc_auc = roc_auc_score(dev_targets, dev_predictions)

            writer.add_scalar('Loss/dev', np.mean(dev_losses), iteration)
            writer.add_scalar('ROC_AUC/dev', dev_roc_auc, iteration)
            print(f"Training loss {np.mean(training_losses)}\tDev loss: {np.mean(dev_losses)}\tDev ROC AUC:{dev_roc_auc}")

            if dev_roc_auc > best_roc_auc:
                best_roc_auc = dev_roc_auc
                best_model = deepcopy(model)
                best_model.recurrent_layers.flatten_parameters()  # After the deepcopy, the weight matrices are not necessarily in contiguous memory, this fixes that issue
                best_iteration = iteration

    return best_model, best_iteration
If you compare the two versions, you can see that we're basically just wrapping our model in a DistributedDataParallel object, which gives us a new model that we call ddp_model. We then replace the calls to model with calls to ddp_model, which is essentially all we need to do: the DistributedDataParallel wrapper averages the gradients across the worker processes during the backward pass, so the optimizer, holding a reference to ddp_model.parameters(), does the right thing without further changes.
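Stripped of the workshop specifics, the wrapping pattern is just the following (a sketch, assuming the model has already been constructed and is to live on the GPU belonging to this rank):

import torch
from torch.nn.parallel import DistributedDataParallel

def wrap_model(model, rank):
    device = torch.device(f'cuda:{rank}')
    model = model.to(device)
    # DDP broadcasts the initial parameters from rank 0 and averages the
    # gradients across all ranks during backward()
    ddp_model = DistributedDataParallel(model, device_ids=[rank])
    return ddp_model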
Centralizing evaluation
Note that we only run the evaluation on the dev set, and only update the best_model copy, in the process with rank 0. The reason for this is that we don't want to have to send prediction results between processes.
This is also what we do in the final block of the distributed_training() function: we only perform the final test set evaluation in the process with rank 0.
Running the code
We have now completed our modifications to the training code and can run it using torchrun:
$ torchrun --master_port 31133 scripts/basic_neural_network_ddp.py dataset/BBBP.csv