models.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. # Copyright 2025 Yakhyokhuja Valikhujaev
  2. # Author: Yakhyokhuja Valikhujaev
  3. # GitHub: https://github.com/yakhyo
  4. from typing import Optional
  5. from uniface.constants import ArcFaceWeights, MobileFaceWeights, SphereFaceWeights
  6. from uniface.model_store import verify_model_weights
  7. from .base import BaseRecognizer, PreprocessConfig
  8. __all__ = ['ArcFace', 'MobileFace', 'SphereFace']
  9. class ArcFace(BaseRecognizer):
  10. """ArcFace model for robust face recognition.
  11. This class provides a concrete implementation of the BaseRecognizer,
  12. pre-configured for ArcFace models. It handles the loading of specific
  13. ArcFace weights and sets up the appropriate default preprocessing.
  14. Args:
  15. model_name (ArcFaceWeights): The specific ArcFace model variant to use.
  16. Defaults to `ArcFaceWeights.MNET`.
  17. preprocessing (Optional[PreprocessConfig]): An optional custom preprocessing
  18. configuration. If None, a default config for ArcFace is used.
  19. Example:
  20. >>> from uniface.recognition import ArcFace
  21. >>> recognizer = ArcFace()
  22. >>> # embedding = recognizer.get_normalized_embedding(image, landmarks)
  23. """
  24. def __init__(
  25. self,
  26. model_name: ArcFaceWeights = ArcFaceWeights.MNET,
  27. preprocessing: Optional[PreprocessConfig] = None,
  28. ) -> None:
  29. if preprocessing is None:
  30. preprocessing = PreprocessConfig(input_mean=127.5, input_std=127.5, input_size=(112, 112))
  31. model_path = verify_model_weights(model_name)
  32. super().__init__(model_path=model_path, preprocessing=preprocessing)
  33. class MobileFace(BaseRecognizer):
  34. """Lightweight MobileFaceNet model for fast face recognition.
  35. This class provides a concrete implementation of the BaseRecognizer,
  36. pre-configured for MobileFaceNet models. It is optimized for speed,
  37. making it suitable for edge devices.
  38. Args:
  39. model_name (MobileFaceWeights): The specific MobileFaceNet model variant to use.
  40. Defaults to `MobileFaceWeights.MNET_V2`.
  41. preprocessing (Optional[PreprocessConfig]): An optional custom preprocessing
  42. configuration. If None, a default config for MobileFaceNet is used.
  43. Example:
  44. >>> from uniface.recognition import MobileFace
  45. >>> recognizer = MobileFace()
  46. >>> # embedding = recognizer.get_normalized_embedding(image, landmarks)
  47. """
  48. def __init__(
  49. self,
  50. model_name: MobileFaceWeights = MobileFaceWeights.MNET_V2,
  51. preprocessing: Optional[PreprocessConfig] = None,
  52. ) -> None:
  53. if preprocessing is None:
  54. preprocessing = PreprocessConfig(input_mean=127.5, input_std=127.5, input_size=(112, 112))
  55. model_path = verify_model_weights(model_name)
  56. super().__init__(model_path=model_path, preprocessing=preprocessing)
  57. class SphereFace(BaseRecognizer):
  58. """SphereFace model using angular margin for face recognition.
  59. This class provides a concrete implementation of the BaseRecognizer,
  60. pre-configured for SphereFace models, which were among the first to
  61. introduce angular margin loss functions.
  62. Args:
  63. model_name (SphereFaceWeights): The specific SphereFace model variant to use.
  64. Defaults to `SphereFaceWeights.SPHERE20`.
  65. preprocessing (Optional[PreprocessConfig]): An optional custom preprocessing
  66. configuration. If None, a default config for SphereFace is used.
  67. Example:
  68. >>> from uniface.recognition import SphereFace
  69. >>> recognizer = SphereFace()
  70. >>> # embedding = recognizer.get_normalized_embedding(image, landmarks)
  71. """
  72. def __init__(
  73. self,
  74. model_name: SphereFaceWeights = SphereFaceWeights.SPHERE20,
  75. preprocessing: Optional[PreprocessConfig] = None,
  76. ) -> None:
  77. if preprocessing is None:
  78. preprocessing = PreprocessConfig(input_mean=127.5, input_std=127.5, input_size=(112, 112))
  79. model_path = verify_model_weights(model_name)
  80. super().__init__(model_path=model_path, preprocessing=preprocessing)