This repository contains the implementation of a client selection mechanism for Federated Learning (FL) based on Multi-Agent Reinforcement Learning (MARL), using Value Decomposition Networks (VDN), with a focus on robustness against label flipping attacks.
In federated learning scenarios, the selection of clients participating in each aggregation round directly impacts the quality and robustness of the global model. Malicious clients can degrade model performance by sending poisoned updates.
This project proposes the use of multi-agent reinforcement learning agents to adaptively select clients, avoiding attackers and prioritizing honest clients based on contribution metrics.
- VDN (Value Decomposition Networks) with Double DQN and Prioritized Experience Replay (PER) for client selection
- Agent state metrics: gradient projection (proj), generalization loss (gener), staleness (estag) and selection streak (serie)
- Attack: Deterministic Targeted Label Flipping with configurable attacker fraction
- Aggregation mechanism: FedAvg
- Data distribution: Non-IID Dirichlet split with configurable alpha
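As an illustration of the non-IID split, a per-class Dirichlet partition can be sketched as follows. The function name and signature are illustrative, not the repository's data.py API; only the general technique (sampling per-class client proportions from Dirichlet(alpha)) matches the feature described above.

```python
import numpy as np

def dirichlet_split(labels, n_clients, alpha, seed=0):
    """Non-IID partition: for each class, sample client proportions from
    Dirichlet(alpha) and distribute that class's indices accordingly.
    Smaller alpha -> more skewed (heterogeneous) client datasets."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    client_indices = [[] for _ in range(n_clients)]
    for c in np.unique(labels):
        idx = rng.permutation(np.where(labels == c)[0])
        props = rng.dirichlet(alpha * np.ones(n_clients))
        # cut points proportional to the sampled fractions
        cuts = (np.cumsum(props)[:-1] * len(idx)).astype(int)
        for client, part in enumerate(np.split(idx, cuts)):
            client_indices[client].extend(part.tolist())
    return client_indices
```

With alpha = 0.3 (the default above) most clients end up dominated by a few classes, which is what makes honest-client identification non-trivial.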
Each round is divided into two phases:
- Metrics phase — all 50 clients train for `local_steps` steps. The metrics (`proj`, `gener`) are computed and used as state variables.
- Training phase — only the K clients selected by the agent train for `local_epochs` epochs. The resulting deltas are aggregated via FedAvg.
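The two-phase round can be sketched as the loop below. Every callable here is a hypothetical stand-in for the corresponding piece of the repository (server-side probing, the VDN selector, local training, FedAvg); only the control flow is taken from the description above.

```python
def run_round(clients, k_select, agent_select, probe_metrics, train, fedavg):
    """One FL round in two phases, as described above."""
    # Phase 1 (metrics): every client does a short probing run;
    # the resulting metrics become the agents' state variables.
    states = {c: probe_metrics(c) for c in clients}
    # Phase 2 (training): the agent picks K clients from the states;
    # only those train fully, and their deltas are aggregated.
    selected = agent_select(states, k_select)
    deltas = [train(c) for c in selected]
    return fedavg(deltas), selected
```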
The local observation vector of each client is composed of four metrics:
- Gradient projection (`proj`)
- Generalization loss (`gener`)
- Staleness (`estag`)
- Selection streak (`serie`)
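Assembled into the per-agent state, this is simply a four-entry vector; the ordering and dtype below are assumptions for illustration.

```python
import numpy as np

def client_observation(proj, gener, estag, serie):
    """Stack the four per-client metrics into the local observation
    vector consumed by each VDN agent: gradient projection (proj),
    generalization loss (gener), staleness (estag), selection
    streak (serie)."""
    return np.array([proj, gener, estag, serie], dtype=np.float32)
```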
The implemented attack is a targeted label flipping, where each class is mapped to a visually similar class following a fixed mapping:
| Original | Flipped |
|---|---|
| airplane | ship |
| ship | airplane |
| automobile | truck |
| truck | automobile |
| cat | dog |
| dog | cat |
| deer | horse |
| horse | deer |
| bird | frog |
| frog | bird |
Unlike random flipping, this approach is more realistic and harder to detect,
as the model learns confusions between visually similar classes. The attack_rate
parameter controls the fraction of samples flipped per attacking client.
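In CIFAR-10 class indices, the fixed mapping above is an involution (applying it twice returns the original label). The wrapper below is a sketch of how `attack_rate` could gate the flip per sample; it is written in pure Python rather than as a `torch.utils.data.Dataset`, and the class name and seeding scheme are illustrative, not the repository's data.py API.

```python
import random

# CIFAR-10 indices for the fixed visually-similar mapping:
# airplane(0)<->ship(8), automobile(1)<->truck(9), bird(2)<->frog(6),
# cat(3)<->dog(5), deer(4)<->horse(7)
FLIP_MAP = {0: 8, 8: 0, 1: 9, 9: 1, 2: 6, 6: 2, 3: 5, 5: 3, 4: 7, 7: 4}

class LabelFlipDataset:
    """Wraps a malicious client's (x, y) samples, flipping a fraction
    attack_rate of labels via the fixed similar-class mapping."""
    def __init__(self, base, attack_rate=1.0, seed=0):
        rng = random.Random(seed)
        self.base = base
        # decide once, per sample, whether its label gets flipped
        self.flip = [rng.random() < attack_rate for _ in range(len(base))]

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        x, y = self.base[i]
        return (x, FLIP_MAP[y]) if self.flip[i] else (x, y)
```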
Clone the repository, install the dependencies, and run:

```bash
git clone https://github.com/GTA-UFRJ/FEDMARL.git
pip install -r requirements.txt
python main.py
```

The main hyperparameters are passed directly to `run_experiment()`:
```python
run_experiment(
    rounds=350,
    n_clients=50,
    k_select=15,
    dir_alpha=0.3,
    run_random=True,           # runs the random-selection track
    run_vdn=True,              # runs the VDN track
    initial_flip_fraction=0.4, # fraction of attacking clients
    flip_rate_initial=1.0,     # fraction of labels flipped per attacker
    local_lr=0.005,
    local_steps=10,
    local_epochs=5,
    marl_lr=1e-4,
)
```

Results are automatically saved to a .json file in the configured output directory.
```
ClientSelection/
+-- main.py        # entry point, hyperparameter configuration
+-- experiment.py  # main experiment loop (RANDOM and VDN tracks)
+-- server.py      # local training, aggregation, server metrics
+-- agent.py       # VDNSelector, AgentMLP, PrioritizedReplayJoint
+-- metrics.py     # eval_acc, eval_loss, probing_loss, windowed_reward
+-- data.py        # Dirichlet split and label flipping dataset
+-- model.py       # ResNet18 adapted for CIFAR-10
+-- config.py      # DEVICE, SEED, seed_worker
+-- flower/        # experimental implementation with Flower 1.26 (in development)
    +-- pyproject.toml
    +-- vdn_fl/
        +-- client_app.py
        +-- server_app.py
        +-- data.py
        +-- ...
```
The project went through three main development stages, each revealing limitations and motivating the following improvements. The examples below adopt the same base configuration: N = 50 clients, K = 15 selected per round, 40% attacking clients with full label inversion (100% label flipping).
The initial version used a simple CNN (SmallCNN):
| Layer | Configuration |
|---|---|
| Layer 1 | Conv(3,3,32) + Pool(2×2) |
| Layer 2 | Conv(3,32,64) + Pool(2×2) |
| Layer 3 | Conv(3,64,128) + Pool(2×2) |
| Output | FC(2048, 256, 10) |
| Optimizer | SGD (momentum=0.9, lr=0.01) |
With this architecture the VDN agent already demonstrated superiority over random selection (FedAvg), reaching ~67% accuracy against ~55% for FedAvg with 40% attacking clients over 500 rounds. The small network, by generating lower magnitude deltas, exhibited natural stability against attacks.
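The SmallCNN in the table can be sketched in PyTorch as below. The ReLU activations and padding=1 are assumptions not stated in the table, chosen so that three 2×2 pools reduce 32×32 inputs to 4×4 feature maps, making the flattened size 128 × 4 × 4 = 2048 and matching FC(2048, 256, 10).

```python
import torch
import torch.nn as nn

class SmallCNN(nn.Module):
    """Stage-1 model sketch: three 3x3 conv + 2x2 max-pool blocks,
    then FC(2048 -> 256 -> 10) for CIFAR-10."""
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 256), nn.ReLU(),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        return self.classifier(self.features(x))
```

Trained locally with SGD (momentum=0.9, lr=0.01) as in the table.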
Replacing it with a ResNet18 adapted for CIFAR-10 (3×3 conv1, no maxpool, standard BatchNorm) and adding data augmentation during local training aimed to increase model capacity and bring results closer to the state of the art. However, without stabilization mechanisms, the higher-magnitude deltas from ResNet18 drastically amplified the impact of attackers, causing sharp and recurrent accuracy drops that made training unstable.
Adding three mechanisms to the aggregation resolved the instability:
| Mechanism | Configuration (N=50) | Configuration (N=100) | Effect |
|---|---|---|---|
| Norm filtering | 2.0 × median_norm | 2.0 × median_norm | Discards deltas with anomalous norm before aggregation |
| Gradient clipping | 0.25 × median_norm | 0.1 × median_norm | Limits the total update magnitude per round |
| FedMedian | — | — | Aggregates by coordinate-wise median |
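A minimal sketch of how the three stabilizers can compose, using the N=50 factors as defaults. The function name is hypothetical, and clipping each surviving delta individually is one plausible reading of "limits the total update magnitude"; the repository may instead clip the aggregated update.

```python
import numpy as np

def robust_aggregate(deltas, norm_filter=2.0, clip_factor=0.25,
                     use_median=False):
    """1) drop deltas whose norm exceeds norm_filter * median norm,
    2) clip survivors to clip_factor * median norm,
    3) aggregate by mean (FedAvg) or coordinate-wise median (FedMedian)."""
    deltas = [np.asarray(d, dtype=np.float64) for d in deltas]
    norms = np.array([np.linalg.norm(d) for d in deltas])
    med = np.median(norms)
    # norm filtering: discard anomalously large deltas
    kept = [d for d, n in zip(deltas, norms) if n <= norm_filter * med]
    # clipping: rescale each surviving delta to at most the limit
    limit = clip_factor * med
    clipped = [d * min(1.0, limit / (np.linalg.norm(d) + 1e-12))
               for d in kept]
    stack = np.stack(clipped)
    return np.median(stack, axis=0) if use_median else stack.mean(axis=0)
```

With one attacker sending a delta 100× larger than the honest ones, the filter removes it entirely and the clipped mean stays small, which is the stabilizing effect described above.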
With those mechanisms, the VDN agent maintains stable accuracy around 85% over 350 rounds while consistently selecting honest clients. Random selection oscillates continuously due to the presence of attackers.
N=50 clients, K=15 selected per round, 40% attackers:
N=100 clients, K=15 selected per round, 40% attackers:
The client selection histogram below confirms that the learned policy systematically prioritizes honest clients over attackers throughout training (N=100):
Every 20 rounds, the server prints the client ranking ordered by advantage (adv = Q1 - Q0). Clients with positive adv are prioritized for selection. The result below illustrates the separation learned by MARL, showing that the policy consistently ranks honest clients above attackers:
| Position | Client | Type | adv |
|---|---|---|---|
| 1st | 41 | HONEST | +0.083774 |
| 2nd | 06 | HONEST | +0.074431 |
| 3rd | 23 | HONEST | +0.070137 |
| 4th | 24 | HONEST | +0.068502 |
| ... | ... | ... | ... |
| 47th | 12 | ATTACKER | -0.144135 |
| 48th | 31 | ATTACKER | -0.142081 |
| 50th | 30 | ATTACKER | -0.187876 |
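The ranking above follows directly from the advantage definition adv = Q1 − Q0. The helper below is a hypothetical illustration of that computation (the repository computes and prints the ranking inside its server loop rather than through a standalone function):

```python
def rank_by_advantage(q_values):
    """q_values: dict mapping client id -> (Q0, Q1), the two action
    values from that client's agent head. adv = Q1 - Q0; clients with
    positive adv are favored for selection."""
    adv = {c: q1 - q0 for c, (q0, q1) in q_values.items()}
    # highest advantage first, as in the printed ranking
    return sorted(adv.items(), key=lambda kv: kv[1], reverse=True)
```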