websocket.ts 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. import { getWsUrl } from "./requests";
  // Protocol constants (kept consistent with the server).

  // Width in bytes of the fixed-size action field that starts every frame.
  const ACTION_HEADER_SIZE = 18;
  // Header layout written as a struct format string.
  // Big-endian: 18-byte action + 4-byte unsigned integer payload_size.
  const PROTOCOL_HEADER_FORMAT = ">18sI";
  // Total header width: the action field plus the 4-byte payload size (18 + 4).
  const PROTOCOL_HEADER_SIZE = ACTION_HEADER_SIZE + 4;
  /**
   * Action names this client sends to the server.
   * Each member's value is the string written into the frame's action field.
   */
  export enum WS_SEND_ACTION_TYPE {
    /** Heartbeat packet. */
    PING = "PING",
    /** Start the engine. */
    ENGINE_START = "ENGINE_START",
    /** Engine input (partial). */
    ENGINE_PARTIAL_INPUT = "PARTIAL_INPUT",
    /** Engine input (final). */
    ENGINE_FINAL_INPUT = "FINAL_INPUT",
    /** Stop the engine. */
    ENGINE_STOP = "ENGINE_STOP",
  }
  /**
   * Action names this client receives from the server.
   * Each member's value is the string carried in the frame's action field.
   */
  export enum WS_RECV_ACTION_TYPE {
    /** Heartbeat response. */
    PONG = "PONG",
    /**
     * Engine is initializing.
     * NOTE(review): "INITIALZING" looks like a misspelling of "INITIALIZING";
     * it is left untouched because both the key and the value are part of the
     * existing interface — confirm against the server before renaming.
     */
    ENGINE_INITIALZING = "ENGINE_INITIALZING",
    /** Engine is ready. */
    ENGINE_STARTED = "ENGINE_STARTED",
    /** Engine output (partial). */
    ENGINE_PARTIAL_OUTPUT = "PARTIAL_OUTPUT",
    /** Engine output (final). */
    ENGINE_FINAL_OUTPUT = "FINAL_OUTPUT",
    /** Engine has been shut down. */
    ENGINE_STOPPED = "ENGINE_STOPPED",
    /** Error response. */
    ERROR = "ERROR",
  }
  /**
   * Encode an action name as the fixed-width action field of a frame.
   *
   * The name is UTF-8 encoded and right-padded with ASCII spaces (0x20) so the
   * result is always exactly ACTION_HEADER_SIZE bytes long.
   *
   * The size limit is enforced on the encoded bytes rather than on
   * `actionName.length`: the string length counts UTF-16 code units, so a name
   * containing non-ASCII characters could pass a character-count check and
   * still encode to more bytes than the field can hold.
   *
   * @param actionName Action name to encode.
   * @returns A new ACTION_HEADER_SIZE-byte array holding the padded name.
   * @throws Error if the UTF-8 encoding of the name exceeds ACTION_HEADER_SIZE bytes.
   */
  function _format_action(actionName: string): Uint8Array {
    const encoded = new TextEncoder().encode(actionName);
    if (encoded.length > ACTION_HEADER_SIZE) {
      throw new Error(`Action name '${actionName}' exceeds ${ACTION_HEADER_SIZE} bytes`);
    }
    // Pre-fill the whole field with spaces, then overwrite the prefix with the name.
    const field = new Uint8Array(ACTION_HEADER_SIZE).fill(0x20);
    field.set(encoded, 0);
    return field;
  }
  /**
   * Decode one binary frame into its action name and payload bytes.
   *
   * Layout: ACTION_HEADER_SIZE bytes of action text, a big-endian uint32
   * payload size, then exactly that many payload bytes.
   *
   * @param data The complete frame.
   * @returns The action name with surrounding whitespace trimmed, and the
   *          payload as a view over `data` (or an empty array when the
   *          declared payload size is 0).
   * @throws Error if the frame is shorter than the header, or if its total
   *         length does not equal header size + declared payload size.
   */
  function parse_message(data: ArrayBuffer): { action: string; payload: Uint8Array } {
    const totalBytes = data.byteLength;
    if (totalBytes < PROTOCOL_HEADER_SIZE) {
      throw new Error(
        `Message too short: ${totalBytes} bytes, expected at least ${PROTOCOL_HEADER_SIZE}`
      );
    }

    // The payload size sits right after the action field, big-endian.
    const declaredSize = new DataView(data).getUint32(ACTION_HEADER_SIZE, false);
    const expectedTotalSize = PROTOCOL_HEADER_SIZE + declaredSize;
    if (totalBytes !== expectedTotalSize) {
      throw new Error(
        `Message size mismatch: got ${totalBytes} bytes, expected ${expectedTotalSize}`
      );
    }

    const actionField = new Uint8Array(data, 0, ACTION_HEADER_SIZE);
    const actionName = new TextDecoder().decode(actionField).trim();

    let payload: Uint8Array;
    if (declaredSize > 0) {
      payload = new Uint8Array(data, PROTOCOL_HEADER_SIZE, declaredSize);
    } else {
      payload = new Uint8Array(0);
    }

    return { action: actionName, payload: payload };
  }
  /**
   * Build one binary frame.
   *
   * Layout: ACTION_HEADER_SIZE bytes of space-padded action text, a big-endian
   * uint32 payload size, then the payload bytes.
   *
   * @param action  Action name; formatted to the fixed-width field by `_format_action`.
   * @param payload Frame body. A string is UTF-8 encoded; bytes are copied as-is.
   *                Defaults to an empty payload.
   * @returns A new ArrayBuffer containing header and payload.
   * @throws Error if the formatted action is not exactly ACTION_HEADER_SIZE
   *         bytes, plus anything `_format_action` throws.
   */
  function struct_message(action: string, payload: string | Uint8Array = new Uint8Array(0)): ArrayBuffer {
    // Normalise the payload to bytes.
    const body = typeof payload === 'string' ? new TextEncoder().encode(payload) : payload;

    const header = _format_action(action);
    if (header.length !== ACTION_HEADER_SIZE) {
      throw new Error(
        `Action must be exactly ${ACTION_HEADER_SIZE} bytes, got ${header.length}`
      );
    }

    const bodySize = body.length;
    const frame = new ArrayBuffer(PROTOCOL_HEADER_SIZE + bodySize);
    const frameBytes = new Uint8Array(frame);

    // Action field (18 bytes).
    frameBytes.set(header, 0);
    // Payload size (4 bytes, big-endian).
    new DataView(frame).setUint32(ACTION_HEADER_SIZE, bodySize, false);
    // Payload; copying an empty body is a no-op.
    frameBytes.set(body, PROTOCOL_HEADER_SIZE);

    return frame;
  }
  /**
   * Client for the binary engine protocol carried over a WebSocket.
   *
   * Outgoing frames are built with `struct_message` and incoming frames are
   * decoded with `parse_message`. When the socket opens, the client sends
   * ENGINE_START with a JSON payload `{ engine, config }`; when `disconnect()`
   * is called on an open socket it sends ENGINE_STOP before closing.
   */
  export class WebsocketClient {
    private _ws: WebSocket | null = null;
    private _url: string;
    private _engine: string;
    private _config: {};
    private _onMessage?: (action: string, data: Uint8Array) => void;
    private _onOpen?: () => void;
    private _onClose?: () => void;
    private _onError?: (error: Error) => void;

    /**
     * @param url       WebSocket URL to connect to.
     * @param engine    Engine identifier sent in the ENGINE_START payload.
     * @param config    Engine configuration sent in the ENGINE_START payload.
     * @param onMessage Called with the action name and payload of every decoded frame.
     * @param onOpen    Called after the socket opens and ENGINE_START has been sent.
     * @param onClose   Called when the socket closes.
     * @param onError   Called on socket errors and on frames that cannot be decoded.
     */
    constructor(
      url: string,
      engine: string,
      config: {},
      onMessage?: (action: string, data: Uint8Array) => void,
      onOpen?: () => void,
      onClose?: () => void,
      onError?: (error: Error) => void
    ) {
      this._url = url;
      this._engine = engine;
      this._config = config;
      this._onMessage = onMessage;
      this._onOpen = onOpen;
      this._onClose = onClose;
      this._onError = onError;
    }

    /**
     * Open the connection and install the event handlers.
     * A socket left over from an earlier `connect()` is disconnected first so
     * it is not leaked.
     */
    public connect() {
      this.disconnect();

      const ws = new WebSocket(this._url);
      ws.binaryType = 'arraybuffer';
      this._ws = ws;

      ws.onopen = () => {
        const payload = JSON.stringify({
          engine: this._engine,
          config: this._config
        });
        this.sendMessage(WS_SEND_ACTION_TYPE.ENGINE_START, payload);
        if (this._onOpen) {
          this._onOpen();
        }
      };

      ws.onmessage = (event) => {
        if (!this._onMessage) {
          return;
        }
        let parsed: { action: string; payload: Uint8Array };
        try {
          parsed = parse_message(event.data);
        } catch (err) {
          // A frame that cannot be decoded is reported through onError instead
          // of escaping the event handler. Without an onError callback the
          // error is rethrown so it is not silently lost.
          const error = err instanceof Error ? err : new Error(String(err));
          if (!this._onError) {
            throw error;
          }
          this._onError(error);
          return;
        }
        this._onMessage(parsed.action, parsed.payload);
      };

      ws.onclose = () => {
        // Nothing can be sent here: the socket is no longer OPEN when the close
        // event fires, so ENGINE_STOP is sent from disconnect() instead.
        if (this._onClose) {
          this._onClose();
        }
      };

      ws.onerror = () => {
        if (this._onError) {
          // The event object carries no failure description, so report the URL.
          this._onError(new Error(`WebSocket error: ${this._url}`));
        }
      };
    }

    /**
     * Send ENGINE_STOP (only transmitted if the socket is open), close the
     * socket and drop the reference. Does nothing when there is no socket.
     */
    public disconnect() {
      if (this._ws) {
        // Must happen before close(): sendMessage drops frames once the socket
        // has left the OPEN state.
        this.sendMessage(WS_SEND_ACTION_TYPE.ENGINE_STOP);
        this._ws.close();
        this._ws = null;
      }
    }

    /**
     * @returns `true` only when a socket exists and is in the OPEN state.
     */
    public isConnected(): boolean {
      return this._ws !== null && this._ws.readyState === WebSocket.OPEN;
    }

    /**
     * Build a frame and send it if the socket is open.
     * When the socket is missing or not OPEN the frame is dropped without an
     * error. The frame is built before that check, so an invalid action name
     * still throws.
     *
     * @param action  Action name for the frame header.
     * @param payload Frame body as text or bytes; defaults to an empty payload.
     */
    public sendMessage(action: string, payload: string | Uint8Array = new Uint8Array(0)) {
      const data = struct_message(action, payload);
      if (this._ws && this._ws.readyState === WebSocket.OPEN) {
        this._ws.send(data);
      }
      // Otherwise: intentionally dropped (best effort), no exception is raised.
    }
  }
  /**
   * Create a WebsocketClient bound to the ASR engine streaming endpoint.
   *
   * The endpoint path is resolved to a WebSocket URL with `getWsUrl`; the
   * engine, config and callbacks are passed through to the client unchanged.
   * The returned client is not connected yet.
   *
   * @param events Engine name, engine config and optional event callbacks.
   * @returns A new WebsocketClient for the ASR stream endpoint.
   */
  export function createASRWebsocketClient(
    events: {
      engine: string,
      config: {},
      onMessage?: (action: string, data: Uint8Array) => void,
      onOpen?: () => void,
      onClose?: () => void,
      onError?: (error: Error) => void
    }
  ): WebsocketClient {
    const url = getWsUrl(`/adh/asr/v0/engine/stream`);
    const { engine, config, onMessage, onOpen, onClose, onError } = events;
    return new WebsocketClient(url, engine, config, onMessage, onOpen, onClose, onError);
  }