There are several strategies for training a deep learning model across multiple devices. Deep learning frameworks provide features for distributed training such as:
- Data Parallelism
- Model Parallelism
- Pipeline Parallelism
Each parallelism scheme has pros and cons, and engineers should choose among them to exploit their devices efficiently.
Data parallelism is a well-known distributed method for training deep learning models. The notion is not specific to deep learning; it appears in plenty of other domains. SIMD instructions, which process multiple data elements with a single instruction, are one form of data parallelism, and the SPMD programming model helps engineers write parallel programs effectively. Data parallelism across multiple devices is also known as batch-splitting: the task is split into subtasks and each device carries out one subtask. For example, with a (256, 32, 32, 3)-shaped input and 4 GPUs, it is easy to divide the input into four (64, 32, 32, 3)-shaped inputs, because in common deep learning tasks the samples along the batch axis do not depend on one another.
Of course, layers like Batch Normalization have to be synchronized across all subtasks so that the means and variances are the same on every device. We will talk about this later.
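As a small illustration of the shape arithmetic in the example above, the sketch below computes the per-device shapes obtained by splitting along the batch axis. The helper name `split_batch_shape` is hypothetical, and the sketch assumes the batch size is divisible by the number of devices.

```python
def split_batch_shape(shape, num_devices):
    """Return the per-device shapes after splitting along the batch axis (axis 0).

    Hypothetical helper for illustration; assumes an even split.
    """
    batch, *rest = shape
    if batch % num_devices != 0:
        raise ValueError("batch size must be divisible by the number of devices")
    # Only the batch axis shrinks; every other axis is left untouched.
    return [(batch // num_devices, *rest) for _ in range(num_devices)]


shards = split_batch_shape((256, 32, 32, 3), 4)
print(shards)
```

Real frameworks also handle uneven splits; this sketch rejects them to keep the arithmetic obvious.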
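To see why the statistics need synchronizing, here is a minimal sketch with made-up numbers and plain Python lists standing in for two devices. Each device's local mean differs from the mean of the full batch, but exchanging the local sums, sums of squares, and counts is enough to recover the full-batch mean and variance. This is only one possible way to synchronize, assumed here for illustration, and not a description of how any particular framework does it.

```python
# Activations of one feature, split across two hypothetical devices.
device_values = [[1, 2, 3, 4], [10, 20, 30, 40]]

# Without synchronization each device normalizes with a different mean.
local_means = [sum(v) / len(v) for v in device_values]

# Each device contributes three numbers: count, sum, and sum of squares.
local_stats = [(len(v), sum(v), sum(x * x for x in v)) for v in device_values]

# Summing the contributions across devices yields full-batch statistics.
count = sum(s[0] for s in local_stats)
total = sum(s[1] for s in local_stats)
total_sq = sum(s[2] for s in local_stats)

global_mean = total / count
global_var = total_sq / count - global_mean ** 2  # biased (population) variance

print(local_means, global_mean, global_var)
```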
The implementation of data parallelism varies. Here I introduce the common concept and algorithm of batch-splitting.
- Copy all parameters to each device.
- For each iteration, split the training batch into sub-batches.
- Distribute one sub-batch to each device.
- Each device computes the forward and backward passes on its sub-batch.
- Sum the gradients from all devices and distribute the sum to every device.
- Update the model parameters.
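The steps above can be sketched as a single-process simulation. This is a minimal sketch under stated assumptions, not any framework's implementation: the "devices" are plain Python lists, the model is a linear model with a squared-error loss, and the names `grad_sum` and `data_parallel_step` are hypothetical. The batch size is assumed to be divisible by the number of devices.

```python
def grad_sum(w, batch):
    """Gradient of sum over samples of 0.5 * (w.x - y)**2 with respect to w."""
    g = [0] * len(w)
    for x, y in batch:
        err = sum(wi * xi for wi, xi in zip(w, x)) - y
        for i, xi in enumerate(x):
            g[i] += err * xi
    return g


def data_parallel_step(w, batch, num_devices, lr):
    """Simulate one batch-splitting training step; returns one parameter copy per device."""
    if len(batch) % num_devices != 0:
        raise ValueError("batch size must be divisible by the number of devices")
    # Copy all parameters to each device.
    replicas = [list(w) for _ in range(num_devices)]
    # Split the batch into sub-batches, one per device.
    size = len(batch) // num_devices
    sub_batches = [batch[d * size:(d + 1) * size] for d in range(num_devices)]
    # Each device computes the gradient on its own sub-batch.
    local_grads = [grad_sum(replicas[d], sub_batches[d]) for d in range(num_devices)]
    # Sum the gradients across devices; every device receives the same sum.
    total = [sum(g[i] for g in local_grads) for i in range(len(w))]
    # Every device applies the same update, so the replicas stay identical.
    return [[wi - lr * gi / len(batch) for wi, gi in zip(r, total)] for r in replicas]


batch = [([1, 2], 3), ([2, 0], 1), ([0, 1], 2), ([3, 1], 0),
         ([1, 1], 1), ([2, 2], 4), ([0, 3], 3), ([1, 0], 2)]
replicas = data_parallel_step([0, 0], batch, num_devices=4, lr=0.1)
print(replicas[0])
```

Because the summed gradient equals the gradient over the whole batch, every replica ends the step with the same parameters that a single device processing the full batch would have produced.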
Although data parallelism is the dominant strategy for training on multiple devices, it cannot train very large models: every device must hold a full copy of the parameters, so the model size is bounded by the memory of a single GPU.
In Data Parallelism
In Model Parallelism
In Pipeline Parallelism
Frameworks for Parallelism
- PyTorch Distributed: Experiences on Accelerating Data Parallel Training
- PyTorch Distributed Overview
- ZeRO: Memory Optimizations Toward Training Trillion Parameter Models
- Mesh-TensorFlow: Deep Learning for Supercomputers
- GPipe: Easy Scaling with Micro-Batch Pipeline Parallelism
- PipeDream: Fast and Efficient Pipeline Parallel DNN Training