Skip to content
Snippets Groups Projects
Commit e2ce6771 authored by Jeffrey Wigger's avatar Jeffrey Wigger
Browse files

moving to pickle; formatting

parent ac8e8e7e
No related branches found
No related tags found
1 merge request!2moving to pickle; two threads per proc
......@@ -11,6 +11,7 @@ def plot(x, y, label, *args):
plt.plot(x, y, *args, label=label)
plt.legend()
def plot_shared(path, title):
model_path = os.path.join(path, "plots")
Path(model_path).mkdir(parents=True, exist_ok=True)
......@@ -29,7 +30,7 @@ def plot_shared(path, title):
current_params *= v
total_params += current_params
print("Total Params: ", str(total_params))
shared_count = np.zeros(total_params, dtype = int)
shared_count = np.zeros(total_params, dtype=int)
del model_vec["shapes"]
model_vec = np.array(model_vec[list(model_vec.keys())[0]])
shared_count[model_vec] += 1
......
import sys
from decentralizepy.datasets.Femnist import Femnist
from decentralizepy.datasets.Femnist import Femnist
if __name__ == "__main__":
f = Femnist(None, None, None)
......
import json
import logging
import pickle
from collections import deque
import zmq
......@@ -85,7 +86,7 @@ class TCP(Communication):
def encrypt(self, data):
"""
Encode data using utf8.
Encode data as python pickle.
Parameters
----------
......@@ -98,11 +99,11 @@ class TCP(Communication):
Encoded data
"""
return json.dumps(data).encode("utf8")
return pickle.dumps(data)
def decrypt(self, sender, data):
"""
Decode received data from utf8.
Decode received pickle data.
Parameters
----------
......@@ -118,7 +119,7 @@ class TCP(Communication):
"""
sender = int(sender.decode())
data = json.loads(data.decode("utf8"))
data = pickle.loads(data)
return sender, data
def connect_neighbors(self, neighbors):
......
from decentralizepy import utils
from decentralizepy.mappings.Mapping import Mapping
class Dataset:
"""
This class defines the Dataset API.
......
......@@ -109,12 +109,12 @@ class PartialModel(Sharing):
def serialized_model(self):
"""
Convert model to json dict. self.alpha specifies the fraction of model to send.
Convert model to a dict. self.alpha specifies the fraction of model to send.
Returns
-------
dict
Model converted to json dict
Model converted to a dict
"""
if self.alpha > self.metadata_cap: # Share fully
......@@ -164,10 +164,7 @@ class PartialModel(Sharing):
logging.info("Generated dictionary to send")
for key in m:
m[key] = json.dumps(m[key])
logging.info("Converted dictionary to json")
logging.info("Converted dictionary to pickle")
self.total_data += len(self.communication.encrypt(m["params"]))
self.total_meta += len(self.communication.encrypt(m["indices"]))
......@@ -175,12 +172,12 @@ class PartialModel(Sharing):
def deserialized_model(self, m):
"""
Convert received json dict to state_dict.
Convert received dict to state_dict.
Parameters
----------
m : dict
json dict received
dict received
Returns
-------
......@@ -207,9 +204,9 @@ class PartialModel(Sharing):
tensors_to_cat.append(t)
T = torch.cat(tensors_to_cat, dim=0)
index_tensor = torch.tensor(json.loads(m["indices"]))
index_tensor = torch.tensor(m["indices"])
logging.debug("Original tensor: {}".format(T[index_tensor]))
T[index_tensor] = torch.tensor(json.loads(m["params"]))
T[index_tensor] = torch.tensor(m["params"])
logging.debug("Final tensor: {}".format(T[index_tensor]))
start_index = 0
for i, key in enumerate(state_dict):
......
import json
import logging
import pickle
from collections import deque
import numpy
......@@ -90,28 +90,28 @@ class Sharing:
def serialized_model(self):
"""
Convert model to json dict. Here we can choose how much to share
Convert model to a dictionary. Here we can choose how much to share
Returns
-------
dict
Model converted to json dict
Model converted to dict
"""
m = dict()
for key, val in self.model.state_dict().items():
m[key] = json.dumps(val.numpy().tolist())
m[key] = val.numpy()
self.total_data += len(self.communication.encrypt(m[key]))
return m
def deserialized_model(self, m):
"""
Convert received json dict to state_dict.
Convert received dict to state_dict.
Parameters
----------
m : dict
json dict received
received dict
Returns
-------
......@@ -121,7 +121,7 @@ class Sharing:
"""
state_dict = dict()
for key, value in m.items():
state_dict[key] = torch.from_numpy(numpy.array(json.loads(value)))
state_dict[key] = torch.from_numpy(value)
return state_dict
def step(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment