# Source code for fast_lisa_subtraction.utils.types

import torch
import pandas as pd
from tensordict import TensorDict

#=======================================
#TensorDict Wrapper Class
#=======================================
class TensorSamples(TensorDict):
    """Wrapper class for TensorDict to better manage samples.

    The helper methods below convert the entries stored under each key into
    other containers: a single stacked tensor, a numpy array, a dict of
    tensors, a dict of numpy arrays, or a pandas DataFrame.
    """

    def flatten(self):
        """Returns the samples as a single flattened (1-D) tensor.

        The per-key values are first stacked along a new last dimension
        (see ``to_tensor``) and the stacked tensor is then flattened.
        """
        # NOTE(review): tensordict's base class may define its own
        # flatten(start_dim, end_dim); this zero-argument override shadows it.
        # Confirm that no caller relies on the base-class signature.
        return self.to_tensor().flatten()

    def to_tensor(self):
        """Returns the samples as a single tensor.

        The value stored under each key is stacked along a new last dimension,
        in ``self.keys()`` order. All values must therefore have the same
        shape. ``torch.stack`` raises if the container holds no keys.
        """
        return torch.stack([self[key] for key in self.keys()], dim=-1)

    def tensor(self):
        """Returns the samples as a single tensor (alias of ``to_tensor``)."""
        return self.to_tensor()

    def numpy(self):
        """Returns the samples as a numpy array, stacked as in ``to_tensor``.

        The stacked tensor is detached from the autograd graph and moved to
        the CPU before conversion. ``Tensor.numpy()`` refuses tensors that
        require gradients or live on an accelerator.
        """
        return self.tensor().detach().cpu().numpy()

    def dict(self):
        """Returns the samples as a dict of tensors (via ``to_dict``)."""
        return self.to_dict()

    def numpy_dict(self):
        """Returns the samples as a dict mapping each key to a numpy array.

        Each value is detached and moved to the CPU before conversion, so
        tensors that require gradients are handled as well.
        """
        return {key: self[key].detach().cpu().numpy() for key in self.keys()}

    def dataframe(self):
        """Returns the samples as a pandas DataFrame with one column per key.

        pandas requires every per-key array to be 1-D and of equal length
        to build the columns.
        """
        return pd.DataFrame(self.numpy_dict())