From 6ec0bd267ba9eb244c285086fef78c677a456d57 Mon Sep 17 00:00:00 2001 From: Rishi Sharma <rishi.sharma@epfl.ch> Date: Wed, 1 Dec 2021 16:32:10 +0100 Subject: [PATCH] Partial model sharing fix - config + test_after --- eval/config.ini | 2 +- eval/main.ipynb | 4 ++-- eval/testing.py | 1 + src/decentralizepy/sharing/Sharing.py | 1 + src/decentralizepy/utils.py | 2 ++ 5 files changed, 7 insertions(+), 3 deletions(-) diff --git a/eval/config.ini b/eval/config.ini index 7df6160..5f495e5 100644 --- a/eval/config.ini +++ b/eval/config.ini @@ -2,7 +2,7 @@ dataset_package = decentralizepy.datasets.Femnist dataset_class = Femnist model_class = CNN -n_procs = 16 +n_procs = 4 train_dir = /home/risharma/Gitlab/decentralizepy/leaf/data/femnist/per_user_data/train test_dir = /home/risharma/Gitlab/decentralizepy/leaf/data/femnist/data/test ; python list of fractions below diff --git a/eval/main.ipynb b/eval/main.ipynb index 3fecdf4..fd6f098 100644 --- a/eval/main.ipynb +++ b/eval/main.ipynb @@ -509,14 +509,14 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from decentralizepy.graphs.FullyConnected import FullyConnected\n", "\n", "s = FullyConnected(96)\n", - "s.write_graph_to_file('9_node_fullyConnected.edges')" + "s.write_graph_to_file('96_node_fullyConnected.edges')" ] }, { diff --git a/eval/testing.py b/eval/testing.py index 563b73d..493c042 100644 --- a/eval/testing.py +++ b/eval/testing.py @@ -61,5 +61,6 @@ if __name__ == "__main__": args.iterations, args.log_dir, log_level[args.log_level], + args.test_after ], ) diff --git a/src/decentralizepy/sharing/Sharing.py b/src/decentralizepy/sharing/Sharing.py index 0e78c15..5f3ce45 100644 --- a/src/decentralizepy/sharing/Sharing.py +++ b/src/decentralizepy/sharing/Sharing.py @@ -63,6 +63,7 @@ class Sharing: degree = data["degree"] del data["degree"] self.peer_deques[sender].append((degree, self.deserialized_model(data))) + logging.info("Deserialized received model from {}".format(sender)) logging.info("Starting model averaging after receiving from all neighbors") total = dict() diff --git a/src/decentralizepy/utils.py b/src/decentralizepy/utils.py index 00cfd4f..f9cc5a2 100644 --- a/src/decentralizepy/utils.py +++ b/src/decentralizepy/utils.py @@ -28,6 +28,7 @@ def get_args(): parser.add_argument("-ll", "--log_level", type=str, default="INFO") parser.add_argument("-gf", "--graph_file", type=str, default="36_nodes.edges") parser.add_argument("-gt", "--graph_type", type=str, default="edges") + parser.add_argument("-ta", "--test_after", type=int, default = 5) args = parser.parse_args() return args @@ -44,6 +45,7 @@ def write_args(args, path): "log_level": args.log_level, "graph_file": args.graph_file, "graph_type": args.graph_type, + "test_after": args.test_after } with open(os.path.join(path, "args.json"), "w") as of: json.dump(data, of) -- GitLab