From b8731d2c3492a57dc42b2c147b4f361027c00235 Mon Sep 17 00:00:00 2001
From: Jeffrey Wigger <jeffrey.wigger@epfl.ch>
Date: Wed, 30 Mar 2022 18:40:54 +0000
Subject: [PATCH] Shared parameter counter

---
 eval/plot.py                               |  17 +--
 eval/plot_percentile.py                    | 125 ++++++++++++++++++
 src/decentralizepy/models/Model.py         |   1 +
 src/decentralizepy/node/Node.py            |   7 +-
 src/decentralizepy/sharing/FFT.py          |   2 +-
 src/decentralizepy/sharing/PartialModel.py |   3 +
 .../sharing/RoundRobinPartial.py           |   1 +
 src/decentralizepy/sharing/SubSampling.py  |   2 +
 src/decentralizepy/sharing/TopKParams.py   |   2 +-
 src/decentralizepy/sharing/Wavelet.py      |   2 +-
 10 files changed, 150 insertions(+), 12 deletions(-)
 create mode 100644 eval/plot_percentile.py

diff --git a/eval/plot.py b/eval/plot.py
index f3f82c7..a7c8dd9 100644
--- a/eval/plot.py
+++ b/eval/plot.py
@@ -5,6 +5,7 @@ import sys
 import numpy as np
 import pandas as pd
 from matplotlib import pyplot as plt
+import torch
 
 
 def get_stats(l):
@@ -152,15 +153,15 @@ def plot_results(path):
         data_stdevs[folder] = list(stdevs.values())[0]
 
     plt.figure(1)
-    plt.savefig(os.path.join(path, "train_loss.png"))
+    plt.savefig(os.path.join(path, "train_loss.png"), dpi=300)
     plt.figure(2)
-    plt.savefig(os.path.join(path, "test_loss.png"))
+    plt.savefig(os.path.join(path, "test_loss.png"), dpi=300)
     plt.figure(3)
-    plt.savefig(os.path.join(path, "test_acc.png"))
+    plt.savefig(os.path.join(path, "test_acc.png"), dpi=300)
     plt.figure(6)
-    plt.savefig(os.path.join(path, "grad_std.png"))
+    plt.savefig(os.path.join(path, "grad_std.png"), dpi=300)
     plt.figure(7)
-    plt.savefig(os.path.join(path, "grad_mean.png"))
+    plt.savefig(os.path.join(path, "grad_mean.png"), dpi=300)
     # Plot total_bytes
     plt.figure(4)
     plt.title("Data Shared")
@@ -174,7 +175,7 @@ def plot_results(path):
     plt.ylabel("Total data shared in MBs")
     plt.xlabel("Fraction of Model Shared")
     plt.xticks(x_pos, list(bytes_means.keys()))
-    plt.savefig(os.path.join(path, "data_shared.png"))
+    plt.savefig(os.path.join(path, "data_shared.png"), dpi=300)
 
     # Plot stacked_bytes
     plt.figure(5)
@@ -198,7 +199,7 @@ def plot_results(path):
     plt.ylabel("Data shared in MBs")
     plt.xlabel("Fraction of Model Shared")
     plt.xticks(x_pos, list(meta_means.keys()))
-    plt.savefig(os.path.join(path, "parameters_metadata.png"))
+    plt.savefig(os.path.join(path, "parameters_metadata.png"), dpi=300)
 
 
 def plot_parameters(path):
@@ -237,4 +238,4 @@
 if __name__ == "__main__":
     assert len(sys.argv) == 2
     plot_results(sys.argv[1])
-    # plot_parameters(sys.argv[1])
+    # plot_parameters(sys.argv[1])
\ No newline at end of file
diff --git a/eval/plot_percentile.py b/eval/plot_percentile.py
new file mode 100644
index 0000000..e8f0da2
--- /dev/null
+++ b/eval/plot_percentile.py
@@ -0,0 +1,125 @@
+import json
+import os
+import sys
+
+import numpy as np
+import pandas as pd
+from matplotlib import pyplot as plt
+import torch
+
+
+def get_stats(l):
+    assert len(l) > 0
+    mean_dict, stdev_dict, min_dict, max_dict = {}, {}, {}, {}
+    for key in l[0].keys():
+        all_nodes = [i[key] for i in l]
+        all_nodes = np.array(all_nodes)
+        mean = np.mean(all_nodes)
+        std = np.std(all_nodes)
+        min = np.min(all_nodes)
+        max = np.max(all_nodes)
+        mean_dict[int(key)] = mean
+        stdev_dict[int(key)] = std
+        min_dict[int(key)] = min
+        max_dict[int(key)] = max
+    return mean_dict, stdev_dict, min_dict, max_dict
+
+
+def plot(means, stdevs, mins, maxs, title, label, loc):
+    plt.title(title)
+    plt.xlabel("communication rounds")
+    x_axis = list(means.keys())
+    y_axis = list(means.values())
+    err = list(stdevs.values())
+    plt.errorbar(x_axis, y_axis, yerr=err, label=label)
+    plt.legend(loc=loc)
+
+
+def plot_results(path):
+    """
+    Plots the shared-parameter percentiles.
+    Based on plot.py.
+
+    Parameters
+    ----------
+    path
+        Path to the folders from which to create the percentile plots.
+    """
+    folders = os.listdir(path)
+    folders.sort()
+    print("Reading folders from: ", path)
+    print("Folders: ", folders)
+    for folder in folders:
+        folder_path = os.path.join(path, folder)
+        if not os.path.isdir(folder_path):
+            continue
+        results = []
+        all_shared_params = []
+        machine_folders = os.listdir(folder_path)
+        for machine_folder in machine_folders:
+            mf_path = os.path.join(folder_path, machine_folder)
+            if not os.path.isdir(mf_path):
+                continue
+            files = os.listdir(mf_path)
+            shared_params = [f for f in files if f.endswith("_shared_parameters.json")]
+            files = [f for f in files if f.endswith("_results.json")]
+            for f in files:
+                filepath = os.path.join(mf_path, f)
+                with open(filepath, "r") as inf:
+                    results.append(json.load(inf))
+            for sp in shared_params:
+                filepath = os.path.join(mf_path, sp)
+                with open(filepath, "r") as spf:
+                    all_shared_params.append(np.array(json.load(spf), dtype=np.int32))
+
+        # Percentile plot of the shared-parameter counts
+        plt.figure(1)
+        # Average of the shared parameters
+        mean = np.mean(all_shared_params, axis=0)
+        std = np.std(all_shared_params, axis=0)
+        with open(
+            os.path.join(path, "shared_params_avg_" + folder + ".json"), "w"
+        ) as mf:
+            json.dump(mean.tolist(), mf)
+
+        with open(
+            os.path.join(path, "shared_params_std_" + folder + ".json"), "w"
+        ) as sf:
+            json.dump(std.tolist(), sf)
+
+        # Percentile curve over the averaged counts (ported from a notebook)
+        percentile = np.percentile(mean, np.arange(0, 100, 1))
+        plt.plot(np.arange(0, 100, 1), percentile, label=folder)
+        plt.title("Shared parameters Percentiles")
+        # plt.ylabel("Absolute frequency value")
+        plt.xlabel("Percentiles")
+        plt.xticks(np.arange(0, 110, 10))
+        plt.legend(loc="lower right")
+
+        plt.figure(2)
+        sort = torch.sort(torch.tensor(mean)).values
+        print(sort)
+        length = sort.shape[0]
+        length = int(length / 20)
+        bins = [torch.sum(sort[length * i : length * (i + 1)]).item() for i in range(20)]
+        total = np.sum(bins)
+        perc = bins / total  # element-wise; equivalent to np.divide(bins, total)
+        print(perc)
+        plt.bar(np.arange(0, 97.5, 5), perc, width=5, align="edge",
+                label=folder)
+
+        plt.title("Shared parameters Percentiles")
+        # plt.ylabel("Absolute frequency value")
+        plt.xlabel("Percentiles")
+        plt.legend(loc="lower right")
+        plt.savefig(os.path.join(path, f"percentiles_histogram_{folder}.png"), dpi=300)
+        plt.clf()
+        plt.cla()
+
+    plt.figure(1)
+    plt.savefig(os.path.join(path, "percentiles.png"), dpi=300)
+
+
+if __name__ == "__main__":
+    assert len(sys.argv) == 2
+    plot_results(sys.argv[1])
\ No newline at end of file
diff --git a/src/decentralizepy/models/Model.py b/src/decentralizepy/models/Model.py
index f3635a9..5edba7f 100644
--- a/src/decentralizepy/models/Model.py
+++ b/src/decentralizepy/models/Model.py
@@ -18,6 +18,7 @@ class Model(nn.Module):
         self._param_count_ot = None
         self._param_count_total = None
         self.accumulated_changes = None
+        self.shared_parameters_counter = None
 
     def count_params(self, only_trainable=False):
         """
diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py
index e85159c..57fa8d5 100644
--- a/src/decentralizepy/node/Node.py
+++ b/src/decentralizepy/node/Node.py
@@ -419,7 +419,12 @@ class Node:
             os.path.join(self.log_dir, "{}_results.json".format(self.rank)), "w"
         ) as of:
             json.dump(results_dict, of)
-
+        if self.model.shared_parameters_counter is not None:
+            logging.info("Saving the shared parameter counts")
+            with open(
+                os.path.join(self.log_dir, "{}_shared_parameters.json".format(self.rank)), "w"
+            ) as of:
+                json.dump(self.model.shared_parameters_counter.numpy().tolist(), of)
         self.communication.disconnect_neighbors()
         logging.info("All neighbors disconnected. Process complete!")
 
diff --git a/src/decentralizepy/sharing/FFT.py b/src/decentralizepy/sharing/FFT.py
index 0c0172f..d912807 100644
--- a/src/decentralizepy/sharing/FFT.py
+++ b/src/decentralizepy/sharing/FFT.py
@@ -163,7 +163,7 @@ class FFT(PartialModel):
 
         with torch.no_grad():
             topk, indices = self.apply_fft()
-
+            self.model.shared_parameters_counter[indices] += 1
             self.model.rewind_accumulation(indices)
             if self.save_shared:
                 shared_params = dict()
diff --git a/src/decentralizepy/sharing/PartialModel.py b/src/decentralizepy/sharing/PartialModel.py
index 97c702b..e898bf9 100644
--- a/src/decentralizepy/sharing/PartialModel.py
+++ b/src/decentralizepy/sharing/PartialModel.py
@@ -126,6 +126,8 @@ class PartialModel(Sharing):
             )
             Path(self.folder_path).mkdir(parents=True, exist_ok=True)
 
+        self.model.shared_parameters_counter = torch.zeros(self.change_transformer(self.init_model).shape[0], dtype=torch.int32)
+
     def extract_top_gradients(self):
         """
         Extract the indices and values of the topK gradients.
@@ -162,6 +164,7 @@ class PartialModel(Sharing):
 
         with torch.no_grad():
             _, G_topk = self.extract_top_gradients()
+            self.model.shared_parameters_counter[G_topk] += 1
             if self.accumulation:
                 self.model.rewind_accumulation(G_topk)
             if self.save_shared:
diff --git a/src/decentralizepy/sharing/RoundRobinPartial.py b/src/decentralizepy/sharing/RoundRobinPartial.py
index 6b4f517..c5288a5 100644
--- a/src/decentralizepy/sharing/RoundRobinPartial.py
+++ b/src/decentralizepy/sharing/RoundRobinPartial.py
@@ -84,6 +84,7 @@ class RoundRobinPartial(Sharing):
             block_end = min(T.shape[0], (self.current_block + 1) * self.block_size)
             self.current_block = (self.current_block + 1) % self.num_blocks
             T_send = T[block_start:block_end]
+            self.model.shared_parameters_counter[block_start:block_end] += 1
             logging.info("Range sending: {}-{}".format(block_start, block_end))
             logging.info("Generating dictionary to send")
 
diff --git a/src/decentralizepy/sharing/SubSampling.py b/src/decentralizepy/sharing/SubSampling.py
index 5ec0c44..6221714 100644
--- a/src/decentralizepy/sharing/SubSampling.py
+++ b/src/decentralizepy/sharing/SubSampling.py
@@ -131,6 +131,7 @@ class SubSampling(Sharing):
                 <= self.alpha
             )
             subsample = concated[binary_mask]
+            self.model.shared_parameters_counter[binary_mask] += 1
             # logging.debug("Subsampling vector is of size: " + str(subsample.size(dim = 0)))
             return (subsample, curr_seed, self.alpha)
         else:
@@ -147,6 +148,7 @@ class SubSampling(Sharing):
                     )
                     <= self.alpha
                 )
+                # TODO: support shared_parameters_counter
                 selected = flat[binary_mask]
                 values_list.append(selected)
                 off += selected.size(dim=0)
diff --git a/src/decentralizepy/sharing/TopKParams.py b/src/decentralizepy/sharing/TopKParams.py
index c6535ce..02531f1 100644
--- a/src/decentralizepy/sharing/TopKParams.py
+++ b/src/decentralizepy/sharing/TopKParams.py
@@ -128,7 +128,7 @@ class TopKParams(Sharing):
 
         with torch.no_grad():
             values, index, offsets = self.extract_top_params()
-
+            self.model.shared_parameters_counter[index] += 1
             if self.save_shared:
                 shared_params = dict()
                 shared_params["order"] = list(self.model.state_dict().keys())
diff --git a/src/decentralizepy/sharing/Wavelet.py b/src/decentralizepy/sharing/Wavelet.py
index 363a487..9a02a64 100644
--- a/src/decentralizepy/sharing/Wavelet.py
+++ b/src/decentralizepy/sharing/Wavelet.py
@@ -184,7 +184,7 @@ class Wavelet(PartialModel):
 
         with torch.no_grad():
             topk, indices = self.apply_wavelet()
-
+            self.model.shared_parameters_counter[indices] += 1
             self.model.rewind_accumulation(indices)
             if self.save_shared:
                 shared_params = dict()
-- 
GitLab
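
Taken together, the patch threads one counter through the stack: PartialModel.__init__ allocates an int32 tensor with one slot per flattened model parameter, each sharing strategy bumps the slots it actually transmits (top-k gradient indices, FFT/wavelet coefficient indices, the current round-robin block, or the subsampling mask), and Node dumps the tensor to {rank}_shared_parameters.json at the end of the run. A minimal standalone sketch of that pattern; the model size and the random stand-in for gradients are illustrative, not from the repo:

import json

import torch

# Illustrative flattened model size; the patch derives this from
# change_transformer(init_model).shape[0] instead.
NUM_PARAMS = 1000
counter = torch.zeros(NUM_PARAMS, dtype=torch.int32)

for step in range(10):
    grads = torch.randn(NUM_PARAMS)  # stand-in for real accumulated gradients
    # Select the top 10% of entries by magnitude, as PartialModel does.
    _, topk_indices = torch.topk(grads.abs(), k=NUM_PARAMS // 10)
    counter[topk_indices] += 1  # the same bump FFT/Wavelet/TopKParams perform

# The same dump Node.py performs after training completes.
with open("0_shared_parameters.json", "w") as of:
    json.dump(counter.numpy().tolist(), of)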
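
The 20-bar histogram in plot_percentile.py measures how concentrated sharing is: it sorts the per-parameter mean counts, slices the sorted vector into twenty 5%-wide buckets, and plots each bucket's share of the total count mass. A NumPy-only sketch of the same computation; the input vector here is synthetic:

import numpy as np

# Synthetic per-parameter mean share counts; in the patch this is the
# "mean" vector averaged over all nodes' JSON dumps.
mean = np.random.randint(0, 50, size=1000).astype(np.float64)

sorted_counts = np.sort(mean)
slice_len = len(sorted_counts) // 20  # 5% of the parameters per bucket
bins = [sorted_counts[slice_len * i : slice_len * (i + 1)].sum() for i in range(20)]
perc = np.array(bins) / np.sum(bins)  # each bucket's share of all sharing

# A last entry near 0.05 means sharing is spread uniformly; a much larger
# value means the most-shared 5% of parameters dominate the traffic.
print(perc)

Note that the integer slicing silently drops up to 19 trailing elements when the length is not divisible by 20, and since the vector is sorted ascending those are the largest counts; the torch version in the patch behaves the same way.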
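
Each node's dump is a plain JSON list of integers, so it can be inspected without decentralizepy installed. A hypothetical quick check; the path is illustrative:

import json

import numpy as np

with open("machine0/0_shared_parameters.json") as f:
    counts = np.array(json.load(f), dtype=np.int32)

print("parameters never shared:", int((counts == 0).sum()))
print("most-shared parameter count:", int(counts.max()))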