From 65f1b43a74941c77d030530ddb639e6af2d9c65f Mon Sep 17 00:00:00 2001
From: Milos Vujasinovic <milos.vujasinovic@epfl.ch>
Date: Wed, 16 Nov 2022 17:58:40 +0100
Subject: [PATCH] Added support for native seed to random states

---
 src/decentralizepy/random.py | 28 +++++++++++++++++++++-------
 1 file changed, 21 insertions(+), 7 deletions(-)

diff --git a/src/decentralizepy/random.py b/src/decentralizepy/random.py
index c1d5a03..e515cf8 100644
--- a/src/decentralizepy/random.py
+++ b/src/decentralizepy/random.py
@@ -1,4 +1,6 @@
+import random
 import contextlib
+
 import torch
 import numpy as np
 
@@ -14,13 +16,16 @@ def temp_seed(seed):
         on CPU regardless if CUDA is used for other things.
         
     """
+    random_state    = random.getstate()
     np_old_state    = np.random.get_state()
     torch_old_state = torch.random.get_rng_state()
-    torch.random.manual_seed(seed)
+    random.seed(seed)
     np.random.seed(seed)
+    torch.random.manual_seed(seed)
     try:
         yield
     finally:
+        random.setstate(random_state)
         np.random.set_state(np_old_state)
         torch.random.set_rng_state(torch_old_state)
 
@@ -33,8 +38,17 @@ class RandomState:
     """
     def __init__(self, seed):
         with temp_seed(seed):
-          self.__np_state    = np.random.get_state()
-          self.__torch_state = torch.random.get_rng_state()
+            self.__refresh_states()
+
+    def __refresh_states(self):
+        self.__random_state = random.getstate()
+        self.__np_state     = np.random.get_state()
+        self.__torch_state  = torch.random.get_rng_state()
+
+    def __set_states(self):
+        random.setstate(self.__random_state)
+        np.random.set_state(self.__np_state)
+        torch.random.set_rng_state(self.__torch_state)
 
     @contextlib.contextmanager
     def activate(self):
@@ -44,14 +58,14 @@ class RandomState:
         is finished
 
         """
+        random_state    = random.getstate()
         np_old_state    = np.random.get_state()
         torch_old_state = torch.random.get_rng_state()
-        np.random.set_state(self.__np_state)
-        torch.random.set_rng_state(self.__torch_state)
+        self.__set_states()
         try:
             yield
         finally:
-            self.__np_state    = np.random.get_state()
-            self.__torch_state = torch.random.get_rng_state()
+            self.__refresh_states()
+            random.setstate(random_state)
             np.random.set_state(np_old_state)
             torch.random.set_rng_state(torch_old_state)
\ No newline at end of file
-- 
GitLab