Tackling the Unlimited Staleness in Federated Learning with Intertwined Data and Device Heterogeneities

Abstract

The efficiency of Federated Learning (FL) is often affected by both data and device heterogeneities. Data heterogeneity refers to the heterogeneity of data distributions across clients. Device heterogeneity refers to clients' varying latencies in uploading their local model updates, due to heterogeneous conditions of local hardware resources, and causes the problem of staleness when addressed by asynchronous FL. Traditional schemes for tackling the impact of staleness treat data and device heterogeneities as two separate and independent aspects of FL, but this assumption is unrealistic in many practical FL scenarios where the two heterogeneities are intertwined. In these cases, traditional weighted-aggregation schemes in FL have proven ineffective, and a better approach is to convert a stale model update into a non-stale one. In this paper, we present a new FL framework that leverages the gradient inversion technique for such conversion, hence efficiently tackling unlimited staleness in clients' model updates. Our basic idea is to use gradient inversion to obtain estimations of clients' local training data from their uploaded stale model updates, and to use these estimations to compute non-stale client model updates. In this way, we address the possible drop of data quality when using gradient inversion, while still preserving the clients' local data privacy. We compared our approach with existing FL strategies on mainstream datasets and models, and experimental results demonstrate that when tackling unlimited staleness, our approach can significantly improve the trained model accuracy by up to 20% and speed up the FL training progress by up to 35%.

Publication
arXiv preprint

Background

In practical Federated Learning (FL) scenarios, it is not uncommon to witness excessive or even unlimited staleness in clients' model updates, especially when client devices have very limited computing power, local energy budgets, or communication capabilities. Traditional asynchronous Federated Learning (AFL) solutions, such as weighted aggregation, result in an improper bias toward fast clients and miss important knowledge in slow clients' model updates when data and device heterogeneities are intertwined.
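For concreteness, below is a minimal sketch of staleness-weighted aggregation in the style of FedAsync; the polynomial decay function and the mixing rate `base_alpha` are illustrative assumptions, not the exact scheme evaluated in our experiments.

```python
import torch

def staleness_weight(base_alpha: float, staleness: int, a: float = 0.5) -> float:
    """Polynomial staleness decay: updates from slower clients get smaller weights."""
    return base_alpha * (1.0 + staleness) ** (-a)

@torch.no_grad()
def weighted_aggregate(global_params, client_params, staleness, base_alpha=0.6):
    """Mix a (possibly stale) client update into the global model.

    global_params / client_params: dicts of parameter tensors (state_dicts).
    staleness: number of global rounds elapsed since the client downloaded the model.
    """
    alpha = staleness_weight(base_alpha, staleness)
    return {
        name: (1.0 - alpha) * g + alpha * client_params[name]
        for name, g in global_params.items()
    }
```

The problem such schemes face is visible in the weight itself: a slow client's update is down-weighted regardless of whether its data is uniquely held by that client.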

To demonstrate this impact, we conducted experiments using the MNIST dataset on 100 clients to train a 3-layer CNN model. We set data heterogeneity such that each client only contains samples from one data class, and set device heterogeneity as a staleness of 40 epochs on clients with data samples in class 5. Results in the bottom-left figure show that staleness leads to a large degradation of model accuracy, and that weighted aggregation further enlarges this degradation. The bottom-right figure also shows that other techniques, such as DC-ASGD, rapidly become ineffective as the staleness (in epochs) grows.
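As an illustration of this extreme heterogeneity setting, the sketch below partitions MNIST so that each of the 100 clients holds samples from a single class (10 clients per class); the helper names are ours, and the exact shard assignment in our experiments may differ.

```python
from collections import defaultdict
from torchvision import datasets

def one_class_partition(labels, num_clients=100, num_classes=10):
    """Give each client samples from exactly one class."""
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[int(label)].append(idx)
    per_class = num_clients // num_classes  # clients assigned to each class
    client_indices = {}
    for c in range(num_classes):
        for i in range(per_class):
            # Strided split of class-c indices into per_class shards.
            client_indices[c * per_class + i] = by_class[c][i::per_class]
    return client_indices  # client_id -> list of sample indices

mnist = datasets.MNIST("data", train=True, download=True)
partition = one_class_partition(mnist.targets)
```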

The impact of staleness in asynchronous Federated Learning

Methodology

We propose to address the above limitations using existing techniques of gradient inversion. Gradient inversion (GI) aims to recover the original training data from a model's gradients under the white-box setting, where all information about the model is known. The basic idea is to minimize the difference between the trained model's gradients and the gradients computed from the recovered data.
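A minimal sketch of this optimization, in the spirit of DLG (Deep Leakage from Gradients), is shown below; the dummy-label parameterization, the LBFGS optimizer, and the step count are illustrative choices rather than the exact configuration used in our framework.

```python
import torch
import torch.nn.functional as F

def invert_gradients(model, target_grads, input_shape, num_classes,
                     steps=300, lr=0.1):
    """Recover synthetic (x, y) whose gradients match the observed gradients.

    target_grads: list of gradient tensors, one per model parameter.
    """
    x = torch.randn(1, *input_shape, requires_grad=True)  # dummy input
    y = torch.randn(1, num_classes, requires_grad=True)   # dummy soft label
    opt = torch.optim.LBFGS([x, y], lr=lr)

    def closure():
        opt.zero_grad()
        loss = F.cross_entropy(model(x), F.softmax(y, dim=-1))
        grads = torch.autograd.grad(loss, model.parameters(), create_graph=True)
        # Minimize the gap between recovered-data gradients and observed gradients.
        grad_diff = sum(((g - t) ** 2).sum() for g, t in zip(grads, target_grads))
        grad_diff.backward()
        return grad_diff

    for _ in range(steps):
        opt.step(closure)
    return x.detach(), F.softmax(y.detach(), dim=-1)
```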

Overall Picture

As shown in the figure above, our proposed technique consists of three key components: 1) recovering an intermediate dataset from the received stale model update via gradient inversion, to represent the distribution of the client's training data; 2) estimating a non-stale model update using the recovered dataset; and 3) deciding when to switch back to vanilla FL in the late stage of training, to avoid excessive estimation error from gradient inversion.
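The sketch below shows one way these three components could fit together, reusing `invert_gradients` from the previous sketch; treating the stale update as a gradient signal and using a simple threshold on a caller-supplied estimation error for the switching decision are our own simplifications.

```python
import copy

import torch
import torch.nn.functional as F

def handle_stale_update(global_model, stale_grads, input_shape, num_classes,
                        local_steps=5, lr=0.01, error_threshold=0.5,
                        estimation_error=0.0):
    """Convert a stale client update into an estimated non-stale one."""
    # (3) Fall back to vanilla aggregation of the stale update when the
    # gradient-inversion estimation error would dominate (late in training).
    if estimation_error > error_threshold:
        return stale_grads

    # (1) Recover an intermediate dataset representing the client's data.
    x_rec, y_rec = invert_gradients(global_model, stale_grads,
                                    input_shape, num_classes)

    # (2) Re-run local training on the *current* global model using the
    # recovered data, to estimate a non-stale update.
    model = copy.deepcopy(global_model)
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    for _ in range(local_steps):
        opt.zero_grad()
        F.cross_entropy(model(x_rec), y_rec).backward()
        opt.step()
    return [(p_new - p_old).detach() for p_new, p_old
            in zip(model.parameters(), global_model.parameters())]
```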

Experiment Results

We evaluated our proposed technique in two FL scenarios. In the first scenario, all clients' local datasets are fixed. In the second scenario, we consider a more practical FL setting where clients' local data is continuously updated and data distributions vary over time. In all experiments, we consider an FL scenario with 100 clients, and assess our approach's performance improvement by measuring the increase in model accuracy on the selected data class affected by staleness.
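For concreteness, the per-class accuracy used in this metric can be computed as in the following sketch (plain PyTorch; the function name is ours):

```python
import torch

@torch.no_grad()
def class_accuracy(model, loader, target_class: int) -> float:
    """Accuracy restricted to test samples of a single class."""
    model.eval()
    correct, total = 0, 0
    for x, y in loader:
        mask = (y == target_class)
        if mask.any():
            preds = model(x[mask]).argmax(dim=-1)
            correct += (preds == y[mask]).sum().item()
            total += mask.sum().item()
    return correct / max(total, 1)
```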

FL Performance in the Fixed Data Scenario

With staleness set to 40 epochs, the figures below show (1) the model accuracy on data class 5, using the MNIST dataset to train a LeNet model, and (2) the model accuracy on data class 2, using the CIFAR-10 dataset to train a ResNet-8 model.

Result in Fixed Data Scenario

Results also show that our GI-based estimation method speeds up the training process, achieving the same model accuracy with less training time. This holds across different staleness conditions, different datasets, and different amounts of data heterogeneity (as indicated by the parameter $\alpha$ in the table below).

Table Results in Fixed Data Scenario
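Here, $\alpha$ presumably denotes the concentration parameter of a Dirichlet label partition, a common convention in FL benchmarks; under that assumption, such a partition can be sketched as follows (smaller $\alpha$ means more heterogeneous clients):

```python
import numpy as np

def dirichlet_partition(labels, num_clients=100, num_classes=10,
                        alpha=0.5, seed=0):
    """Split sample indices across clients with Dirichlet(alpha) class proportions."""
    rng = np.random.default_rng(seed)
    client_indices = [[] for _ in range(num_clients)]
    for c in range(num_classes):
        idx = np.flatnonzero(np.asarray(labels) == c)
        rng.shuffle(idx)
        # Fraction of class-c samples assigned to each client.
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        cuts = (np.cumsum(proportions) * len(idx)).astype(int)[:-1]
        for client_id, shard in enumerate(np.split(idx, cuts)):
            client_indices[client_id].extend(shard.tolist())
    return client_indices  # client_id -> list of sample indices
```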

FL Performance in the Variant Data Scenario

To verify performance under continuously varying distributions of clients' local data, we use the MNIST and SVHN datasets, which target the same learning task (i.e., handwritten digit recognition) but have different feature representations. Each client's local dataset is initialized from the MNIST dataset in the same way as in the fixed data scenario. Afterwards, during training, each client continuously replaces random data samples in its local dataset with new samples from the SVHN dataset.
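A sketch of this streaming replacement is shown below; the per-round replacement fraction `rate` and the bookkeeping of samples as (dataset, index) pairs are our own assumptions.

```python
import random

def stream_replace(local_samples, svhn_pool, rate=0.05, rng=random):
    """Replace a fraction of a client's MNIST samples with unseen SVHN samples.

    local_samples: the client's current list of (dataset_name, index) pairs.
    svhn_pool: list of SVHN indices not yet assigned to this client.
    rate: fraction of the local dataset replaced per round (streaming rate).
    """
    n_replace = max(1, int(rate * len(local_samples)))
    positions = rng.sample(range(len(local_samples)), n_replace)
    for pos in positions:
        if not svhn_pool:
            break
        local_samples[pos] = ("svhn", svhn_pool.pop())
    return local_samples
```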

Results in the figure below show that in this variant data scenario, the model accuracy improvements achieved by the existing FL training strategies exhibit significant fluctuations over time and stay low (<40%). In comparison, our proposed gradient-inversion-based estimation achieves much higher model accuracy, which is comparable to FL without staleness and 20% higher than that of existing FL schemes.

Results with variant data distributions

Further results below, with different amounts of staleness and rates of data variation (streaming rates), also demonstrate that our proposed method outperforms the existing FL strategies across scenarios with different dynamics of local data patterns.

Table results with variant data distributions

Haoming Wang
PhD Student in Electrical and Computer Engineering

Wei Gao
Associate Professor at University of Pittsburgh