Skip to content
Snippets Groups Projects
Commit 1bac78e3 authored by Milos Vujasinovic's avatar Milos Vujasinovic
Browse files

Added utilities for handling random states

parent 0c1ab13e
No related branches found
No related tags found
No related merge requests found
import contextlib
import torch
import numpy as np
@contextlib.contextmanager
def temp_seed(seed):
"""
Creates a context with seeds set to given value. Returns to the
previous seed afterwards.
Note: Based on torch implementation there might be issues with CUDA
causing troubles with the correctness of this function. Function
torch.rand() work fine from testing as their results are generated
on CPU regardless if CUDA is used for other things.
"""
np_old_state = np.random.get_state()
torch_old_state = torch.random.get_rng_state()
torch.random.manual_seed(seed)
np.random.seed(seed)
try:
yield
finally:
np.random.set_state(np_old_state)
torch.random.set_rng_state(torch_old_state)
class RandomState:
"""
Creates a state that affects random number generation on
torch and numpy and whose context can be activated at will
"""
def __init__(self, seed):
with temp_seed(seed):
self.__np_state = np.random.get_state()
self.__torch_state = torch.random.get_rng_state()
@contextlib.contextmanager
def activate(self):
"""
Activates this state in the given context for torch and
numpy. The previous state is restored when the context
is finished
"""
np_old_state = np.random.get_state()
torch_old_state = torch.random.get_rng_state()
np.random.set_state(self.__np_state)
torch.random.set_rng_state(self.__torch_state)
try:
yield
finally:
self.__np_state = np.random.get_state()
self.__torch_state = torch.random.get_rng_state()
np.random.set_state(np_old_state)
torch.random.set_rng_state(torch_old_state)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment