Commit da846d2a authored by Rishi Sharma

Update random alpha

parent 1586bbba
@@ -201,6 +201,8 @@ class PartialModel(Sharing):
if not self.dict_ordered:
raise NotImplementedError
m["alpha"] = self.alpha
m["indices"] = G_topk.numpy().astype(np.int32)
m["params"] = T_topk.numpy()
......
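For context, a minimal sketch of the top-k selection that could produce the G_topk and T_topk used in this hunk; the function name and the use of torch.topk over parameter magnitudes are assumptions, only the three m[...] assignments come from the diff itself.

```python
import numpy as np
import torch

def serialize_topk(flat_params: torch.Tensor, alpha: float) -> dict:
    # Keep the alpha-fraction of parameters with the largest magnitude.
    k = int(alpha * flat_params.numel())
    _, G_topk = torch.topk(flat_params.abs(), k, sorted=False)
    T_topk = flat_params[G_topk]
    m = {}
    m["alpha"] = alpha
    m["indices"] = G_topk.numpy().astype(np.int32)
    m["params"] = T_topk.numpy()
    return m

m = serialize_topk(torch.randn(100), alpha=0.1)  # 10 values plus their indices
```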
import random
from decentralizepy.sharing.PartialModel import PartialModel
from decentralizepy.utils import identity
class RandomAlpha(PartialModel):
@@ -19,9 +20,14 @@ class RandomAlpha(PartialModel):
model,
dataset,
log_dir,
+ alpha_list="[0.1, 0.2, 0.3, 0.4, 1.0]",
dict_ordered=True,
save_shared=False,
metadata_cap=1.0,
+ accumulation=False,
+ save_accumulated="",
+ change_transformer=identity,
+ accumulate_averaging_changes=False,
):
"""
Constructor
@@ -65,6 +71,14 @@ class RandomAlpha(PartialModel):
dict_ordered,
save_shared,
metadata_cap,
+ accumulation,
+ save_accumulated,
+ change_transformer,
+ accumulate_averaging_changes,
)
+ self.alpha_list = eval(alpha_list)
+ random.seed(
+ self.mapping.get_uid(self.rank, self.machine_id)
+ )
def step(self):
@@ -72,8 +86,5 @@ class RandomAlpha(PartialModel):
Perform a sharing step. Implements D-PSGD with alpha randomly chosen.
"""
- random.seed(
- self.mapping.get_uid(self.rank, self.machine_id) + self.communication_round
- )
- self.alpha = random.randint(1, 7) / 10.0
+ self.alpha = random.choice(self.alpha_list)
super().step()
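The behavioural change: the old code reseeded the RNG on every step with uid + communication_round and drew alpha from random.randint(1, 7) / 10.0, while the new code seeds once per node in the constructor and draws each step's alpha from the configurable alpha_list. A runnable illustration (the uid value is made up; in the class it comes from self.mapping.get_uid(self.rank, self.machine_id)):

```python
import random

alpha_list = [0.1, 0.2, 0.3, 0.4, 1.0]
uid = 3                 # hypothetical node uid
random.seed(uid)        # seeded once at construction, not once per round

for communication_round in range(5):
    # replaces the old random.randint(1, 7) / 10.0
    alpha = random.choice(alpha_list)
    print(communication_round, alpha)
```

Seeding once per node gives each node its own reproducible sequence of alphas across rounds instead of an alpha that is a fixed function of uid + round.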
import random
from decentralizepy.sharing.Wavelet import Wavelet
class RandomAlpha(Wavelet):
"""
This class implements the partial model sharing with a random alpha each iteration.
"""
def __init__(
self,
rank,
machine_id,
communication,
mapping,
graph,
model,
dataset,
log_dir,
alpha_list="[0.1, 0.2, 0.3, 0.4, 1.0]",
dict_ordered=True,
save_shared=False,
metadata_cap=1.0,
wavelet="haar",
level=4,
change_based_selection=True,
save_accumulated="",
accumulation=False,
accumulate_averaging_changes=False,
):
"""
Constructor
Parameters
----------
rank : int
Local rank
machine_id : int
Global machine id
communication : decentralizepy.communication.Communication
Communication module used to send and receive messages
mapping : decentralizepy.mappings.Mapping
Mapping (rank, machine_id) -> uid
graph : decentralizepy.graphs.Graph
Graph representing neighbors
model : decentralizepy.models.Model
Model to train
dataset : decentralizepy.datasets.Dataset
Dataset for sharing data. Not implemented yet! TODO
log_dir : str
Location to write shared_params (only writing for 2 procs per machine)
alpha_list : str
List of alphas to choose from at each step, given as a string and parsed with eval
dict_ordered : bool
Specifies if the python dict maintains the order of insertion
save_shared : bool
Specifies if the indices of shared parameters should be logged
metadata_cap : float
Share full model when self.alpha > metadata_cap
"""
super().__init__(
rank,
machine_id,
communication,
mapping,
graph,
model,
dataset,
log_dir,
1.0,
dict_ordered,
save_shared,
metadata_cap,
wavelet,
level,
change_based_selection,
save_accumulated,
accumulation,
accumulate_averaging_changes,
)
self.alpha_list = eval(alpha_list)
random.seed(
self.mapping.get_uid(self.rank, self.machine_id)
)
def step(self):
"""
Perform a sharing step. Implements D-PSGD with alpha randomly chosen.
"""
self.alpha = random.choice(self.alpha_list)
super().step()
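Both RandomAlpha variants parse alpha_list with eval(), which fits the default above being written as a string: sharing classes in this codebase are presumably constructed from text config files, so the value arrives as a string. A small illustration (the safer ast.literal_eval shown alongside is a suggestion, not what the commit uses):

```python
alpha_list = "[0.1, 0.2, 0.3, 0.4, 1.0]"   # as it would arrive from a config file
parsed = eval(alpha_list)
assert parsed == [0.1, 0.2, 0.3, 0.4, 1.0]

# ast.literal_eval is a safer drop-in for untrusted config values:
import ast
assert ast.literal_eval(alpha_list) == parsed
```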
@@ -179,7 +179,7 @@ class Wavelet(PartialModel):
Model converted to json dict
"""
- if self.alpha > self.metadata_cap: # Share fully
+ if self.alpha >= self.metadata_cap: # Share fully
return super().serialized_model()
with torch.no_grad():
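The switch from > to >= matters precisely because alpha is now drawn from a list that includes 1.0 while metadata_cap defaults to 1.0: under the strict comparison, alpha == 1.0 still took the partial-sharing path and paid the metadata cost of sending every index. A one-line check:

```python
alpha, metadata_cap = 1.0, 1.0
assert not (alpha > metadata_cap)  # old check: full-sharing path never taken
assert alpha >= metadata_cap       # new check: alpha == 1.0 shares the full model
```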
@@ -218,6 +218,8 @@ class Wavelet(PartialModel):
m["indices"] = indices.numpy().astype(np.int32)
m["send_partial"] = True
self.total_data += len(self.communication.encrypt(m["params"]))
self.total_meta += len(self.communication.encrypt(m["indices"])) + len(
self.communication.encrypt(m["alpha"])
)
@@ -240,7 +242,7 @@ class Wavelet(PartialModel):
state_dict of received
"""
- if self.alpha > self.metadata_cap: # Share fully
+ if "send_partial" not in m:
return super().deserialized_model(m)
with torch.no_grad():
......
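The deserialization change is the counterpart of m["send_partial"] = True above: since alpha is now randomised per node and per round, a receiver can no longer compare its own alpha against metadata_cap to guess whether an incoming message is partial, so the sender flags partial payloads explicitly. A minimal sketch of the dispatch (helper return values are illustrative):

```python
def deserialize(m: dict) -> str:
    if "send_partial" not in m:
        return "full model path"     # i.e. super().deserialized_model(m)
    return "partial model path"      # decode m["indices"] / m["params"]

assert deserialize({"params": [0.5]}) == "full model path"
assert deserialize({"send_partial": True, "indices": [2]}) == "partial model path"
```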