Skip to content
Snippets Groups Projects
Commit 65f1b43a authored by Milos Vujasinovic's avatar Milos Vujasinovic
Browse files

Added support for native seed to random states

parent 1bac78e3
No related branches found
No related tags found
No related merge requests found
import random
import contextlib
import torch
import numpy as np
......@@ -14,13 +16,16 @@ def temp_seed(seed):
on CPU regardless if CUDA is used for other things.
"""
random_state = random.getstate()
np_old_state = np.random.get_state()
torch_old_state = torch.random.get_rng_state()
torch.random.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
torch.random.manual_seed(seed)
try:
yield
finally:
random.setstate(random_state)
np.random.set_state(np_old_state)
torch.random.set_rng_state(torch_old_state)
......@@ -33,8 +38,17 @@ class RandomState:
"""
def __init__(self, seed):
with temp_seed(seed):
self.__np_state = np.random.get_state()
self.__torch_state = torch.random.get_rng_state()
self.__refresh_states()
def __refresh_states(self):
self.__random_state = random.getstate()
self.__np_state = np.random.get_state()
self.__torch_state = torch.random.get_rng_state()
def __set_states(self):
random.setstate(self.__random_state)
np.random.set_state(self.__np_state)
torch.random.set_rng_state(self.__torch_state)
@contextlib.contextmanager
def activate(self):
......@@ -44,14 +58,14 @@ class RandomState:
is finished
"""
random_state = random.getstate()
np_old_state = np.random.get_state()
torch_old_state = torch.random.get_rng_state()
np.random.set_state(self.__np_state)
torch.random.set_rng_state(self.__torch_state)
self.__set_states()
try:
yield
finally:
self.__np_state = np.random.get_state()
self.__torch_state = torch.random.get_rng_state()
self.__refresh_states()
random.setstate(random_state)
np.random.set_state(np_old_state)
torch.random.set_rng_state(torch_old_state)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment