base.py 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. # Copyright 2025 Yakhyokhuja Valikhujaev
  2. # Author: Yakhyokhuja Valikhujaev
  3. # GitHub: https://github.com/yakhyo
  4. from abc import ABC, abstractmethod
  5. from typing import Any
  6. import numpy as np
  7. class Attribute(ABC):
  8. """
  9. Abstract base class for face attribute models.
  10. This class defines the common interface that all attribute models
  11. (e.g., age-gender, emotion) must implement. It ensures a consistent API
  12. across different attribute prediction modules in the library, making them
  13. interchangeable and easy to use.
  14. """
  15. @abstractmethod
  16. def _initialize_model(self) -> None:
  17. """
  18. Initializes the underlying model for inference.
  19. This method should handle loading model weights, creating the
  20. inference session (e.g., ONNX Runtime, PyTorch), and any necessary
  21. warm-up procedures to prepare the model for prediction.
  22. """
  23. raise NotImplementedError('Subclasses must implement the _initialize_model method.')
  24. @abstractmethod
  25. def preprocess(self, image: np.ndarray, *args: Any) -> Any:
  26. """
  27. Preprocesses the input data for the model.
  28. This method should take a raw image and any other necessary data
  29. (like bounding boxes or landmarks) and convert it into the format
  30. expected by the model's inference engine (e.g., a blob or tensor).
  31. Args:
  32. image (np.ndarray): The input image containing the face, typically
  33. in BGR format.
  34. *args: Additional arguments required for preprocessing, such as
  35. bounding boxes or facial landmarks.
  36. Returns:
  37. The preprocessed data ready for model inference.
  38. """
  39. raise NotImplementedError('Subclasses must implement the preprocess method.')
  40. @abstractmethod
  41. def postprocess(self, prediction: Any) -> Any:
  42. """
  43. Postprocesses the raw model output into a human-readable format.
  44. This method takes the raw output from the model's inference and
  45. converts it into a meaningful result, such as an age value, a gender
  46. label, or an emotion category.
  47. Args:
  48. prediction (Any): The raw output from the model's inference.
  49. Returns:
  50. The final, processed attributes.
  51. """
  52. raise NotImplementedError('Subclasses must implement the postprocess method.')
  53. @abstractmethod
  54. def predict(self, image: np.ndarray, *args: Any) -> Any:
  55. """
  56. Performs end-to-end attribute prediction on a given image.
  57. This method orchestrates the full pipeline: it calls the preprocess,
  58. inference, and postprocess steps to return the final, user-friendly
  59. attribute prediction.
  60. Args:
  61. image (np.ndarray): The input image containing the face.
  62. *args: Additional data required for prediction, such as a bounding
  63. box or landmarks.
  64. Returns:
  65. The final predicted attributes.
  66. """
  67. raise NotImplementedError('Subclasses must implement the predict method.')
  68. def __call__(self, *args, **kwargs) -> Any:
  69. """
  70. Provides a convenient, callable shortcut for the `predict` method.
  71. """
  72. return self.predict(*args, **kwargs)