diff --git a/src/decentralizepy/random.py b/src/decentralizepy/random.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1d5a036851b91a1837334cfd66e53aa245fe1cf
--- /dev/null
+++ b/src/decentralizepy/random.py
@@ -0,0 +1,57 @@
+import contextlib
+import torch
+import numpy as np
+
+@contextlib.contextmanager
+def temp_seed(seed):
+    """
+    Creates a context with seeds set to the given value and returns
+    to the previous seeds afterwards.
+
+    Note: Based on the torch implementation, there might be issues with
+    CUDA affecting the correctness of this function. In testing,
+    torch.rand() works fine, as its results are generated on the CPU
+    regardless of whether CUDA is used for other things.
+
+    """
+    np_old_state = np.random.get_state()
+    torch_old_state = torch.random.get_rng_state()
+    torch.random.manual_seed(seed)
+    np.random.seed(seed)
+    try:
+        yield
+    finally:
+        np.random.set_state(np_old_state)
+        torch.random.set_rng_state(torch_old_state)
+
+
+class RandomState:
+    """
+    Creates a state that affects random number generation in torch
+    and numpy, and whose context can be activated at will.
+
+    """
+    def __init__(self, seed):
+        with temp_seed(seed):
+            self.__np_state = np.random.get_state()
+            self.__torch_state = torch.random.get_rng_state()
+
+    @contextlib.contextmanager
+    def activate(self):
+        """
+        Activates this state for torch and numpy within the given
+        context. The previous state is restored when the context
+        finishes.
+
+        """
+        np_old_state = np.random.get_state()
+        torch_old_state = torch.random.get_rng_state()
+        np.random.set_state(self.__np_state)
+        torch.random.set_rng_state(self.__torch_state)
+        try:
+            yield
+        finally:
+            self.__np_state = np.random.get_state()
+            self.__torch_state = torch.random.get_rng_state()
+            np.random.set_state(np_old_state)
+            torch.random.set_rng_state(torch_old_state)
\ No newline at end of file
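
A minimal usage sketch (not part of the diff), assuming the module above is importable as `decentralizepy.random`. It illustrates the difference between the two utilities: `temp_seed` replays the same stream on every entry, while `RandomState` is a persistent stream that resumes where the previous activation stopped, because `activate()` saves the stream's state back into the object before restoring the ambient state.

```python
import numpy as np
import torch

from decentralizepy.random import RandomState, temp_seed

# temp_seed: draws inside the block are reproducible, and the ambient
# RNG state is left untouched afterwards.
with temp_seed(90):
    a = torch.rand(3)
with temp_seed(90):
    b = torch.rand(3)
assert torch.equal(a, b)  # same seed -> identical draws

# RandomState: each activation continues the stream rather than
# replaying it from the seed.
state = RandomState(seed=42)  # seed value chosen for illustration
with state.activate():
    first = np.random.rand()
with state.activate():
    second = np.random.rand()  # continues the stream; not a replay of `first`
```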