model_store.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. # Copyright 2025 Yakhyokhuja Valikhujaev
  2. # Author: Yakhyokhuja Valikhujaev
  3. # GitHub: https://github.com/yakhyo
import hashlib
import os
from enum import Enum

import requests
from tqdm import tqdm

import uniface.constants as const
from uniface.log import Logger
  10. __all__ = ['verify_model_weights']
  11. def verify_model_weights(model_name: str, root: str = '~/.uniface/models') -> str:
  12. """
  13. Ensure model weights are present, downloading and verifying them using SHA-256 if necessary.
  14. Given a model identifier from an Enum class (e.g., `RetinaFaceWeights.MNET_V2`), this function checks if
  15. the corresponding `.onnx` weight file exists locally. If not, it downloads the file from a predefined URL.
  16. After download, the file’s integrity is verified using a SHA-256 hash. If verification fails, the file is deleted
  17. and an error is raised.
  18. Args:
  19. model_name (Enum): Model weight identifier (e.g., `RetinaFaceWeights.MNET_V2`, `ArcFaceWeights.RESNET`, etc.).
  20. root (str, optional): Directory to store or locate the model weights. Defaults to '~/.uniface/models'.
  21. Returns:
  22. str: Absolute path to the verified model weights file.
  23. Raises:
  24. ValueError: If the model is unknown or SHA-256 verification fails.
  25. ConnectionError: If downloading the file fails.
  26. Examples:
  27. >>> from uniface.models import RetinaFaceWeights, verify_model_weights
  28. >>> verify_model_weights(RetinaFaceWeights.MNET_V2)
  29. '/home/user/.uniface/models/retinaface_mnet_v2.onnx'
  30. >>> verify_model_weights(RetinaFaceWeights.RESNET34, root='/custom/dir')
  31. '/custom/dir/retinaface_r34.onnx'
  32. """
  33. root = os.path.expanduser(root)
  34. os.makedirs(root, exist_ok=True)
  35. # Keep model_name as enum for dictionary lookup
  36. url = const.MODEL_URLS.get(model_name)
  37. if not url:
  38. Logger.error(f"No URL found for model '{model_name}'")
  39. raise ValueError(f"No URL found for model '{model_name}'")
  40. file_ext = os.path.splitext(url)[1]
  41. model_path = os.path.normpath(os.path.join(root, f'{model_name.value}{file_ext}'))
  42. if not os.path.exists(model_path):
  43. Logger.info(f"Downloading model '{model_name}' from {url}")
  44. try:
  45. download_file(url, model_path)
  46. Logger.info(f"Successfully downloaded '{model_name}' to {model_path}")
  47. except Exception as e:
  48. Logger.error(f"Failed to download model '{model_name}': {e}")
  49. raise ConnectionError(f"Download failed for '{model_name}'") from e
  50. expected_hash = const.MODEL_SHA256.get(model_name)
  51. if expected_hash and not verify_file_hash(model_path, expected_hash):
  52. os.remove(model_path) # Remove corrupted file
  53. Logger.warning('Corrupted weight detected. Removing...')
  54. raise ValueError(f"Hash mismatch for '{model_name}'. The file may be corrupted; please try downloading again.")
  55. return model_path
  56. def download_file(url: str, dest_path: str) -> None:
  57. """Download a file from a URL in chunks and save it to the destination path."""
  58. try:
  59. response = requests.get(url, stream=True)
  60. response.raise_for_status()
  61. with (
  62. open(dest_path, 'wb') as file,
  63. tqdm(
  64. desc=f'Downloading {dest_path}',
  65. unit='B',
  66. unit_scale=True,
  67. unit_divisor=1024,
  68. ) as progress,
  69. ):
  70. for chunk in response.iter_content(chunk_size=const.CHUNK_SIZE):
  71. if chunk:
  72. file.write(chunk)
  73. progress.update(len(chunk))
  74. except requests.RequestException as e:
  75. raise ConnectionError(f'Failed to download file from {url}. Error: {e}') from e
  76. def verify_file_hash(file_path: str, expected_hash: str) -> bool:
  77. """Compute the SHA-256 hash of the file and compare it with the expected hash."""
  78. file_hash = hashlib.sha256()
  79. with open(file_path, 'rb') as f:
  80. for chunk in iter(lambda: f.read(const.CHUNK_SIZE), b''):
  81. file_hash.update(chunk)
  82. actual_hash = file_hash.hexdigest()
  83. if actual_hash != expected_hash:
  84. Logger.warning(f'Expected hash: {expected_hash}, but got: {actual_hash}')
  85. return actual_hash == expected_hash
  86. if __name__ == '__main__':
  87. model_names = [model.value for model in const.RetinaFaceWeights]
  88. # Download each model in the list
  89. for model_name in model_names:
  90. model_path = verify_model_weights(model_name)