Source code for dl_data_pipeline.validator.shape_validator

from typing import Any
from .base_validator import Validator
from .validation_error import ValidationError

class ShapeValidator(Validator):
    """Validator that checks whether the shape of the data matches an expected shape.

    This class extends the `Validator` base class. It reads the data's shape
    through a configurable attribute name and compares it with the accepted
    shape. If either shape is a ``list`` or ``tuple``, it is converted to a
    plain ``tuple`` before comparing, so ``[100, 200]`` and ``(100, 200)`` are
    treated as the same shape. Any other value is compared with ``!=`` as-is.

    Args:
        accepted_shape (Any): The shape the data should match, e.g. a tuple
            (or list) of ints for multidimensional arrays.
        shape_getter (str, optional): Name of the attribute on the data that
            holds its shape. Defaults to ``"shape"`` (e.g. NumPy arrays).

    Raises:
        ValidationError: If the data has no attribute named ``shape_getter``,
            or if its shape differs from ``accepted_shape``.

    Examples:
        >>> import numpy as np
        >>> validator = ShapeValidator((100, 200))
        >>> validator.validate(np.zeros((100, 200)))  # shapes match: no error
        >>> ShapeValidator([100, 200]).validate(np.zeros((100, 200)))  # list works too
        >>> validator.validate(np.zeros((100, 300)))  # doctest: +SKIP
        Traceback (most recent call last):
            ...
        ValidationError: Invalid shape, expected : (100, 200), received : (100, 300).
    """

    def __init__(self, accepted_shape: Any, shape_getter: str = "shape") -> None:
        self.__shape = accepted_shape
        self.__getter = shape_getter

    @staticmethod
    def _normalize_shape(shape: Any) -> Any:
        """Return ``shape`` as a plain tuple if it is a list or tuple, else unchanged.

        Python never considers a list equal to a tuple (``(2, 3) != [2, 3]``).
        Without this step, an expected shape written as a list would reject
        every NumPy array, whose ``.shape`` is a tuple.
        """
        if isinstance(shape, (list, tuple)):
            return tuple(shape)
        return shape

    def validate(self, data: Any) -> None:
        """Check that ``data``'s shape equals the accepted shape.

        Args:
            data (Any): Object expected to expose its shape via the attribute
                configured as ``shape_getter``.

        Raises:
            ValidationError: If ``data`` lacks the shape attribute, or if the
                shape does not match the accepted shape.
        """
        try:
            shape = getattr(data, self.__getter)
        except AttributeError:
            # Same condition the previous hasattr() check caught; the original
            # AttributeError adds nothing for the caller, so drop the context.
            raise ValidationError(f"data must have an attribute {self.__getter}.") from None
        if self._normalize_shape(shape) != self._normalize_shape(self.__shape):
            # Report the values exactly as supplied, not their normalized forms.
            raise ValidationError(f"Invalid shape, expected : {self.__shape}, received : {shape}.")