Skip to content
Snippets Groups Projects
Commit f99d4005 authored by Rishi Sharma's avatar Rishi Sharma
Browse files

Merge branch 'pickling_pr' into 'main'

moving to pickle; two threads per proc

See merge request sacs/decentralizepy!2
parents ac8e8e7e bee80272
No related branches found
No related tags found
No related merge requests found
......@@ -11,6 +11,7 @@ def plot(x, y, label, *args):
plt.plot(x, y, *args, label=label)
plt.legend()
def plot_shared(path, title):
model_path = os.path.join(path, "plots")
Path(model_path).mkdir(parents=True, exist_ok=True)
......@@ -29,7 +30,7 @@ def plot_shared(path, title):
current_params *= v
total_params += current_params
print("Total Params: ", str(total_params))
shared_count = np.zeros(total_params, dtype = int)
shared_count = np.zeros(total_params, dtype=int)
del model_vec["shapes"]
model_vec = np.array(model_vec[list(model_vec.keys())[0]])
shared_count[model_vec] += 1
......
import sys
from decentralizepy.datasets.Femnist import Femnist
from decentralizepy.datasets.Femnist import Femnist
if __name__ == "__main__":
f = Femnist(None, None, None)
......
import json
import logging
import pickle
from collections import deque
import zmq
......@@ -85,7 +86,7 @@ class TCP(Communication):
def encrypt(self, data):
"""
Encode data using utf8.
Encode data as python pickle.
Parameters
----------
......@@ -98,11 +99,11 @@ class TCP(Communication):
Encoded data
"""
return json.dumps(data).encode("utf8")
return pickle.dumps(data)
def decrypt(self, sender, data):
"""
Decode received data from utf8.
Decode received pickle data.
Parameters
----------
......@@ -118,7 +119,7 @@ class TCP(Communication):
"""
sender = int(sender.decode())
data = json.loads(data.decode("utf8"))
data = pickle.loads(data)
return sender, data
def connect_neighbors(self, neighbors):
......
from decentralizepy import utils
from decentralizepy.mappings.Mapping import Mapping
class Dataset:
"""
This class defines the Dataset API.
......
......@@ -59,3 +59,16 @@ class Linear(Mapping):
"""
return (uid % self.procs_per_machine), (uid // self.procs_per_machine)
def get_local_procs_count(self):
"""
Gives number of processes that run on the node
Returns
-------
int
the number of local processes
"""
return self.procs_per_machine
......@@ -67,3 +67,16 @@ class Mapping:
"""
raise NotImplementedError
def get_local_procs_count(self):
"""
Gives number of processes that run on the node
Returns
-------
int
the number of local processes
"""
raise NotImplementedError
import importlib
import json
import logging
import math
import os
import torch
from matplotlib import pyplot as plt
from decentralizepy import utils
......@@ -420,6 +422,10 @@ class Node:
Other arguments
"""
total_threads = os.cpu_count()
threads_per_proc = max(math.floor(total_threads / mapping.procs_per_machine), 1)
torch.set_num_threads(threads_per_proc)
torch.set_num_interop_threads(1)
self.instantiate(
rank,
machine_id,
......@@ -432,5 +438,8 @@ class Node:
test_after,
*args
)
logging.info(
"Each proc uses %d threads out of %d.", threads_per_proc, total_threads
)
self.run()
......@@ -109,12 +109,12 @@ class PartialModel(Sharing):
def serialized_model(self):
"""
Convert model to json dict. self.alpha specifies the fraction of model to send.
Convert model to a dict. self.alpha specifies the fraction of model to send.
Returns
-------
dict
Model converted to json dict
Model converted to a dict
"""
if self.alpha > self.metadata_cap: # Share fully
......@@ -155,19 +155,16 @@ class PartialModel(Sharing):
if not self.dict_ordered:
raise NotImplementedError
m["indices"] = G_topk.numpy().tolist()
m["indices"] = G_topk.numpy()
m["params"] = T_topk.numpy().tolist()
m["params"] = T_topk.numpy()
assert len(m["indices"]) == len(m["params"])
logging.info("Elements sending: {}".format(len(m["indices"])))
logging.info("Generated dictionary to send")
for key in m:
m[key] = json.dumps(m[key])
logging.info("Converted dictionary to json")
logging.info("Converted dictionary to pickle")
self.total_data += len(self.communication.encrypt(m["params"]))
self.total_meta += len(self.communication.encrypt(m["indices"]))
......@@ -175,12 +172,12 @@ class PartialModel(Sharing):
def deserialized_model(self, m):
"""
Convert received json dict to state_dict.
Convert received dict to state_dict.
Parameters
----------
m : dict
json dict received
dict received
Returns
-------
......@@ -207,9 +204,9 @@ class PartialModel(Sharing):
tensors_to_cat.append(t)
T = torch.cat(tensors_to_cat, dim=0)
index_tensor = torch.tensor(json.loads(m["indices"]))
index_tensor = torch.tensor(m["indices"])
logging.debug("Original tensor: {}".format(T[index_tensor]))
T[index_tensor] = torch.tensor(json.loads(m["params"]))
T[index_tensor] = torch.tensor(m["params"])
logging.debug("Final tensor: {}".format(T[index_tensor]))
start_index = 0
for i, key in enumerate(state_dict):
......
import json
import logging
import pickle
from collections import deque
import numpy
......@@ -90,28 +90,28 @@ class Sharing:
def serialized_model(self):
"""
Convert model to json dict. Here we can choose how much to share
Convert model to a dictionary. Here we can choose how much to share
Returns
-------
dict
Model converted to json dict
Model converted to dict
"""
m = dict()
for key, val in self.model.state_dict().items():
m[key] = json.dumps(val.numpy().tolist())
m[key] = val.numpy()
self.total_data += len(self.communication.encrypt(m[key]))
return m
def deserialized_model(self, m):
"""
Convert received json dict to state_dict.
Convert received dict to state_dict.
Parameters
----------
m : dict
json dict received
received dict
Returns
-------
......@@ -121,7 +121,7 @@ class Sharing:
"""
state_dict = dict()
for key, value in m.items():
state_dict[key] = torch.from_numpy(numpy.array(json.loads(value)))
state_dict[key] = torch.from_numpy(value)
return state_dict
def step(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment