Federated Learning in Health AI Part 2: How Federated Learning Works
Learn how Federated Learning transforms Health AI through decentralized training, the FedAvg algorithm, and privacy-focused weight averaging that enables collaboration without data sharing.
In Part 1 of this series, we established the necessity for a mechanism like Federated Learning. To deal with the fragmented and sensitive nature of healthcare data, an approach that could eliminate the need to consolidate data into a centralized repository for training would be appropriate. Federated Learning addresses this by bringing the training to where the data resides; in our case, these are the healthcare institutions.
The Simplicity Behind Federated Learning
How does Federated Learning achieve decentralized training of a centralized model? Let’s start with a high-level overview, based on the original FedAvg paper, “Communication-Efficient Learning of Deep Networks from Decentralized Data” by McMahan et al.
To see how it works, let’s consider a running example: training an AI model to detect COVID-19 infections from patients’ chest X-rays (CXRs). Health institutions around the country each have their own sets of CXR data. Such a model would benefit from learning from the diverse CXRs held by different institutions, but without the data itself ever being collected in one place.
FedAvg
Federated training of an AI model is carried out over N communication rounds. The central location where the model (hereafter referred to as the “Global Model”) resides is called the Server. The separate sites where data resides and training occurs are called the Clients. In a healthcare setting, each client could be a hospital with its own dataset of (CXR image, diagnosis) pairs. At the start of each communication round, the server sends a copy of the Global Model to each hospital participating in federated learning.
Each hospital keeps its copy of the global model as the “Local Model” and trains it through conventional supervised learning on its own dataset. This goes on for a preset number of epochs L, referred to as Local Epochs. Once local training at each hospital is complete, the local models are sent back to the central server. So how does the Global Model aggregate what was learned in the different local training runs to update itself? It turns out the solution is really simple: averaging. The server averages the local model weights, giving more weight to clients that trained on more samples. This averaging step merges the collective learning from each institution into a single, improved global model that represents the experience of all without ever sharing sensitive data.
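The mathematical formulation of this weighted averaging, in the notation of the FedAvg paper, is presented below. Here K is the number of participating clients, n_k is the number of training samples at client k, n is the total number of samples across all clients, and w_{t+1}^k is the local model returned by client k at the end of round t:

$$w_{t+1} = \sum_{k=1}^{K} \frac{n_k}{n} \, w_{t+1}^{k}$$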
The same process is repeated until all N communication rounds are completed. What remains at the central Server is the final trained model, which combines the learnings from each individual dataset.
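The round structure described above can be sketched in a few lines of plain Python. This is a minimal illustration under stated assumptions, not the paper's reference implementation: the "model" is a toy two-parameter linear regressor standing in for a CXR classifier, each client runs full-batch gradient descent, every client participates in every round, and the function names (`local_train`, `weighted_average`, `fedavg`, `make_client`) and the synthetic data are hypothetical.

```python
def local_train(weights, data, local_epochs, lr):
    """Train a copy of the global model on one client's data (full-batch gradient descent)."""
    w, b = weights
    n = len(data)
    for _ in range(local_epochs):
        # Gradient of the mean squared error of y_hat = w * x + b
        grad_w = sum(2 * (w * x + b - y) * x for x, y in data) / n
        grad_b = sum(2 * (w * x + b - y) for x, y in data) / n
        w -= lr * grad_w
        b -= lr * grad_b
    return [w, b]


def weighted_average(client_weights, client_sizes):
    """Average the clients' weights, weighting each client by its sample count n_k."""
    total = sum(client_sizes)
    dim = len(client_weights[0])
    return [
        sum(w[i] * n_k for w, n_k in zip(client_weights, client_sizes)) / total
        for i in range(dim)
    ]


def fedavg(clients, rounds, local_epochs, lr):
    """Run N communication rounds: broadcast, train locally, average on the server."""
    global_weights = [0.0, 0.0]
    sizes = [len(data) for data in clients]
    for _ in range(rounds):
        # Each client starts from the current global model; only weights travel back.
        local_models = [
            local_train(global_weights, data, local_epochs, lr) for data in clients
        ]
        global_weights = weighted_average(local_models, sizes)
    return global_weights


def make_client(xs):
    """Hypothetical client dataset drawn from the same underlying rule y = 2x - 1."""
    return [(x, 2.0 * x - 1.0) for x in xs]


# Three "hospitals" with different amounts and ranges of data.
clients = [
    make_client([-1.0, 0.0, 1.0]),
    make_client([-2.0, -1.0, 0.0, 1.0, 2.0]),
    make_client([0.0, 1.0, 2.0, 3.0]),
]
final_weights = fedavg(clients, rounds=40, local_epochs=5, lr=0.1)
print(final_weights)  # approaches [2.0, -1.0]
```

Note that the server only ever sees the lists returned by `local_train`; the `(x, y)` pairs never leave the client that owns them.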
The Counterintuitive: Weight Averaging vs. Gradient Averaging
When we first hear about combining the learning of multiple models in Federated Learning, a natural question arises: “Shouldn’t we just collect the gradients from all clients and update the global model using the average gradient?” That seems intuitive, and it is the natural first approach when formulating Federated Learning. So why does FedAvg average model weights after local training rather than averaging gradients directly? In fact, weight averaging is not a departure from gradient averaging. The authors observed that when every client takes a single gradient step from the same global model, averaging the resulting weights is mathematically equivalent to averaging the gradients first and then updating once.
Here’s how they explained it (it gets only a little more technical in this section):
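A typical implementation of FedSGD (the authors’ name for the gradient-averaging approach) has each client k compute the average gradient of its local loss F_k on its local data at the current global model w_t:

$$g_k = \nabla F_k(w_t)$$

The server then aggregates these gradients, weighting each client by its share n_k / n of the total training samples, and performs the update with learning rate η:

$$w_{t+1} \leftarrow w_t - \eta \sum_{k=1}^{K} \frac{n_k}{n} \, g_k$$

An equivalent update can be expressed as:

For each client,

$$w_{t+1}^{k} \leftarrow w_t - \eta \, g_k$$

and then on the server,

$$w_{t+1} \leftarrow \sum_{k=1}^{K} \frac{n_k}{n} \, w_{t+1}^{k}$$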
In other words, each client takes one local gradient descent step, producing an updated local model. The server then averages these locally updated weights to obtain the new global model.
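The equivalence for a single local step is easy to check numerically. The sketch below is only an illustration: it reuses a toy linear model with mean squared error, and the helper names (`gradient`, `weighted_sum`, `fedsgd_step`, `weight_averaging_step`) and the two small client datasets are made up for the example.

```python
def gradient(weights, data):
    """Average gradient of the squared error of y_hat = w * x + b on one client's data."""
    w, b = weights
    n = len(data)
    grad_w = sum(2 * (w * x + b - y) * x for x, y in data) / n
    grad_b = sum(2 * (w * x + b - y) for x, y in data) / n
    return [grad_w, grad_b]


def weighted_sum(vectors, sizes):
    """Combine per-client vectors, weighting client k by n_k / n."""
    total = sum(sizes)
    dim = len(vectors[0])
    return [
        sum(v[i] * n_k for v, n_k in zip(vectors, sizes)) / total
        for i in range(dim)
    ]


def fedsgd_step(w_t, clients, lr):
    """Gradient averaging: average the clients' gradients, then update once on the server."""
    sizes = [len(data) for data in clients]
    avg_grad = weighted_sum([gradient(w_t, data) for data in clients], sizes)
    return [w - lr * g for w, g in zip(w_t, avg_grad)]


def weight_averaging_step(w_t, clients, lr):
    """Weight averaging: each client takes one local step, then the server averages weights."""
    sizes = [len(data) for data in clients]
    local_models = []
    for data in clients:
        g_k = gradient(w_t, data)
        local_models.append([w - lr * g for w, g in zip(w_t, g_k)])
    return weighted_sum(local_models, sizes)


# Two hypothetical clients with different data and different sample counts.
clients = [
    [(0.0, 1.0), (1.0, 3.0)],
    [(2.0, 4.0), (3.0, 9.0), (4.0, 8.0)],
]
w_t = [0.5, -0.5]

via_gradients = fedsgd_step(w_t, clients, lr=0.05)
via_weights = weight_averaging_step(w_t, clients, lr=0.05)
max_diff = max(abs(a - b) for a, b in zip(via_gradients, via_weights))
print(via_gradients, via_weights, max_diff)  # the two updates agree up to rounding
```

The two results match because the weights n_k / n sum to one, so averaging `w_t - lr * g_k` over clients is the same as subtracting `lr` times the averaged gradient from `w_t`.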
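Once the algorithm was written this way, it became easy to extend from just one local gradient step to multiple steps before the averaging. This idea of adding local computation before aggregation is what led to the Federated Averaging algorithm. With more than one local step the result is no longer exactly the same as averaging gradients, but the structure of the algorithm is unchanged. Therefore, instead of being a conceptual leap, the shift from gradient averaging to weight averaging is an implementation insight. Because each client does more work between exchanges with the server, fewer communication rounds are needed, which makes the algorithm much more efficient in real-world federated settings.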
We have now unpacked the core idea behind how Federated Learning actually works: training models collaboratively without ever sharing data. In the next part, we will move from how Federated Learning works to how it stays private by exploring privacy-preserving techniques that make it truly trustworthy for sensitive domains like healthcare.
If you found this blog valuable and would like to explore more insights on Federated Learning in Health AI, follow us on LinkedIn and subscribe to our Substack for future updates.
Bibek Niroula is an AI researcher at Multimodal Learning Lab, NAAMII. He is a computer engineer whose research interests include AI in Healthcare, Federated Learning, Multimodal Learning, and the intersections among them. He also holds an ISC2 Certification in Cybersecurity and has a strong interest in supporting others in navigating the digital space safely.










