fast_lisa_subtraction.utils.types module

class fast_lisa_subtraction.utils.types.TensorSamples(source: T | dict[NestedKey, Tensor | TensorCollection] | None = None, batch_size: Sequence[int] | Size | int | None = None, device: device | str | int | None = None, names: Sequence[str] | None = None, non_blocking: bool | None = None, lock: bool = False, **kwargs: Any)

Bases: TensorDict

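Subclass of TensorDict for holding a batch of samples, with helper methods that convert the samples to a pandas DataFrame, NumPy arrays, plain dicts, or single tensors.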

dataframe()

Returns the samples as a pandas DataFrame
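The conversion presumably produces one column per key. A minimal sketch of that behaviour on a plain dict of 1-D tensors, with the hypothetical helper `samples_to_dataframe` standing in for the method:

```python
import pandas as pd
import torch

def samples_to_dataframe(samples: dict[str, torch.Tensor]) -> pd.DataFrame:
    """Hypothetical stand-in for TensorSamples.dataframe(): one column per key."""
    # detach() and cpu() make .numpy() work for tensors that require grad or live on a GPU.
    return pd.DataFrame({key: value.detach().cpu().numpy() for key, value in samples.items()})

samples = {"m1": torch.tensor([1.0, 2.0, 3.0]), "m2": torch.tensor([4.0, 5.0, 6.0])}
df = samples_to_dataframe(samples)
print(df.shape)  # (3, 2)
```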

dict()

Returns the samples as a dict of tensors

flatten()

Returns the samples as a single flattened tensor
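The docstring does not say how values are ordered in the flattened result. One plausible reading, sketched below with the hypothetical helper `flatten_samples`, concatenates each key's values end to end in key order; if the library instead stacks per sample before flattening, the keys would interleave, as the second half of the sketch shows.

```python
import torch

def flatten_samples(samples: dict[str, torch.Tensor]) -> torch.Tensor:
    """Hypothetical stand-in for TensorSamples.flatten(): 1-D concatenation in key order."""
    # reshape(-1) lets keys with extra trailing dimensions contribute as well.
    return torch.cat([value.reshape(-1) for value in samples.values()])

samples = {"m1": torch.tensor([1.0, 2.0]), "m2": torch.tensor([3.0, 4.0])}
flat = flatten_samples(samples)
print(flat)  # tensor([1., 2., 3., 4.])

# The interleaved alternative, for comparison: stack per sample, then flatten.
interleaved = torch.stack(list(samples.values()), dim=-1).reshape(-1)
print(interleaved)  # tensor([1., 3., 2., 4.])
```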

numpy()

Returns the samples as a NumPy array

numpy_dict()

Returns the samples as a dict of NumPy arrays
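The two NumPy conversions differ in shape. A minimal sketch on a plain dict of 1-D tensors, assuming that `numpy()` places one parameter per column (that layout is an assumption, not documented above):

```python
import numpy as np
import torch

samples = {"m1": torch.tensor([1.0, 2.0]), "m2": torch.tensor([3.0, 4.0])}

# numpy_dict()-style result: one NumPy array per key.
as_numpy_dict = {key: value.detach().cpu().numpy() for key, value in samples.items()}

# numpy()-style result: one (n_samples, n_params) array, columns in key order (assumed).
as_numpy = np.stack(list(as_numpy_dict.values()), axis=-1)
print(as_numpy.shape)  # (2, 2)
```

The dict form keeps the parameter names; the single array is what most NumPy and SciPy routines expect, at the cost of tracking the column order separately.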

tensor()

Returns the samples as a single tensor

to_tensor()

Returns the samples as a single tensor
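`tensor()` and `to_tensor()` carry the same docstring, so they appear to be aliases, though that is not confirmed here. A single tensor loses the key names, so the key order must be kept to map columns back to parameters. The sketch below assumes an `(n_samples, n_params)` layout and shows the round trip; both helper names are hypothetical.

```python
import torch

def samples_to_tensor(samples: dict[str, torch.Tensor]) -> tuple[torch.Tensor, list[str]]:
    """Hypothetical: stack 1-D values into (n_samples, n_params) and return the key order."""
    keys = list(samples.keys())
    return torch.stack([samples[key] for key in keys], dim=-1), keys

def tensor_to_samples(stacked: torch.Tensor, keys: list[str]) -> dict[str, torch.Tensor]:
    """Hypothetical inverse: split the columns back into a dict keyed by name."""
    return dict(zip(keys, stacked.unbind(dim=-1)))

samples = {"m1": torch.tensor([1.0, 2.0, 3.0]), "m2": torch.tensor([4.0, 5.0, 6.0])}
stacked, keys = samples_to_tensor(samples)
print(stacked.shape)  # torch.Size([3, 2])

# Round trip: every column maps back to the key it came from.
restored = tensor_to_samples(stacked, keys)
assert all(torch.equal(restored[key], samples[key]) for key in keys)
```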