Skip to content
Snippets Groups Projects
Commit ed4148ea authored by Jeffrey Wigger's avatar Jeffrey Wigger
Browse files

reformatting

parent 1a56aadc
No related branches found
No related tags found
1 merge request!3FFT Wavelets and more
...@@ -3,8 +3,8 @@ import os ...@@ -3,8 +3,8 @@ import os
import sys import sys
import numpy as np import numpy as np
from matplotlib import pyplot as plt
import pandas as pd import pandas as pd
from matplotlib import pyplot as plt
def get_stats(l): def get_stats(l):
...@@ -62,20 +62,50 @@ def plot_results(path): ...@@ -62,20 +62,50 @@ def plot_results(path):
plt.figure(1) plt.figure(1)
means, stdevs, mins, maxs = get_stats([x["train_loss"] for x in results]) means, stdevs, mins, maxs = get_stats([x["train_loss"] for x in results])
plot(means, stdevs, mins, maxs, "Training Loss", folder, "upper right") plot(means, stdevs, mins, maxs, "Training Loss", folder, "upper right")
df = pd.DataFrame({"mean": list(means.values()), "std": list(stdevs.values()), "nr_nodes": [len(results)]*len(means)}, list(means.keys()), columns=["mean", "std", "nr_nodes"]) df = pd.DataFrame(
df.to_csv(os.path.join(path, "train_loss_" + folder + ".csv")) {
"mean": list(means.values()),
"std": list(stdevs.values()),
"nr_nodes": [len(results)] * len(means),
},
list(means.keys()),
columns=["mean", "std", "nr_nodes"],
)
df.to_csv(
os.path.join(path, "train_loss_" + folder + ".csv"), index_label="rounds"
)
# Plot Testing loss # Plot Testing loss
plt.figure(2) plt.figure(2)
means, stdevs, mins, maxs = get_stats([x["test_loss"] for x in results]) means, stdevs, mins, maxs = get_stats([x["test_loss"] for x in results])
plot(means, stdevs, mins, maxs, "Testing Loss", folder, "upper right") plot(means, stdevs, mins, maxs, "Testing Loss", folder, "upper right")
df = pd.DataFrame({"mean": list(means.values()), "std": list(stdevs.values()), "nr_nodes": [len(results)]*len(means)}, list(means.keys()), columns=["mean", "std", "nr_nodes"]) df = pd.DataFrame(
df.to_csv(os.path.join(path, "test_loss_" + folder + ".csv")) {
"mean": list(means.values()),
"std": list(stdevs.values()),
"nr_nodes": [len(results)] * len(means),
},
list(means.keys()),
columns=["mean", "std", "nr_nodes"],
)
df.to_csv(
os.path.join(path, "test_loss_" + folder + ".csv"), index_label="rounds"
)
# Plot Testing Accuracy # Plot Testing Accuracy
plt.figure(3) plt.figure(3)
means, stdevs, mins, maxs = get_stats([x["test_acc"] for x in results]) means, stdevs, mins, maxs = get_stats([x["test_acc"] for x in results])
plot(means, stdevs, mins, maxs, "Testing Accuracy", folder, "lower right") plot(means, stdevs, mins, maxs, "Testing Accuracy", folder, "lower right")
df = pd.DataFrame({"mean": list(means.values()), "std": list(stdevs.values()), "nr_nodes": [len(results)]*len(means)}, list(means.keys()), columns=["mean", "std", "nr_nodes"]) df = pd.DataFrame(
df.to_csv(os.path.join(path, "test_acc_" + folder + ".csv")) {
"mean": list(means.values()),
"std": list(stdevs.values()),
"nr_nodes": [len(results)] * len(means),
},
list(means.keys()),
columns=["mean", "std", "nr_nodes"],
)
df.to_csv(
os.path.join(path, "test_acc_" + folder + ".csv"), index_label="rounds"
)
plt.figure(6) plt.figure(6)
means, stdevs, mins, maxs = get_stats([x["grad_std"] for x in results]) means, stdevs, mins, maxs = get_stats([x["grad_std"] for x in results])
plot( plot(
......
...@@ -56,4 +56,4 @@ class Model(nn.Module): ...@@ -56,4 +56,4 @@ class Model(nn.Module):
""" """
if self.accumulated_changes is not None: if self.accumulated_changes is not None:
self.accumulated_changes[indices] = 0.0 self.accumulated_changes[indices] = 0.0
\ No newline at end of file
...@@ -167,7 +167,7 @@ class ChangeAccumulator(Training): ...@@ -167,7 +167,7 @@ class ChangeAccumulator(Training):
else: else:
flats = [v.data.flatten() for _, v in self.init_model.items()] flats = [v.data.flatten() for _, v in self.init_model.items()]
flat = torch.cat(flats) flat = torch.cat(flats)
self.model.accumulated_changes += (flat - self.prev) self.model.accumulated_changes += flat - self.prev
self.prev = flat self.prev = flat
super().train(dataset) super().train(dataset)
...@@ -181,7 +181,7 @@ class ChangeAccumulator(Training): ...@@ -181,7 +181,7 @@ class ChangeAccumulator(Training):
flat_change = torch.cat(flats_change) flat_change = torch.cat(flats_change)
# flatten does not copy data if input is already flattened # flatten does not copy data if input is already flattened
# however cat copies # however cat copies
change = {"flat" : self.model.accumulated_changes + flat_change} change = {"flat": self.model.accumulated_changes + flat_change}
self.model.accumulated_gradients.append(change) self.model.accumulated_gradients.append(change)
......
...@@ -88,7 +88,9 @@ class FrequencyAccumulator(Training): ...@@ -88,7 +88,9 @@ class FrequencyAccumulator(Training):
""" """
with torch.no_grad(): with torch.no_grad():
self.model.accumulated_gradients = [] self.model.accumulated_gradients = []
tensors_to_cat = [v.data.flatten() for _, v in self.model.state_dict().items()] tensors_to_cat = [
v.data.flatten() for _, v in self.model.state_dict().items()
]
concated = torch.cat(tensors_to_cat, dim=0) concated = torch.cat(tensors_to_cat, dim=0)
self.init_model = fft.rfft(concated) self.init_model = fft.rfft(concated)
if self.accumulation: if self.accumulation:
...@@ -96,17 +98,19 @@ class FrequencyAccumulator(Training): ...@@ -96,17 +98,19 @@ class FrequencyAccumulator(Training):
self.model.accumulated_changes = torch.zeros_like(self.init_model) self.model.accumulated_changes = torch.zeros_like(self.init_model)
self.prev = self.init_model self.prev = self.init_model
else: else:
self.model.accumulated_changes += (self.init_model - self.prev) self.model.accumulated_changes += self.init_model - self.prev
self.prev = self.init_model self.prev = self.init_model
super().train(dataset) super().train(dataset)
with torch.no_grad(): with torch.no_grad():
tensors_to_cat = [v.data.flatten() for _, v in self.model.state_dict().items()] tensors_to_cat = [
v.data.flatten() for _, v in self.model.state_dict().items()
]
concated = torch.cat(tensors_to_cat, dim=0) concated = torch.cat(tensors_to_cat, dim=0)
end_model = fft.rfft(concated) end_model = fft.rfft(concated)
change = end_model - self.init_model change = end_model - self.init_model
if self.accumulation: if self.accumulation:
change += self.model.accumulated_changes change += self.model.accumulated_changes
self.model.accumulated_gradients.append(change) self.model.accumulated_gradients.append(change)
\ No newline at end of file
...@@ -93,7 +93,9 @@ class FrequencyWaveletAccumulator(Training): ...@@ -93,7 +93,9 @@ class FrequencyWaveletAccumulator(Training):
# this looks at the change from the last round averaging of the frequencies # this looks at the change from the last round averaging of the frequencies
with torch.no_grad(): with torch.no_grad():
self.model.accumulated_gradients = [] self.model.accumulated_gradients = []
tensors_to_cat = [v.data.flatten() for _, v in self.model.state_dict().items()] tensors_to_cat = [
v.data.flatten() for _, v in self.model.state_dict().items()
]
concated = torch.cat(tensors_to_cat, dim=0) concated = torch.cat(tensors_to_cat, dim=0)
coeff = pywt.wavedec(concated.numpy(), self.wavelet, level=self.level) coeff = pywt.wavedec(concated.numpy(), self.wavelet, level=self.level)
data, coeff_slices = pywt.coeffs_to_array(coeff) data, coeff_slices = pywt.coeffs_to_array(coeff)
...@@ -103,13 +105,15 @@ class FrequencyWaveletAccumulator(Training): ...@@ -103,13 +105,15 @@ class FrequencyWaveletAccumulator(Training):
self.model.accumulated_changes = torch.zeros_like(self.init_model) self.model.accumulated_changes = torch.zeros_like(self.init_model)
self.prev = self.init_model self.prev = self.init_model
else: else:
self.model.accumulated_changes += (self.init_model - self.prev) self.model.accumulated_changes += self.init_model - self.prev
self.prev = self.init_model self.prev = self.init_model
super().train(dataset) super().train(dataset)
with torch.no_grad(): with torch.no_grad():
tensors_to_cat = [v.data.flatten() for _, v in self.model.state_dict().items()] tensors_to_cat = [
v.data.flatten() for _, v in self.model.state_dict().items()
]
concated = torch.cat(tensors_to_cat, dim=0) concated = torch.cat(tensors_to_cat, dim=0)
coeff = pywt.wavedec(concated.numpy(), self.wavelet, level=self.level) coeff = pywt.wavedec(concated.numpy(), self.wavelet, level=self.level)
data, coeff_slices = pywt.coeffs_to_array(coeff) data, coeff_slices = pywt.coeffs_to_array(coeff)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment