onnx_utils.py

# Copyright 2025 Yakhyokhuja Valikhujaev
# Author: Yakhyokhuja Valikhujaev
# GitHub: https://github.com/yakhyo

from typing import List, Optional

import onnxruntime as ort

from uniface.log import Logger

def get_available_providers() -> List[str]:
    """
    Get the list of available ONNX Runtime execution providers for the current platform.

    Automatically detects and prioritizes hardware acceleration:
    - CoreML on Apple Silicon (M1/M2/M3/M4)
    - CUDA on NVIDIA GPUs
    - CPU as fallback (always available)

    Returns:
        List[str]: Ordered list of execution providers to use.

    Examples:
        >>> providers = get_available_providers()
        >>> # On M4 Mac: ['CoreMLExecutionProvider', 'CPUExecutionProvider']
        >>> # On Linux with CUDA: ['CUDAExecutionProvider', 'CPUExecutionProvider']
        >>> # On CPU-only: ['CPUExecutionProvider']
    """
    available = ort.get_available_providers()
    providers = []

    # Priority order: CoreML > CUDA > CPU
    if 'CoreMLExecutionProvider' in available:
        providers.append('CoreMLExecutionProvider')
        Logger.info('CoreML acceleration enabled (Apple Silicon)')

    if 'CUDAExecutionProvider' in available:
        providers.append('CUDAExecutionProvider')
        Logger.info('CUDA acceleration enabled (NVIDIA GPU)')

    # CPU is always available as a fallback
    providers.append('CPUExecutionProvider')

    if len(providers) == 1:
        Logger.info('Using CPU execution (no hardware acceleration detected)')

    return providers

def create_onnx_session(model_path: str, providers: Optional[List[str]] = None) -> ort.InferenceSession:
    """
    Create an ONNX Runtime inference session with optimal provider selection.

    Args:
        model_path (str): Path to the ONNX model file.
        providers (Optional[List[str]]): List of providers to use.
            If None, automatically detects the best available providers.

    Returns:
        ort.InferenceSession: Configured ONNX Runtime session.

    Raises:
        RuntimeError: If session creation fails.

    Examples:
        >>> # Automatically uses the best available providers
        >>> session = create_onnx_session("model.onnx")
        >>> # Force CPU-only execution
        >>> session = create_onnx_session("model.onnx", providers=["CPUExecutionProvider"])
    """
    if providers is None:
        providers = get_available_providers()

    # Suppress ONNX Runtime warnings (e.g., CoreML partition warnings).
    # Log severity levels: 0=VERBOSE, 1=INFO, 2=WARNING, 3=ERROR, 4=FATAL
    sess_options = ort.SessionOptions()
    sess_options.log_severity_level = 3  # Only show ERROR and FATAL

    try:
        session = ort.InferenceSession(model_path, sess_options=sess_options, providers=providers)
        active_provider = session.get_providers()[0]
        Logger.debug(f'Session created with provider: {active_provider}')

        # Show a user-friendly message about which provider is being used
        provider_names = {
            'CoreMLExecutionProvider': 'CoreML (Apple Silicon)',
            'CUDAExecutionProvider': 'CUDA (NVIDIA GPU)',
            'CPUExecutionProvider': 'CPU',
        }
        provider_display = provider_names.get(active_provider, active_provider)
        print(f'✓ Model loaded ({provider_display})')

        return session
    except Exception as e:
        Logger.error(f'Failed to create ONNX session: {e}', exc_info=True)
        raise RuntimeError(f'Failed to initialize ONNX Runtime session: {e}') from e
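

# --- Usage sketch (not part of the original module) ---
# A minimal example showing the two helpers above end to end. The path
# 'model.onnx' is a hypothetical placeholder, and the zero-filled float32
# input is an assumption; the input name/shape are read from the model's
# own metadata rather than hard-coded.
if __name__ == '__main__':
    import numpy as np

    # Auto-selects CoreML/CUDA/CPU via get_available_providers()
    session = create_onnx_session('model.onnx')

    # Inspect the model's first input instead of hard-coding its name/shape.
    input_meta = session.get_inputs()[0]

    # Dynamic dimensions appear as strings/None; substitute 1 for each.
    shape = [d if isinstance(d, int) else 1 for d in input_meta.shape]
    dummy = np.zeros(shape, dtype=np.float32)

    outputs = session.run(None, {input_meta.name: dummy})
    print(f'{len(outputs)} output(s); first output shape: {outputs[0].shape}')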