Commit 52720200 authored by Rishi Sharma
Add TopKPlusRandom

parent 44f0d32d
import logging

import numpy as np
import torch

from decentralizepy.sharing.PartialModel import PartialModel


class TopKPlusRandom(PartialModel):
"""
This class implements partial model sharing with some random additions.
"""
def __init__(
self,
rank,
machine_id,
communication,
mapping,
graph,
model,
dataset,
log_dir,
alpha=1.0,
dict_ordered=True,
save_shared=False,
metadata_cap=1.0,
):
"""
Constructor
Parameters
----------
rank : int
Local rank
machine_id : int
Global machine id
communication : decentralizepy.communication.Communication
Communication module used to send and receive messages
mapping : decentralizepy.mappings.Mapping
Mapping (rank, machine_id) -> uid
graph : decentralizepy.graphs.Graph
Graph reprensenting neighbors
model : decentralizepy.models.Model
Model to train
dataset : decentralizepy.datasets.Dataset
Dataset for sharing data. Not implemented yet! TODO
log_dir : str
Location to write shared_params (only writing for 2 procs per machine)
alpha : float
Percentage of model to share
dict_ordered : bool
Specifies if the python dict maintains the order of insertion
save_shared : bool
Specifies if the indices of shared parameters should be logged
metadata_cap : float
Share full model when self.alpha > metadata_cap
"""
        super().__init__(
            rank,
            machine_id,
            communication,
            mapping,
            graph,
            model,
            dataset,
            log_dir,
            alpha,
            dict_ordered,
            save_shared,
            metadata_cap,
        )
    def extract_top_gradients(self):
        """
        Extract the indices and values of the topK gradients and add some extra random ones.
        The gradients must have been accumulated.

        Returns
        -------
        tuple
            (a, b). a: The magnitudes of the selected gradients, b: Their indices.
        """
        logging.info("Summing up gradients")
        assert len(self.model.accumulated_gradients) > 0
        gradient_sum = self.model.accumulated_gradients[0]
        for i in range(1, len(self.model.accumulated_gradients)):
            for key in self.model.accumulated_gradients[i]:
                gradient_sum[key] += self.model.accumulated_gradients[i][key]
        logging.info("Returning topk gradients")
        tensors_to_cat = [v.data.flatten() for _, v in gradient_sum.items()]
        G = torch.abs(torch.cat(tensors_to_cat, dim=0))
        std, mean = torch.std_mean(G, unbiased=False)
        self.std = std.item()
        self.mean = mean.item()
        # Half of the alpha budget is spent on top-k entries, the other half on random extras.
        elements_to_pick = round(self.alpha / 2.0 * G.shape[0])
        G_topK = torch.topk(G, min(G.shape[0], elements_to_pick), dim=0, sorted=False)
        # Exclude the top-k indices before sampling the random extras;
        # np.delete returns a new array, so its result must be assigned back.
        more_indices = np.arange(G.shape[0], dtype=int)
        more_indices = np.delete(more_indices, G_topK[1].numpy())
        # Sample distinct extra indices (replace=False avoids duplicates).
        more_indices = np.random.choice(
            more_indices, min(more_indices.shape[0], elements_to_pick), replace=False
        )
        G_topK0 = torch.cat([G_topK[0], G[more_indices]], dim=0)
        G_topK1 = torch.cat([G_topK[1], torch.tensor(more_indices)], dim=0)
        return G_topK0, G_topK1
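
For reference, the short standalone sketch below (not part of the commit) mirrors the selection rule of extract_top_gradients on a toy vector: half of the alpha budget goes to the largest-magnitude entries, the other half to randomly drawn indices outside the top-k set. The vector length and the alpha value are illustrative assumptions, and the snippet does not touch any decentralizepy classes.

# Illustrative sketch only: toy sizes, not the decentralizepy code path.
import numpy as np
import torch

alpha = 0.4                                          # assumed sharing fraction
G = torch.abs(torch.randn(20))                       # toy flattened gradient magnitudes

elements_to_pick = round(alpha / 2.0 * G.shape[0])   # 4 top-k + 4 random extras
topk = torch.topk(G, min(G.shape[0], elements_to_pick), dim=0, sorted=False)

# Random extras are drawn only from indices not already covered by the top-k set.
remaining = np.delete(np.arange(G.shape[0], dtype=int), topk.indices.numpy())
extra = np.random.choice(remaining, min(remaining.shape[0], elements_to_pick), replace=False)

values = torch.cat([topk.values, G[torch.from_numpy(extra)]], dim=0)
indices = torch.cat([topk.indices, torch.from_numpy(extra)], dim=0)
print(sorted(indices.tolist()))                      # 8 distinct indices: 4 largest + 4 random others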