File tree Expand file tree Collapse file tree 3 files changed +35
-2
lines changed
Expand file tree Collapse file tree 3 files changed +35
-2
lines changed Original file line number Diff line number Diff line change @@ -81,6 +81,38 @@ def strtobool(val: Any) -> bool:
8181
8282BATCHED_PIPE_TIMEOUT = float (os .environ .get ("BATCHED_PIPE_TIMEOUT" , "10000.0" ))
8383
84+ _TORCH_DTYPES = (
85+ torch .bfloat16 ,
86+ torch .bool ,
87+ torch .complex128 ,
88+ torch .complex32 ,
89+ torch .complex64 ,
90+ torch .float16 ,
91+ torch .float32 ,
92+ torch .float64 ,
93+ torch .int16 ,
94+ torch .int32 ,
95+ torch .int64 ,
96+ torch .int8 ,
97+ torch .qint32 ,
98+ torch .qint8 ,
99+ torch .quint4x2 ,
100+ torch .quint8 ,
101+ torch .uint8 ,
102+ )
103+ if hasattr (torch , "uint16" ):
104+ _TORCH_DTYPES = _TORCH_DTYPES + (torch .uint16 ,)
105+ if hasattr (torch , "uint32" ):
106+ _TORCH_DTYPES = _TORCH_DTYPES + (torch .uint32 ,)
107+ if hasattr (torch , "uint64" ):
108+ _TORCH_DTYPES = _TORCH_DTYPES + (torch .uint64 ,)
109+ _STR_DTYPE_TO_DTYPE = {str (dtype ): dtype for dtype in _TORCH_DTYPES }
110+ _STRDTYPE2DTYPE = _STR_DTYPE_TO_DTYPE
111+ _DTYPE_TO_STR_DTYPE = {
112+ dtype : str_dtype for str_dtype , dtype in _STR_DTYPE_TO_DTYPE .items ()
113+ }
114+ _DTYPE2STRDTYPE = _STR_DTYPE_TO_DTYPE
115+
84116
85117class timeit :
86118 """A dirty but easy to use decorator for profiling code."""
Original file line number Diff line number Diff line change 1818 TensorDict ,
1919)
2020from tensordict .memmap import MemoryMappedTensor
21- from tensordict . utils import _STRDTYPE2DTYPE
21+ from torchrl . _utils import _STRDTYPE2DTYPE
2222
2323from torchrl .data .replay_buffers .utils import (
2424 _save_pytree ,
Original file line number Diff line number Diff line change 1616import numpy as np
1717import torch
1818from tensordict import is_tensor_collection , MemoryMappedTensor , TensorDictBase
19- from tensordict .utils import _STRDTYPE2DTYPE , expand_as_right , is_tensorclass
19+ from tensordict .utils import expand_as_right , is_tensorclass
2020from torch import multiprocessing as mp
21+ from torchrl ._utils import _STRDTYPE2DTYPE
2122
2223try :
2324 from torch .utils ._pytree import tree_leaves
You can’t perform that action at this time.
0 commit comments