From 8cf2ee7f299a85e8fc1c5901590eb9796971ce07 Mon Sep 17 00:00:00 2001
From: kirsten <elisabeth.kirsten@stud.tu-darmstadt.de>
Date: Thu, 8 Sep 2022 11:55:55 +0200
Subject: [PATCH] add working rate as parameter

---
 eval/run.sh                 | 3 +--
 eval/run_xtimes_cifar.sh    | 3 ++-
 eval/testingFederated.py    | 5 +----
 src/decentralizepy/utils.py | 2 ++
 4 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/eval/run.sh b/eval/run.sh
index b96f8b8..9d71d96 100755
--- a/eval/run.sh
+++ b/eval/run.sh
@@ -11,8 +11,7 @@ procs_per_machine=8
 machines=2
 iterations=5
 test_after=2
-eval_file=testingFederated.py
-#eval_file=testingPeerSampler.py
+eval_file=testingPeerSampler.py
 log_level=INFO
 
 m=`cat $(grep addresses_filepath $original_config | awk '{print $3}') | grep $(/sbin/ifconfig ens785 | grep 'inet ' | awk '{print $2}') | cut -d'"' -f2`
diff --git a/eval/run_xtimes_cifar.sh b/eval/run_xtimes_cifar.sh
index 038fef1..a5b9eca 100755
--- a/eval/run_xtimes_cifar.sh
+++ b/eval/run_xtimes_cifar.sh
@@ -44,6 +44,7 @@ machines=6
 global_epochs=100
 eval_file=testingFederated.py
 log_level=INFO
+working_rate=0.1
 
 ip_machines=$nfs_home/configs/ip_addr_6Machines.json
 
@@ -104,7 +105,7 @@ do
     $python_bin/crudini --set $config_file TRAIN_PARAMS rounds $batches_per_comm_round
     $python_bin/crudini --set $config_file TRAIN_PARAMS batch_size $batchsize
     $python_bin/crudini --set $config_file DATASET random_seed $seed
-    $env_python $eval_file -ro 0 -tea $test_after -ld $log_dir -wsd $weight_store_dir -mid $m -ps $procs_per_machine -ms $machines -is $new_iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level
+    $env_python $eval_file -ro 0 -tea $test_after -ld $log_dir -wsd $weight_store_dir -mid $m -ps $procs_per_machine -ms $machines -is $new_iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level -wr $working_rate
     echo $i is done
     sleep 200
     echo end of sleep
diff --git a/eval/testingFederated.py b/eval/testingFederated.py
index b9cdeca..c1b0c4d 100644
--- a/eval/testingFederated.py
+++ b/eval/testingFederated.py
@@ -54,9 +54,6 @@ if __name__ == "__main__":
     sm = args.server_machine
     sr = args.server_rank
 
-    # TODO
-    working_fraction = 1.0
-
     processes = []
     if sm == m_id:
         processes.append(
@@ -74,7 +71,7 @@ if __name__ == "__main__":
                     log_level[args.log_level],
                     args.test_after,
                     args.train_evaluate_after,
-                    working_fraction,
+                    args.working_rate,
                 ],
             )
         )
diff --git a/src/decentralizepy/utils.py b/src/decentralizepy/utils.py
index aad8dda..03b36f4 100644
--- a/src/decentralizepy/utils.py
+++ b/src/decentralizepy/utils.py
@@ -87,6 +87,7 @@ def get_args():
     parser.add_argument("-cte", "--centralized_test_eval", type=int, default=0)
     parser.add_argument("-sm", "--server_machine", type=int, default=0)
     parser.add_argument("-sr", "--server_rank", type=int, default=-1)
+    parser.add_argument("-wr", "--working_rate", type=float, default=1.0)
 
     args = parser.parse_args()
     return args
@@ -120,6 +121,7 @@ def write_args(args, path):
         "reset_optimizer": args.reset_optimizer,
         "centralized_train_eval": args.centralized_train_eval,
         "centralized_test_eval": args.centralized_test_eval,
+        "working_rate": args.working_rate,
     }
     with open(os.path.join(path, "args.json"), "w") as of:
         json.dump(data, of)
-- 
GitLab