Skip to content
Snippets Groups Projects
Commit 327519cc authored by Rishi Sharma's avatar Rishi Sharma
Browse files

Debug TCP

parent 892cf251
No related branches found
No related tags found
No related merge requests found
import logging
from decentralizepy.sharing.PartialModel import PartialModel
from decentralizepy.sharing.Sharing import Sharing
class GrowingAlpha(PartialModel):
......@@ -18,9 +17,9 @@ class GrowingAlpha(PartialModel):
init_alpha=0.0,
max_alpha=1.0,
k=10,
metadata_cap=0.6,
dict_ordered=True,
save_shared=False,
metadata_cap=1.0,
):
super().__init__(
rank,
......@@ -34,34 +33,20 @@ class GrowingAlpha(PartialModel):
init_alpha,
dict_ordered,
save_shared,
metadata_cap,
)
self.init_alpha = init_alpha
self.max_alpha = max_alpha
self.k = k
self.metadata_cap = metadata_cap
self.base = None
def step(self):
    """
    Perform one sharing round, growing ``alpha`` toward ``max_alpha``.

    Every ``k``-th communication round, ``alpha`` is increased by
    ``(max_alpha - init_alpha) / k`` and clipped to 1.0.  While ``alpha``
    is still 0.0 nothing is exchanged: the round counter is advanced and
    the method returns early.  Otherwise sharing is delegated to
    ``PartialModel.step()``; the parent class itself falls back to full
    (non-partial) sharing once ``alpha`` exceeds ``metadata_cap``, so no
    separate ``Sharing`` fallback is needed here.
    """
    # Grow alpha once every k rounds, never past 1.0.
    if (self.communication_round + 1) % self.k == 0:
        self.alpha += (self.max_alpha - self.init_alpha) / self.k
        self.alpha = min(self.alpha, 1.00)

    # alpha == 0.0 would share an empty model; skip the exchange but
    # still count the round so alpha keeps growing on schedule.
    if self.alpha == 0.0:
        logging.info("Not sending/receiving data (alpha=0.0)")
        self.communication_round += 1
        return

    super().step()
......@@ -22,6 +22,7 @@ class PartialModel(Sharing):
alpha=1.0,
dict_ordered=True,
save_shared=False,
metadata_cap=1.0,
):
"""
Constructor
......@@ -51,6 +52,7 @@ class PartialModel(Sharing):
self.alpha = alpha
self.dict_ordered = dict_ordered
self.save_shared = save_shared
self.metadata_cap = metadata_cap
# Only save for 2 procs
if rank == 0 or rank == 1:
......@@ -78,6 +80,9 @@ class PartialModel(Sharing):
)
def serialized_model(self):
if self.alpha > self.metadata_cap: # Share fully
return super().serialized_model()
with torch.no_grad():
_, G_topk = self.extract_top_gradients()
......@@ -129,6 +134,9 @@ class PartialModel(Sharing):
return m
def deserialized_model(self, m):
if self.alpha > self.metadata_cap: # Share fully
return super().deserialized_model(m)
with torch.no_grad():
state_dict = self.model.state_dict()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.