# Copyright 2025 Yakhyokhuja Valikhujaev
# Author: Yakhyokhuja Valikhujaev
# GitHub: https://github.com/yakhyo

from typing import List, Tuple, Union

import cv2
import numpy as np
import torch

from uniface.attribute.base import Attribute
from uniface.constants import DDAMFNWeights
from uniface.face_utils import face_alignment
from uniface.log import Logger
from uniface.model_store import verify_model_weights

__all__ = ['Emotion']


class Emotion(Attribute):
    """
    Emotion recognition attribute backed by a TorchScript DDAMFN model.

    This class inherits from the base `Attribute` class and predicts one of
    several emotion categories from a face image. It requires 5-point facial
    landmarks so the face can be aligned before inference.
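
    Example (a minimal sketch; assumes a uniface detector that provides 5-point
    landmarks, as in the demo at the bottom of this file; the image path is
    hypothetical):
        >>> from uniface.constants import RetinaFaceWeights
        >>> from uniface.detection import create_detector
        >>> image = cv2.imread('face.jpg')  # any BGR image containing a face
        >>> detector = create_detector(model_name=RetinaFaceWeights.MNET_V2)
        >>> emotion = Emotion()
        >>> faces = detector.detect(image)
        >>> label, score = emotion.predict(image, faces[0]['landmarks'])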
    """

    def __init__(
        self,
        model_weights: DDAMFNWeights = DDAMFNWeights.AFFECNET7,
        input_size: Tuple[int, int] = (112, 112),
    ) -> None:
        """
        Initializes the emotion recognition model.

        Args:
            model_weights (DDAMFNWeights): The enum member specifying which model weights to load.
            input_size (Tuple[int, int]): The expected input size for the model.
        """
        Logger.info(f'Initializing Emotion with model={model_weights.name}')
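
        # Pick the best available device: Apple MPS, then CUDA, then CPU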
        if torch.backends.mps.is_available():
            self.device = torch.device('mps')
        elif torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')

        self.input_size = input_size
        self.model_path = verify_model_weights(model_weights)

        # Define emotion labels based on the selected model
        self.emotion_labels = [
            'Neutral',
            'Happy',
            'Sad',
            'Surprise',
            'Fear',
            'Disgust',
            'Angry',
        ]
        if model_weights == DDAMFNWeights.AFFECNET8:
            self.emotion_labels.append('Contempt')

        self._initialize_model()

    def _initialize_model(self) -> None:
        """
        Loads and initializes the TorchScript model for inference.
        """
        try:
            self.model = torch.jit.load(self.model_path, map_location=self.device)
            self.model.eval()

            # Warm up with a dummy input for faster first inference
            dummy_input = torch.randn(1, 3, *self.input_size).to(self.device)
            with torch.no_grad():
                self.model(dummy_input)

            Logger.info(f'Successfully initialized Emotion model on {self.device}')
        except Exception as e:
            Logger.error(f"Failed to load Emotion model from '{self.model_path}'", exc_info=True)
            raise RuntimeError(f'Failed to initialize Emotion model: {e}') from e

    def preprocess(self, image: np.ndarray, landmark: Union[List, np.ndarray]) -> torch.Tensor:
        """
        Aligns the face using landmarks and preprocesses it into a tensor.

        Args:
            image (np.ndarray): The full input image in BGR format.
            landmark (Union[List, np.ndarray]): The 5-point facial landmarks.

        Returns:
            torch.Tensor: The preprocessed image tensor ready for inference.
        """
        landmark = np.asarray(landmark)
        aligned_image, _ = face_alignment(image, landmark)

        # Convert BGR to RGB, resize, normalize, and convert to a CHW tensor
        rgb_image = cv2.cvtColor(aligned_image, cv2.COLOR_BGR2RGB)
        resized_image = cv2.resize(rgb_image, self.input_size).astype(np.float32) / 255.0
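        # Normalize with the standard ImageNet channel mean and std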
        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        normalized_image = (resized_image - mean) / std
        transposed_image = normalized_image.transpose((2, 0, 1))

        return torch.from_numpy(transposed_image).unsqueeze(0).to(self.device)

    def postprocess(self, prediction: torch.Tensor) -> Tuple[str, float]:
        """
        Processes the raw model output to get the emotion label and confidence score.
        """
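        # Softmax turns the logits into probabilities; the argmax index selects the label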
        probabilities = torch.nn.functional.softmax(prediction, dim=1).squeeze().cpu().numpy()
        pred_index = np.argmax(probabilities)
        emotion_label = self.emotion_labels[pred_index]
        confidence = float(probabilities[pred_index])
        return emotion_label, confidence

    def predict(self, image: np.ndarray, landmark: Union[List, np.ndarray]) -> Tuple[str, float]:
        """
        Predicts the emotion from a single face specified by its landmarks.
        """
        input_tensor = self.preprocess(image, landmark)
        with torch.no_grad():
            output = self.model(input_tensor)
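        # The TorchScript model may return a tuple; the first element is treated as the logits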
        if isinstance(output, tuple):
            output = output[0]

        return self.postprocess(output)


# TODO: below is only for testing, remove it later
if __name__ == '__main__':
    from uniface.constants import RetinaFaceWeights
    from uniface.detection import create_detector

    print('Initializing models for live inference...')

    # 1. Initialize the face detector.
    # Using a smaller model for faster real-time performance.
    detector = create_detector(model_name=RetinaFaceWeights.MNET_V2)

    # 2. Initialize the attribute predictor
    emotion_predictor = Emotion()

    # 3. Start webcam capture
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        print('Error: Could not open webcam.')
        exit()

    print("Starting webcam feed. Press 'q' to quit.")
    while True:
        ret, frame = cap.read()
        if not ret:
            print('Error: Failed to capture frame.')
            break

        # Detect faces in the current frame.
        # This method returns a list of dictionaries, one per detected face.
        detections = detector.detect(frame)

        # For each detected face, predict the emotion
        for detection in detections:
            box = detection['bbox']
            landmark = detection['landmarks']
            x1, y1, x2, y2 = map(int, box)

            # Predict the emotion using the face landmarks
            emotion, confidence = emotion_predictor.predict(frame, landmark)

            # Prepare the label text and draw the box and label on the frame
            label = f'{emotion} ({confidence:.2f})'
            cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 0, 0), 2)
            cv2.putText(
                frame,
                label,
                (x1, y1 - 10),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.8,
                (255, 0, 0),
                2,
            )

        # Display the resulting frame
        cv2.imshow("Emotion Inference (Press 'q' to quit)", frame)

        # Break the loop if 'q' is pressed
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    # Release resources
    cap.release()
    cv2.destroyAllWindows()
    print('Inference stopped.')