Module xingyun.random.my_random
Expand source code
import time
import copy
import random
from typing import Any
from xingyun.universal.import_module import my_import_module
from .set_random_seed import RandomAllowedModule
def get_random_state(module: RandomAllowedModule) -> Any:
if module == "random":
return random.getstate()
if module == "numpy":
np = my_import_module("numpy")
return np.random.get_state()
if module == "torch":
torch = my_import_module("torch")
cuda = my_import_module("torch.cuda")
return {
"torch": torch.random.get_rng_state(),
"cuda" : cuda .random.get_rng_state(),
}
def set_random_state(state: Any, module: RandomAllowedModule) -> bool:
flag = True
if module == "random":
try:
random.setstate(state)
except:
flag = False
if module == "numpy":
try:
np = my_import_module("numpy")
np.random.set_state(state)
except:
flag = False
if module == "torch":
try:
torch = my_import_module("torch")
cuda = my_import_module("torch.cuda")
torch.random.set_rng_state(state["torch"])
cuda.random.set_rng_state(state["cuda"])
except:
flag = False
if not flag:
raise RuntimeError(f"set random state of module {module} bad.")
return flag
class MyRandom:
def __init__(self, random_seed: int | None = None, modules: list[RandomAllowedModule] = ["random" , "torch" , "numpy"]):
'''This class create a temporary environment, inside which the random seed
is set to a given value while not affecting the global random seed.
Notice that, to make this class work, the global random seed must be also managed by `xingyun`.
'''
if random_seed is None:
random_seed = int( time.time() )
self.random_seed = random_seed
self.modules = modules
self.entering_state = {}
def __enter__(self):
for m in self.modules:
self.entering_state[m] = get_random_state(m)
def __exit__(self, *args, **kwargs):
for m in self.modules:
set_random_state(self.entering_state[m],m)
Functions
def get_random_state(module: Union[Literal['torch'], Literal['numpy'], Literal['random']]) ‑> Any-
Expand source code
def get_random_state(module: RandomAllowedModule) -> Any: if module == "random": return random.getstate() if module == "numpy": np = my_import_module("numpy") return np.random.get_state() if module == "torch": torch = my_import_module("torch") cuda = my_import_module("torch.cuda") return { "torch": torch.random.get_rng_state(), "cuda" : cuda .random.get_rng_state(), } def set_random_state(state: Any, module: Union[Literal['torch'], Literal['numpy'], Literal['random']]) ‑> bool-
Expand source code
def set_random_state(state: Any, module: RandomAllowedModule) -> bool: flag = True if module == "random": try: random.setstate(state) except: flag = False if module == "numpy": try: np = my_import_module("numpy") np.random.set_state(state) except: flag = False if module == "torch": try: torch = my_import_module("torch") cuda = my_import_module("torch.cuda") torch.random.set_rng_state(state["torch"]) cuda.random.set_rng_state(state["cuda"]) except: flag = False if not flag: raise RuntimeError(f"set random state of module {module} bad.") return flag
Classes
class MyRandom (random_seed: int | None = None, modules: list[typing.Union[typing.Literal['torch'], typing.Literal['numpy'], typing.Literal['random']]] = ['random', 'torch', 'numpy'])-
This class create a temporary environment, inside which the random seed is set to a given value while not affecting the global random seed.
Notice that, to make this class work, the global random seed must be also managed by
xingyun.Expand source code
class MyRandom: def __init__(self, random_seed: int | None = None, modules: list[RandomAllowedModule] = ["random" , "torch" , "numpy"]): '''This class create a temporary environment, inside which the random seed is set to a given value while not affecting the global random seed. Notice that, to make this class work, the global random seed must be also managed by `xingyun`. ''' if random_seed is None: random_seed = int( time.time() ) self.random_seed = random_seed self.modules = modules self.entering_state = {} def __enter__(self): for m in self.modules: self.entering_state[m] = get_random_state(m) def __exit__(self, *args, **kwargs): for m in self.modules: set_random_state(self.entering_state[m],m)