Compare revisions
Commits on Source (6)
@@ -9,4 +9,5 @@
**.egg-info
2021**
2022**
**/massif.out*
\ No newline at end of file
**/massif.out*
*swp
@@ -2,6 +2,9 @@
decentralizepy
==============
decentralizepy is a framework for running distributed applications (particularly ML) on top of arbitrary topologies (decentralized, federated, parameter server).
It was primarily conceived for evaluating research ideas on several aspects of distributed learning (communication efficiency, privacy, data heterogeneity, etc.).
-------------------------
Setting up decentralizepy
-------------------------
@@ -23,10 +26,14 @@ Setting up decentralizepy
pip install --upgrade pip
* On Mac M1, installing ``pyzmq`` fails with ``pip``. Use `conda <https://conda.io>`_.
* Install decentralizepy for development. ::
* Install decentralizepy for development. (zsh) ::
pip3 install --editable .\[dev\]
* Install decentralizepy for development. (bash) ::
pip3 install --editable .[dev]
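* The brackets are escaped in the zsh variant because zsh otherwise interprets ``[dev]`` as a glob pattern.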
----------------
Running the code
----------------
......
@@ -26,9 +26,9 @@ def get_stats(l):
return mean_dict, stdev_dict, min_dict, max_dict
def plot(means, stdevs, mins, maxs, title, label, loc):
def plot(means, stdevs, mins, maxs, title, label, loc, xlabel="communication rounds"):
plt.title(title)
plt.xlabel("communication rounds")
plt.xlabel(xlabel)
x_axis = np.array(list(means.keys()))
y_axis = np.array(list(means.values()))
err = np.array(list(stdevs.values()))
@@ -37,6 +37,13 @@ def plot(means, stdevs, mins, maxs, title, label, loc):
plt.legend(loc=loc)
def replace_dict_key(d_org: dict, d_other: dict):
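# Returns a copy of d_org in which every key k is replaced by d_other[k]; values are unchanged.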
result = {}
for x, y in d_org.items():
result[d_other[x]] = y
return result
def plot_results(path, centralized, data_machine="machine0", data_node=0):
folders = os.listdir(path)
if centralized.lower() in ["true", "1", "t", "y", "yes"]:
@@ -74,19 +81,54 @@ def plot_results(path, centralized, data_machine="machine0", data_node=0):
with open(folder_path / data_machine / f"{data_node}_results.json", "r") as f:
main_data = json.load(f)
main_data = [main_data]
# Plotting bytes over time
plt.figure(10)
b_means, stdevs, mins, maxs = get_stats([x["total_bytes"] for x in results])
plot(b_means, stdevs, mins, maxs, "Total Bytes", folder, "lower right")
df = pd.DataFrame(
{
"mean": list(b_means.values()),
"std": list(stdevs.values()),
"nr_nodes": [len(results)] * len(b_means),
},
list(b_means.keys()),
columns=["mean", "std", "nr_nodes"],
)
df.to_csv(
os.path.join(path, "total_bytes_" + folder + ".csv"), index_label="rounds"
)
# Plot Training loss
plt.figure(1)
means, stdevs, mins, maxs = get_stats([x["train_loss"] for x in results])
plot(means, stdevs, mins, maxs, "Training Loss", folder, "upper right")
correct_bytes = [b_means[x] for x in means]
df = pd.DataFrame(
{
"mean": list(means.values()),
"std": list(stdevs.values()),
"nr_nodes": [len(results)] * len(means),
"total_bytes": correct_bytes,
},
list(means.keys()),
columns=["mean", "std", "nr_nodes"],
columns=["mean", "std", "nr_nodes", "total_bytes"],
)
plt.figure(11)
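# Re-key the per-round means by cumulative bytes (b_means) so this curve plots loss against total bytes.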
means = replace_dict_key(means, b_means)
plot(
means,
stdevs,
mins,
maxs,
"Training Loss",
folder,
"upper right",
"Total Bytes per node",
)
df.to_csv(
os.path.join(path, "train_loss_" + folder + ".csv"), index_label="rounds"
)
@@ -102,10 +144,24 @@ def plot_results(path, centralized, data_machine="machine0", data_node=0):
"mean": list(means.values()),
"std": list(stdevs.values()),
"nr_nodes": [len(results)] * len(means),
"total_bytes": correct_bytes,
},
list(means.keys()),
columns=["mean", "std", "nr_nodes"],
columns=["mean", "std", "nr_nodes", "total_bytes"],
)
plt.figure(12)
means = replace_dict_key(means, b_means)
plot(
means,
stdevs,
mins,
maxs,
"Testing Loss",
folder,
"upper right",
"Total Bytes per node",
)
df.to_csv(
os.path.join(path, "test_loss_" + folder + ".csv"), index_label="rounds"
)
@@ -121,9 +177,22 @@ def plot_results(path, centralized, data_machine="machine0", data_node=0):
"mean": list(means.values()),
"std": list(stdevs.values()),
"nr_nodes": [len(results)] * len(means),
"total_bytes": correct_bytes,
},
list(means.keys()),
columns=["mean", "std", "nr_nodes"],
columns=["mean", "std", "nr_nodes", "total_bytes"],
)
plt.figure(13)
means = replace_dict_key(means, b_means)
plot(
means,
stdevs,
mins,
maxs,
"Testing Accuracy",
folder,
"lower right",
"Total Bytes per node",
)
df.to_csv(
os.path.join(path, "test_acc_" + folder + ".csv"), index_label="rounds"
@@ -157,6 +226,15 @@ def plot_results(path, centralized, data_machine="machine0", data_node=0):
data_means[folder] = list(means.values())[0]
data_stdevs[folder] = list(stdevs.values())[0]
plt.figure(10)
plt.savefig(os.path.join(path, "total_bytes.png"), dpi=300)
plt.figure(11)
plt.savefig(os.path.join(path, "bytes_train_loss.png"), dpi=300)
plt.figure(12)
plt.savefig(os.path.join(path, "bytes_test_loss.png"), dpi=300)
plt.figure(13)
plt.savefig(os.path.join(path, "bytes_test_acc.png"), dpi=300)
plt.figure(1)
plt.savefig(os.path.join(path, "train_loss.png"), dpi=300)
plt.figure(2)
......
#!/bin/bash
script_path=$(realpath $(dirname $0))
decpy_path=/mnt/nfs/kirsten/Gitlab/jac_decentralizepy/decentralizepy/eval
# Working directory, where config files are read from and logs are written.
decpy_path=/mnt/nfs/$(whoami)/decpy_workingdir
cd $decpy_path
env_python=~/miniconda3/envs/decpy/bin/python3
graph=/mnt/nfs/kirsten/Gitlab/tutorial/regular_16.txt
original_config=/mnt/nfs/kirsten/Gitlab/tutorial/config_celeba_sharing.ini
config_file=~/tmp/config_celeba_sharing.ini
# Python interpreter
env_python=python3
# File regular_16.txt is available in /tutorial
graph=$decpy_path/regular_16.txt
# File config_celeba_sharing.ini is available in /tutorial
# In this config file, change addresses_filepath to correspond to your list of machines (example in /tutorial/ip.json)
original_config=$decpy_path/config_celeba_sharing.ini
# Local config file
config_file=/tmp/$(basename $original_config)
# Python script to be executed
eval_file=$script_path/testingPeerSampler.py
# General parameters
procs_per_machine=8
machines=2
iterations=5
test_after=2
eval_file=testingPeerSampler.py
log_level=INFO
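# Determine this machine's id m: read the machine list from addresses_filepath in the config, pick the entry whose IP matches the local address of interface ens785, and extract its key.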
m=`cat $(grep addresses_filepath $original_config | awk '{print $3}') | grep $(/sbin/ifconfig ens785 | grep 'inet ' | awk '{print $2}') | cut -d'"' -f2`
@@ -19,6 +33,8 @@ echo M is $m
log_dir=$(date '+%Y-%m-%dT%H:%M')/machine$m
mkdir -p $log_dir
# Copy and manipulate the local config file
cp $original_config $config_file
# echo "alpha = 0.10" >> $config_file
$env_python $eval_file -ro 0 -tea $test_after -ld $log_dir -mid $m -ps $procs_per_machine -ms $machines -is $iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level -ctr 0 -cte 0 -wsd $log_dir
\ No newline at end of file
$env_python $eval_file -ro 0 -tea $test_after -ld $log_dir -mid $m -ps $procs_per_machine -ms $machines -is $iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level -wsd $log_dir
import logging
from pathlib import Path
from shutil import copy
from localconfig import LocalConfig
from torch import multiprocessing as mp
from decentralizepy import utils
from decentralizepy.graphs.Graph import Graph
from decentralizepy.mappings.Linear import Linear
from decentralizepy.node.KNN import KNN
def read_ini(file_path):
config = LocalConfig(file_path)
for section in config:
print("Section: ", section)
for key, value in config.items(section):
print((key, value))
print(dict(config.items("DATASET")))
return config
if __name__ == "__main__":
args = utils.get_args()
Path(args.log_dir).mkdir(parents=True, exist_ok=True)
log_level = {
"INFO": logging.INFO,
"DEBUG": logging.DEBUG,
"WARNING": logging.WARNING,
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL,
}
config = read_ini(args.config_file)
my_config = dict()
for section in config:
my_config[section] = dict(config.items(section))
copy(args.config_file, args.log_dir)
copy(args.graph_file, args.log_dir)
utils.write_args(args, args.log_dir)
g = Graph()
g.read_graph_from_file(args.graph_file, args.graph_type)
n_machines = args.machines
procs_per_machine = args.procs_per_machine
l = Linear(n_machines, procs_per_machine)
m_id = args.machine_id
processes = []
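# Spawn one KNN node per local rank; each process builds the overlay and then trains.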
for r in range(procs_per_machine):
processes.append(
mp.Process(
target=KNN,
args=[
r,
m_id,
l,
g,
my_config,
args.iterations,
args.log_dir,
args.weights_store_dir,
log_level[args.log_level],
args.test_after,
args.train_evaluate_after,
args.reset_optimizer,
],
)
)
for p in processes:
p.start()
for p in processes:
p.join()
......@@ -10,7 +10,6 @@ from decentralizepy.graphs.Graph import Graph
from decentralizepy.mappings.Linear import Linear
from decentralizepy.node.DPSGDWithPeerSampler import DPSGDWithPeerSampler
from decentralizepy.node.PeerSampler import PeerSampler
# from decentralizepy.node.PeerSamplerDynamic import PeerSamplerDynamic
def read_ini(file_path):
......
@@ -47,6 +47,7 @@ class TCP(Communication):
total_procs,
addresses_filepath,
offset=9000,
recv_timeout=50,
):
"""
Constructor
@@ -79,11 +80,14 @@
self.machine_id = machine_id
self.mapping = mapping
self.offset = offset
self.recv_timeout = recv_timeout
self.uid = mapping.get_uid(rank, machine_id)
self.identity = str(self.uid).encode()
self.context = zmq.Context()
self.router = self.context.socket(zmq.ROUTER)
self.router.setsockopt(zmq.IDENTITY, self.identity)
self.router.setsockopt(zmq.RCVTIMEO, self.recv_timeout)
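# RCVTIMEO is in milliseconds; receives now raise zmq.error.Again (errno EAGAIN) on timeout instead of blocking forever.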
self.router.setsockopt(zmq.ROUTER_MANDATORY, 1)
self.router.bind(self.addr(rank, machine_id))
self.total_data = 0
@@ -170,7 +174,7 @@
id = str(neighbor).encode()
return id in self.peer_sockets
def receive(self):
def receive(self, block=True):
"""
Returns ONE message received.
@@ -185,10 +189,19 @@
If received HELLO
"""
sender, recv = self.router.recv_multipart()
s, r = self.decrypt(sender, recv)
return s, r
while True:
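# A timed-out recv surfaces as a ZMQError with errno EAGAIN: retry if blocking, otherwise report "no message" with None.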
try:
sender, recv = self.router.recv_multipart()
s, r = self.decrypt(sender, recv)
return s, r
except zmq.ZMQError as exc:
if exc.errno == zmq.EAGAIN:
if not block:
return None
else:
continue
else:
raise
def send(self, uid, data, encrypt=True):
"""
......
@@ -114,6 +114,8 @@ class CIFAR10(Dataset):
test_batch_size,
)
self.num_classes = NUM_CLASSES
self.partition_niid = partition_niid
self.shards = shards
self.transform = transforms.Compose(
......
@@ -230,6 +230,8 @@ class Celeba(Dataset):
self.IMAGES_DIR = utils.conditional_value(images_dir, "", None)
assert self.IMAGES_DIR != None
self.num_classes = NUM_CLASSES
if self.__training__:
self.load_trainset()
......
@@ -52,6 +52,7 @@ class Dataset:
self.test_dir = utils.conditional_value(test_dir, "", None)
self.sizes = utils.conditional_value(sizes, "", None)
self.test_batch_size = utils.conditional_value(test_batch_size, "", 64)
self.num_classes = None
if self.sizes:
if type(self.sizes) == str:
self.sizes = eval(self.sizes)
@@ -66,6 +67,20 @@
else:
self.__testing__ = False
self.label_distribution = None
def get_label_distribution(self):
# Only supported for classification
if self.label_distribution is None:
self.label_distribution = [0 for _ in range(self.num_classes)]
tr_set = self.get_trainset()
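# Tally class labels across the (data, labels) batches of the training set.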
for _, ys in tr_set:
for y in ys:
y_val = y.item()
self.label_distribution[y_val] += 1
return self.label_distribution
def get_trainset(self):
"""
Function to get the training set
......
@@ -223,6 +223,8 @@ class Femnist(Dataset):
test_batch_size,
)
self.num_classes = NUM_CLASSES
if self.__training__:
self.load_trainset()
......
import logging
import math
import os
import queue
from random import Random
from threading import Lock, Thread
import numpy as np
import torch
from numpy.linalg import norm
from decentralizepy import utils
from decentralizepy.graphs.Graph import Graph
from decentralizepy.mappings.Mapping import Mapping
from decentralizepy.node.OverlayNode import OverlayNode
class KNN(OverlayNode):
"""
This class defines a KNN learning node
"""
def similarityMetric(self, candidate):
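# Cosine similarity between this node's label distribution and the candidate's.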
logging.debug("A: {}".format(self.othersInfo[self.uid]))
logging.debug("B: {}".format(self.othersInfo[candidate]))
A = np.array(self.othersInfo[self.uid])
B = np.array(self.othersInfo[candidate])
return np.dot(A, B) / (norm(A) * norm(B))
def get_most_similar(self, candidates, to_keep=4):
if len(candidates) <= to_keep:
return candidates
cur_candidates = dict()
for i in candidates:
simil = round(self.similarityMetric(i), 3)
if simil not in cur_candidates:
cur_candidates[simil] = []
cur_candidates[simil].append(i)
similarity_scores = list(cur_candidates.keys())
similarity_scores.sort()
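# Note: the ascending sort drains the lowest-similarity buckets first; keeping the highest-similarity candidates would require sort(reverse=True).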
left_to_keep = to_keep
return_result = set()
for i in similarity_scores:
if left_to_keep >= len(cur_candidates[i]):
return_result.update(cur_candidates[i])
left_to_keep -= len(cur_candidates[i])
elif left_to_keep > 0:
return_result.update(
list(self.rng.sample(cur_candidates[i], left_to_keep))
)
left_to_keep = 0
break
else:
break
return return_result
def create_message_to_send(
self,
channel="KNNConstr",
boolean_flags=[],
add_my_info=False,
add_neighbor_info=False,
):
message = {"CHANNEL": channel, "KNNRound": self.knn_round}
for x in boolean_flags:
message[x] = True
if add_my_info:
message[self.uid] = self.othersInfo[self.uid]
if add_neighbor_info:
for neighbors in self.out_edges:
if neighbors in self.othersInfo:
message[neighbors] = self.othersInfo[neighbors]
return message
def receive_KNN_message(self):
return self.receive_channel("KNNConstr", block=False)
def process_init_receive(self, message):
self.mutex.acquire()
if "RESPONSE" in message[1]:
self.num_initializations += 1
else:
self.communication.send(
message[0],
self.create_message_to_send(
boolean_flags=["INIT", "RESPONSE"], add_my_info=True
),
)
x = (
message[0],
utils.remove_keys(message[1], ["RESPONSE", "INIT", "CHANNEL", "KNNRound"]),
)
self.othersInfo.update(x[1])
self.mutex.release()
def remove_meta_from_message(self, message):
return (
message[0],
utils.remove_keys(message[1], ["RESPONSE", "INIT", "CHANNEL", "KNNRound"]),
)
def process_candidates_without_lock(self, current_candidates, message):
if not self.exit_receiver:
message = (
message[0],
utils.remove_keys(
message[1], ["CHANNEL", "RESPONSE", "INIT", "KNNRound"]
),
)
self.othersInfo.update(message[1])
new_candidates = set(message[1].keys())
current_candidates = current_candidates.union(new_candidates)
if self.uid in current_candidates:
current_candidates.remove(self.uid)
self.out_edges = self.get_most_similar(current_candidates)
def send_response(self, message, add_neighbor_info=False, process_candidates=False):
self.mutex.acquire()
logging.debug("Responding to {}".format(message[0]))
self.communication.send(
message[0],
self.create_message_to_send(
boolean_flags=["RESPONSE"],
add_my_info=True,
add_neighbor_info=add_neighbor_info,
),
)
if process_candidates:
self.process_candidates_without_lock(set(self.out_edges), message)
self.mutex.release()
def receiver_thread(self):
knnBYEs = set()
self.num_initializations = 0
waiting_queue = queue.Queue()
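# Until initial_neighbors INIT responses have arrived, only INIT messages are handled; everything else is parked in waiting_queue for later.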
while True:
if len(knnBYEs) == self.mapping.get_n_procs() - 1:
self.mutex.acquire()
if self.exit_receiver:
self.mutex.release()
logging.debug("Exiting thread")
return
self.mutex.release()
if self.num_initializations < self.initial_neighbors:
x = self.receive_KNN_message()
if x is None:
continue
elif "INIT" in x[1]:
self.process_init_receive(x)
else:
waiting_queue.put(x)
else:
logging.debug("Waiting for messages")
if waiting_queue.empty():
x = self.receive_KNN_message()
if x is None:
continue
else:
x = waiting_queue.get()
if "INIT" in x[1]:
logging.debug("A past INIT Message received from {}".format(x[0]))
self.process_init_receive(x)
elif "RESPONSE" in x[1]:
logging.debug(
"A response message received from {} from KNNRound {}".format(
x[0], x[1]["KNNRound"]
)
)
x = self.remove_meta_from_message(x)
self.responseQueue.put(x)
elif "RANDOM_DISCOVERY" in x[1]:
logging.debug(
"A Random Discovery message received from {} from KNNRound {}".format(
x[0], x[1]["KNNRound"]
)
)
self.send_response(
x, add_neighbor_info=False, process_candidates=False
)
elif "KNNBYE" in x[1]:
self.mutex.acquire()
knnBYEs.add(x[0])
logging.debug("{} KNN Byes received".format(knnBYEs))
if self.uid in x[1]["CLOSE"]:
self.in_edges.add(x[0])
self.mutex.release()
else:
logging.debug(
"A KNN sharing message received from {} from KNNRound {}".format(
x[0], x[1]["KNNRound"]
)
)
self.send_response(
x, add_neighbor_info=True, process_candidates=True
)
def build_topology(self, rounds=30, random_nodes=4):
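# Per round: exchange info with one random out-neighbor, probe up to random_nodes unknown peers (RANDOM_DISCOVERY), then re-select out_edges from all known candidates; finish by broadcasting KNNBYE with the chosen edges.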
self.knn_round = 0
self.exit_receiver = False
t = Thread(target=self.receiver_thread)
t.start()
# Initializations : Send my dataset info to others
self.mutex.acquire()
initial_KNN_message = self.create_message_to_send(
boolean_flags=["INIT"], add_my_info=True
)
for x in self.out_edges:
self.communication.send(x, initial_KNN_message)
self.mutex.release()
for round in range(rounds):
self.knn_round = round
logging.info("Starting KNN Round {}".format(round))
self.mutex.acquire()
rand_neighbor = self.rng.choice(list(self.out_edges))
logging.debug("Random neighbor: {}".format(rand_neighbor))
self.communication.send(
rand_neighbor,
self.create_message_to_send(add_my_info=True, add_neighbor_info=True),
)
self.mutex.release()
logging.debug("Waiting for knn response from {}".format(rand_neighbor))
response = self.responseQueue.get(block=True)
logging.debug("Got response from random neighbor")
self.mutex.acquire()
random_candidates = set(
self.rng.sample(list(range(self.mapping.get_n_procs())), random_nodes)
)
req_responses = 0
for rc in random_candidates:
logging.debug("Current random discovery: {}".format(rc))
if rc not in self.othersInfo and rc != self.uid:
logging.debug("Sending discovery request to {}".format(rc))
self.communication.send(
rc,
self.create_message_to_send(boolean_flags=["RANDOM_DISCOVERY"]),
)
req_responses += 1
self.mutex.release()
while req_responses > 0:
logging.debug(
"Waiting for {} random discovery responses.".format(req_responses)
)
req_responses -= 1
random_discovery_response = self.responseQueue.get(block=True)
logging.debug(
"Received discovery response from {}".format(
random_discovery_response[0]
)
)
self.mutex.acquire()
self.othersInfo.update(random_discovery_response[1])
self.mutex.release()
self.mutex.acquire()
self.process_candidates_without_lock(
random_candidates.union(self.out_edges), response
)
self.mutex.release()
logging.info("Completed KNN Round {}".format(round))
logging.debug("OutNodes: {}".format(self.out_edges))
# Send out_edges and BYE to all
to_send = self.create_message_to_send(boolean_flags=["KNNBYE"])
logging.info("Sending KNNByes")
self.mutex.acquire()
self.exit_receiver = True
to_send["CLOSE"] = list(self.out_edges) # Optimize to only send Yes/No
for receiver in range(self.mapping.get_n_procs()):
if receiver != self.uid:
self.communication.send(receiver, to_send)
self.mutex.release()
logging.info("KNNByes Sent")
t.join()
logging.info("Receiver Thread Returned")
def __init__(
self,
rank: int,
machine_id: int,
mapping: Mapping,
graph: Graph,
config,
iterations=1,
log_dir=".",
weights_store_dir=".",
log_level=logging.INFO,
test_after=5,
train_evaluate_after=1,
reset_optimizer=1,
initial_neighbors=4,
*args
):
"""
Constructor
Parameters
----------
rank : int
Rank of process local to the machine
machine_id : int
Machine ID on which the process is running
mapping : decentralizepy.mappings
The object containing the mapping rank <--> uid
graph : decentralizepy.graphs
The object containing the global graph
config : dict
A dictionary of configurations. Must contain the following:
[DATASET]
dataset_package
dataset_class
model_class
[OPTIMIZER_PARAMS]
optimizer_package
optimizer_class
[TRAIN_PARAMS]
training_package = decentralizepy.training.Training
training_class = Training
epochs_per_round = 25
batch_size = 64
iterations : int
Number of iterations (communication steps) for which the model should be trained
log_dir : str
Logging directory
weights_store_dir : str
Directory in which to store model weights
log_level : logging.Level
One of DEBUG, INFO, WARNING, ERROR, CRITICAL
test_after : int
Number of iterations after which the test loss and accuracy are calculated
train_evaluate_after : int
Number of iterations after which the train loss is calculated
reset_optimizer : int
1 if optimizer should be reset every communication round, else 0
args : optional
Other arguments
"""
total_threads = os.cpu_count()
self.threads_per_proc = max(
math.floor(total_threads / mapping.procs_per_machine), 1
)
torch.set_num_threads(self.threads_per_proc)
torch.set_num_interop_threads(1)
self.instantiate(
rank,
machine_id,
mapping,
graph,
config,
iterations,
log_dir,
weights_store_dir,
log_level,
test_after,
train_evaluate_after,
reset_optimizer,
*args
)
self.rng = Random()
self.rng.seed(self.uid + 100)
self.initial_neighbors = initial_neighbors
self.in_edges = set()
self.out_edges = set(
self.rng.sample(
list(self.graph.neighbors(self.uid)), self.initial_neighbors
)
)
self.responseQueue = queue.Queue()
self.mutex = Lock()
self.othersInfo = {self.uid: list(self.dataset.get_label_distribution())}
# ld = self.dataset.get_label_distribution()
# ld_keys = sorted(list(ld.keys()))
# self.othersInfo = {self.uid: []}
# for key in range(max(ld_keys) + 1):
# if key in ld:
# self.othersInfo[self.uid].append(ld[key])
# else:
# self.othersInfo[self.uid].append(0)
logging.info("Label Distributions: {}".format(self.othersInfo))
logging.info(
"Each proc uses %d threads out of %d.", self.threads_per_proc, total_threads
)
self.run()
@@ -26,14 +26,19 @@ class Node:
self.communication.init_connection(neighbor)
self.communication.send(neighbor, {"HELLO": self.uid, "CHANNEL": "CONNECT"})
def receive_channel(self, channel):
def receive_channel(self, channel, block=True):
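# With block=False this returns None as soon as no message is pending, instead of waiting.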
if channel not in self.message_queue:
self.message_queue[channel] = deque()
if len(self.message_queue[channel]) > 0:
return self.message_queue[channel].popleft()
else:
sender, recv = self.communication.receive()
x = self.communication.receive(block=block)
if x is None:
assert not block
return None
sender, recv = x
logging.info(
"Received some message from {} with CHANNEL: {}".format(
sender, recv["CHANNEL"]
@@ -44,7 +49,11 @@
if recv["CHANNEL"] not in self.message_queue:
self.message_queue[recv["CHANNEL"]] = deque()
self.message_queue[recv["CHANNEL"]].append((sender, recv))
sender, recv = self.communication.receive()
x = self.communication.receive(block=block)
if x is None:
assert not block
return None
sender, recv = x
logging.info(
"Received some message from {} with CHANNEL: {}".format(
sender, recv["CHANNEL"]
......
import importlib
import json
import logging
import math
import os
from collections import deque
import torch
from matplotlib import pyplot as plt
from decentralizepy import utils
from decentralizepy.graphs.Graph import Graph
from decentralizepy.mappings.Mapping import Mapping
from decentralizepy.node.Node import Node
class OverlayNode(Node):
"""
This class defines a node on an overlay graph
"""
def save_plot(self, l, label, title, xlabel, filename):
"""
Save Matplotlib plot. Clears previous plots.
Parameters
----------
l : dict
dict of x -> y. `x` must be castable to int.
label : str
label of the plot. Used for legend.
title : str
Header
xlabel : str
x-axis label
filename : str
Name of file to save the plot as.
"""
plt.clf()
y_axis = [l[key] for key in l.keys()]
x_axis = list(map(int, l.keys()))
plt.plot(x_axis, y_axis, label=label)
plt.xlabel(xlabel)
plt.title(title)
plt.savefig(filename)
def get_neighbors(self, node=None):
return self.my_neighbors
def receive_DPSGD(self):
return self.receive_channel("DPSGD")
def build_topology(self):
pass
def run(self):
"""
Start the decentralized learning
"""
self.testset = self.dataset.get_testset()
rounds_to_test = self.test_after
rounds_to_train_evaluate = self.train_evaluate_after
global_epoch = 1
change = 1
self.connect_neighbors()
logging.info("Connected to all neighbors")
self.build_topology()
logging.info("OutNodes: {}".format(self.out_edges))
logging.info("InNodes: {}".format(self.in_edges))
logging.info("Unifying edges")
self.out_edges = self.out_edges.union(self.in_edges)
self.my_neighbors = self.in_edges = set(self.out_edges)
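# The learned overlay is made symmetric: models are exchanged over the union of in- and out-edges.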
logging.info("Total number of neighbor: {}".format(len(self.my_neighbors)))
for iteration in range(self.iterations):
logging.info("Starting training iteration: %d", iteration)
rounds_to_train_evaluate -= 1
rounds_to_test -= 1
self.iteration = iteration
self.trainer.train(self.dataset)
to_send = self.sharing.get_data_to_send()
to_send["CHANNEL"] = "DPSGD"
to_send["degree"] = len(self.in_edges)
assert len(self.out_edges) != 0
assert len(self.in_edges) != 0
for neighbor in self.out_edges:
self.communication.send(neighbor, to_send)
while not self.received_from_all():
sender, data = self.receive_DPSGD()
logging.info(
"Received Model from {} of iteration {}".format(
sender, data["iteration"]
)
)
if sender not in self.peer_deques:
self.peer_deques[sender] = deque()
self.peer_deques[sender].append(data)
averaging_deque = dict()
for neighbor in self.in_edges:
averaging_deque[neighbor] = self.peer_deques[neighbor]
self.sharing._averaging(averaging_deque)
if self.reset_optimizer:
self.optimizer = self.optimizer_class(
self.model.parameters(), **self.optimizer_params
) # Reset optimizer state
self.trainer.reset_optimizer(self.optimizer)
if iteration:
with open(
os.path.join(self.log_dir, "{}_results.json".format(self.rank)),
"r",
) as inf:
results_dict = json.load(inf)
else:
results_dict = {
"train_loss": {},
"test_loss": {},
"test_acc": {},
"total_bytes": {},
"total_meta": {},
"total_data_per_n": {},
}
results_dict["total_bytes"][iteration + 1] = self.communication.total_bytes
if hasattr(self.communication, "total_meta"):
results_dict["total_meta"][
iteration + 1
] = self.communication.total_meta
if hasattr(self.communication, "total_data"):
results_dict["total_data_per_n"][
iteration + 1
] = self.communication.total_data
if rounds_to_train_evaluate == 0:
logging.info("Evaluating on train set.")
rounds_to_train_evaluate = self.train_evaluate_after * change
loss_after_sharing = self.trainer.eval_loss(self.dataset)
results_dict["train_loss"][iteration + 1] = loss_after_sharing
self.save_plot(
results_dict["train_loss"],
"train_loss",
"Training Loss",
"Communication Rounds",
os.path.join(self.log_dir, "{}_train_loss.png".format(self.rank)),
)
if self.dataset.__testing__ and rounds_to_test == 0:
rounds_to_test = self.test_after * change
logging.info("Evaluating on test set.")
ta, tl = self.dataset.test(self.model, self.loss)
results_dict["test_acc"][iteration + 1] = ta
results_dict["test_loss"][iteration + 1] = tl
if global_epoch == 49:
change *= 2
global_epoch += change
with open(
os.path.join(self.log_dir, "{}_results.json".format(self.rank)), "w"
) as of:
json.dump(results_dict, of)
# if self.model.shared_parameters_counter is not None:
# logging.info("Saving the shared parameter counts")
# with open(
# os.path.join(
# self.log_dir, "{}_shared_parameters.json".format(self.rank)
# ),
# "w",
# ) as of:
# json.dump(self.model.shared_parameters_counter.numpy().tolist(), of)
self.disconnect_neighbors()
# logging.info("Storing final weight")
# self.model.dump_weights(self.weights_store_dir, self.uid, iteration)
logging.info("All neighbors disconnected. Process complete!")
def cache_fields(
self,
rank,
machine_id,
mapping,
graph,
iterations,
log_dir,
weights_store_dir,
test_after,
train_evaluate_after,
reset_optimizer,
):
"""
Instantiate object field with arguments.
Parameters
----------
rank : int
Rank of process local to the machine
machine_id : int
Machine ID on which the process is running
mapping : decentralizepy.mappings
The object containing the mapping rank <--> uid
graph : decentralizepy.graphs
The object containing the global graph
iterations : int
Number of iterations (communication steps) for which the model should be trained
log_dir : str
Logging directory
weights_store_dir : str
Directory in which to store model weights
test_after : int
Number of iterations after which the test loss and accuracy are calculated
train_evaluate_after : int
Number of iterations after which the train loss is calculated
reset_optimizer : int
1 if optimizer should be reset every communication round, else 0
"""
self.rank = rank
self.machine_id = machine_id
self.graph = graph
self.mapping = mapping
self.uid = self.mapping.get_uid(rank, machine_id)
self.log_dir = log_dir
self.weights_store_dir = weights_store_dir
self.iterations = iterations
self.test_after = test_after
self.train_evaluate_after = train_evaluate_after
self.reset_optimizer = reset_optimizer
self.sent_disconnections = False
logging.info("Rank: %d", self.rank)
logging.info("type(graph): %s", str(type(self.rank)))
logging.info("type(mapping): %s", str(type(self.mapping)))
def init_comm(self, comm_configs):
"""
Instantiate communication module from config.
Parameters
----------
comm_configs : dict
Python dict containing communication config params
"""
comm_module = importlib.import_module(comm_configs["comm_package"])
comm_class = getattr(comm_module, comm_configs["comm_class"])
comm_params = utils.remove_keys(comm_configs, ["comm_package", "comm_class"])
self.addresses_filepath = comm_params.get("addresses_filepath", None)
self.communication = comm_class(
self.rank, self.machine_id, self.mapping, self.graph.n_procs, **comm_params
)
def instantiate(
self,
rank: int,
machine_id: int,
mapping: Mapping,
graph: Graph,
config,
iterations=1,
log_dir=".",
weights_store_dir=".",
log_level=logging.INFO,
test_after=5,
train_evaluate_after=1,
reset_optimizer=1,
*args
):
"""
Construct objects.
Parameters
----------
rank : int
Rank of process local to the machine
machine_id : int
Machine ID on which the process in running
mapping : decentralizepy.mappings
The object containing the mapping rank <--> uid
graph : decentralizepy.graphs
The object containing the global graph
config : dict
A dictionary of configurations.
iterations : int
Number of iterations (communication steps) for which the model should be trained
log_dir : str
Logging directory
weights_store_dir : str
Directory in which to store model weights
log_level : logging.Level
One of DEBUG, INFO, WARNING, ERROR, CRITICAL
test_after : int
Number of iterations after which the test loss and accuracy are calculated
train_evaluate_after : int
Number of iterations after which the train loss is calculated
reset_optimizer : int
1 if optimizer should be reset every communication round, else 0
args : optional
Other arguments
"""
logging.info("Started process.")
self.init_log(log_dir, rank, log_level)
self.cache_fields(
rank,
machine_id,
mapping,
graph,
iterations,
log_dir,
weights_store_dir,
test_after,
train_evaluate_after,
reset_optimizer,
)
self.init_dataset_model(config["DATASET"])
self.init_optimizer(config["OPTIMIZER_PARAMS"])
self.init_trainer(config["TRAIN_PARAMS"])
self.init_comm(config["COMMUNICATION"])
self.message_queue = dict()
self.barrier = set()
self.my_neighbors = self.graph.neighbors(self.uid)
self.init_sharing(config["SHARING"])
self.peer_deques = dict()
self.connect_neighbors()
def received_from_all(self):
"""
Check if all neighbors have sent the current iteration
Returns
-------
bool
True if required data has been received, False otherwise
"""
for k in self.in_edges:
if (k not in self.peer_deques) or len(self.peer_deques[k]) == 0:
return False
return True
def __init__(
self,
rank: int,
machine_id: int,
mapping: Mapping,
graph: Graph,
config,
iterations=1,
log_dir=".",
weights_store_dir=".",
log_level=logging.INFO,
test_after=5,
train_evaluate_after=1,
reset_optimizer=1,
*args
):
"""
Constructor
Parameters
----------
rank : int
Rank of process local to the machine
machine_id : int
Machine ID on which the process is running
mapping : decentralizepy.mappings
The object containing the mapping rank <--> uid
graph : decentralizepy.graphs
The object containing the global graph
config : dict
A dictionary of configurations. Must contain the following:
[DATASET]
dataset_package
dataset_class
model_class
[OPTIMIZER_PARAMS]
optimizer_package
optimizer_class
[TRAIN_PARAMS]
training_package = decentralizepy.training.Training
training_class = Training
epochs_per_round = 25
batch_size = 64
iterations : int
Number of iterations (communication steps) for which the model should be trained
log_dir : str
Logging directory
weights_store_dir : str
Directory in which to store model weights
log_level : logging.Level
One of DEBUG, INFO, WARNING, ERROR, CRITICAL
test_after : int
Number of iterations after which the test loss and accuracy are calculated
train_evaluate_after : int
Number of iterations after which the train loss is calculated
reset_optimizer : int
1 if optimizer should be reset every communication round, else 0
args : optional
Other arguments
"""
total_threads = os.cpu_count()
self.threads_per_proc = max(
math.floor(total_threads / mapping.procs_per_machine), 1
)
torch.set_num_threads(self.threads_per_proc)
torch.set_num_interop_threads(1)
self.instantiate(
rank,
machine_id,
mapping,
graph,
config,
iterations,
log_dir,
weights_store_dir,
log_level,
test_after,
train_evaluate_after,
reset_optimizer,
*args
)
self.in_edges = set()
self.out_edges = set()
logging.info(
"Each proc uses %d threads out of %d.", self.threads_per_proc, total_threads
)
self.run()
import logging
from collections import OrderedDict
import torch
from decentralizepy.sharing.Sharing import Sharing
def zeros_like_state_dict(state_dict):
"""
"""
Creates a new state dictionary such that it has the same
layers (name and size) as the input state dictionary, but all values
are zero
@@ -22,11 +22,12 @@ def zeros_like_state_dict(state_dict):
result_dict[tensor_name] = torch.zeros_like(tensor_values)
return result_dict
def get_dict_keys_and_check_matching(dict_1, dict_2):
"""
"""
Checks if keys of the two dictionaries match and
returns them if they do, otherwise raises ValueError
Parameters
----------
dict_1: dict
@@ -40,11 +41,12 @@ def get_dict_keys_and_check_matching(dict_1, dict_2):
"""
keys = dict_1.keys()
if set(keys).difference(set(dict_2.keys())):
raise ValueError('Dictionaries must have matching keys')
raise ValueError("Dictionaries must have matching keys")
return keys
def subtract_state_dicts(_1, _2):
"""
"""
Subtracts one state dictionary from another
Parameters
@@ -67,12 +69,13 @@ def subtract_state_dicts(_1, _2):
result_dict[key] = _1[key] - _2[key]
return result_dict
def self_add_state_dict(_1, _2, constant=1.):
def self_add_state_dict(_1, _2, constant=1.0):
"""
Scales one state dictionary by a constant and
adds it directly to another, minimizing the copies
created. Equivalent to the operation `_1 += constant * _2`
Parameters
----------
_1: dict[str, torch.Tensor]
@@ -93,11 +96,12 @@ def self_add_state_dict(_1, _2, constant=1.):
# Size checking is done by torch during the subtraction
_1[key] += constant * _2[key]
def flatten_state_dict(state_dict):
"""
Transforms state dictionary into a flat tensor
by flattening and concatenating tensors of the
state dictionary.
Note: changes made to the result won't affect the state dictionary
@@ -107,10 +111,8 @@ def flatten_state_dict(state_dict):
A state dictionary to flatten
"""
return torch.cat([
tensor.flatten()\
for tensor in state_dict.values()
], axis=0)
return torch.cat([tensor.flatten() for tensor in state_dict.values()], axis=0)
def unflatten_state_dict(flat_tensor, reference_state_dict):
"""
@@ -138,11 +140,11 @@ def unflatten_state_dict(flat_tensor, reference_state_dict):
start_index = 0
for tensor_name, tensor in reference_state_dict.items():
end_index = start_index + tensor.numel()
result[tensor_name] = flat_tensor[start_index:end_index].reshape(
tensor.shape)
result[tensor_name] = flat_tensor[start_index:end_index].reshape(tensor.shape)
start_index = end_index
return result
def serialize_sparse_tensor(tensor):
"""
Serializes sparse tensor by flattening it and
@@ -158,6 +160,7 @@ def serialize_sparse_tensor(tensor):
values = flat[indices]
return values, indices
def deserialize_sparse_tensor(values, indices, shape):
"""
Deserializes tensor from its non-zero values and indices
@@ -171,12 +174,12 @@ def deserialize_sparse_tensor(values, indices, shape):
Respective indices of non-zero entries of flattened original tensor
shape: torch.Size or tuple[*int]
Shape of the original tensor
"""
result = torch.zeros(size=shape)
if len(indices):
flat_result = result.flatten()
flat_result[indices] = values
flat_result = result.flatten()
flat_result[indices] = values
return result
@@ -203,6 +206,7 @@ def topk_sparsification_tensor(tensor, alpha):
tensor[tensor_abs < -cutoff_value] = 0
return tensor
def topk_sparsification(state_dict, alpha):
"""
Performs topk sparsification of a state_dict
@@ -221,17 +225,18 @@ def topk_sparsification(state_dict, alpha):
"""
flat_tensor = flatten_state_dict(state_dict)
return unflatten_state_dict(
topk_sparsification_tensor(flat_tensor, alpha),
state_dict)
topk_sparsification_tensor(flat_tensor, alpha), state_dict
)
def serialize_sparse_state_dict(state_dict):
with torch.no_grad():
concatted_tensors = torch.cat([
tensor.flatten()\
for tensor in state_dict.values()
], axis=0)
concatted_tensors = torch.cat(
[tensor.flatten() for tensor in state_dict.values()], axis=0
)
return serialize_sparse_tensor(concatted_tensors)
def deserialize_sparse_state_dict(values, indices, reference_state_dict):
with torch.no_grad():
keys = []
@@ -310,16 +315,20 @@ class Choco(Sharing):
model,
dataset,
log_dir,
compress=False,
compression_package=None,
compression_class=None
compress,
compression_package,
compression_class,
)
self.step_size = step_size
self.alpha = alpha
logging.info("type(step_size): %s, value: %s",
str(type(self.step_size)), str(self.step_size))
logging.info("type(alpha): %s, value: %s",
str(type(self.alpha)), str(self.alpha))
logging.info(
"type(step_size): %s, value: %s",
str(type(self.step_size)),
str(self.step_size),
)
logging.info(
"type(alpha): %s, value: %s", str(type(self.alpha)), str(self.alpha)
)
model_state_dict = model.state_dict()
self.model_hat = zeros_like_state_dict(model_state_dict)
self.s = zeros_like_state_dict(model_state_dict)
@@ -351,10 +360,10 @@ class Choco(Sharing):
"""
with torch.no_grad():
self.my_q = self._compress(subtract_state_dicts(
self.model.state_dict(), self.model_hat
))
self.my_q = self._compress(
subtract_state_dicts(self.model.state_dict(), self.model_hat)
)
def serialized_model(self):
"""
Convert self q to a dictionary. Here we can choose how much to share
@@ -395,15 +404,16 @@ class Choco(Sharing):
indices = torch.tensor(m["indices"], dtype=torch.long)
values = torch.tensor(m["params"])
return deserialize_sparse_state_dict(
values, indices, self.model.state_dict())
values, indices, self.model.state_dict()
)
def _averaging(self, peer_deques):
"""
Averages the received model with the local model
"""
with torch.no_grad():
self_add_state_dict(self.model_hat, self.my_q) # x_hat = q_self + x_hat
self_add_state_dict(self.model_hat, self.my_q) # x_hat = q_self + x_hat
weight_total = 0
for i, n in enumerate(peer_deques):
data = peer_deques[n].popleft()
@@ -433,7 +443,8 @@ class Choco(Sharing):
self_add_state_dict(
total,
subtract_state_dicts(self.s, self.model_hat),
constant=self.step_size) # x = x + gamma * (s - x_hat)
constant=self.step_size,
) # x = x + gamma * (s - x_hat)
self.model.load_state_dict(total)
self._post_step()
@@ -444,5 +455,4 @@ class Choco(Sharing):
Averages the received models of all working nodes
"""
raise NotImplementedError()
\ No newline at end of file
raise NotImplementedError()