From 65f1b43a74941c77d030530ddb639e6af2d9c65f Mon Sep 17 00:00:00 2001 From: Milos Vujasinovic <milos.vujasinovic@epfl.ch> Date: Wed, 16 Nov 2022 17:58:40 +0100 Subject: [PATCH] Added support for native seed to random states --- src/decentralizepy/random.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/src/decentralizepy/random.py b/src/decentralizepy/random.py index c1d5a03..e515cf8 100644 --- a/src/decentralizepy/random.py +++ b/src/decentralizepy/random.py @@ -1,4 +1,6 @@ +import random import contextlib + import torch import numpy as np @@ -14,13 +16,16 @@ def temp_seed(seed): on CPU regardless if CUDA is used for other things. """ + random_state = random.getstate() np_old_state = np.random.get_state() torch_old_state = torch.random.get_rng_state() - torch.random.manual_seed(seed) + random.seed(seed) np.random.seed(seed) + torch.random.manual_seed(seed) try: yield finally: + random.setstate(random_state) np.random.set_state(np_old_state) torch.random.set_rng_state(torch_old_state) @@ -33,8 +38,17 @@ class RandomState: """ def __init__(self, seed): with temp_seed(seed): - self.__np_state = np.random.get_state() - self.__torch_state = torch.random.get_rng_state() + self.__refresh_states() + + def __refresh_states(self): + self.__random_state = random.getstate() + self.__np_state = np.random.get_state() + self.__torch_state = torch.random.get_rng_state() + + def __set_states(self): + random.setstate(self.__random_state) + np.random.set_state(self.__np_state) + torch.random.set_rng_state(self.__torch_state) @contextlib.contextmanager def activate(self): @@ -44,14 +58,14 @@ class RandomState: is finished """ + random_state = random.getstate() np_old_state = np.random.get_state() torch_old_state = torch.random.get_rng_state() - np.random.set_state(self.__np_state) - torch.random.set_rng_state(self.__torch_state) + self.__set_states() try: yield finally: - self.__np_state = np.random.get_state() - self.__torch_state = torch.random.get_rng_state() + self.__refresh_states() + random.setstate(random_state) np.random.set_state(np_old_state) torch.random.set_rng_state(torch_old_state) \ No newline at end of file -- GitLab