Training machine learning models faster with Dask

Machine learning (ML) training relies on stochastic algorithms, all of which approximate the gradient with "batch size" examples. Growing the batch size as the optimization proceeds is a simple and usable method to reduce the training time, provided that the number of workers grows with the batch size. In this work, we provide a package that trains PyTorch models on Dask clusters, and can grow the batch size if desired. Our simulations indicate that for a particular model that uses GPUs for a popular image classification task, the training time can be reduced from about 120 minutes with standard SGD to 45 minutes with a variable batch size method.


Introduction
Training deep machine learning models takes a long time. For example, training a popular image classification model [RRSS19] to reasonable accuracy takes "around 17 hours" on Google servers. 1 Another example is training an NLP model for 10 days on 8 high-end GPUs [RNSS18]. 2 Notably, the number of floating point operations (FLOPs) required for "the largest AI training runs" doubles every 3.4 months. 3 Model training is fundamentally an optimization problem: it tries to find a model $\mathbf{w}^\star$ that minimizes a loss function $F$:
$$\mathbf{w}^\star = \arg\min_{\mathbf{w}} F(\mathbf{w}) = \arg\min_{\mathbf{w}} \frac{1}{n}\sum_{i=1}^{n} f(\mathbf{w}; \mathbf{z}_i)$$
where there are $n$ examples in the training set, and each example is represented by $\mathbf{z}_i$. For classification, $\mathbf{z}_i = (\mathbf{x}_i, y_i)$ for a label $y_i$ and feature vector $\mathbf{x}_i$. The loss function $F$ is the mean of the loss $f$ over different examples. To compute this minimization for large scale machine learning, stochastic gradient descent (SGD) or a variant thereof is used [BCN18]. SGD is iterative, and the model update at each step $k$ is computed via
$$\mathbf{w}_{k+1} = \mathbf{w}_k - \gamma_k \frac{1}{B_k}\sum_{s=1}^{B_k} \mathbf{g}(\mathbf{w}_k; \mathbf{z}_{i_s})$$
where $\mathbf{g}$ is the gradient of the loss function $f$, $B_k \geq 1$ is the batch size, each index $i_s$ is chosen uniformly at random and $\gamma_k > 0$ is the learning rate or step size. The objective function's gradient is approximated with $B_k$ examples: the gradient approximation $\frac{1}{B_k}\sum_{s=1}^{B_k} \mathbf{g}(\mathbf{w}_k; \mathbf{z}_{i_s})$ is an unbiased estimator of the gradient of the loss function $F$. This computation is common in the vast majority of SGD variants, and is found in popular variants like Adam [KB14], RMSprop [ZSJ + 19], Adagrad [DHS11], Adadelta [Zei12], and averaged SGD [PJ92]. Most variants make modifications to the learning rate $\gamma_k$ [DHS11], [Zei12], [KB14], [ZSJ + 19].
2. See OpenAI's blog post "Improving Language Understanding with Unsupervised Learning." 3. See OpenAI's blog post "AI and Compute."
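To make the update rule concrete, here is a minimal pure-Python sketch of mini-batch SGD on a hypothetical scalar problem. It assumes the per-example loss $f(w; z) = \tfrac{1}{2}(w - z)^2$, so $g(w; z) = w - z$ and $F$ is minimized at the mean of the data. The loss, data, learning rate, and batch size are illustrative assumptions, not settings used in this work.

```python
import random


def grad(w, z):
    """Gradient of the per-example loss f(w; z) = 0.5 * (w - z) ** 2."""
    return w - z


def minibatch_sgd(data, w0=0.0, lr=0.05, batch_size=32, steps=500, seed=0):
    """Run w_{k+1} = w_k - lr * (1/B) * sum_s g(w_k; z_{i_s}) for `steps` updates."""
    rng = random.Random(seed)
    w = w0
    for _ in range(steps):
        # Each index i_s is drawn uniformly at random (with replacement).
        batch = rng.choices(data, k=batch_size)
        # Averaging B per-example gradients gives an unbiased estimate of F'(w).
        g_hat = sum(grad(w, z) for z in batch) / batch_size
        w = w - lr * g_hat
    return w


data = [float(i) for i in range(100)]  # F is minimized at the mean, 49.5
w_final = minibatch_sgd(data)
```

With a constant learning rate, the iterate ends up fluctuating near 49.5 rather than landing on it exactly; a larger `batch_size` shrinks that fluctuation because the gradient estimate has lower variance.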
Increasing the batch size $B_k$ reduces the number of model updates without requiring more FLOPs or gradient computations, both empirically [SKYL17] and theoretically [Sie20]. Typically, the number of FLOPs controls the training time because training is performed with a single processor. At first glance, fewer model updates seem like an internal detail that does not affect training time.
The benefit comes when training with multiple machines, that is, a distributed system. Notably, the time required to complete a single model update is (nearly) independent of the batch size, provided the number of workers in a distributed system grows with the batch size. In one experiment, the time to complete a model update grows by only 13% despite the batch size growing by a factor of 44 [GDG + 17, Sec. 5.5]. This acceleration has also been observed with an increasing batch size schedule [SKYL17, Sec. 5.4].
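The reasoning above can be illustrated with a toy cost model for one synchronous model update. This sketch assumes that compute time is proportional to the number of examples each worker handles and that communication adds a fixed overhead; both constants are hypothetical and not measured in this work.

```python
def time_per_update(batch_size, workers, sec_per_example=1e-3, overhead_sec=0.05):
    """Hypothetical cost model for one synchronous model update.

    Workers process batch_size / workers examples each in parallel; overhead_sec
    stands in for communication and synchronization, assumed constant here.
    """
    return sec_per_example * batch_size / workers + overhead_sec


# Workers grow with the batch size (128 examples per worker): the time per
# update stays the same even though the batch size grows by a factor of 25.
weak = [time_per_update(b, b // 128) for b in (128, 640, 3200)]

# Batch size fixed at 128: adding workers only shrinks the compute term, so the
# time per update approaches overhead_sec (diminishing returns, as in Amdahl's law).
strong = [time_per_update(128, p) for p in (1, 4, 16, 64)]
```

Under this model, total training time is the number of model updates multiplied by the (constant) time per update, which is why fewer, larger updates translate into less wall-clock time.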

Contributions
We provide software to accelerate machine learning model training, at least with certain distributed systems. For acceleration, the distributed system must be capable of assigning a different number of workers according to a fixed schedule. Specifically, this work provides the following: […]
[…] size (and the number of workers with it) by a factor of 44, but the time for a single model update only increases by a factor of 1.13 [GDG + 17, Sec. 5.5]. Now, let's cover related work to understand why variable batch sizes provide a benefit in a distributed system. Then, let's cover the details of our software before presenting simulations. These simulations confirm that model training can be accelerated if the number of workers grows with the batch size. Methods to work around limitations on the number of workers will be presented.

Related work
The data flow for distributed model training involves distributing the computation of the gradient estimate, $\frac{1}{B}\sum_{i=1}^{B} \mathbf{g}(\mathbf{w}_k; \mathbf{z}_i)$. Typically, each worker computes the gradients for $B/P$ examples when there is a batch size of $B$ and $P$ machines. Then, the average of these gradients is taken and the model is updated. 6 Clearly, Amdahl's law is relevant because there are diminishing returns as the number of workers $P$ is increased [GVY + 18]. This is referred to as "strong scaling" because the batch size is fixed and the number of workers is treated as an internal detail. By contrast, growing the amount of data with the number of workers is known as "weak scaling." Of course, relevant experiments show that weak scaling exhibits better scaling than strong scaling [QST17].
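The data flow described above can be checked numerically: when the batch of size $B$ is split into $P$ equal shards, averaging the per-worker average gradients gives exactly the full-batch gradient estimate. This pure-Python sketch reuses the hypothetical scalar loss $f(w; z) = \tfrac{1}{2}(w - z)^2$ and assumes $P$ divides $B$; the function names are illustrative only.

```python
def grad(w, z):
    """Gradient of the hypothetical per-example loss f(w; z) = 0.5 * (w - z) ** 2."""
    return w - z


def worker_gradient(w, shard):
    """One worker averages the gradients of its B / P examples."""
    return sum(grad(w, z) for z in shard) / len(shard)


def distributed_gradient(w, batch, num_workers):
    """Split a batch of size B into P equal shards and average the worker results."""
    if len(batch) % num_workers != 0:
        raise ValueError("this sketch assumes P divides B")
    size = len(batch) // num_workers
    shards = [batch[i * size:(i + 1) * size] for i in range(num_workers)]
    # The master averages the per-worker averages before updating the model.
    return sum(worker_gradient(w, s) for s in shards) / num_workers


batch = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
full = sum(grad(0.5, z) for z in batch) / len(batch)   # single-machine estimate
dist = distributed_gradient(0.5, batch, num_workers=4)  # four-worker estimate
```

Equal shard sizes matter here: with unequal shards, the master would need to weight each worker's average by its shard size to recover the same estimate.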

Constant batch sizes
To circumvent Amdahl's law, a common technique is to increase the batch size [ZLN + 19] alongside the learning rate [JAGG20]. Using moderately large batch sizes yields high quality results more quickly and, in practice, requires no more computation than small batch sizes, both empirically [GDG + 17] and theoretically [YPL + 18].
There are many methods to choose the best constant batch size (e.g., [GGS19], [KSL + 20]). Some methods are data dependent [YPL + 18], and others depend on the model complexity. Another method uses hardware topology (e.g., network bandwidth) in a distributed system [PKK + 19].
Large constant batch sizes present generalization challenges [GDG + 17]. The generalization error is hypothesized to come from "sharp" minima, strongly influenced by the learning rate and noise in the gradient estimate [KMN + 16]. To match performance on the training dataset, careful thought must be given to hyperparameter selection [GDG + 17, Sec. 3 and 5.2]. In fact, this has motivated algorithms specifically designed for large constant batch sizes and distributed systems [JAGG20], [JSH + 18], [YGG17].

Increasing the batch size
Model quality greatly influences the amount of information in the gradient, which in turn influences the appropriate batch size [Sie20]. For example, if models are poorly initialized, then using a large batch size has no benefit: the gradient (or direction to the optimal model) for each example will produce very similar numbers. An illustration is given in Figure 1.
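The intuition behind Figure 1 can be reproduced with its function, $f(w_x, w_y) = \sum_{i=0}^{3} (w_x - x_i)^2 + (w_y - y_i)^2$. This sketch assumes four hypothetical data points at the corners of the unit square and measures how "diverse" the per-datum gradients are via their smallest pairwise cosine similarity.

```python
import math

# Four hypothetical data points (x_i, y_i); f is minimized at their mean (0.5, 0.5).
points = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0)]


def datum_gradient(w, p):
    """Gradient of (w_x - x_i)^2 + (w_y - y_i)^2 with respect to (w_x, w_y)."""
    return (2 * (w[0] - p[0]), 2 * (w[1] - p[1]))


def cosine(u, v):
    """Cosine similarity between two 2-D vectors."""
    return (u[0] * v[0] + u[1] * v[1]) / (math.hypot(*u) * math.hypot(*v))


def min_pairwise_cosine(w):
    """Smallest cosine similarity between any two per-datum gradients at model w."""
    grads = [datum_gradient(w, p) for p in points]
    return min(cosine(a, b) for i, a in enumerate(grads) for b in grads[i + 1:])


far = min_pairwise_cosine((10.0, 10.0))  # poorly initialized model
near = min_pairwise_cosine((0.6, 0.55))  # model close to the optimum
```

Far from the optimum every datum's gradient points in nearly the same direction, so averaging more of them adds little information; near the optimum they point in conflicting directions, and a larger batch is needed to average out that disagreement.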
Various methods to adaptively change the batch size based on model performance have been proposed [Sie20], [DYJG16], [BRH17], [BCNW12]. Of course, these methods are adaptive, so computing the batch size requires additional computation (though there are workarounds [Sie20], [BRH17]).
6. Related but tangential methods include methods to efficiently communicate the gradient estimates […]
Convergence results have been given for adaptive batch sizes [Sie20], [BCN18], [ZYF18]. Increasing the batch size is a provably good measure that requires far fewer model updates and no more computation than standard SGD for strongly convex functions [BCN18, Ch. 5], and for all function classes if the batch size is provided by an oracle [Sie20]. Convergence proofs have also been given for passively increasing the batch size, both for strongly convex functions [BCN18, Ch. 5] and for non-convex functions [ZYF18]. Both of these methods require fewer model updates than SGD and do not increase the number of gradient computations.
Notably, a geometric batch size increase schedule has shown great empirical performance in image classification [SKYL17]. Specifically, the number of model updates required to finish training decreased by a factor of 2.2 over standard SGD [SKYL17]. Smith et al. make an observation that increasing the batch size and decreasing the learning rate both decay the optimization's "noise scale" (or variance of the model update), which has connections to simulated annealing [SKYL17]. This motivates increasing the batch size by the same factor the learning rate decays [SKYL17].
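The "noise scale" argument can be made concrete. Smith et al. approximate the noise scale as roughly $g \approx \epsilon N / (B(1 - m))$ for learning rate $\epsilon$, training set size $N$, batch size $B$ and momentum $m$, so with $N$ and $m$ fixed it depends only on $\epsilon / B$. This sketch assumes a hypothetical initial batch size of 128 together with the schedule used later in this paper (initial learning rate 0.05, factor 5 at epochs 60, 120 and 180, momentum 0.9) and checks that decaying the learning rate and growing the batch size give identical noise scales.

```python
from fractions import Fraction


def schedule(epoch, lr0, batch0, factor, milestones, grow_batch):
    """Learning rate and batch size at `epoch` for a step schedule."""
    steps = sum(epoch >= m for m in milestones)
    if grow_batch:
        return lr0, batch0 * factor ** steps  # keep lr, grow the batch
    return lr0 / factor ** steps, batch0      # decay lr, keep the batch


def noise_scale(lr, batch, n=50_000, momentum=Fraction(9, 10)):
    """Approximate noise scale lr * n / (batch * (1 - momentum))."""
    return lr * n / (batch * (1 - momentum))


lr0 = Fraction(5, 100)  # initial learning rate 0.05, kept exact with Fraction
milestones = (60, 120, 180)
epochs = (0, 60, 120, 180)
lr_decay_noise = [noise_scale(*schedule(e, lr0, 128, 5, milestones, False)) for e in epochs]
batch_growth_noise = [noise_scale(*schedule(e, lr0, 128, 5, milestones, True)) for e in epochs]
```

Exact fractions are used so that the two schedules can be compared with `==` rather than a floating-point tolerance.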
Both growing the batch size and using large constant batch sizes should require the same number of floating point operations as using standard SGD with small batch sizes to reach a particular training loss (respectively [Sie20], [BCN18] and [JAGG20], [YLR + 19], [YPL + 18]). Some proof techniques suggest that variable batch size methods mirror gradient descent [Sie20], [KNS16], so correspondingly, the implementations do not require much additional hyperparameter tuning [SKYL17].

Distributed training with Dask
We have written "AdaDamp," a software package to train a PyTorch model with a Scikit-learn API on any Dask cluster. 7 It supports the use of constant or variable batch sizes, which fits nicely with Dask's ability to change the number of workers. 8 In this section, we will walk through the basic architecture of our software and an example usage. We will defer showing the primary benefit of our software to the experimental results.

Architecture
Our software uses a centralized synchronous parameter server and controls the data flow of the optimization with Dask (and does not rely on PyTorch's distributed support). Specifically, the following happens on every model update:
1) The master node broadcasts the model to every worker.
2) The workers calculate the gradients.
3) The workers communicate the gradients back to the master.
4) The master performs a model update with the aggregated gradients.
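The four steps above can be sketched as a loop over futures. This is a stand-in, not AdaDamp's code: it uses the standard library's `concurrent.futures.ThreadPoolExecutor` so it runs anywhere, whereas a Dask version would use the similar futures interface of `dask.distributed.Client` (`submit`, `gather`). The scalar loss, data, learning rate, and the use of the whole dataset as the batch are hypothetical simplifications.

```python
from concurrent.futures import ThreadPoolExecutor


def grad(w, z):
    """Gradient of the hypothetical per-example loss f(w; z) = 0.5 * (w - z) ** 2."""
    return w - z


def worker_step(w, shard):
    """Step 2: a worker computes the summed gradient over its shard."""
    return sum(grad(w, z) for z in shard), len(shard)


def train(data, num_workers=4, lr=0.5, iterations=20):
    w = 0.0
    shards = [data[i::num_workers] for i in range(num_workers)]
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        for _ in range(iterations):
            # Step 1: "broadcast" the current model by passing w to every task.
            futures = [pool.submit(worker_step, w, s) for s in shards]
            # Step 3: gather the gradients back on the master.
            results = [f.result() for f in futures]
            total = sum(g for g, _ in results)
            count = sum(c for _, c in results)
            # Step 4: the master updates the model with the aggregated gradient.
            w = w - lr * total / count
    return w


w_trained = train([float(i) for i in range(100)])
```

Because every iteration waits for all workers before updating, this loop is synchronous. A single slow worker delays every update, which is one of the costs of a centralized synchronous design.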
We use Dask to implement this data flow, which adds some overhead. 9 AdaDamp supports static batch sizes; however, there is little incentive to use AdaDamp with a static batch size: PyTorch's native solution has less overhead [LZV + 20] and already has a Dask wrapper. 10
7. While our software works with a constant batch size, the native implementations work with constant batch sizes and very likely have less overhead (e.g., PyTorch Distributed [LZV + 20]). 8. https://github.com/stsievert/adadamp
Fig. 1: An illustration of why the batch size should increase. Here, let's find a model $\mathbf{w} = [w_x, w_y]$ that minimizes the function $f(w_x, w_y) = \sum_{i=0}^{3} (w_x - x_i)^2 + (w_y - y_i)^2$, where $x_i$ and $y_i$ are the x and y coordinates of each datum. When closer to the optimum at model A, the gradients are more "diverse," so the magnitude and orientation of each datum's gradient varies more [YPL + 18].
[…] So far, a PyTorch model and optimizer have been specified. As per the Scikit-learn API, we specify parameters for the model/optimizer with double underscores, so in our example HiddenLayer(features=10) will be created. We can set the batch size increase parameters at initialization if desired, or inside set_params. This will increase the batch size by a factor of 5 every 60 epochs, which is used in the experiments. Now, we can train:

```python
import torch
from sklearn.datasets import make_regression

# `est` is the AdaDamp estimator configured earlier in this section.
X, y = make_regression(n_features=10)
X = torch.from_numpy(X.astype("float32"))
y = torch.from_numpy(y.astype("float32")).reshape(-1, 1)  # column vector of targets
est.fit(X, y)
```
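The double-underscore convention is the same one Scikit-learn uses for nested parameters (e.g., `Pipeline.set_params(step__param=value)`). The following sketch shows how such keys can be routed to the component they configure. The `route_params` helper, the `module__`/`optimizer__` prefixes, and the `HiddenLayer` stand-in are hypothetical illustrations; AdaDamp's actual parameter names are not reproduced in this text.

```python
def route_params(params):
    """Group Scikit-learn-style "prefix__name" keys by prefix (hypothetical helper)."""
    routed = {}
    for key, value in params.items():
        prefix, sep, name = key.partition("__")
        if not sep or not name:
            raise ValueError(f"expected 'prefix__name', got {key!r}")
        routed.setdefault(prefix, {})[name] = value
    return routed


class HiddenLayer:
    """Stand-in for the PyTorch module named in the text (not a real nn.Module)."""

    def __init__(self, features):
        self.features = features


params = {"module__features": 10, "optimizer__lr": 0.05, "optimizer__momentum": 0.9}
routed = route_params(params)
layer = HiddenLayer(**routed["module"])  # equivalent to HiddenLayer(features=10)
```

Keeping all parameters in one flat dictionary is what lets Scikit-learn tools such as `GridSearchCV` tune model and optimizer parameters together.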

Experiments
In this section, we present two sets of experiments. 11 Both experiments will use the same setup, a Wide-ResNet model in a "16-4" architecture [ZK16] to perform image classification on the CIFAR10 dataset [KH09]. This is a deep learning model with about 2.75 million weights that requires a GPU to train. 12 The experiments will provide evidence for the following points:
1) Increasing the batch size reduces the number of model updates.
2) The time required for model training is roughly proportional to the number of model updates (presuming the distributed system is configured correctly).
To provide evidence for these points, let's run one set of experiments that varies the batch size increase schedule. These experiments will mirror the experiments by Smith et al. [SKYL17]. Additionally, let's ensure that our software accelerates model training as the number of GPUs increases.
We train each batch size increase schedule once and then write the historical performance to disk. This reduces the need for many GPUs, and allows us to simulate different networks and highlight the performance of Dask. In other words, our simulations emulate model training by having the computer sleep for an appropriate and realistic amount of time.
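A simulation of this kind can be sketched by replaying a recorded sequence of batch sizes and adding up how long each model update would take. This sketch sums the durations instead of sleeping, and its cost model is an assumption: compute time proportional to examples per worker, a fixed communication cost per update, and one worker per 128 examples up to a hypothetical cap of 32 workers. None of these constants come from this work.

```python
import math


def simulate_wall_clock(batch_sizes, sec_per_example=2e-4, comm_sec=0.02,
                        examples_per_worker=128, max_workers=32):
    """Total simulated seconds to replay one model update per recorded batch size."""
    total = 0.0
    for batch in batch_sizes:
        # Assign enough workers to keep about `examples_per_worker` examples each.
        workers = min(max_workers, math.ceil(batch / examples_per_worker))
        total += sec_per_example * batch / workers + comm_sec
    return total


# Both histories process 12,800 examples, but the second uses 5x fewer updates.
sgd_time = simulate_wall_clock([128] * 100)
grown_time = simulate_wall_clock([640] * 20)
```

When enough workers are available, the simulated time is proportional to the number of model updates; once `max_workers` is reached, each update takes longer and that proportionality breaks down.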
11. Full details on these experiments can be found at https://github.com/stsievert/adadamp-experiments 12. Specifically, we used an NVIDIA T4 GPU on an Amazon g4dn.xlarge instance. Training consumes 2.2GB of GPU memory with a batch size of 32, and 5.5GB with a batch size of 256.

Batch size increase
To illustrate the primary benefit of our software, let's perform several trainings that require a different number of model updates. These experiments explicitly mirror the experiments by Smith et al. [SKYL17, Sec. 5.1], which reduces the amount of hyperparameter tuning needed.
Largely, the same hyperparameters are used. These experiments only differ in the choice of batch size and learning rate, as shown in Figure 2. As in the Smith et al. experiments, every optimizer uses Nesterov momentum [Nes98] with the same momentum (0.9) and weight decay ($0.5 \cdot 10^{-3}$). They start with the same initial learning rate (0.05), 13 and at particular epochs (60, 120 and 180) either the learning rate is decreased or the batch size is increased by a specified factor (5). This means that the variance of the model update is reduced by a constant factor at each of these epochs. These different decay schedules exhibit the same performance in terms of the number of epochs, which is proportional to the number of FLOPs, as shown in Figure 3. The number of FLOPs is (approximately) proportional to the cost, at least on Amazon EC2, where the cost to rent a server tends to be proportional to the number of GPUs.
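Although the epoch counts match, the number of model updates does not. This sketch counts model updates for the two schedules above (factor 5 at epochs 60, 120 and 180) on the 50,000 CIFAR10 training images. The total of 200 epochs and the initial batch size of 128 are hypothetical assumptions, and the final partial batch of each epoch is assumed to count as one update.

```python
import math


def model_updates(n_examples, total_epochs, batch0, factor, milestones, grow_batch):
    """Count model updates for a step schedule (hypothetical helper)."""
    updates = 0
    for epoch in range(total_epochs):
        steps = sum(epoch >= m for m in milestones)
        batch = batch0 * factor ** steps if grow_batch else batch0
        updates += math.ceil(n_examples / batch)  # one update per (partial) batch
    return updates


n = 50_000  # CIFAR10 training images
milestones = (60, 120, 180)
decay_lr_updates = model_updates(n, 200, 128, 5, milestones, grow_batch=False)
grow_batch_updates = model_updates(n, 200, 128, 5, milestones, grow_batch=True)
ratio = decay_lr_updates / grow_batch_updates
```

Under these assumptions, the learning rate decay schedule needs 78,200 updates and the batch size increase schedule needs 29,240. The exact ratio depends on the assumed epoch count and initial batch size.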
Importantly, this work focuses on increasing the number of workers with the batch size, the effect of which is hidden in Figure 3. However, because performance does not change with different schedules, choosing a different batch size increase schedule will not require more wall-clock time if only a single worker is available. Combined with the hyperparameter similarity between the different schedules, this reduces deployment and debugging concerns.
13. These are the same as Smith et al. [SKYL17] with the exception of the learning rate (which had to be reduced by a factor of 2).
If the number of workers grows with the batch size, then the number of model updates largely determines the wall-clock time. Figure 4 shows the number of model updates and wall-clock time required to reach a model of a particular test accuracy. Of course, there is some overhead to our current framework, which is why the wall-clock time required to complete training is not exactly proportional to the number of model updates. In summary, the time required to complete training is shown in Table 1.

Future work
Architecture
Fundamentally, the model weights can be held either on a master node (centralized) or on every node (decentralized). Respectively, these storage architectures typically use point-to-point communication or "all-reduce" communication. Both centralized [LAP + …]
[…] Items (1), (3) and (4) are a large concern in our implementation. Decentralized communication has the advantage of eliminating items (1) and (4), and mitigates (3) with a smarter communication strategy (all-reduce vs. point-to-point). Item (2) is still a concern with straggler nodes [DCM + 12], but recent work has achieved "near-linear scalability with 256 GPUs" in a homogeneous computing environment [LZV + 20]. Items (2) and (5) can be avoided with asynchronous methods (e.g., [RRWN11], [ZHA16]).
That is, most of the concerns in our implementation would be resolved with a decentralized communication strategy. The PyTorch distributed communication package uses a synchronous decentralized strategy, so the model is communicated to each worker and gradients are sent between workers with an all-reduce scheme [LZV + 20]. […] will change under different architectures and networks. The "centralized" architecture is the currently implemented architecture, and has the same numbers as "training time" in Table 4.

Simulations
We have simulated the expected gain from enabling decentralized communication with two networks that use a decentralized all-reduce strategy:
• decentralized-medium assumes a network with an inter-worker bandwidth of 54Gb/s and a latency of 0.05µs.
• decentralized-high has the same network as decentralized-medium but has an inter-worker bandwidth of 800Gb/s and a latency of 0.025µs.
To provide baseline performance, we also show the results with the current implementation:
• centralized uses the same network as decentralized-medium but with the centralized communication scheme that is currently implemented.
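Why the communication strategy matters can be seen with the standard latency-bandwidth cost model, applied here to the two network settings above. This is an illustrative sketch, not the simulator behind Table 2. It assumes the model is about 2.75 million float32 weights and that a centralized master serializes all of its sends and receives on one link. A ring all-reduce performs $2(P-1)$ steps, each moving $1/P$ of the data, so its bandwidth cost stays nearly constant as $P$ grows.

```python
def centralized_sec(workers, model_bits, bandwidth_bps, latency_s):
    """Per-update communication time with a central parameter server.

    The master sends the model to every worker and receives one gradient from
    every worker; all 2 * workers messages are assumed to share the master's link.
    """
    return 2 * workers * (latency_s + model_bits / bandwidth_bps)


def ring_allreduce_sec(workers, model_bits, bandwidth_bps, latency_s):
    """Per-update time of a ring all-reduce (reduce-scatter, then all-gather)."""
    steps = 2 * (workers - 1)  # each step moves 1 / workers of the data
    return steps * (latency_s + model_bits / (workers * bandwidth_bps))


MODEL_BITS = 2.75e6 * 32  # about 2.75 million float32 weights
NETWORKS = {
    "decentralized-medium": (54e9, 0.05e-6),  # 54 Gb/s, 0.05 µs
    "decentralized-high": (800e9, 0.025e-6),  # 800 Gb/s, 0.025 µs
}

for name, (bandwidth, latency) in NETWORKS.items():
    for p in (4, 16, 64):
        c = centralized_sec(p, MODEL_BITS, bandwidth, latency)
        r = ring_allreduce_sec(p, MODEL_BITS, bandwidth, latency)
        print(f"{name:22s} P={p:3d} centralized={c * 1e3:8.2f} ms  all-reduce={r * 1e3:6.2f} ms")
```

Under this model, the centralized cost grows linearly with the number of workers, while the all-reduce cost approaches about twice the time needed to send the model once.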
decentralized-medium is most applicable for clusters that have decent bandwidth between nodes. It is also applicable in certain cases when Amazon EC2 is used with one GPU per worker, 15 or when workers have a moderate InfiniBand setup. 16 decentralized-high is a simulation of the network used by the PyTorch developers to illustrate their distributed communication [LZV + 20]. We have run simulations to illustrate the effects of these networks. Of course, changing the underlying network does not affect the number of epochs or model updates, so Figures 3 and 4 also apply here.
A summary of how different networks affect training time is shown in Table 2. Figure 5 visualizes Table 2, showing how network quality affects the performance of different optimization methods. Figure 6 shows the training time for a particular network (decentralized-medium); decentralized-high shows similar performance, as illustrated in Table 2. Clearly, the optimization method (and the maximum number of workers) is more important than the network.
[…] batch size. This simulation will use the decentralized-high network and has the advantage of removing any overhead. The results in Figure 7 show that the speedups start saturating around 128 examples/worker for the model used with a batch size of 512. Larger batch sizes will likely mirror this performance: computation is the bottleneck with this model/dataset/hardware.

Conclusion
In this work, we have provided a package to train PyTorch ML models on Dask clusters. This package reduces the amount of time required to train a model with the current centralized setup. However, it can be further accelerated by integration with PyTorch's distributed communication package, as illustrated by extensive simulations. For a particular model, only 45 minutes is required for training, an improvement over the 120 minutes required with standard SGD.