Commit fc6ee11c authored by Jeffrey Wigger

reformatting

parent 1b7936b1
@@ -18,7 +18,7 @@ ip_machines=$nfs_home/configs/ip_addr_6Machines.json
m=`cat $ip_machines | grep $(/sbin/ifconfig ens785 | grep 'inet ' | awk '{print $2}') | cut -d'"' -f2`
export PYTHONFAULTHANDLER=1
tests=("step_configs/config_celeba.ini" "step_configs/config_celeba_100.ini" "step_configs/config_celeba_fft.ini" "step_configs/config_celeba_wavelet.ini"
tests=("step_configs/config_celeba_partialmodel.ini" "step_configs/config_celeba_sharing.ini" "step_configs/config_celeba_fft.ini" "step_configs/config_celeba_wavelet.ini"
"step_configs/config_celeba_grow.ini" "step_configs/config_celeba_manualadapt.ini" "step_configs/config_celeba_randomalpha.ini"
"step_configs/config_celeba_randomalphainc.ini" "step_configs/config_celeba_roundrobin.ini" "step_configs/config_celeba_subsampling.ini"
"step_configs/config_celeba_topkrandom.ini" "step_configs/config_celeba_topkacc.ini" "step_configs/config_celeba_topkparam.ini")
@@ -65,4 +65,3 @@ if __name__ == "__main__":
args.reset_optimizer,
],
)
print("after spawn")
@@ -481,4 +481,3 @@ class Node:
)
self.run()
logging.info("Node finished running")
@@ -27,6 +27,7 @@ def change_transformer_fft(x):
"""
return fft.rfft(x)
class FFT(PartialModel):
"""
This class implements the fft version of model sharing
@@ -51,7 +52,7 @@ class FFT(PartialModel):
change_based_selection=True,
save_accumulated="",
accumulation=True,
accumulate_averaging_changes=False
accumulate_averaging_changes=False,
):
"""
Constructor
@@ -94,8 +95,22 @@ class FFT(PartialModel):
"""
super().__init__(
rank, machine_id, communication, mapping, graph, model, dataset, log_dir, alpha, dict_ordered, save_shared,
metadata_cap, accumulation, save_accumulated, change_transformer_fft, accumulate_averaging_changes
rank,
machine_id,
communication,
mapping,
graph,
model,
dataset,
log_dir,
alpha,
dict_ordered,
save_shared,
metadata_cap,
accumulation,
save_accumulated,
change_transformer_fft,
accumulate_averaging_changes,
)
self.change_based_selection = change_based_selection
@@ -113,7 +128,9 @@ class FFT(PartialModel):
logging.info("Returning fft compressed model weights")
with torch.no_grad():
tensors_to_cat = [v.data.flatten() for _, v in self.model.state_dict().items()]
tensors_to_cat = [
v.data.flatten() for _, v in self.model.state_dict().items()
]
concated = torch.cat(tensors_to_cat, dim=0)
flat_fft = self.change_transformer(concated)
if self.change_based_selection:
@@ -123,7 +140,10 @@ class FFT(PartialModel):
)
else:
_, index = torch.topk(
flat_fft.abs(), round(self.alpha * len(flat_fft)), dim=0, sorted=False
flat_fft.abs(),
round(self.alpha * len(flat_fft)),
dim=0,
sorted=False,
)
return flat_fft[index], index
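The hunk above keeps only the largest-magnitude FFT coefficients of the flattened model. A minimal standalone sketch of that selection, assuming torch is available (the helper name and signature are illustrative, not part of the codebase):

import torch
from torch import fft

def topk_fft(model, alpha):
    # Flatten all parameters, move them to the frequency domain, and keep the
    # alpha-fraction of coefficients with the largest magnitude.
    with torch.no_grad():
        flat = torch.cat([v.data.flatten() for v in model.state_dict().values()], dim=0)
        flat_fft = fft.rfft(flat)
        k = round(alpha * len(flat_fft))
        _, index = torch.topk(flat_fft.abs(), k, dim=0, sorted=False)
        return flat_fft[index], index  # shared coefficient values and their positions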
@@ -233,7 +253,9 @@ class FFT(PartialModel):
for i, n in enumerate(self.peer_deques):
degree, iteration, data = self.peer_deques[n].popleft()
logging.debug(
"Averaging model from neighbor {} of iteration {}".format(n, iteration)
"Averaging model from neighbor {} of iteration {}".format(
n, iteration
)
)
data = self.deserialized_model(data)
params = data["params"]
@@ -257,7 +279,9 @@ class FFT(PartialModel):
std_dict = {}
for i, key in enumerate(self.model.state_dict()):
end_index = start_index + self.lens[i]
std_dict[key] = reverse_total[start_index:end_index].reshape(self.shapes[i])
std_dict[key] = reverse_total[start_index:end_index].reshape(
self.shapes[i]
)
start_index = end_index
self.model.load_state_dict(std_dict)
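On the receiving side shown above, the averaged coefficients have to be brought back to the space domain before the state_dict is rebuilt. A hedged sketch of that reconstruction, assuming the same per-parameter lens/shapes bookkeeping (argument names are illustrative):

import torch
from torch import fft

def rebuild_state_dict(total_fft, model, lens, shapes, total_len):
    # Invert the rfft, then cut the flat tensor back into per-parameter tensors.
    reverse_total = fft.irfft(total_fft, n=total_len)
    std_dict, start_index = {}, 0
    for i, key in enumerate(model.state_dict()):
        end_index = start_index + lens[i]
        std_dict[key] = reverse_total[start_index:end_index].reshape(shapes[i])
        start_index = end_index
    model.load_state_dict(std_dict)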
@@ -30,10 +30,10 @@ class PartialModel(Sharing):
dict_ordered=True,
save_shared=False,
metadata_cap=1.0,
accumulation = False,
accumulation=False,
save_accumulated="",
change_transformer = identity,
accumulate_averaging_changes = False
change_transformer=identity,
accumulate_averaging_changes=False,
):
"""
Constructor
@@ -100,9 +100,11 @@ class PartialModel(Sharing):
tensors_to_cat.append(t)
self.init_model = torch.cat(tensors_to_cat, dim=0)
if self.accumulation:
self.model.accumulated_changes = torch.zeros_like(self.change_transformer(self.init_model))
self.model.accumulated_changes = torch.zeros_like(
self.change_transformer(self.init_model)
)
self.prev = self.init_model
if self.save_accumulated:
self.model_change_path = os.path.join(
self.log_dir, "model_change/{}".format(self.rank)
@@ -295,7 +297,9 @@ class PartialModel(Sharing):
self.init_model = post_share_model
if self.accumulation:
if self.accumulate_averaging_changes:
self.model.accumulated_changes += self.change_transformer(self.init_model - self.prev)
self.model.accumulated_changes += self.change_transformer(
self.init_model - self.prev
)
self.prev = self.init_model
self.model.model_change = None
if self.save_accumulated:
@@ -336,4 +340,4 @@ class PartialModel(Sharing):
Saves the change and the gradient values for every iteration
"""
self.save_vector(self.model.model_change, self.model_change_path)
\ No newline at end of file
self.save_vector(self.model.model_change, self.model_change_path)
@@ -147,7 +147,9 @@ class Sharing:
for i, n in enumerate(self.peer_deques):
degree, iteration, data = self.peer_deques[n].popleft()
logging.debug(
"Averaging model from neighbor {} of iteration {}".format(n, iteration)
"Averaging model from neighbor {} of iteration {}".format(
n, iteration
)
)
data = self.deserialized_model(data)
weight = 1 / (max(len(self.peer_deques), degree) + 1) # Metro-Hastings
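The weight above follows a Metropolis-Hastings rule: each neighbor's model is scaled by 1/(max(own_degree, neighbor_degree)+1), and the node keeps the leftover weight for its own parameters. A small illustrative sketch of that averaging step (not the project's API):

import torch

def metro_hastings_average(own_params, neighbor_models):
    # own_params: dict of tensors; neighbor_models: list of (degree, dict of tensors)
    own_degree = len(neighbor_models)
    total = {key: torch.zeros_like(value) for key, value in own_params.items()}
    weight_sum = 0.0
    for degree, params in neighbor_models:
        weight = 1.0 / (max(own_degree, degree) + 1)
        weight_sum += weight
        for key, value in params.items():
            total[key] += weight * value
    for key, value in own_params.items():
        total[key] += (1.0 - weight_sum) * value  # remaining weight stays with the node
    return total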
@@ -10,27 +10,29 @@ import torch
from decentralizepy.sharing.PartialModel import PartialModel
def change_transformer_wavelet(x, wavelet, level):
"""
Transforms the model changes into wavelet frequency domain
Parameters
----------
x : torch.Tensor
Model change in the space domain
wavelet : str
name of the wavelet to be used in gradient compression
level: int
name of the wavelet to be used in gradient compression
def change_transformer_wavelet(x, wavelet, level):
"""
Transforms the model changes into wavelet frequency domain
Parameters
----------
x : torch.Tensor
Model change in the space domain
wavelet : str
name of the wavelet to be used in gradient compression
level: int
name of the wavelet to be used in gradient compression
Returns
-------
x : torch.Tensor
Representation of the change int the wavelet domain
"""
coeff = pywt.wavedec(x, wavelet, level=level)
data, coeff_slices = pywt.coeffs_to_array(coeff)
return torch.from_numpy(data.ravel())
Returns
-------
x : torch.Tensor
Representation of the change int the wavelet domain
"""
coeff = pywt.wavedec(x, wavelet, level=level)
data, coeff_slices = pywt.coeffs_to_array(coeff)
return torch.from_numpy(data.ravel())
class Wavelet(PartialModel):
"""
@@ -58,7 +60,7 @@ class Wavelet(PartialModel):
change_based_selection=True,
save_accumulated="",
accumulation=False,
accumulate_averaging_changes = False
accumulate_averaging_changes=False,
):
"""
Constructor
@@ -107,9 +109,22 @@ class Wavelet(PartialModel):
self.level = level
super().__init__(
rank, machine_id, communication, mapping, graph, model, dataset, log_dir, alpha, dict_ordered, save_shared,
metadata_cap, accumulation, save_accumulated, lambda x : change_transformer_wavelet(x, wavelet, level),
accumulate_averaging_changes
rank,
machine_id,
communication,
mapping,
graph,
model,
dataset,
log_dir,
alpha,
dict_ordered,
save_shared,
metadata_cap,
accumulation,
save_accumulated,
lambda x: change_transformer_wavelet(x, wavelet, level),
accumulate_averaging_changes,
)
self.change_based_selection = change_based_selection
@@ -132,13 +147,11 @@ class Wavelet(PartialModel):
"""
logging.info("Returning dwt compressed model weights")
logging.info("Returning wavelet compressed model weights")
tensors_to_cat = [v.data.flatten() for _, v in self.model.state_dict().items()]
concated = torch.cat(tensors_to_cat, dim=0)
data = self.change_transformer(concated)
logging.info("produced wavelet representation of current model")
if self.change_based_selection:
logging.info("changed based selection")
diff = self.model.model_change
_, index = torch.topk(
diff.abs(),
@@ -146,7 +159,6 @@ class Wavelet(PartialModel):
dim=0,
sorted=False,
)
logging.info("finished change based selection")
else:
_, index = torch.topk(
data.abs(),
@@ -167,7 +179,6 @@ class Wavelet(PartialModel):
Model converted to json dict
"""
logging.info("serializing wavelet model")
if self.alpha > self.metadata_cap: # Share fully
return super().serialized_model()
@@ -175,7 +186,6 @@ class Wavelet(PartialModel):
topk, indices = self.apply_wavelet()
self.model.rewind_accumulation(indices)
logging.info("finished rewind")
if self.save_shared:
shared_params = dict()
shared_params["order"] = list(self.model.state_dict().keys())
@@ -230,7 +240,6 @@ class Wavelet(PartialModel):
state_dict of received
"""
logging.info("deserializing wavelet model")
if self.alpha > self.metadata_cap: # Share fully
return super().deserialized_model(m)
@@ -265,7 +274,9 @@ class Wavelet(PartialModel):
for i, n in enumerate(self.peer_deques):
degree, iteration, data = self.peer_deques[n].popleft()
logging.debug(
"Averaging model from neighbor {} of iteration {}".format(n, iteration)
"Averaging model from neighbor {} of iteration {}".format(
n, iteration
)
)
data = self.deserialized_model(data)
params = data["params"]
@@ -296,8 +307,9 @@ class Wavelet(PartialModel):
std_dict = {}
for i, key in enumerate(self.model.state_dict()):
end_index = start_index + self.lens[i]
std_dict[key] = reverse_total[start_index:end_index].reshape(self.shapes[i])
std_dict[key] = reverse_total[start_index:end_index].reshape(
self.shapes[i]
)
start_index = end_index
self.model.load_state_dict(std_dict)
@@ -109,6 +109,7 @@ def write_args(args, path):
with open(os.path.join(path, "args.json"), "w") as of:
json.dump(data, of)
def identity(obj):
"""
Identity function
@@ -121,4 +122,4 @@ def identity(obj):
obj
The same object
"""
return obj
\ No newline at end of file
return obj