# scrfd.py
  1. # Copyright 2025 Yakhyokhuja Valikhujaev
  2. # Author: Yakhyokhuja Valikhujaev
  3. # GitHub: https://github.com/yakhyo
  4. from typing import Any, Dict, List, Literal, Tuple
  5. import cv2
  6. import numpy as np
  7. from uniface.common import distance2bbox, distance2kps, non_max_suppression, resize_image
  8. from uniface.constants import SCRFDWeights
  9. from uniface.log import Logger
  10. from uniface.model_store import verify_model_weights
  11. from uniface.onnx_utils import create_onnx_session
  12. from .base import BaseDetector
  13. __all__ = ['SCRFD']
  14. class SCRFD(BaseDetector):
  15. """
  16. Face detector based on the SCRFD architecture.
  17. Title: "Sample and Computation Redistribution for Efficient Face Detection"
  18. Paper: https://arxiv.org/abs/2105.04714
  19. Args:
  20. **kwargs: Keyword arguments passed to BaseDetector and SCRFD. Supported keys include:
  21. model_name (SCRFDWeights, optional): Predefined model enum (e.g., `SCRFD_10G_KPS`).
  22. Specifies the SCRFD variant to load. Defaults to SCRFD_10G_KPS.
  23. conf_thresh (float, optional): Confidence threshold for filtering detections. Defaults to 0.5.
  24. nms_thresh (float, optional): Non-Maximum Suppression threshold. Defaults to 0.4.
  25. input_size (Tuple[int, int], optional): Input image size (width, height). Defaults to (640, 640).
  26. Attributes:
  27. conf_thresh (float): Threshold used to filter low-confidence detections.
  28. nms_thresh (float): Threshold used during NMS to suppress overlapping boxes.
  29. input_size (Tuple[int, int]): Image size to which inputs are resized before inference.
  30. _fmc (int): Number of feature map levels used in the model.
  31. _feat_stride_fpn (List[int]): Feature map strides corresponding to each detection level.
  32. _num_anchors (int): Number of anchors per feature location.
  33. _center_cache (Dict): Cached anchor centers for efficient forward passes.
  34. _model_path (str): Absolute path to the downloaded/verified model weights.
  35. Raises:
  36. ValueError: If the model weights are invalid or not found.
  37. RuntimeError: If the ONNX model fails to load or initialize.
  38. """
  39. def __init__(self, **kwargs) -> None:
  40. super().__init__(**kwargs)
  41. self._supports_landmarks = True # SCRFD supports landmarks
  42. model_name = kwargs.get('model_name', SCRFDWeights.SCRFD_10G_KPS)
  43. conf_thresh = kwargs.get('conf_thresh', 0.5)
  44. nms_thresh = kwargs.get('nms_thresh', 0.4)
  45. input_size = kwargs.get('input_size', (640, 640))
  46. self.conf_thresh = conf_thresh
  47. self.nms_thresh = nms_thresh
  48. self.input_size = input_size
  49. # ------- SCRFD model params ------
  50. self._fmc = 3
  51. self._feat_stride_fpn = [8, 16, 32]
  52. self._num_anchors = 2
  53. self._center_cache = {}
  54. # ---------------------------------
  55. Logger.info(
  56. f'Initializing SCRFD with model={model_name}, conf_thresh={conf_thresh}, nms_thresh={nms_thresh}, '
  57. f'input_size={input_size}'
  58. )
  59. # Get path to model weights
  60. self._model_path = verify_model_weights(model_name)
  61. Logger.info(f'Verified model weights located at: {self._model_path}')
  62. # Initialize model
  63. self._initialize_model(self._model_path)
  64. def _initialize_model(self, model_path: str) -> None:
  65. """
  66. Initializes an ONNX model session from the given path.
  67. Args:
  68. model_path (str): The file path to the ONNX model.
  69. Raises:
  70. RuntimeError: If the model fails to load, logs an error and raises an exception.
  71. """
  72. try:
  73. self.session = create_onnx_session(model_path)
  74. self.input_names = self.session.get_inputs()[0].name
  75. self.output_names = [x.name for x in self.session.get_outputs()]
  76. Logger.info(f'Successfully initialized the model from {model_path}')
  77. except Exception as e:
  78. Logger.error(f"Failed to load model from '{model_path}': {e}", exc_info=True)
  79. raise RuntimeError(f"Failed to initialize model session for '{model_path}'") from e
  80. def preprocess(self, image: np.ndarray) -> Tuple[np.ndarray, Tuple[int, int]]:
  81. """Preprocess image for inference.
  82. Args:
  83. image (np.ndarray): Input image
  84. Returns:
  85. Tuple[np.ndarray, Tuple[int, int]]: Preprocessed blob and input size
  86. """
  87. image = image.astype(np.float32)
  88. image = (image - 127.5) / 127.5
  89. image = image.transpose(2, 0, 1) # HWC to CHW
  90. image = np.expand_dims(image, axis=0)
  91. return image
  92. def inference(self, input_tensor: np.ndarray) -> List[np.ndarray]:
  93. """Perform model inference on the preprocessed image tensor.
  94. Args:
  95. input_tensor (np.ndarray): Preprocessed input tensor.
  96. Returns:
  97. Tuple[np.ndarray, np.ndarray]: Raw model outputs.
  98. """
  99. return self.session.run(self.output_names, {self.input_names: input_tensor})
  100. def postprocess(self, outputs: List[np.ndarray], image_size: Tuple[int, int]):
  101. scores_list = []
  102. bboxes_list = []
  103. kpss_list = []
  104. image_size = image_size
  105. fmc = self._fmc
  106. for idx, stride in enumerate(self._feat_stride_fpn):
  107. scores = outputs[idx]
  108. bbox_preds = outputs[fmc + idx] * stride
  109. kps_preds = outputs[2 * fmc + idx] * stride
  110. # Generate anchors
  111. fm_height = image_size[0] // stride
  112. fm_width = image_size[1] // stride
  113. cache_key = (fm_height, fm_width, stride)
  114. if cache_key in self._center_cache:
  115. anchor_centers = self._center_cache[cache_key]
  116. else:
  117. y, x = np.mgrid[:fm_height, :fm_width]
  118. anchor_centers = np.stack((x, y), axis=-1).astype(np.float32)
  119. anchor_centers = (anchor_centers * stride).reshape(-1, 2)
  120. if self._num_anchors > 1:
  121. anchor_centers = np.tile(anchor_centers[:, None, :], (1, self._num_anchors, 1)).reshape(-1, 2)
  122. if len(self._center_cache) < 100:
  123. self._center_cache[cache_key] = anchor_centers
  124. pos_indices = np.where(scores >= self.conf_thresh)[0]
  125. if len(pos_indices) == 0:
  126. continue
  127. bboxes = distance2bbox(anchor_centers, bbox_preds)[pos_indices]
  128. scores_selected = scores[pos_indices]
  129. scores_list.append(scores_selected)
  130. bboxes_list.append(bboxes)
  131. landmarks = distance2kps(anchor_centers, kps_preds)
  132. landmarks = landmarks.reshape((landmarks.shape[0], -1, 2))
  133. kpss_list.append(landmarks[pos_indices])
  134. return scores_list, bboxes_list, kpss_list
  135. def detect(
  136. self,
  137. image: np.ndarray,
  138. max_num: int = 0,
  139. metric: Literal['default', 'max'] = 'max',
  140. center_weight: float = 2,
  141. ) -> List[Dict[str, Any]]:
  142. """
  143. Perform face detection on an input image and return bounding boxes and facial landmarks.
  144. Args:
  145. image (np.ndarray): Input image as a NumPy array of shape (H, W, C).
  146. max_num (int): Maximum number of detections to return. Use 0 to return all detections. Defaults to 0.
  147. metric (Literal["default", "max"]): Metric for ranking detections when `max_num` is limited.
  148. - "default": Prioritize detections closer to the image center.
  149. - "max": Prioritize detections with larger bounding box areas.
  150. center_weight (float): Weight for penalizing detections farther from the image center
  151. when using the "default" metric. Defaults to 2.0.
  152. Returns:
  153. List[Dict[str, Any]]: List of face detection dictionaries, each containing:
  154. - 'bbox' (np.ndarray): Bounding box coordinates with shape (4,) as [x1, y1, x2, y2]
  155. - 'confidence' (float): Detection confidence score (0.0 to 1.0)
  156. - 'landmarks' (np.ndarray): 5-point facial landmarks with shape (5, 2)
  157. Example:
  158. >>> faces = detector.detect(image)
  159. >>> for face in faces:
  160. ... bbox = face['bbox'] # np.ndarray with shape (4,)
  161. ... confidence = face['confidence'] # float
  162. ... landmarks = face['landmarks'] # np.ndarray with shape (5, 2)
  163. ... # Can pass landmarks directly to recognition
  164. ... embedding = recognizer.get_normalized_embedding(image, landmarks)
  165. """
  166. original_height, original_width = image.shape[:2]
  167. image, resize_factor = resize_image(image, target_shape=self.input_size)
  168. image_tensor = self.preprocess(image)
  169. # ONNXRuntime inference
  170. outputs = self.inference(image_tensor)
  171. scores_list, bboxes_list, kpss_list = self.postprocess(outputs, image_size=image.shape[:2])
  172. # Handle case when no faces are detected
  173. if not scores_list:
  174. return []
  175. scores = np.vstack(scores_list)
  176. scores_ravel = scores.ravel()
  177. order = scores_ravel.argsort()[::-1]
  178. bboxes = np.vstack(bboxes_list) / resize_factor
  179. landmarks = np.vstack(kpss_list) / resize_factor
  180. pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False)
  181. pre_det = pre_det[order, :]
  182. keep = non_max_suppression(pre_det, threshold=self.nms_thresh)
  183. detections = pre_det[keep, :]
  184. landmarks = landmarks[order, :, :]
  185. landmarks = landmarks[keep, :, :].astype(np.int32)
  186. if 0 < max_num < detections.shape[0]:
  187. # Calculate area of detections
  188. area = (detections[:, 2] - detections[:, 0]) * (detections[:, 3] - detections[:, 1])
  189. # Calculate offsets from image center
  190. center = (original_height // 2, original_width // 2)
  191. offsets = np.vstack(
  192. [
  193. (detections[:, 0] + detections[:, 2]) / 2 - center[1],
  194. (detections[:, 1] + detections[:, 3]) / 2 - center[0],
  195. ]
  196. )
  197. # Calculate scores based on the chosen metric
  198. offset_dist_squared = np.sum(np.power(offsets, 2.0), axis=0)
  199. if metric == 'max':
  200. values = area
  201. else:
  202. values = area - offset_dist_squared * center_weight
  203. # Sort by scores and select top `max_num`
  204. sorted_indices = np.argsort(values)[::-1][:max_num]
  205. detections = detections[sorted_indices]
  206. landmarks = landmarks[sorted_indices]
  207. faces = []
  208. for i in range(detections.shape[0]):
  209. face_dict = {
  210. 'bbox': detections[i, :4].astype(np.float32),
  211. 'confidence': float(detections[i, 4]),
  212. 'landmarks': landmarks[i].astype(np.float32),
  213. }
  214. faces.append(face_dict)
  215. return faces
  216. # TODO: below is only for testing, remove it later
  217. def draw_bbox(frame, bbox, score, color=(0, 255, 0), thickness=2):
  218. x1, y1, x2, y2 = map(int, bbox) # Unpack 4 bbox values
  219. cv2.rectangle(frame, (x1, y1), (x2, y2), color, thickness)
  220. cv2.putText(frame, f'{score:.2f}', (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
  221. def draw_keypoints(frame, points, color=(0, 0, 255), radius=2):
  222. for x, y in points.astype(np.int32):
  223. cv2.circle(frame, (int(x), int(y)), radius, color, -1)
  224. if __name__ == '__main__':
  225. detector = SCRFD(model_name=SCRFDWeights.SCRFD_500M_KPS)
  226. print(detector.get_info())
  227. cap = cv2.VideoCapture(0)
  228. if not cap.isOpened():
  229. print('Failed to open webcam.')
  230. exit()
  231. print("Webcam started. Press 'q' to exit.")
  232. while True:
  233. ret, frame = cap.read()
  234. if not ret:
  235. print('Failed to read frame.')
  236. break
  237. # Get face detections as list of dictionaries
  238. faces = detector.detect(frame)
  239. # Process each detected face
  240. for face in faces:
  241. # Extract bbox and landmarks from dictionary
  242. bbox = face['bbox'] # [x1, y1, x2, y2]
  243. landmarks = face['landmarks'] # [[x1, y1], [x2, y2], ...]
  244. confidence = face['confidence']
  245. # Pass bbox and confidence separately
  246. draw_bbox(frame, bbox, confidence)
  247. # Convert landmarks to numpy array format if needed
  248. if landmarks is not None and len(landmarks) > 0:
  249. # Convert list of [x, y] pairs to numpy array
  250. points = np.array(landmarks, dtype=np.float32) # Shape: (5, 2)
  251. draw_keypoints(frame, points)
  252. # Display face count
  253. cv2.putText(
  254. frame,
  255. f'Faces: {len(faces)}',
  256. (10, 30),
  257. cv2.FONT_HERSHEY_SIMPLEX,
  258. 0.7,
  259. (255, 255, 255),
  260. 2,
  261. )
  262. cv2.imshow('FaceDetection', frame)
  263. if cv2.waitKey(1) & 0xFF == ord('q'):
  264. break
  265. cap.release()
  266. cv2.destroyAllWindows()