TorchCheckpointIO
- class lightning.fabric.plugins.io.torch_io.TorchCheckpointIO
Bases: CheckpointIO
A CheckpointIO that uses torch.save() and torch.load() to save and load checkpoints, respectively. This covers most use cases.

Warning
This is an experimental feature.
- load_checkpoint(path, map_location=<function TorchCheckpointIO.<lambda>>, weights_only=None)
Loads a checkpoint using torch.load(), with additional handling for loading remote files through fsspec.
- Parameters:
map_location (Optional[Callable]) – a function, torch.device, string, or dict specifying how to remap storage locations.
weights_only (Optional[bool]) – Defaults to None. If True, restricts loading to state_dicts of plain torch.Tensor and other primitive types. If loading a checkpoint from a trusted source that contains an nn.Module, use weights_only=False. If loading a checkpoint from an untrusted source, we recommend using weights_only=True. For more information, refer to the PyTorch Developer Notes on Serialization Semantics.
- Returns:
The loaded checkpoint.
- Raises:
FileNotFoundError – If path is not found by the fsspec filesystem.