import logging
import numpy as np
import torch
from decentralizepy.sharing.PartialModel import PartialModel
class TopKPlusRandom(PartialModel):
"""
This class implements partial model sharing with some random additions.
"""
def __init__(
self,
rank,
machine_id,
communication,
mapping,
graph,
model,
dataset,
log_dir,
alpha=1.0,
dict_ordered=True,
save_shared=False,
metadata_cap=1.0,
compress=False,
compression_package=None,
compression_class=None,
):
"""
Constructor
Parameters
----------
rank : int
Local rank
machine_id : int
Global machine id
communication : decentralizepy.communication.Communication
Communication module used to send and receive messages
mapping : decentralizepy.mappings.Mapping
Mapping (rank, machine_id) -> uid
graph : decentralizepy.graphs.Graph
Graph representing neighbors
model : decentralizepy.models.Model
Model to train
dataset : decentralizepy.datasets.Dataset
Dataset for sharing data. Not implemented yet! TODO
log_dir : str
Location to write shared_params (only writing for 2 procs per machine)
alpha : float
Percentage of model to share
dict_ordered : bool
Specifies if the python dict maintains the order of insertion
save_shared : bool
Specifies if the indices of shared parameters should be logged
metadata_cap : float
Share full model when self.alpha > metadata_cap
"""
super().__init__(
rank,
machine_id,
communication,
mapping,
graph,
model,
dataset,
log_dir,
alpha,
dict_ordered,
save_shared,
metadata_cap,
compress,
compression_package,
compression_class,
)
def extract_top_gradients(self):
"""
Extract the indices and values of the topK gradients and add some extra randomly chosen ones.
The gradients must have been accumulated.
Returns
-------
tuple
(a,b). a: The magnitudes of the topK gradients, b: Their indices.
"""
logging.info("Returning topk gradients")
G = torch.abs(self.model.model_change)
std, mean = torch.std_mean(G, unbiased=False)
self.std = std.item()
self.mean = mean.item()
# Half of the alpha budget goes to the top-k entries, the other half to random extras
elements_to_pick = round(self.alpha / 2.0 * G.shape[0])
G_topK = torch.topk(G, min(G.shape[0], elements_to_pick), dim=0, sorted=False)
# np.delete returns a new array; keep the result so the random extras are
# drawn only from indices outside the top-k set
more_indices = np.arange(G.shape[0], dtype=int)
more_indices = np.delete(more_indices, G_topK[1].numpy())
more_indices = np.random.choice(
more_indices, min(more_indices.shape[0], elements_to_pick)
)
G_topK0 = torch.cat([G_topK[0], G[more_indices]], dim=0)
G_topK1 = torch.cat([G_topK[1], torch.tensor(more_indices)], dim=0)
return G_topK0, G_topK1
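A minimal standalone sketch (illustrative only; alpha, G, and k are toy stand-ins for the attributes used above, and the extras are drawn without replacement here for clarity) of the selection rule: half of the alpha budget is spent on the largest-magnitude entries, the other half on random indices drawn from the rest.

import numpy as np
import torch

alpha = 0.4
G = torch.abs(torch.randn(20))  # toy accumulated gradient magnitudes
k = round(alpha / 2.0 * G.shape[0])  # half the budget for the top-k entries
topk_vals, topk_idx = torch.topk(G, k, dim=0, sorted=False)
pool = np.delete(np.arange(G.shape[0], dtype=int), topk_idx.numpy())  # exclude top-k
extra = np.random.choice(pool, min(pool.shape[0], k), replace=False)  # random extras
values = torch.cat([topk_vals, G[extra]], dim=0)
indices = torch.cat([topk_idx, torch.tensor(extra)], dim=0)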
import json
import logging
import os
import numpy as np
import pywt
import torch
from decentralizepy.sharing.PartialModel import PartialModel
def change_transformer_wavelet(x, wavelet, level):
"""
Transforms the model changes into wavelet frequency domain
Parameters
----------
x : torch.Tensor
Model change in the space domain
wavelet : str
name of the wavelet to be used in gradient compression
level : int
decomposition level used for the wavelet transform
Returns
-------
x : torch.Tensor
Representation of the change in the wavelet domain
"""
coeff = pywt.wavedec(x, wavelet, level=level)
data, coeff_slices = pywt.coeffs_to_array(coeff)
return torch.from_numpy(data.ravel())
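A minimal round-trip sketch (an illustrative helper, not part of the module) showing that the flattened coefficient array produced above can be inverted with the coeff_slices bookkeeping from pywt.coeffs_to_array, exactly as Wavelet._averaging does below:

def _wavelet_round_trip_demo():
    # Decompose a toy signal, flatten the coefficients, then reconstruct it
    x = np.random.randn(64)
    coeff = pywt.wavedec(x, "haar", level=4)
    data, coeff_slices = pywt.coeffs_to_array(coeff)
    coeffs_back = pywt.array_to_coeffs(data, coeff_slices, output_format="wavedec")
    restored = pywt.waverec(coeffs_back, wavelet="haar")
    assert np.allclose(x, restored)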
class Wavelet(PartialModel):
"""
This class implements the wavelet version of model sharing.
It is based on PartialModel.py.
"""
def __init__(
self,
rank,
machine_id,
communication,
mapping,
graph,
model,
dataset,
log_dir,
alpha=1.0,
dict_ordered=True,
save_shared=False,
metadata_cap=1.0,
wavelet="haar",
level=4,
change_based_selection=True,
save_accumulated="",
accumulation=False,
accumulate_averaging_changes=False,
compress=False,
compression_package=None,
compression_class=None,
):
"""
Constructor
Parameters
----------
rank : int
Local rank
machine_id : int
Global machine id
communication : decentralizepy.communication.Communication
Communication module used to send and receive messages
mapping : decentralizepy.mappings.Mapping
Mapping (rank, machine_id) -> uid
graph : decentralizepy.graphs.Graph
Graph representing neighbors
model : decentralizepy.models.Model
Model to train
dataset : decentralizepy.datasets.Dataset
Dataset for sharing data. Not implemented yet! TODO
log_dir : str
Location to write shared_params (only writing for 2 procs per machine)
alpha : float
Percentage of model to share
dict_ordered : bool
Specifies if the python dict maintains the order of insertion
save_shared : bool
Specifies if the indices of shared parameters should be logged
metadata_cap : float
Share full model when self.alpha > metadata_cap
wavelet : str
name of the wavelet to be used in gradient compression
level : int
decomposition level used for the wavelet transform
change_based_selection : bool
use the frequency change to select the topk frequencies
save_accumulated : bool
True if the accumulated weight change in the wavelet domain should be written to file. In case of accumulation,
the accumulated change is stored.
accumulation : bool
True if the indices to share should be selected based on the accumulated frequency change
accumulate_averaging_changes : bool
True if the accumulation should account for the model change due to averaging
"""
self.wavelet = wavelet
self.level = level
super().__init__(
rank,
machine_id,
communication,
mapping,
graph,
model,
dataset,
log_dir,
alpha,
dict_ordered,
save_shared,
metadata_cap,
accumulation,
save_accumulated,
lambda x: change_transformer_wavelet(x, wavelet, level),
accumulate_averaging_changes,
compress,
compression_package,
compression_class,
)
self.change_based_selection = change_based_selection
# Do a dummy transform to get the shape and coefficient slices
coeff = pywt.wavedec(self.init_model.numpy(), self.wavelet, level=self.level)
data, coeff_slices = pywt.coeffs_to_array(coeff)
self.wt_shape = data.shape
self.coeff_slices = coeff_slices
def apply_wavelet(self):
"""
Applies the wavelet transform to the model parameters and selects the topK (alpha fraction) of them in the
frequency domain, based on the change they underwent during the current training step
Returns
-------
tuple
(a,b). a: selected wavelet coefficients, b: Their indices.
"""
logging.info("Returning wavelet compressed model weights")
data = self.pre_share_model_transformed
if self.change_based_selection:
diff = self.model.model_change
_, index = torch.topk(
diff.abs(),
round(self.alpha * len(diff)),
dim=0,
sorted=False,
)
else:
_, index = torch.topk(
data.abs(),
round(self.alpha * len(data)),
dim=0,
sorted=False,
)
index, _ = torch.sort(index)
return data[index], index
def serialized_model(self):
"""
Convert model to json dict. self.alpha specifies the fraction of model to send.
Returns
-------
dict
Model converted to json dict
"""
m = dict()
if self.alpha >= self.metadata_cap: # Share fully
data = self.pre_share_model_transformed
m["params"] = data.numpy()
if self.model.accumulated_changes is not None:
self.model.accumulated_changes = torch.zeros_like(
self.model.accumulated_changes
)
return self.compress_data(m)
with torch.no_grad():
topk, indices = self.apply_wavelet()
self.model.shared_parameters_counter[indices] += 1
self.model.rewind_accumulation(indices)
if self.save_shared:
shared_params = dict()
shared_params["order"] = list(self.model.state_dict().keys())
shapes = dict()
for k, v in self.model.state_dict().items():
shapes[k] = list(v.shape)
shared_params["shapes"] = shapes
# converting the indices to a Python list is slow
shared_params[self.communication_round] = indices.tolist()
shared_params["alpha"] = self.alpha
with open(
os.path.join(
self.folder_path,
"{}_shared_params.json".format(self.communication_round + 1),
),
"w",
) as of:
json.dump(shared_params, of)
if not self.dict_ordered:
raise NotImplementedError
m["alpha"] = self.alpha
m["params"] = topk.numpy()
m["indices"] = indices.numpy().astype(np.int32)
m["send_partial"] = True
return self.compress_data(m)
def deserialized_model(self, m):
"""
Convert received dict to state_dict.
Parameters
----------
m : dict
received dict
Returns
-------
state_dict
state_dict of the received model
"""
m = self.decompress_data(m)
ret = dict()
if "send_partial" not in m:
params = m["params"]
params_tensor = torch.tensor(params)
ret["params"] = params_tensor
return ret
with torch.no_grad():
if not self.dict_ordered:
raise NotImplementedError
alpha = m["alpha"]
params_tensor = torch.tensor(m["params"])
indices_tensor = torch.tensor(m["indices"], dtype=torch.long)
ret = dict()
ret["indices"] = indices_tensor
ret["params"] = params_tensor
ret["send_partial"] = True
return ret
def _averaging(self, peer_deques):
"""
Averages the received model with the local model
"""
with torch.no_grad():
total = None
weight_total = 0
wt_params = self.pre_share_model_transformed
for i, n in enumerate(peer_deques):
data = peer_deques[n].popleft()
degree, iteration = data["degree"], data["iteration"]
del data["degree"]
del data["iteration"]
del data["CHANNEL"]
logging.debug(
"Averaging model from neighbor {} of iteration {}".format(
n, iteration
)
)
data = self.deserialized_model(data)
params = data["params"]
if "indices" in data:
indices = data["indices"]
# use local data to complement
topkwf = wt_params.clone().detach()
topkwf[indices] = params
topkwf = topkwf.reshape(self.wt_shape)
else:
topkwf = params.reshape(self.wt_shape)
# Metropolis-Hastings
weight = 1 / (max(len(peer_deques), degree) + 1)
weight_total += weight
if total is None:
total = weight * topkwf
else:
total += weight * topkwf
# Metropolis-Hastings
total += (1 - weight_total) * wt_params
avg_wf_params = pywt.array_to_coeffs(
total.numpy(), self.coeff_slices, output_format="wavedec"
)
reverse_total = torch.from_numpy(
pywt.waverec(avg_wf_params, wavelet=self.wavelet)
)
start_index = 0
std_dict = {}
for i, key in enumerate(self.model.state_dict()):
end_index = start_index + self.lens[i]
std_dict[key] = reverse_total[start_index:end_index].reshape(
self.shapes[i]
)
start_index = end_index
self.model.load_state_dict(std_dict)
self._post_step()
self.communication_round += 1
def _averaging_server(self, peer_deques):
"""
Averages the received models of all working nodes
"""
with torch.no_grad():
total = None
wt_params = self.pre_share_model_transformed
for i, n in enumerate(peer_deques):
data = peer_deques[n].popleft()
degree, iteration = data["degree"], data["iteration"]
del data["degree"]
del data["iteration"]
del data["CHANNEL"]
logging.debug(
"Averaging model from neighbor {} of iteration {}".format(
n, iteration
)
)
data = self.deserialized_model(data)
params = data["params"]
if "indices" in data:
indices = data["indices"]
# use local data to complement
topkwf = wt_params.clone().detach()
topkwf[indices] = params
topkwf = topkwf.reshape(self.wt_shape)
else:
topkwf = params.reshape(self.wt_shape)
weight = 1 / len(peer_deques)
if total is None:
total = weight * topkwf
else:
total += weight * topkwf
avg_wf_params = pywt.array_to_coeffs(
total.numpy(), self.coeff_slices, output_format="wavedec"
)
reverse_total = torch.from_numpy(
pywt.waverec(avg_wf_params, wavelet=self.wavelet)
)
start_index = 0
std_dict = {}
for i, key in enumerate(self.model.state_dict()):
end_index = start_index + self.lens[i]
std_dict[key] = reverse_total[start_index:end_index].reshape(
self.shapes[i]
)
start_index = end_index
self.model.load_state_dict(std_dict)
self._post_step()
self.communication_round += 1
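A short sketch (illustrative; _metropolis_hastings_weights is a hypothetical helper, not part of the library) of the Metropolis-Hastings weighting used in _averaging above: each neighbor j contributes weight 1 / (max(num_neighbors, degree_j) + 1), and the local model keeps the leftover mass, so the weights always sum to one.

def _metropolis_hastings_weights(num_neighbors, degrees):
    # One weight per neighbor, bounded by the larger of the two degrees
    weights = [1.0 / (max(num_neighbors, d) + 1) for d in degrees]
    # The local model absorbs the remaining probability mass
    local_weight = 1.0 - sum(weights)
    assert abs(sum(weights) + local_weight - 1.0) < 1e-12
    return weights, local_weight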
import json
import os
from pathlib import Path
import torch
from decentralizepy.training.Training import Training
from decentralizepy.utils import conditional_value
class ChangeAccumulator(Training):
"""
This class implements the training module which also accumulates model change in a list.
"""
def __init__(
self,
rank,
machine_id,
mapping,
model,
optimizer,
loss,
log_dir,
rounds="",
full_epochs="",
batch_size="",
shuffle="",
save_accumulated="",
):
"""
Constructor
Parameters
----------
rank : int
Rank of process local to the machine
machine_id : int
Machine ID on which the process is running
mapping : decentralizepy.mappings
The object containing the mapping rank <--> uid
model : torch.nn.Module
Neural Network for training
optimizer : torch.optim
Optimizer to learn parameters
loss : function
Loss function
log_dir : str
Directory to log the model change.
rounds : int, optional
Number of steps/epochs per training call
full_epochs : bool, optional
True if 1 round = 1 epoch. False if 1 round = 1 minibatch
batch_size : int, optional
Number of items to learn over, in one batch
shuffle : bool
True if the dataset should be shuffled before training.
save_accumulated : bool
True if accumulated weight change should be written to file
"""
super().__init__(
rank,
machine_id,
mapping,
model,
optimizer,
loss,
log_dir,
rounds,
full_epochs,
batch_size,
shuffle,
)
self.save_accumulated = conditional_value(save_accumulated, "", True)
self.communication_round = 0
if self.save_accumulated:
self.model_change_path = os.path.join(
self.log_dir, "model_change/{}".format(self.rank)
)
Path(self.model_change_path).mkdir(parents=True, exist_ok=True)
self.model_val_path = os.path.join(
self.log_dir, "model_val/{}".format(self.rank)
)
Path(self.model_val_path).mkdir(parents=True, exist_ok=True)
def save_vector(self, v, s):
"""
Saves the given vector to a file in the given folder.
Parameters
----------
v : torch.tensor
The torch tensor to write to file
s : str
Path to folder to write to
"""
output_dict = dict()
output_dict["order"] = list(self.model.state_dict().keys())
shapes = dict()
for k, v1 in self.model.state_dict().items():
shapes[k] = list(v1.shape)
output_dict["shapes"] = shapes
output_dict["tensor"] = v.tolist()
with open(
os.path.join(
s,
"{}.json".format(self.communication_round + 1),
),
"w",
) as of:
json.dump(output_dict, of)
def save_change(self):
"""
Saves the absolute model change for the current iteration
"""
tensors_to_cat = [
v.data.flatten() for _, v in self.model.accumulated_gradients[0].items()
]
change = torch.abs(torch.cat(tensors_to_cat, dim=0))
self.save_vector(change, self.model_change_path)
def save_model_params(self):
"""
Saves the absolute model parameter values for the current iteration
"""
tensors_to_cat = [v.data.flatten() for _, v in self.model.state_dict().items()]
params = torch.abs(torch.cat(tensors_to_cat, dim=0))
self.save_vector(params, self.model_val_path)
def train(self, dataset):
"""
One training iteration with accumulation of model change in model.accumulated_gradients.
Goes through the entire dataset.
Parameters
----------
dataset : decentralizepy.datasets.Dataset
The training dataset. Should implement get_trainset(batch_size, shuffle)
"""
self.model.accumulated_gradients = []
# Pair state_dict keys with parameter tensors; assumes state_dict() and
# parameters() iterate in the same order (holds for models without buffers)
self.init_model = {
k: v.data.clone().detach()
for k, v in zip(self.model.state_dict(), self.model.parameters())
}
super().train(dataset)
with torch.no_grad():
change = {
k: v.data.clone().detach() - self.init_model[k]
for k, v in zip(self.model.state_dict(), self.model.parameters())
}
self.model.accumulated_gradients.append(change)
if self.save_accumulated:
self.save_change()
self.save_model_params()
self.communication_round += 1
import logging
from decentralizepy.training.Training import Training
class GradientAccumulator(Training):
"""
This class implements the training module which also accumulates gradients of steps in a list.
"""
def __init__(
self,
rank,
machine_id,
mapping,
model,
optimizer,
loss,
log_dir,
rounds="",
full_epochs="",
batch_size="",
shuffle="",
):
"""
Constructor
Parameters
----------
rank : int
Rank of process local to the machine
machine_id : int
Machine ID on which the process is running
mapping : decentralizepy.mappings
The object containing the mapping rank <--> uid
model : torch.nn.Module
Neural Network for training
optimizer : torch.optim
Optimizer to learn parameters
loss : function
Loss function
log_dir : str
Directory to log the model change.
rounds : int, optional
Number of steps/epochs per training call
full_epochs : bool, optional
True if 1 round = 1 epoch. False if 1 round = 1 minibatch
batch_size : int, optional
Number of items to learn over, in one batch
shuffle : bool
True if the dataset should be shuffled before training.
"""
super().__init__(
rank,
machine_id,
mapping,
model,
optimizer,
loss,
log_dir,
rounds,
full_epochs,
batch_size,
shuffle,
)
def trainstep(self, data, target):
"""
One training step on a minibatch.
Parameters
----------
data : any
Data item
target : any
Label
Returns
-------
float
Loss value for the step
"""
self.model.zero_grad()
output = self.model(data)
loss_val = self.loss(output, target)
loss_val.backward()
logging.debug("Accumulating Gradients")
self.model.accumulated_gradients.append(
{
# state_dict keys paired with parameter gradients; assumes matching iteration order
k: v.grad.clone().detach()
for k, v in zip(self.model.state_dict(), self.model.parameters())
}
)
self.optimizer.step()
return loss_val.item()
def train(self, dataset):
"""
One training iteration with accumulation of gradients in model.accumulated_gradients.
Goes through the entire dataset.
Parameters
----------
dataset : decentralizepy.datasets.Dataset
The training dataset. Should implement get_trainset(batch_size, shuffle)
"""
self.model.accumulated_gradients = []
super().train(dataset)
@@ -46,7 +46,7 @@ class Training:
Directory to log the model change.
rounds : int, optional
Number of steps/epochs per training call
-full_epochs: bool, optional
+full_epochs : bool, optional
True if 1 round = 1 epoch. False if 1 round = 1 minibatch
batch_size : int, optional
Number of items to learn over, in one batch
@@ -69,12 +69,23 @@ def get_args():
type=str,
default="./{}".format(datetime.datetime.now().isoformat(timespec="minutes")),
)
parser.add_argument(
"-wsd",
"--weights_store_dir",
type=str,
default="./{}_ws".format(datetime.datetime.now().isoformat(timespec="minutes")),
)
parser.add_argument("-is", "--iterations", type=int, default=1)
parser.add_argument("-cf", "--config_file", type=str, default="config.ini")
parser.add_argument("-ll", "--log_level", type=str, default="INFO")
parser.add_argument("-gf", "--graph_file", type=str, default="36_nodes.edges")
parser.add_argument("-gt", "--graph_type", type=str, default="edges")
parser.add_argument("-ta", "--test_after", type=int, default=5)
parser.add_argument("-tea", "--train_evaluate_after", type=int, default=1)
parser.add_argument("-ro", "--reset_optimizer", type=int, default=1)
parser.add_argument("-sm", "--server_machine", type=int, default=0)
parser.add_argument("-sr", "--server_rank", type=int, default=-1)
parser.add_argument("-wr", "--working_rate", type=float, default=1.0)
args = parser.parse_args()
return args
@@ -97,12 +108,31 @@ def write_args(args, path):
"procs_per_machine": args.procs_per_machine,
"machines": args.machines,
"log_dir": args.log_dir,
"weights_store_dir": args.weights_store_dir,
"iterations": args.iterations,
"config_file": args.config_file,
"log_level": args.log_level,
"graph_file": args.graph_file,
"graph_type": args.graph_type,
"test_after": args.test_after,
"train_evaluate_after": args.train_evaluate_after,
"reset_optimizer": args.reset_optimizer,
"working_rate": args.working_rate,
}
with open(os.path.join(path, "args.json"), "w") as of:
json.dump(data, of)
def identity(obj):
"""
Identity function
Parameters
----------
obj
Some object
Returns
-------
obj
The same object
"""
return obj
[DATASET]
dataset_package = decentralizepy.datasets.CIFAR10
dataset_class = CIFAR10
model_class = LeNet
train_dir = /mnt/nfs/shared/CIFAR
test_dir = /mnt/nfs/shared/CIFAR
; python list of fractions below
sizes =
random_seed = 90
partition_niid = True
shards = 4
[OPTIMIZER_PARAMS]
optimizer_package = torch.optim
optimizer_class = SGD
lr = 0.01
[TRAIN_PARAMS]
training_package = decentralizepy.training.Training
training_class = Training
rounds = 3
full_epochs = False
batch_size = 8
shuffle = True
loss_package = torch.nn
loss_class = CrossEntropyLoss
[COMMUNICATION]
comm_package = decentralizepy.communication.TCP
comm_class = TCP
addresses_filepath = /mnt/nfs/risharma/Gitlab/tutorial/ip.json
[SHARING]
sharing_package = decentralizepy.sharing.Sharing
sharing_class = Sharing
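A minimal sketch (assuming the standard library configparser; the file name is illustrative) of how a config like the one above can be read:

import configparser

config = configparser.ConfigParser()
config.read("config.ini")
# Section and key names match the file above
dataset_class = config["DATASET"]["dataset_class"]        # "CIFAR10"
lr = config["OPTIMIZER_PARAMS"].getfloat("lr")            # 0.01
batch_size = config["TRAIN_PARAMS"].getint("batch_size")  # 8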
{
"0": "127.0.0.1"
}
\ No newline at end of file
16
0 12
0 14
0 15
1 8
1 3
1 6
2 9
2 10
2 5
3 1
3 11
3 9
4 9
4 12
4 13
5 2
5 6
5 7
6 1
6 5
6 7
7 5
7 6
7 14
8 1
8 13
8 14
9 2
9 3
9 4
10 2
10 11
10 13
11 10
11 3
11 15
12 0
12 4
12 15
13 8
13 10
13 4
14 0
14 8
14 7
15 0
15 11
15 12
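A minimal parsing sketch (read_edges is a hypothetical helper, not part of the library) for the edge-list format above: the first line gives the node count, and each following line lists one directed edge, with both directions present for an undirected graph.

def read_edges(path):
    # Build an adjacency set per node from the edge-list file
    with open(path) as f:
        n = int(f.readline())
        adjacency = {i: set() for i in range(n)}
        for line in f:
            if not line.strip():
                continue
            u, v = map(int, line.split())
            adjacency[u].add(v)
    return adjacency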
#!/bin/bash
decpy_path=/mnt/nfs/risharma/Gitlab/decentralizepy/eval
cd $decpy_path
env_python=~/miniconda3/envs/decpy/bin/python3
graph=/mnt/nfs/risharma/Gitlab/tutorial/96_regular.edges
original_config=/mnt/nfs/risharma/Gitlab/tutorial/config_celeba_sharing.ini
config_file=~/tmp/config.ini
procs_per_machine=16
machines=1
iterations=80
test_after=20
eval_file=testingPeerSampler.py
log_level=INFO
# Derive this machine's id: match the local IP of interface ens785 against the addresses in ip.json
m=`cat $(grep addresses_filepath $original_config | awk '{print $3}') | grep $(/sbin/ifconfig ens785 | grep 'inet ' | awk '{print $2}') | cut -d'"' -f2`
echo M is $m
log_dir=$(date '+%Y-%m-%dT%H:%M')/machine$m
mkdir -p $log_dir
cp $original_config $config_file
# echo "alpha = 0.10" >> $config_file
$env_python $eval_file -ro 0 -tea $test_after -ld $log_dir -mid $m -ps $procs_per_machine -ms $machines -is $iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level -wsd $log_dir
\ No newline at end of file
#!/bin/bash
decpy_path=/mnt/nfs/risharma/Gitlab/decentralizepy/eval
cd $decpy_path
env_python=~/miniconda3/envs/decpy/bin/python3
graph=/mnt/nfs/risharma/Gitlab/tutorial/96_regular.edges
original_config=/mnt/nfs/risharma/Gitlab/tutorial/config_celeba_sharing.ini
config_file=~/tmp/config.ini
procs_per_machine=16
machines=1
iterations=80
test_after=20
eval_file=testingFederated.py
log_level=INFO
server_rank=-1
server_machine=0
working_rate=0.5
# Derive this machine's id: match the local IP of interface ens785 against the addresses in ip.json
m=`cat $(grep addresses_filepath $original_config | awk '{print $3}') | grep $(/sbin/ifconfig ens785 | grep 'inet ' | awk '{print $2}') | cut -d'"' -f2`
echo M is $m
log_dir=$(date '+%Y-%m-%dT%H:%M')/machine$m
mkdir -p $log_dir
cp $original_config $config_file
# echo "alpha = 0.10" >> $config_file
$env_python $eval_file -ro 0 -tea $test_after -ld $log_dir -mid $m -ps $procs_per_machine -ms $machines -is $iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level -ctr 0 -cte 0 -wsd $log_dir -sm $server_machine -sr $server_rank -wr $working_rate
\ No newline at end of file