Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions pina/_src/condition/condition_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pina._src.condition.condition_interface import ConditionInterface
from pina._src.core.graph import LabelBatch
from pina._src.core.label_tensor import LabelTensor
from pina._src.data.dummy_dataloader import DummyDataloader


class ConditionBase(ConditionInterface):
Expand All @@ -33,6 +34,7 @@ def __init__(self, **kwargs):
"""
super().__init__()
self.data = self.store_data(**kwargs)
self.has_custom_dataloader_fn = False

@property
def problem(self):
Expand Down Expand Up @@ -85,7 +87,8 @@ def automatic_batching_collate_fn(cls, batch):
if not batch:
return {}
instance_class = batch[0].__class__
return instance_class.create_batch(batch)
batch = instance_class.create_batch(batch)
return batch

@staticmethod
def collate_fn(batch, condition):
Expand All @@ -103,7 +106,11 @@ def collate_fn(batch, condition):
return data

def create_dataloader(
self, dataset, batch_size, shuffle, automatic_batching
self,
dataset,
batch_size,
automatic_batching,
**kwargs,
):
"""
Create a DataLoader for the condition.
Expand All @@ -114,14 +121,28 @@ def create_dataloader(
:rtype: torch.utils.data.DataLoader
"""
if batch_size == len(dataset):
pass # will be updated in the near future
return DummyDataloader(dataset)
return DataLoader(
dataset=dataset,
batch_size=batch_size,
shuffle=shuffle,
collate_fn=(
partial(self.collate_fn, condition=self)
if not automatic_batching
else self.automatic_batching_collate_fn
),
batch_size=batch_size,
**kwargs,
)

def switch_dataloader_fn(self, create_dataloader_fn):
    """
    Replace this condition's ``create_dataloader`` method with a
    user-supplied callable and mark the condition as using a custom
    dataloader factory.

    NOTE(review): despite the original wording, this is a plain setter,
    not a decorator — it returns ``None``.

    :param create_dataloader_fn: The new dataloader factory. It is
        assigned in place of :meth:`create_dataloader`, so it should be
        callable with the same signature — TODO confirm expected
        signature with callers.
    :type create_dataloader_fn: callable
    :rtype: None
    """
    # Shadow the bound method on this instance so later calls use the
    # user-supplied factory; the flag lets other code detect the switch.
    self.has_custom_dataloader_fn = True
    self.create_dataloader = create_dataloader_fn
2 changes: 1 addition & 1 deletion pina/_src/condition/data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def create_batch(items):
if isinstance(sample, LabelTensor)
else torch.stack
)
batch_data[k] = batch_fn(vals, dim=0)
batch_data[k] = batch_fn(vals)
else:
batch_data[k] = sample
return batch_data
Expand Down
33 changes: 16 additions & 17 deletions pina/_src/core/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(
test_size=0.0,
val_size=0.0,
compile=None,
repeat=None,
batching_mode="common_batch_size",
automatic_batching=None,
num_workers=None,
pin_memory=None,
Expand All @@ -61,9 +61,9 @@ def __init__(
:param bool compile: If ``True``, the model is compiled before training.
Default is ``False``. For Windows users, it is always disabled. Not
supported for python version greater or equal than 3.14.
:param bool repeat: Whether to repeat the dataset data in each
condition during training. For further details, see the
:class:`~pina.data.data_module.PinaDataModule` class. Default is
:param str batching_mode: The batching mode to use. Options are
``"common_batch_size"``, ``"proportional"``, and
``"separate_conditions"``. Default is ``"common_batch_size"``.
:param bool automatic_batching: If ``True``, automatic PyTorch batching
is performed, otherwise the items are retrieved from the dataset
Expand All @@ -87,7 +87,7 @@ def __init__(
train_size=train_size,
test_size=test_size,
val_size=val_size,
repeat=repeat,
batching_mode=batching_mode,
automatic_batching=automatic_batching,
compile=compile,
)
Expand Down Expand Up @@ -127,8 +127,6 @@ def __init__(
UserWarning,
)

repeat = repeat if repeat is not None else False

automatic_batching = (
automatic_batching if automatic_batching is not None else False
)
Expand All @@ -144,7 +142,7 @@ def __init__(
test_size=test_size,
val_size=val_size,
batch_size=batch_size,
repeat=repeat,
batching_mode=batching_mode,
automatic_batching=automatic_batching,
pin_memory=pin_memory,
num_workers=num_workers,
Expand Down Expand Up @@ -182,7 +180,7 @@ def _create_datamodule(
test_size,
val_size,
batch_size,
repeat,
batching_mode,
automatic_batching,
pin_memory,
num_workers,
Expand All @@ -201,8 +199,9 @@ def _create_datamodule(
:param float val_size: The percentage of elements to include in the
validation dataset.
:param int batch_size: The number of samples per batch to load.
:param bool repeat: Whether to repeat the dataset data in each
condition during training.
:param str batching_mode: The batching mode to use. Options are
``"common_batch_size"``, ``"proportional"``, and
``"separate_conditions"``.
:param bool automatic_batching: Whether to perform automatic batching
with PyTorch.
:param bool pin_memory: Whether to use pinned memory for faster data
Expand Down Expand Up @@ -232,7 +231,7 @@ def _create_datamodule(
test_size=test_size,
val_size=val_size,
batch_size=batch_size,
repeat=repeat,
batching_mode=batching_mode,
automatic_batching=automatic_batching,
num_workers=num_workers,
pin_memory=pin_memory,
Expand Down Expand Up @@ -284,7 +283,7 @@ def _check_input_consistency(
train_size,
test_size,
val_size,
repeat,
batching_mode,
automatic_batching,
compile,
):
Expand All @@ -298,8 +297,9 @@ def _check_input_consistency(
test dataset.
:param float val_size: The percentage of elements to include in the
validation dataset.
:param bool repeat: Whether to repeat the dataset data in each
condition during training.
:param str batching_mode: The batching mode to use. Options are
``"common_batch_size"``, ``"proportional"``, and
``"separate_conditions"``.
:param bool automatic_batching: Whether to perform automatic batching
with PyTorch.
:param bool compile: If ``True``, the model is compiled before training.
Expand All @@ -309,8 +309,7 @@ def _check_input_consistency(
check_consistency(train_size, float)
check_consistency(test_size, float)
check_consistency(val_size, float)
if repeat is not None:
check_consistency(repeat, bool)
check_consistency(batching_mode, str)
if automatic_batching is not None:
check_consistency(automatic_batching, bool)
if compile is not None:
Expand Down
61 changes: 61 additions & 0 deletions pina/_src/data/aggregator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""
Aggregator for multiple dataloaders.
"""


class _Aggregator:
"""
The class :class:`_Aggregator` is responsible for aggregating multiple
dataloaders into a single iterable object. It supports different batching
modes to accommodate various training requirements.
"""

def __init__(self, dataloaders, batching_mode):
"""
Initialization of the :class:`_Aggregator` class.

:param dataloaders: A dictionary mapping condition names to their
respective dataloaders.
:type dataloaders: dict[str, DataLoader]
:param batching_mode: The batching mode to use. Options are
``"common_batch_size"``, ``"proportional"``, and
``"separate_conditions"``.
:type batching_mode: str
"""
self.dataloaders = dataloaders
self.batching_mode = batching_mode

def __len__(self):
"""
Return the length of the aggregated dataloader.

:return: The length of the aggregated dataloader.
:rtype: int
"""
if self.batching_mode == "separate_conditions":
return sum(len(dl) for dl in self.dataloaders.values())
return max(len(dl) for dl in self.dataloaders.values())

def __iter__(self):
"""
Return an iterator over the aggregated dataloader.

:return: An iterator over the aggregated dataloader.
:rtype: iterator
"""
if self.batching_mode == "separate_conditions":
# TODO: implement separate_conditions batching mode
raise NotImplementedError(
"Batching mode 'separate_conditions' is not implemented yet."
)

iterators = {name: iter(dl) for name, dl in self.dataloaders.items()}
for _ in range(len(self)):
batch = {}
for name, it in iterators.items():
try:
batch[name] = next(it)
except StopIteration:
iterators[name] = iter(self.dataloaders[name])
batch[name] = next(iterators[name])
yield batch
Loading