age_gender.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
# Copyright 2025 Yakhyokhuja Valikhujaev
# Author: Yakhyokhuja Valikhujaev
# GitHub: https://github.com/yakhyo
  4. from typing import List, Tuple, Union
  5. import cv2
  6. import numpy as np
  7. from uniface.attribute.base import Attribute
  8. from uniface.constants import AgeGenderWeights
  9. from uniface.face_utils import bbox_center_alignment
  10. from uniface.log import Logger
  11. from uniface.model_store import verify_model_weights
  12. from uniface.onnx_utils import create_onnx_session
  13. __all__ = ['AgeGender']
  14. class AgeGender(Attribute):
  15. """
  16. Age and gender prediction model using ONNX Runtime.
  17. This class inherits from the base `Attribute` class and implements the
  18. functionality for predicting age (in years) and gender ID (0 for Female,
  19. 1 for Male) from a face image. It requires a bounding box to locate the face.
  20. """
  21. def __init__(self, model_name: AgeGenderWeights = AgeGenderWeights.DEFAULT) -> None:
  22. """
  23. Initializes the AgeGender prediction model.
  24. Args:
  25. model_name (AgeGenderWeights): The enum specifying the model weights
  26. to load.
  27. """
  28. Logger.info(f'Initializing AgeGender with model={model_name.name}')
  29. self.model_path = verify_model_weights(model_name)
  30. self._initialize_model()
  31. def _initialize_model(self) -> None:
  32. """
  33. Initializes the ONNX model and creates an inference session.
  34. """
  35. try:
  36. self.session = create_onnx_session(self.model_path)
  37. # Get model input details from the loaded model
  38. input_meta = self.session.get_inputs()[0]
  39. self.input_name = input_meta.name
  40. self.input_size = tuple(input_meta.shape[2:4]) # (height, width)
  41. self.output_names = [output.name for output in self.session.get_outputs()]
  42. Logger.info(f'Successfully initialized AgeGender model with input size {self.input_size}')
  43. except Exception as e:
  44. Logger.error(
  45. f"Failed to load AgeGender model from '{self.model_path}'",
  46. exc_info=True,
  47. )
  48. raise RuntimeError(f'Failed to initialize AgeGender model: {e}') from e
  49. def preprocess(self, image: np.ndarray, bbox: Union[List, np.ndarray]) -> np.ndarray:
  50. """
  51. Aligns the face based on the bounding box and preprocesses it for inference.
  52. Args:
  53. image (np.ndarray): The full input image in BGR format.
  54. bbox (Union[List, np.ndarray]): The face bounding box coordinates [x1, y1, x2, y2].
  55. Returns:
  56. np.ndarray: The preprocessed image blob ready for inference.
  57. """
  58. bbox = np.asarray(bbox)
  59. width, height = bbox[2] - bbox[0], bbox[3] - bbox[1]
  60. center = ((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2)
  61. scale = self.input_size[1] / (max(width, height) * 1.5)
  62. # **Rotation parameter restored here**
  63. rotation = 0.0
  64. aligned_face, _ = bbox_center_alignment(image, center, self.input_size[1], scale, rotation)
  65. blob = cv2.dnn.blobFromImage(
  66. aligned_face,
  67. scalefactor=1.0,
  68. size=self.input_size[::-1],
  69. mean=(0.0, 0.0, 0.0),
  70. swapRB=True,
  71. )
  72. return blob
  73. def postprocess(self, prediction: np.ndarray) -> Tuple[int, int]:
  74. """
  75. Processes the raw model output to extract gender and age.
  76. Args:
  77. prediction (np.ndarray): The raw output from the model inference.
  78. Returns:
  79. Tuple[int, int]: A tuple containing the predicted gender ID (0 for Female, 1 for Male)
  80. and age (in years).
  81. """
  82. # First two values are gender logits
  83. gender_id = int(np.argmax(prediction[:2]))
  84. # Third value is normalized age, scaled by 100
  85. age = int(np.round(prediction[2] * 100))
  86. return gender_id, age
  87. def predict(self, image: np.ndarray, bbox: Union[List, np.ndarray]) -> Tuple[int, int]:
  88. """
  89. Predicts age and gender for a single face specified by a bounding box.
  90. Args:
  91. image (np.ndarray): The full input image in BGR format.
  92. bbox (Union[List, np.ndarray]): The face bounding box coordinates [x1, y1, x2, y2].
  93. Returns:
  94. Tuple[int, int]: A tuple containing the predicted gender ID (0 for Female, 1 for Male) and age.
  95. """
  96. face_blob = self.preprocess(image, bbox)
  97. prediction = self.session.run(self.output_names, {self.input_name: face_blob})[0][0]
  98. gender_id, age = self.postprocess(prediction)
  99. return gender_id, age
  100. # TODO: below is only for testing, remove it later
  101. if __name__ == '__main__':
  102. # To run this script, you need to have uniface.detection installed
  103. # or available in your path.
  104. from uniface.constants import RetinaFaceWeights
  105. from uniface.detection import create_detector
  106. print('Initializing models for live inference...')
  107. # 1. Initialize the face detector
  108. # Using a smaller model for faster real-time performance
  109. detector = create_detector(model_name=RetinaFaceWeights.MNET_V2)
  110. # 2. Initialize the attribute predictor
  111. age_gender_predictor = AgeGender()
  112. # 3. Start webcam capture
  113. cap = cv2.VideoCapture(0)
  114. if not cap.isOpened():
  115. print('Error: Could not open webcam.')
  116. exit()
  117. print("Starting webcam feed. Press 'q' to quit.")
  118. while True:
  119. ret, frame = cap.read()
  120. if not ret:
  121. print('Error: Failed to capture frame.')
  122. break
  123. # Detect faces in the current frame
  124. detections = detector.detect(frame)
  125. # For each detected face, predict age and gender
  126. for detection in detections:
  127. box = detection['bbox']
  128. x1, y1, x2, y2 = map(int, box)
  129. # Predict attributes
  130. gender_id, age = age_gender_predictor.predict(frame, box)
  131. gender_str = 'Female' if gender_id == 0 else 'Male'
  132. # Prepare text and draw on the frame
  133. label = f'{gender_str}, {age}'
  134. cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
  135. cv2.putText(
  136. frame,
  137. label,
  138. (x1, y1 - 10),
  139. cv2.FONT_HERSHEY_SIMPLEX,
  140. 0.8,
  141. (0, 255, 0),
  142. 2,
  143. )
  144. # Display the resulting frame
  145. cv2.imshow("Age and Gender Inference (Press 'q' to quit)", frame)
  146. # Break the loop if 'q' is pressed
  147. if cv2.waitKey(1) & 0xFF == ord('q'):
  148. break
  149. # Release resources
  150. cap.release()
  151. cv2.destroyAllWindows()
  152. print('Inference stopped.')