Skip to content
Snippets Groups Projects
Commit 0c1ab13e authored by Milos Vujasinovic's avatar Milos Vujasinovic
Browse files

Class name fix for SecureCompressedAggregation

parent a21933e8
No related branches found
No related tags found
No related merge requests found
...@@ -8,7 +8,7 @@ from torch import multiprocessing as mp ...@@ -8,7 +8,7 @@ from torch import multiprocessing as mp
from decentralizepy import utils from decentralizepy import utils
from decentralizepy.graphs.Graph import Graph from decentralizepy.graphs.Graph import Graph
from decentralizepy.mappings.Linear import Linear from decentralizepy.mappings.Linear import Linear
from decentralizepy.node.SecureCompressedAggregatopn import SecureCompressedAggregatopn from decentralizepy.node.SecureCompressedAggregation import SecureCompressedAggregation
def read_ini(file_path): def read_ini(file_path):
...@@ -54,7 +54,7 @@ if __name__ == "__main__": ...@@ -54,7 +54,7 @@ if __name__ == "__main__":
for r in range(procs_per_machine): for r in range(procs_per_machine):
processes.append( processes.append(
mp.Process( mp.Process(
target=SecureCompressedAggregatopn, target=SecureCompressedAggregation,
args=[ args=[
r, r,
m_id, m_id,
......
...@@ -84,7 +84,7 @@ def temp_seed(seed): ...@@ -84,7 +84,7 @@ def temp_seed(seed):
finally: finally:
np.random.set_state(state) np.random.set_state(state)
class SecureCompressedAggregatopn(DPSGDNode): class SecureCompressedAggregation(DPSGDNode):
""" """
This class defines the node for secure compressed aggregation This class defines the node for secure compressed aggregation
...@@ -226,7 +226,8 @@ class SecureCompressedAggregatopn(DPSGDNode): ...@@ -226,7 +226,8 @@ class SecureCompressedAggregatopn(DPSGDNode):
def generate_mask(self, seed, size): def generate_mask(self, seed, size):
with temp_seed(seed): with temp_seed(seed):
return torch.Tensor(np.random.uniform(1, 10, size=size)) # Figure out best distribution to add
return torch.Tensor(np.random.normal(0, 100000, size=size))
def run(self): def run(self):
""" """
...@@ -244,7 +245,6 @@ class SecureCompressedAggregatopn(DPSGDNode): ...@@ -244,7 +245,6 @@ class SecureCompressedAggregatopn(DPSGDNode):
for iteration in range(self.iterations): for iteration in range(self.iterations):
logging.info("Starting training iteration: %d", iteration) logging.info("Starting training iteration: %d", iteration)
print("Starting training iteration:", iteration)
rounds_to_train_evaluate -= 1 rounds_to_train_evaluate -= 1
rounds_to_test -= 1 rounds_to_test -= 1
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment