Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Commits on Source (16)

Showing with 833 additions and 243 deletions
@@ -8,4 +8,6 @@
 **/leaf/
 **.egg-info
 2021**
-2022**
\ No newline at end of file
+2022**
+**/massif.out*
+*swp
@@ -2,6 +2,9 @@
 decentralizepy
 ==============

+decentralizepy is a framework for running distributed applications (particularly ML) on top of arbitrary topologies (decentralized, federated, parameter server).
+It was primarily conceived for assessing scientific ideas on several aspects of distributed learning (communication efficiency, privacy, data heterogeneity, etc.).
+
 -------------------------
 Setting up decentralizepy
 -------------------------
@@ -23,10 +26,14 @@ Setting up decentralizepy
     pip install --upgrade pip

 * On Mac M1, installing ``pyzmq`` fails with `pip`. Use `conda <https://conda.io>`_.
-* Install decentralizepy for development. ::
+* Install decentralizepy for development. (zsh) ::

     pip3 install --editable .\[dev\]

+* Install decentralizepy for development. (bash) ::
+
+    pip3 install --editable .[dev]
+
 ----------------
 Running the code
 ----------------
...
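A quick sanity check of the editable install (a minimal sketch, not part of this revision; it only assumes the install above succeeded)::

    # Run after `pip3 install --editable .[dev]`.
    import decentralizepy
    print(decentralizepy.__file__)  # should point into your local checkout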
@@ -26,9 +26,9 @@ def get_stats(l):
     return mean_dict, stdev_dict, min_dict, max_dict


-def plot(means, stdevs, mins, maxs, title, label, loc):
+def plot(means, stdevs, mins, maxs, title, label, loc, xlabel="communication rounds"):
     plt.title(title)
-    plt.xlabel("communication rounds")
+    plt.xlabel(xlabel)
     x_axis = np.array(list(means.keys()))
     y_axis = np.array(list(means.values()))
     err = np.array(list(stdevs.values()))
@@ -37,6 +37,13 @@ def plot(means, stdevs, mins, maxs, title, label, loc):
     plt.legend(loc=loc)


+def replace_dict_key(d_org: dict, d_other: dict):
+    result = {}
+    for x, y in d_org.items():
+        result[d_other[x]] = y
+    return result
+
+
 def plot_results(path, centralized, data_machine="machine0", data_node=0):
     folders = os.listdir(path)
     if centralized.lower() in ["true", "1", "t", "y", "yes"]:
@@ -67,26 +74,61 @@ def plot_results(path, centralized, data_machine="machine0", data_node=0):
             filepath = os.path.join(mf_path, f)
             with open(filepath, "r") as inf:
                 results.append(json.load(inf))
-        if folder.startswith("FL"):
+        if folder.startswith("FL") or folder.startswith("Parameter Server"):
             data_node = -1
         else:
             data_node = 0
         with open(folder_path / data_machine / f"{data_node}_results.json", "r") as f:
             main_data = json.load(f)
         main_data = [main_data]
+
+        # Plotting bytes over time
+        plt.figure(10)
+        b_means, stdevs, mins, maxs = get_stats([x["total_bytes"] for x in results])
+        plot(b_means, stdevs, mins, maxs, "Total Bytes", folder, "lower right")
+        df = pd.DataFrame(
+            {
+                "mean": list(b_means.values()),
+                "std": list(stdevs.values()),
+                "nr_nodes": [len(results)] * len(b_means),
+            },
+            list(b_means.keys()),
+            columns=["mean", "std", "nr_nodes"],
+        )
+        df.to_csv(
+            os.path.join(path, "total_bytes_" + folder + ".csv"), index_label="rounds"
+        )
+
         # Plot Training loss
         plt.figure(1)
         means, stdevs, mins, maxs = get_stats([x["train_loss"] for x in results])
         plot(means, stdevs, mins, maxs, "Training Loss", folder, "upper right")
+        correct_bytes = [b_means[x] for x in means]
         df = pd.DataFrame(
             {
                 "mean": list(means.values()),
                 "std": list(stdevs.values()),
                 "nr_nodes": [len(results)] * len(means),
+                "total_bytes": correct_bytes,
             },
             list(means.keys()),
-            columns=["mean", "std", "nr_nodes"],
+            columns=["mean", "std", "nr_nodes", "total_bytes"],
         )
+        plt.figure(11)
+        means = replace_dict_key(means, b_means)
+        plot(
+            means,
+            stdevs,
+            mins,
+            maxs,
+            "Training Loss",
+            folder,
+            "upper right",
+            "Total Bytes per node",
+        )
         df.to_csv(
             os.path.join(path, "train_loss_" + folder + ".csv"), index_label="rounds"
         )
@@ -102,10 +144,24 @@ def plot_results(path, centralized, data_machine="machine0", data_node=0):
                 "mean": list(means.values()),
                 "std": list(stdevs.values()),
                 "nr_nodes": [len(results)] * len(means),
+                "total_bytes": correct_bytes,
             },
             list(means.keys()),
-            columns=["mean", "std", "nr_nodes"],
+            columns=["mean", "std", "nr_nodes", "total_bytes"],
         )
+        plt.figure(12)
+        means = replace_dict_key(means, b_means)
+        plot(
+            means,
+            stdevs,
+            mins,
+            maxs,
+            "Testing Loss",
+            folder,
+            "upper right",
+            "Total Bytes per node",
+        )
         df.to_csv(
             os.path.join(path, "test_loss_" + folder + ".csv"), index_label="rounds"
         )
@@ -121,30 +177,27 @@ def plot_results(path, centralized, data_machine="machine0", data_node=0):
                 "mean": list(means.values()),
                 "std": list(stdevs.values()),
                 "nr_nodes": [len(results)] * len(means),
+                "total_bytes": correct_bytes,
             },
             list(means.keys()),
-            columns=["mean", "std", "nr_nodes"],
+            columns=["mean", "std", "nr_nodes", "total_bytes"],
         )
-        df.to_csv(
-            os.path.join(path, "test_acc_" + folder + ".csv"), index_label="rounds"
-        )
-        plt.figure(6)
-        means, stdevs, mins, maxs = get_stats([x["grad_std"] for x in results])
+        plt.figure(13)
+        means = replace_dict_key(means, b_means)
         plot(
             means,
             stdevs,
             mins,
             maxs,
-            "Gradient Variation over Nodes",
+            "Testing Accuracy",
             folder,
-            "upper right",
+            "lower right",
+            "Total Bytes per node",
         )
-        # Plot Testing loss
-        plt.figure(7)
-        means, stdevs, mins, maxs = get_stats([x["grad_mean"] for x in results])
-        plot(
-            means, stdevs, mins, maxs, "Gradient Magnitude Mean", folder, "upper right"
+        df.to_csv(
+            os.path.join(path, "test_acc_" + folder + ".csv"), index_label="rounds"
        )
         # Collect total_bytes shared
         bytes_list = []
         for x in results:
@@ -173,16 +226,21 @@ def plot_results(path, centralized, data_machine="machine0", data_node=0):
     data_means[folder] = list(means.values())[0]
     data_stdevs[folder] = list(stdevs.values())[0]

+    plt.figure(10)
+    plt.savefig(os.path.join(path, "total_bytes.png"), dpi=300)
+    plt.figure(11)
+    plt.savefig(os.path.join(path, "bytes_train_loss.png"), dpi=300)
+    plt.figure(12)
+    plt.savefig(os.path.join(path, "bytes_test_loss.png"), dpi=300)
+    plt.figure(13)
+    plt.savefig(os.path.join(path, "bytes_test_acc.png"), dpi=300)
+
     plt.figure(1)
     plt.savefig(os.path.join(path, "train_loss.png"), dpi=300)
     plt.figure(2)
     plt.savefig(os.path.join(path, "test_loss.png"), dpi=300)
     plt.figure(3)
     plt.savefig(os.path.join(path, "test_acc.png"), dpi=300)
-    plt.figure(6)
-    plt.savefig(os.path.join(path, "grad_std.png"), dpi=300)
-    plt.figure(7)
-    plt.savefig(os.path.join(path, "grad_mean.png"), dpi=300)
     # Plot total_bytes
     plt.figure(4)
     plt.title("Data Shared")
...
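To illustrate what the new `replace_dict_key` does for figures 11-13 above: it re-keys a per-round metric dictionary by the mean bytes communicated up to that round, so the same curves can be plotted against "Total Bytes per node" instead of communication rounds. A minimal sketch with made-up numbers (not output of the script)::

    # Hypothetical per-round values, shaped like get_stats() output.
    train_loss = {1: 0.9, 2: 0.7, 3: 0.6}   # round -> mean loss
    b_means = {1: 1e6, 2: 2e6, 3: 3e6}      # round -> mean bytes so far

    def replace_dict_key(d_org: dict, d_other: dict):
        # Re-key d_org using d_other: each old key looks up its new key.
        return {d_other[x]: y for x, y in d_org.items()}

    loss_by_bytes = replace_dict_key(train_loss, b_means)
    print(loss_by_bytes)  # {1000000.0: 0.9, 2000000.0: 0.7, 3000000.0: 0.6}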
 #!/bin/bash
+script_path=$(realpath $(dirname $0))

-decpy_path=/mnt/nfs/kirsten/Gitlab/jac_decentralizepy/decentralizepy/eval
+# Working directory, where config files are read from and logs are written.
+decpy_path=/mnt/nfs/$(whoami)/decpy_workingdir
 cd $decpy_path

-env_python=~/miniconda3/envs/decpy/bin/python3
-graph=/mnt/nfs/kirsten/Gitlab/tutorial/regular_16.txt
-original_config=/mnt/nfs/kirsten/Gitlab/tutorial/config_celeba_sharing.ini
-config_file=~/tmp/config_celeba_sharing.ini
+# Python interpreter
+env_python=python3
+
+# File regular_16.txt is available in /tutorial
+graph=$decpy_path/regular_16.txt
+
+# File config_celeba_sharing.ini is available in /tutorial
+# In this config file, change addresses_filepath to correspond to your list of machines (example in /tutorial/ip.json)
+original_config=$decpy_path/config_celeba_sharing.ini
+
+# Local config file
+config_file=/tmp/$(basename $original_config)
+
+# Python script to be executed
+eval_file=$script_path/testingPeerSampler.py
+
+# General parameters
 procs_per_machine=8
 machines=2
 iterations=5
 test_after=2
-eval_file=testingFederated.py
-#eval_file=testingPeerSampler.py
 log_level=INFO

 m=`cat $(grep addresses_filepath $original_config | awk '{print $3}') | grep $(/sbin/ifconfig ens785 | grep 'inet ' | awk '{print $2}') | cut -d'"' -f2`
@@ -20,6 +33,8 @@ echo M is $m
 log_dir=$(date '+%Y-%m-%dT%H:%M')/machine$m
 mkdir -p $log_dir

+# Copy and manipulate the local config file
 cp $original_config $config_file
 # echo "alpha = 0.10" >> $config_file

-$env_python $eval_file -ro 0 -tea $test_after -ld $log_dir -mid $m -ps $procs_per_machine -ms $machines -is $iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level -ctr 0 -cte 0 -wsd $log_dir
\ No newline at end of file
+$env_python $eval_file -ro 0 -tea $test_after -ld $log_dir -mid $m -ps $procs_per_machine -ms $machines -is $iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level -wsd $log_dir
@@ -44,6 +44,7 @@ machines=6
 global_epochs=100
 eval_file=testingFederated.py
 log_level=INFO
+working_rate=0.1

 ip_machines=$nfs_home/configs/ip_addr_6Machines.json
@@ -104,7 +105,7 @@ do
     $python_bin/crudini --set $config_file TRAIN_PARAMS rounds $batches_per_comm_round
     $python_bin/crudini --set $config_file TRAIN_PARAMS batch_size $batchsize
     $python_bin/crudini --set $config_file DATASET random_seed $seed
-    $env_python $eval_file -ro 0 -tea $test_after -ld $log_dir -wsd $weight_store_dir -mid $m -ps $procs_per_machine -ms $machines -is $new_iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level
+    $env_python $eval_file -ro 0 -tea $test_after -ld $log_dir -wsd $weight_store_dir -mid $m -ps $procs_per_machine -ms $machines -is $new_iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level -wr $working_rate
     echo $i is done
     sleep 200
     echo end of sleep
...
@@ -50,22 +50,30 @@ if __name__ == "__main__":
     l = Linear(n_machines, procs_per_machine)
     m_id = args.machine_id

-    mp.spawn(
-        fn=DPSGDNode,
-        nprocs=procs_per_machine,
-        args=[
-            m_id,
-            l,
-            g,
-            my_config,
-            args.iterations,
-            args.log_dir,
-            args.weights_store_dir,
-            log_level[args.log_level],
-            args.test_after,
-            args.train_evaluate_after,
-            args.reset_optimizer,
-            args.centralized_train_eval,
-            args.centralized_test_eval,
-        ],
-    )
+    processes = []
+    for r in range(procs_per_machine):
+        processes.append(
+            mp.Process(
+                target=DPSGDNode,
+                args=[
+                    r,
+                    m_id,
+                    l,
+                    g,
+                    my_config,
+                    args.iterations,
+                    args.log_dir,
+                    args.weights_store_dir,
+                    log_level[args.log_level],
+                    args.test_after,
+                    args.train_evaluate_after,
+                    args.reset_optimizer,
+                ],
+            )
+        )
+
+    for p in processes:
+        p.start()
+
+    for p in processes:
+        p.join()
@@ -8,8 +8,8 @@ from torch import multiprocessing as mp
 from decentralizepy import utils
 from decentralizepy.graphs.Graph import Graph
 from decentralizepy.mappings.Linear import Linear
-from decentralizepy.node.FederatedParameterServer import FederatedParameterServer
 from decentralizepy.node.DPSGDNodeFederated import DPSGDNodeFederated
+from decentralizepy.node.FederatedParameterServer import FederatedParameterServer


 def read_ini(file_path):
@@ -54,9 +54,6 @@ if __name__ == "__main__":
     sm = args.server_machine
     sr = args.server_rank

-    # TODO
-    working_fraction = 1.0
-
     processes = []
     if sm == m_id:
         processes.append(
@@ -74,7 +71,7 @@ if __name__ == "__main__":
                     log_level[args.log_level],
                     args.test_after,
                     args.train_evaluate_after,
-                    working_fraction,
+                    args.working_rate,
                 ],
             )
         )
...
import logging
from pathlib import Path
from shutil import copy
from localconfig import LocalConfig
from torch import multiprocessing as mp
from decentralizepy import utils
from decentralizepy.graphs.Graph import Graph
from decentralizepy.mappings.Linear import Linear
from decentralizepy.node.KNN import KNN
def read_ini(file_path):
config = LocalConfig(file_path)
for section in config:
print("Section: ", section)
for key, value in config.items(section):
print((key, value))
print(dict(config.items("DATASET")))
return config
if __name__ == "__main__":
args = utils.get_args()
Path(args.log_dir).mkdir(parents=True, exist_ok=True)
log_level = {
"INFO": logging.INFO,
"DEBUG": logging.DEBUG,
"WARNING": logging.WARNING,
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL,
}
config = read_ini(args.config_file)
my_config = dict()
for section in config:
my_config[section] = dict(config.items(section))
copy(args.config_file, args.log_dir)
copy(args.graph_file, args.log_dir)
utils.write_args(args, args.log_dir)
g = Graph()
g.read_graph_from_file(args.graph_file, args.graph_type)
n_machines = args.machines
procs_per_machine = args.procs_per_machine
l = Linear(n_machines, procs_per_machine)
m_id = args.machine_id
processes = []
for r in range(procs_per_machine):
processes.append(
mp.Process(
target=KNN,
args=[
r,
m_id,
l,
g,
my_config,
args.iterations,
args.log_dir,
args.weights_store_dir,
log_level[args.log_level],
args.test_after,
args.train_evaluate_after,
args.reset_optimizer,
],
)
)
for p in processes:
p.start()
for p in processes:
p.join()
@@ -9,10 +9,7 @@ from decentralizepy import utils
 from decentralizepy.graphs.Graph import Graph
 from decentralizepy.mappings.Linear import Linear
 from decentralizepy.node.DPSGDWithPeerSampler import DPSGDWithPeerSampler
-from decentralizepy.node.PeerSamplerDynamic import PeerSamplerDynamic
 from decentralizepy.node.PeerSampler import PeerSampler
-from decentralizepy.node.ParameterServer import ParameterServer
-from decentralizepy.node.DPSGDNodeWithParameterServer import DPSGDNodeWithParameterServer


 def read_ini(file_path):
@@ -62,8 +59,7 @@ if __name__ == "__main__":
         processes.append(
             mp.Process(
                 # target=PeerSamplerDynamic,
-                target=ParameterServer,
-                # target=PeerSampler,
+                target=PeerSampler,
                 args=[
                     sr,
                     m_id,
@@ -80,8 +76,7 @@ if __name__ == "__main__":
     for r in range(0, procs_per_machine):
         processes.append(
             mp.Process(
-                target=DPSGDNodeWithParameterServer,
-                # target=DPSGDWithPeerSampler,
+                target=DPSGDWithPeerSampler,
                 args=[
                     r,
                     m_id,
@@ -95,8 +90,6 @@ if __name__ == "__main__":
                     args.test_after,
                     args.train_evaluate_after,
                     args.reset_optimizer,
-                    args.centralized_train_eval,
-                    args.centralized_test_eval,
                 ],
             )
         )
...
import logging
from pathlib import Path
from shutil import copy
from localconfig import LocalConfig
from torch import multiprocessing as mp
from decentralizepy import utils
from decentralizepy.graphs.Graph import Graph
from decentralizepy.mappings.Linear import Linear
from decentralizepy.node.DPSGDWithPeerSampler import DPSGDWithPeerSampler
from decentralizepy.node.PeerSamplerDynamic import PeerSamplerDynamic
def read_ini(file_path):
config = LocalConfig(file_path)
for section in config:
print("Section: ", section)
for key, value in config.items(section):
print((key, value))
print(dict(config.items("DATASET")))
return config
if __name__ == "__main__":
args = utils.get_args()
Path(args.log_dir).mkdir(parents=True, exist_ok=True)
log_level = {
"INFO": logging.INFO,
"DEBUG": logging.DEBUG,
"WARNING": logging.WARNING,
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL,
}
config = read_ini(args.config_file)
my_config = dict()
for section in config:
my_config[section] = dict(config.items(section))
copy(args.config_file, args.log_dir)
copy(args.graph_file, args.log_dir)
utils.write_args(args, args.log_dir)
g = Graph()
g.read_graph_from_file(args.graph_file, args.graph_type)
n_machines = args.machines
procs_per_machine = args.procs_per_machine
l = Linear(n_machines, procs_per_machine)
m_id = args.machine_id
sm = args.server_machine
sr = args.server_rank
processes = []
if sm == m_id:
processes.append(
mp.Process(
target=PeerSamplerDynamic,
args=[
sr,
m_id,
l,
g,
my_config,
args.iterations,
args.log_dir,
log_level[args.log_level],
],
)
)
for r in range(0, procs_per_machine):
processes.append(
mp.Process(
target=DPSGDWithPeerSampler,
args=[
r,
m_id,
l,
g,
my_config,
args.iterations,
args.log_dir,
args.weights_store_dir,
log_level[args.log_level],
args.test_after,
args.train_evaluate_after,
args.reset_optimizer,
],
)
)
for p in processes:
p.start()
for p in processes:
p.join()
@@ -47,6 +47,7 @@ class TCP(Communication):
         total_procs,
         addresses_filepath,
         offset=9000,
+        recv_timeout=50,
     ):
         """
         Constructor
@@ -79,11 +80,14 @@ class TCP(Communication):
         self.machine_id = machine_id
         self.mapping = mapping
         self.offset = offset
+        self.recv_timeout = recv_timeout
         self.uid = mapping.get_uid(rank, machine_id)
         self.identity = str(self.uid).encode()
         self.context = zmq.Context()
         self.router = self.context.socket(zmq.ROUTER)
         self.router.setsockopt(zmq.IDENTITY, self.identity)
+        self.router.setsockopt(zmq.RCVTIMEO, self.recv_timeout)
+        self.router.setsockopt(zmq.ROUTER_MANDATORY, 1)
         self.router.bind(self.addr(rank, machine_id))
         self.total_data = 0
@@ -170,7 +174,7 @@ class TCP(Communication):
         id = str(neighbor).encode()
         return id in self.peer_sockets

-    def receive(self):
+    def receive(self, block=True):
         """
         Returns ONE message received.
@@ -185,10 +189,19 @@ class TCP(Communication):
             If received HELLO

         """
-        sender, recv = self.router.recv_multipart()
-        s, r = self.decrypt(sender, recv)
-        return s, r
+        while True:
+            try:
+                sender, recv = self.router.recv_multipart()
+                s, r = self.decrypt(sender, recv)
+                return s, r
+            except zmq.ZMQError as exc:
+                if exc.errno == zmq.EAGAIN:
+                    if not block:
+                        return None
+                    else:
+                        continue
+                else:
+                    raise

     def send(self, uid, data, encrypt=True):
         """
...
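For context on the `recv_timeout` change above: once `zmq.RCVTIMEO` is set, `recv_multipart()` raises `zmq.ZMQError` with `errno == zmq.EAGAIN` when no message arrives within the timeout, which is what lets `receive(block=False)` return `None` instead of hanging. A standalone sketch of the same pattern in plain pyzmq (hypothetical address, no real peers)::

    import zmq

    context = zmq.Context()
    socket = context.socket(zmq.ROUTER)
    socket.setsockopt(zmq.RCVTIMEO, 50)   # milliseconds, as in the diff
    socket.bind("tcp://127.0.0.1:5555")   # hypothetical endpoint

    try:
        # With a ROUTER socket a two-frame message is [identity, payload].
        sender, payload = socket.recv_multipart()
    except zmq.ZMQError as exc:
        if exc.errno == zmq.EAGAIN:
            print("no message within 50 ms; a non-blocking caller can retry later")
        else:
            raise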
@@ -114,6 +114,8 @@ class CIFAR10(Dataset):
             test_batch_size,
         )

+        self.num_classes = NUM_CLASSES
+
         self.partition_niid = partition_niid
         self.shards = shards
         self.transform = transforms.Compose(
...
@@ -230,6 +230,8 @@ class Celeba(Dataset):
         self.IMAGES_DIR = utils.conditional_value(images_dir, "", None)
         assert self.IMAGES_DIR != None

+        self.num_classes = NUM_CLASSES
+
         if self.__training__:
             self.load_trainset()
...
@@ -52,6 +52,7 @@ class Dataset:
         self.test_dir = utils.conditional_value(test_dir, "", None)
         self.sizes = utils.conditional_value(sizes, "", None)
         self.test_batch_size = utils.conditional_value(test_batch_size, "", 64)
+        self.num_classes = None
         if self.sizes:
             if type(self.sizes) == str:
                 self.sizes = eval(self.sizes)
@@ -66,6 +67,20 @@ class Dataset:
         else:
             self.__testing__ = False

+        self.label_distribution = None
+
+    def get_label_distribution(self):
+        # Only supported for classification
+        if self.label_distribution is None:
+            self.label_distribution = [0 for _ in range(self.num_classes)]
+            tr_set = self.get_trainset()
+            for _, ys in tr_set:
+                for y in ys:
+                    y_val = y.item()
+                    self.label_distribution[y_val] += 1
+
+        return self.label_distribution
+
     def get_trainset(self):
         """
         Function to get the training set
...
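The new `get_label_distribution` above iterates over batches from `get_trainset()`, so it assumes a classification dataset whose subclass has set `self.num_classes` (which is exactly what the CIFAR10, Celeba, and Femnist changes in this revision do). A self-contained sketch of the same counting logic over a toy loader::

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    num_classes = 3
    xs = torch.randn(8, 4)
    ys = torch.tensor([0, 1, 1, 2, 0, 0, 2, 1])
    loader = DataLoader(TensorDataset(xs, ys), batch_size=4)

    label_distribution = [0 for _ in range(num_classes)]
    for _, batch_ys in loader:          # mirrors `for _, ys in tr_set`
        for y in batch_ys:
            label_distribution[y.item()] += 1

    print(label_distribution)  # [3, 3, 2]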
@@ -223,6 +223,8 @@ class Femnist(Dataset):
             test_batch_size,
         )

+        self.num_classes = NUM_CLASSES
+
         if self.__training__:
             self.load_trainset()
...
@@ -9,12 +9,9 @@ import torch
 from matplotlib import pyplot as plt

 from decentralizepy import utils
-from decentralizepy.communication.TCP import TCP
 from decentralizepy.graphs.Graph import Graph
-from decentralizepy.graphs.Star import Star
 from decentralizepy.mappings.Mapping import Mapping
 from decentralizepy.node.Node import Node
-from decentralizepy.train_test_evaluation import TrainTestHelper


 class DPSGDNode(Node):
@@ -70,42 +67,18 @@ class DPSGDNode(Node):
         rounds_to_train_evaluate = self.train_evaluate_after
         global_epoch = 1
         change = 1
-        if self.uid == 0:
-            dataset = self.dataset
-            if self.centralized_train_eval:
-                dataset_params_copy = self.dataset_params.copy()
-                if "sizes" in dataset_params_copy:
-                    del dataset_params_copy["sizes"]
-                self.whole_dataset = self.dataset_class(
-                    self.rank,
-                    self.machine_id,
-                    self.mapping,
-                    sizes=[1.0],
-                    **dataset_params_copy
-                )
-                dataset = self.whole_dataset
-            if self.centralized_test_eval:
-                tthelper = TrainTestHelper(
-                    dataset,  # self.whole_dataset,
-                    # self.model_test, # todo: this only works if eval_train is set to false
-                    self.model,
-                    self.loss,
-                    self.weights_store_dir,
-                    self.mapping.get_n_procs(),
-                    self.trainer,
-                    self.testing_comm,
-                    self.star,
-                    self.threads_per_proc,
-                    eval_train=self.centralized_train_eval,
-                )

         for iteration in range(self.iterations):
             logging.info("Starting training iteration: %d", iteration)
+            rounds_to_train_evaluate -= 1
+            rounds_to_test -= 1
+
             self.iteration = iteration
             self.trainer.train(self.dataset)

             new_neighbors = self.get_neighbors()

+            # The following code does not work because TCP sockets are supposed to be long lived.
             # for neighbor in self.my_neighbors:
             #     if neighbor not in new_neighbors:
             #         logging.info("Removing neighbor {}".format(neighbor))
@@ -163,8 +136,6 @@
                 "total_bytes": {},
                 "total_meta": {},
                 "total_data_per_n": {},
-                "grad_mean": {},
-                "grad_std": {},
             }

         results_dict["total_bytes"][iteration + 1] = self.communication.total_bytes
@@ -177,14 +148,8 @@
                 results_dict["total_data_per_n"][
                     iteration + 1
                 ] = self.communication.total_data
-            if hasattr(self.sharing, "mean"):
-                results_dict["grad_mean"][iteration + 1] = self.sharing.mean
-            if hasattr(self.sharing, "std"):
-                results_dict["grad_std"][iteration + 1] = self.sharing.std

-            rounds_to_train_evaluate -= 1
-
-            if rounds_to_train_evaluate == 0 and not self.centralized_train_eval:
+            if rounds_to_train_evaluate == 0:
                 logging.info("Evaluating on train set.")
                 rounds_to_train_evaluate = self.train_evaluate_after * change
                 loss_after_sharing = self.trainer.eval_loss(self.dataset)
@@ -197,26 +162,12 @@
                     os.path.join(self.log_dir, "{}_train_loss.png".format(self.rank)),
                 )

-            rounds_to_test -= 1
-
             if self.dataset.__testing__ and rounds_to_test == 0:
                 rounds_to_test = self.test_after * change
-                if self.centralized_test_eval:
-                    if self.uid == 0:
-                        ta, tl, trl = tthelper.train_test_evaluation(iteration)
-                        results_dict["test_acc"][iteration + 1] = ta
-                        results_dict["test_loss"][iteration + 1] = tl
-                        if trl is not None:
-                            results_dict["train_loss"][iteration + 1] = trl
-                    else:
-                        self.testing_comm.send(0, self.model.get_weights())
-                        sender, data = self.testing_comm.receive()
-                        assert sender == 0 and data == "finished"
-                else:
-                    logging.info("Evaluating on test set.")
-                    ta, tl = self.dataset.test(self.model, self.loss)
-                    results_dict["test_acc"][iteration + 1] = ta
-                    results_dict["test_loss"][iteration + 1] = tl
+                logging.info("Evaluating on test set.")
+                ta, tl = self.dataset.test(self.model, self.loss)
+                results_dict["test_acc"][iteration + 1] = ta
+                results_dict["test_loss"][iteration + 1] = tl

                 if global_epoch == 49:
                     change *= 2
@@ -253,8 +204,6 @@
         test_after,
         train_evaluate_after,
         reset_optimizer,
-        centralized_train_eval,
-        centralized_test_eval,
     ):
         """
         Instantiate object field with arguments.
@@ -281,10 +230,6 @@
             Number of iterations after which the train loss is calculated
         reset_optimizer : int
             1 if optimizer should be reset every communication round, else 0
-        centralized_train_eval : bool
-            If set the train set evaluation happens at the node with uid 0
-        centralized_test_eval : bool
-            If set the train set evaluation happens at the node with uid 0

         """
         self.rank = rank
         self.machine_id = machine_id
@@ -297,17 +242,12 @@
         self.test_after = test_after
         self.train_evaluate_after = train_evaluate_after
         self.reset_optimizer = reset_optimizer
-        self.centralized_train_eval = centralized_train_eval
-        self.centralized_test_eval = centralized_test_eval
         self.sent_disconnections = False

         logging.info("Rank: %d", self.rank)
         logging.info("type(graph): %s", str(type(self.rank)))
         logging.info("type(mapping): %s", str(type(self.mapping)))

-        if centralized_test_eval or centralized_train_eval:
-            self.star = Star(self.mapping.get_n_procs())
-
     def init_comm(self, comm_configs):
         """
         Instantiate communication module from config.
@@ -322,17 +262,6 @@
         comm_class = getattr(comm_module, comm_configs["comm_class"])
         comm_params = utils.remove_keys(comm_configs, ["comm_package", "comm_class"])
         self.addresses_filepath = comm_params.get("addresses_filepath", None)
-        if self.centralized_test_eval:
-            self.testing_comm = TCP(
-                self.rank,
-                self.machine_id,
-                self.mapping,
-                self.star.n_procs,
-                self.addresses_filepath,
-                offset=self.star.n_procs,
-            )
-            self.testing_comm.connect_neighbors(self.star.neighbors(self.uid))
-
         self.communication = comm_class(
             self.rank, self.machine_id, self.mapping, self.graph.n_procs, **comm_params
         )
@@ -351,8 +280,6 @@
         test_after=5,
         train_evaluate_after=1,
         reset_optimizer=1,
-        centralized_train_eval=False,
-        centralized_test_eval=True,
         *args
     ):
         """
@@ -384,10 +311,6 @@
             Number of iterations after which the train loss is calculated
         reset_optimizer : int
             1 if optimizer should be reset every communication round, else 0
-        centralized_train_eval : bool
-            If set the train set evaluation happens at the node with uid 0
-        centralized_test_eval : bool
-            If set the train set evaluation happens at the node with uid 0
         args : optional
             Other arguments
@@ -407,8 +330,6 @@
             test_after,
             train_evaluate_after,
             reset_optimizer,
-            centralized_train_eval,
-            centralized_test_eval,
         )
         self.init_dataset_model(config["DATASET"])
         self.init_optimizer(config["OPTIMIZER_PARAMS"])
@@ -423,7 +344,6 @@
         self.init_sharing(config["SHARING"])
         self.peer_deques = dict()
         self.connect_neighbors()
-        # self.instantiate_peer_deques()

     def received_from_all(self):
         """
@@ -454,8 +374,6 @@
         test_after=5,
         train_evaluate_after=1,
         reset_optimizer=1,
-        centralized_train_eval=0,
-        centralized_test_eval=1,
         *args
     ):
         """
@@ -499,19 +417,10 @@
             Number of iterations after which the train loss is calculated
         reset_optimizer : int
             1 if optimizer should be reset every communication round, else 0
-        centralized_train_eval : int
-            If set then the train set evaluation happens at the node with uid 0.
-            Note: If it is True then centralized_test_eval needs to be true as well!
-        centralized_test_eval : int
-            If set then the trainset evaluation happens at the node with uid 0
         args : optional
             Other arguments

         """
-        centralized_train_eval = centralized_train_eval == 1
-        centralized_test_eval = centralized_test_eval == 1
-        # If centralized_train_eval is True then centralized_test_eval needs to be true as well!
-        assert not centralized_train_eval or centralized_test_eval

         total_threads = os.cpu_count()
         self.threads_per_proc = max(
@@ -532,8 +441,6 @@
             test_after,
             train_evaluate_after,
             reset_optimizer,
-            centralized_train_eval == 1,
-            centralized_test_eval == 1,
             *args
         )
         logging.info(
...
 import importlib
+import json
 import logging
 import math
 import os
@@ -35,19 +36,15 @@ class DPSGDNodeFederated(Node):
             del data["iteration"]
             del data["CHANNEL"]

-            if iteration == 0:
-                del data["degree"]
-
-            data = self.sharing.deserialized_model(data)
-            self.model.load_state_dict(data)
+            self.model.load_state_dict(data["params"])
             self.sharing._post_step()
             self.sharing.communication_round += 1

-            logging.info("Received worker request at node {}, global iteration {}, local round {}".format(
-                self.uid,
-                iteration,
-                self.participated
-            ))
+            logging.info(
+                "Received worker request at node {}, global iteration {}, local round {}".format(
+                    self.uid, iteration, self.participated
+                )
+            )

             if self.reset_optimizer:
                 self.optimizer = self.optimizer_class(
@@ -64,6 +61,38 @@
             to_send["CHANNEL"] = "DPSGD"
             self.communication.send(self.parameter_server_uid, to_send)

+            if self.participated > 0:
+                with open(
+                    os.path.join(self.log_dir, "{}_results.json".format(self.rank)),
+                    "r",
+                ) as inf:
+                    results_dict = json.load(inf)
+            else:
+                results_dict = {
+                    "train_loss": {},
+                    "test_loss": {},
+                    "test_acc": {},
+                    "total_bytes": {},
+                    "total_meta": {},
+                    "total_data_per_n": {},
+                }
+
+            results_dict["total_bytes"][iteration + 1] = self.communication.total_bytes
+
+            if hasattr(self.communication, "total_meta"):
+                results_dict["total_meta"][
+                    iteration + 1
+                ] = self.communication.total_meta
+            if hasattr(self.communication, "total_data"):
+                results_dict["total_data_per_n"][
+                    iteration + 1
+                ] = self.communication.total_data
+
+            with open(
+                os.path.join(self.log_dir, "{}_results.json".format(self.rank)), "w"
+            ) as of:
+                json.dump(results_dict, of)
+
             self.participated += 1

         # only if has participated in learning
@@ -84,7 +113,7 @@
         weights_store_dir,
         test_after,
         train_evaluate_after,
-        reset_optimizer
+        reset_optimizer,
     ):
         """
         Instantiate object field with arguments.
@@ -141,8 +170,7 @@
         """
         comm_module = importlib.import_module(comm_configs["comm_package"])
         comm_class = getattr(comm_module, comm_configs["comm_class"])
-        comm_params = utils.remove_keys(
-            comm_configs, ["comm_package", "comm_class"])
+        comm_params = utils.remove_keys(comm_configs, ["comm_package", "comm_class"])
         self.addresses_filepath = comm_params.get("addresses_filepath", None)
         self.communication = comm_class(
             self.rank, self.machine_id, self.mapping, self.graph.n_procs, **comm_params
...
@@ -47,8 +47,6 @@ class DPSGDWithPeerSampler(DPSGDNode):
         test_after=5,
         train_evaluate_after=1,
         reset_optimizer=1,
-        centralized_train_eval=0,
-        centralized_test_eval=1,
         peer_sampler_uid=-1,
         *args
     ):
@@ -93,19 +91,10 @@ class DPSGDWithPeerSampler(DPSGDNode):
             Number of iterations after which the train loss is calculated
         reset_optimizer : int
             1 if optimizer should be reset every communication round, else 0
-        centralized_train_eval : int
-            If set then the train set evaluation happens at the node with uid 0.
-            Note: If it is True then centralized_test_eval needs to be true as well!
-        centralized_test_eval : int
-            If set then the trainset evaluation happens at the node with uid 0
         args : optional
             Other arguments

         """
-        centralized_train_eval = centralized_train_eval == 1
-        centralized_test_eval = centralized_test_eval == 1
-        # If centralized_train_eval is True then centralized_test_eval needs to be true as well!
-        assert not centralized_train_eval or centralized_test_eval

         total_threads = os.cpu_count()
         self.threads_per_proc = max(
@@ -126,8 +115,6 @@ class DPSGDWithPeerSampler(DPSGDNode):
             test_after,
             train_evaluate_after,
             reset_optimizer,
-            centralized_train_eval == 1,
-            centralized_test_eval == 1,
             *args
         )
         logging.info(
...
@@ -5,6 +5,7 @@ import math
 import os
 import random
 from collections import deque
+
 from matplotlib import pyplot as plt

 from decentralizepy import utils
@@ -134,8 +135,7 @@ class FederatedParameterServer(Node):
         """
         comm_module = importlib.import_module(comm_configs["comm_package"])
         comm_class = getattr(comm_module, comm_configs["comm_class"])
-        comm_params = utils.remove_keys(
-            comm_configs, ["comm_package", "comm_class"])
+        comm_params = utils.remove_keys(comm_configs, ["comm_package", "comm_class"])
         self.addresses_filepath = comm_params.get("addresses_filepath", None)
         self.communication = comm_class(
             self.rank, self.machine_id, self.mapping, self.graph.n_procs, **comm_params
@@ -285,18 +285,13 @@ class FederatedParameterServer(Node):
             self.current_workers = self.get_working_nodes()

             # Params to send to workers
-            # if this is the first iteration, use the init parameters, else use averaged params from last iteration
-            if iteration == 0:
-                to_send = self.sharing.get_data_to_send()
-
+            to_send["params"] = self.model.state_dict()
             to_send["CHANNEL"] = "WORKER_REQUEST"
             to_send["iteration"] = iteration

             # Notify workers
             for worker in self.current_workers:
-                self.communication.send(
-                    worker, to_send
-                )
+                self.communication.send(worker, to_send)

             # Receive updates from current workers
             while not self.received_from_all():
@@ -309,33 +304,15 @@ class FederatedParameterServer(Node):
             # Average received updates
             averaging_deque = dict()
-            total = dict()
             for worker in self.current_workers:
                 averaging_deque[worker] = self.peer_deques[worker]

-            for i, n in enumerate(averaging_deque):
-                data = averaging_deque[n].popleft()
-                del data["degree"]
-                del data["iteration"]
-                del data["CHANNEL"]
-                data = self.sharing.deserialized_model(data)
-                for key, value in data.items():
-                    if key in total:
-                        total[key] += value
-                    else:
-                        total[key] = value
-
-            for key, value in total.items():
-                total[key] = total[key] / len(averaging_deque)
-
-            self.model.load_state_dict(total)
-
-            to_send = total
+            self.sharing._pre_step()
+            self.sharing._averaging_server(averaging_deque)

             if iteration:
                 with open(
-                    os.path.join(
-                        self.log_dir, "{}_results.json".format(self.rank)),
+                    os.path.join(self.log_dir, "{}_results.json".format(self.rank)),
                     "r",
                 ) as inf:
                     results_dict = json.load(inf)
@@ -347,12 +324,9 @@ class FederatedParameterServer(Node):
                     "total_bytes": {},
                     "total_meta": {},
                     "total_data_per_n": {},
-                    "grad_mean": {},
-                    "grad_std": {},
                 }

-            results_dict["total_bytes"][iteration
-                                        + 1] = self.communication.total_bytes
+            results_dict["total_bytes"][iteration + 1] = self.communication.total_bytes

             if hasattr(self.communication, "total_meta"):
                 results_dict["total_meta"][
@@ -362,10 +336,6 @@ class FederatedParameterServer(Node):
                 results_dict["total_data_per_n"][
                     iteration + 1
                 ] = self.communication.total_data
-            if hasattr(self.sharing, "mean"):
-                results_dict["grad_mean"][iteration + 1] = self.sharing.mean
-            if hasattr(self.sharing, "std"):
-                results_dict["grad_std"][iteration + 1] = self.sharing.std

             rounds_to_train_evaluate -= 1
@@ -379,8 +349,7 @@ class FederatedParameterServer(Node):
                     "train_loss",
                     "Training Loss",
                     "Communication Rounds",
-                    os.path.join(
-                        self.log_dir, "{}_train_loss.png".format(self.rank)),
+                    os.path.join(self.log_dir, "{}_train_loss.png".format(self.rank)),
                 )

             rounds_to_test -= 1
@@ -398,8 +367,7 @@ class FederatedParameterServer(Node):
             global_epoch += change

         with open(
-            os.path.join(
-                self.log_dir, "{}_results.json".format(self.rank)), "w"
+            os.path.join(self.log_dir, "{}_results.json".format(self.rank)), "w"
         ) as of:
             json.dump(results_dict, of)
@@ -411,8 +379,7 @@ class FederatedParameterServer(Node):
             ),
             "w",
         ) as of:
-            json.dump(
-                self.model.shared_parameters_counter.numpy().tolist(), of)
+            json.dump(self.model.shared_parameters_counter.numpy().tolist(), of)

         self.disconnect_neighbors()
         logging.info("Storing final weight")
...
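The `working_rate` plumbed through this revision (the `-wr` flag in the run script and `args.working_rate` in the launcher) controls what fraction of nodes the server treats as workers each round via `get_working_nodes()`. The selection logic itself is not shown in this diff; the following is only a plausible sketch of fraction-based sampling, with a hypothetical function name and signature::

    import math
    import random

    def get_working_nodes(rng: random.Random, all_nodes, working_rate: float):
        # Hypothetical: sample ceil(working_rate * n) distinct nodes per round.
        k = max(1, math.ceil(working_rate * len(all_nodes)))
        return rng.sample(list(all_nodes), k)

    rng = random.Random(42)
    print(get_working_nodes(rng, range(16), 0.1))  # 2 of 16 nodes at working_rate=0.1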
import logging
import math
import os
import queue
from random import Random
from threading import Lock, Thread
import numpy as np
import torch
from numpy.linalg import norm
from decentralizepy import utils
from decentralizepy.graphs.Graph import Graph
from decentralizepy.mappings.Mapping import Mapping
from decentralizepy.node.OverlayNode import OverlayNode
class KNN(OverlayNode):
"""
This class defines the node for KNN Learning Node
"""
def similarityMetric(self, candidate):
logging.debug("A: {}".format(self.othersInfo[self.uid]))
logging.debug("B: {}".format(self.othersInfo[candidate]))
A = np.array(self.othersInfo[self.uid])
B = np.array(self.othersInfo[candidate])
return np.dot(A, B) / (norm(A) * norm(B))
def get_most_similar(self, candidates, to_keep=4):
if len(candidates) <= to_keep:
return candidates
cur_candidates = dict()
for i in candidates:
simil = round(self.similarityMetric(i), 3)
if simil not in cur_candidates:
cur_candidates[simil] = []
cur_candidates[simil].append(i)
similarity_scores = list(cur_candidates.keys())
similarity_scores.sort()
left_to_keep = to_keep
return_result = set()
for i in similarity_scores:
if left_to_keep >= len(cur_candidates[i]):
return_result.update(cur_candidates[i])
left_to_keep -= len(cur_candidates[i])
elif left_to_keep > 0:
return_result.update(
list(self.rng.sample(cur_candidates[i], left_to_keep))
)
left_to_keep = 0
break
else:
break
return return_result
def create_message_to_send(
self,
channel="KNNConstr",
boolean_flags=[],
add_my_info=False,
add_neighbor_info=False,
):
message = {"CHANNEL": channel, "KNNRound": self.knn_round}
for x in boolean_flags:
message[x] = True
if add_my_info:
message[self.uid] = self.othersInfo[self.uid]
if add_neighbor_info:
for neighbors in self.out_edges:
if neighbors in self.othersInfo:
message[neighbors] = self.othersInfo[neighbors]
return message
def receive_KNN_message(self):
return self.receive_channel("KNNConstr", block=False)
def process_init_receive(self, message):
self.mutex.acquire()
if "RESPONSE" in message[1]:
self.num_initializations += 1
else:
self.communication.send(
message[0],
self.create_message_to_send(
boolean_flags=["INIT", "RESPONSE"], add_my_info=True
),
)
x = (
message[0],
utils.remove_keys(message[1], ["RESPONSE", "INIT", "CHANNEL", "KNNRound"]),
)
self.othersInfo.update(x[1])
self.mutex.release()
def remove_meta_from_message(self, message):
return (
message[0],
utils.remove_keys(message[1], ["RESPONSE", "INIT", "CHANNEL", "KNNRound"]),
)
def process_candidates_without_lock(self, current_candidates, message):
if not self.exit_receiver:
message = (
message[0],
utils.remove_keys(
message[1], ["CHANNEL", "RESPONSE", "INIT", "KNNRound"]
),
)
self.othersInfo.update(message[1])
new_candidates = set(message[1].keys())
current_candidates = current_candidates.union(new_candidates)
if self.uid in current_candidates:
current_candidates.remove(self.uid)
self.out_edges = self.get_most_similar(current_candidates)
def send_response(self, message, add_neighbor_info=False, process_candidates=False):
self.mutex.acquire()
logging.debug("Responding to {}".format(message[0]))
self.communication.send(
message[0],
self.create_message_to_send(
boolean_flags=["RESPONSE"],
add_my_info=True,
add_neighbor_info=add_neighbor_info,
),
)
if process_candidates:
self.process_candidates_without_lock(set(self.out_edges), message)
self.mutex.release()
def receiver_thread(self):
knnBYEs = set()
self.num_initializations = 0
waiting_queue = queue.Queue()
while True:
if len(knnBYEs) == self.mapping.get_n_procs() - 1:
self.mutex.acquire()
if self.exit_receiver:
self.mutex.release()
logging.debug("Exiting thread")
return
self.mutex.release()
if self.num_initializations < self.initial_neighbors:
x = self.receive_KNN_message()
                if x is None:
continue
elif "INIT" in x[1]:
self.process_init_receive(x)
else:
waiting_queue.put(x)
else:
logging.debug("Waiting for messages")
if waiting_queue.empty():
x = self.receive_KNN_message()
                    if x is None:
continue
else:
x = waiting_queue.get()
if "INIT" in x[1]:
logging.debug("A past INIT Message received from {}".format(x[0]))
self.process_init_receive(x)
elif "RESPONSE" in x[1]:
logging.debug(
"A response message received from {} from KNNRound {}".format(
x[0], x[1]["KNNRound"]
)
)
x = self.remove_meta_from_message(x)
self.responseQueue.put(x)
elif "RANDOM_DISCOVERY" in x[1]:
logging.debug(
"A Random Discovery message received from {} from KNNRound {}".format(
x[0], x[1]["KNNRound"]
)
)
self.send_response(
x, add_neighbor_info=False, process_candidates=False
)
elif "KNNBYE" in x[1]:
self.mutex.acquire()
knnBYEs.add(x[0])
logging.debug("{} KNN Byes received".format(knnBYEs))
if self.uid in x[1]["CLOSE"]:
self.in_edges.add(x[0])
self.mutex.release()
else:
logging.debug(
"A KNN sharing message received from {} from KNNRound {}".format(
x[0], x[1]["KNNRound"]
)
)
self.send_response(
x, add_neighbor_info=True, process_candidates=True
)
def build_topology(self, rounds=30, random_nodes=4):
self.knn_round = 0
self.exit_receiver = False
t = Thread(target=self.receiver_thread)
t.start()
# Initializations : Send my dataset info to others
self.mutex.acquire()
initial_KNN_message = self.create_message_to_send(
boolean_flags=["INIT"], add_my_info=True
)
for x in self.out_edges:
self.communication.send(x, initial_KNN_message)
self.mutex.release()
for round in range(rounds):
self.knn_round = round
logging.info("Starting KNN Round {}".format(round))
self.mutex.acquire()
rand_neighbor = self.rng.choice(list(self.out_edges))
logging.debug("Random neighbor: {}".format(rand_neighbor))
self.communication.send(
rand_neighbor,
self.create_message_to_send(add_my_info=True, add_neighbor_info=True),
)
self.mutex.release()
logging.debug("Waiting for knn response from {}".format(rand_neighbor))
response = self.responseQueue.get(block=True)
logging.debug("Got response from random neighbor")
self.mutex.acquire()
random_candidates = set(
self.rng.sample(list(range(self.mapping.get_n_procs())), random_nodes)
)
req_responses = 0
for rc in random_candidates:
logging.debug("Current random discovery: {}".format(rc))
if rc not in self.othersInfo and rc != self.uid:
logging.debug("Sending discovery request to {}".format(rc))
self.communication.send(
rc,
self.create_message_to_send(boolean_flags=["RANDOM_DISCOVERY"]),
)
req_responses += 1
self.mutex.release()
while req_responses > 0:
logging.debug(
"Waiting for {} random discovery responses.".format(req_responses)
)
req_responses -= 1
random_discovery_response = self.responseQueue.get(block=True)
logging.debug(
"Received discovery response from {}".format(
random_discovery_response[0]
)
)
self.mutex.acquire()
self.othersInfo.update(random_discovery_response[1])
self.mutex.release()
self.mutex.acquire()
self.process_candidates_without_lock(
random_candidates.union(self.out_edges), response
)
self.mutex.release()
logging.info("Completed KNN Round {}".format(round))
logging.debug("OutNodes: {}".format(self.out_edges))
# Send out_edges and BYE to all
to_send = self.create_message_to_send(boolean_flags=["KNNBYE"])
logging.info("Sending KNNByes")
self.mutex.acquire()
self.exit_receiver = True
to_send["CLOSE"] = list(self.out_edges) # Optimize to only send Yes/No
for receiver in range(self.mapping.get_n_procs()):
if receiver != self.uid:
self.communication.send(receiver, to_send)
self.mutex.release()
logging.info("KNNByes Sent")
t.join()
logging.info("Receiver Thread Returned")
def __init__(
self,
rank: int,
machine_id: int,
mapping: Mapping,
graph: Graph,
config,
iterations=1,
log_dir=".",
weights_store_dir=".",
log_level=logging.INFO,
test_after=5,
train_evaluate_after=1,
reset_optimizer=1,
initial_neighbors=4,
*args
):
"""
Constructor
Parameters
----------
rank : int
Rank of process local to the machine
machine_id : int
Machine ID on which the process in running
mapping : decentralizepy.mappings
The object containing the mapping rank <--> uid
graph : decentralizepy.graphs
The object containing the global graph
config : dict
A dictionary of configurations. Must contain the following:
[DATASET]
dataset_package
dataset_class
model_class
[OPTIMIZER_PARAMS]
optimizer_package
optimizer_class
[TRAIN_PARAMS]
training_package = decentralizepy.training.Training
training_class = Training
epochs_per_round = 25
batch_size = 64
iterations : int
Number of iterations (communication steps) for which the model should be trained
log_dir : str
Logging directory
weights_store_dir : str
Directory in which to store model weights
log_level : logging.Level
One of DEBUG, INFO, WARNING, ERROR, CRITICAL
test_after : int
            Number of iterations after which the test loss and accuracy are calculated
train_evaluate_after : int
Number of iterations after which the train loss is calculated
reset_optimizer : int
1 if optimizer should be reset every communication round, else 0
args : optional
Other arguments
"""
total_threads = os.cpu_count()
self.threads_per_proc = max(
math.floor(total_threads / mapping.procs_per_machine), 1
)
torch.set_num_threads(self.threads_per_proc)
torch.set_num_interop_threads(1)
self.instantiate(
rank,
machine_id,
mapping,
graph,
config,
iterations,
log_dir,
weights_store_dir,
log_level,
test_after,
train_evaluate_after,
reset_optimizer,
*args
)
self.rng = Random()
self.rng.seed(self.uid + 100)
self.initial_neighbors = initial_neighbors
self.in_edges = set()
self.out_edges = set(
self.rng.sample(
list(self.graph.neighbors(self.uid)), self.initial_neighbors
)
)
self.responseQueue = queue.Queue()
self.mutex = Lock()
self.othersInfo = {self.uid: list(self.dataset.get_label_distribution())}
# ld = self.dataset.get_label_distribution()
# ld_keys = sorted(list(ld.keys()))
# self.othersInfo = {self.uid: []}
# for key in range(max(ld_keys) + 1):
# if key in ld:
# self.othersInfo[self.uid].append(ld[key])
# else:
# self.othersInfo[self.uid].append(0)
logging.info("Label Distributions: {}".format(self.othersInfo))
logging.info(
"Each proc uses %d threads out of %d.", self.threads_per_proc, total_threads
)
self.run()
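For reference, `similarityMetric` in the KNN class above is plain cosine similarity between two nodes' label-distribution vectors, and `get_most_similar` keeps the `to_keep` most similar candidates. A compact standalone sketch of the same idea (simplified: it omits the class's rounding-based tie handling and random sampling)::

    import numpy as np

    def cosine(a, b):
        a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
        return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

    my_dist = [10, 0, 5]                                   # this node's label counts
    others = {1: [8, 1, 4], 2: [0, 9, 0], 3: [9, 0, 6]}    # uid -> label counts

    to_keep = 2
    ranked = sorted(others, key=lambda uid: cosine(my_dist, others[uid]), reverse=True)
    print(ranked[:to_keep])  # the candidates whose label mix is closest to ours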