// asrWebSocket.ts
  1. /**
  2. * ASR WebSocket 流式识别客户端
  3. * 基于测试文件 test_asr_websocket_client.py 的协议实现
  4. */
  5. // 协议常量定义(与服务端保持一致)
  6. const ACTION_HEADER_SIZE = 18;
  7. const DEFAULT_AUDIO_CHUNK_SIZE = 15360;
  8. const MAX_PAYLOAD_SIZE = DEFAULT_AUDIO_CHUNK_SIZE * 2;
  9. const PROTOCOL_HEADER_FORMAT = ">18sI"; // 大端序: 18字节action + 4字节无符号整数payload_size
  10. const PROTOCOL_HEADER_SIZE = 22; // 18 + 4
  11. /**
  12. * 格式化action名称为18字节,右侧用空格填充
  13. */
  14. function formatAction(actionName: string): Uint8Array {
  15. if (actionName.length > ACTION_HEADER_SIZE) {
  16. throw new Error(`Action name '${actionName}' exceeds ${ACTION_HEADER_SIZE} bytes`);
  17. }
  18. const padded = actionName.padEnd(ACTION_HEADER_SIZE, ' ');
  19. return new TextEncoder().encode(padded);
  20. }
// Action type headers, pre-encoded once as fixed 18-byte values so they can
// be written directly into outgoing frames and byte-compared against the
// action field of incoming frames.
export const ActionType = {
  // Client -> server request types
  START_STREAM: formatAction("START_STREAM"),
  AUDIO_CHUNK: formatAction("AUDIO_CHUNK"),
  FINAL_CHUNK: formatAction("FINAL_CHUNK"),
  END_STREAM: formatAction("END_STREAM"),
  PING: formatAction("PING"),
  // Server -> client response types
  CONNECTION_ACK: formatAction("CONNECTION_ACK"),
  ENGINE_READY: formatAction("ENGINE_READY"),
  STREAM_STARTED: formatAction("STREAM_STARTED"),
  PARTIAL_TRANSCRIPT: formatAction("PARTIAL_TRANSCRIPT"),
  FINAL_TRANSCRIPT: formatAction("FINAL_TRANSCRIPT"),
  STREAM_ENDED: formatAction("STREAM_ENDED"),
  ERROR: formatAction("ERROR"),
  PONG: formatAction("PONG"),
};
  39. /**
  40. * 解析二进制消息,返回{action, payload}
  41. */
  42. function parseBinaryMessage(data: ArrayBuffer): { action: Uint8Array; payload: Uint8Array } {
  43. if (data.byteLength < PROTOCOL_HEADER_SIZE) {
  44. throw new Error(
  45. `Message too short: ${data.byteLength} bytes, expected at least ${PROTOCOL_HEADER_SIZE}`
  46. );
  47. }
  48. const view = new DataView(data);
  49. const action = new Uint8Array(data, 0, ACTION_HEADER_SIZE);
  50. const payloadSize = view.getUint32(ACTION_HEADER_SIZE, false); // 大端序
  51. const expectedTotalSize = PROTOCOL_HEADER_SIZE + payloadSize;
  52. if (data.byteLength !== expectedTotalSize) {
  53. throw new Error(
  54. `Message size mismatch: got ${data.byteLength} bytes, expected ${expectedTotalSize}`
  55. );
  56. }
  57. const payload = payloadSize > 0
  58. ? new Uint8Array(data, PROTOCOL_HEADER_SIZE, payloadSize)
  59. : new Uint8Array(0);
  60. return { action, payload };
  61. }
  62. /**
  63. * 创建二进制消息
  64. */
  65. function createBinaryMessage(action: Uint8Array, payload: Uint8Array = new Uint8Array(0)): ArrayBuffer {
  66. if (action.length !== ACTION_HEADER_SIZE) {
  67. throw new Error(
  68. `Action must be exactly ${ACTION_HEADER_SIZE} bytes, got ${action.length}`
  69. );
  70. }
  71. const payloadSize = payload.length;
  72. const totalSize = PROTOCOL_HEADER_SIZE + payloadSize;
  73. const buffer = new ArrayBuffer(totalSize);
  74. const view = new DataView(buffer);
  75. // 写入action (18字节)
  76. new Uint8Array(buffer, 0, ACTION_HEADER_SIZE).set(action);
  77. // 写入payload大小 (4字节,大端序)
  78. view.setUint32(ACTION_HEADER_SIZE, payloadSize, false);
  79. // 写入payload
  80. if (payloadSize > 0) {
  81. new Uint8Array(buffer, PROTOCOL_HEADER_SIZE, payloadSize).set(payload);
  82. }
  83. return buffer;
  84. }
  85. /**
  86. * 将文本编码为UTF-8字节
  87. */
  88. function encodeTextPayload(text: string): Uint8Array {
  89. return new TextEncoder().encode(text);
  90. }
  91. /**
  92. * 将字节解码为UTF-8文本
  93. */
  94. function decodeTextPayload(payload: Uint8Array): string {
  95. return payload.length > 0 ? new TextDecoder().decode(payload) : "";
  96. }
  97. /**
  98. * 比较两个Uint8Array是否相等
  99. */
  100. function arraysEqual(a: Uint8Array, b: Uint8Array): boolean {
  101. if (a.length !== b.length) return false;
  102. for (let i = 0; i < a.length; i++) {
  103. if (a[i] !== b[i]) return false;
  104. }
  105. return true;
  106. }
  107. /**
  108. * 音频录制器
  109. */
  110. export class AudioRecorder {
  111. private sampleRate: number;
  112. private channels: number;
  113. private chunkSize: number;
  114. private audioContext: AudioContext | null = null;
  115. private mediaStream: MediaStream | null = null;
  116. private audioWorkletNode: AudioWorkletNode | null = null;
  117. private isRecording = false;
  118. private targetChunkSize: number; // 服务器要求的音频块大小:240ms * 16000Hz * 2字节 = 15360字节
  119. private audioBuffer: number[] = [];
  120. private onAudioChunk?: (chunk: Uint8Array) => void;
  121. constructor(
  122. sampleRate = 16000,
  123. channels = 1,
  124. chunkSize = 1024,
  125. onAudioChunk?: (chunk: Uint8Array) => void
  126. ) {
  127. this.sampleRate = sampleRate;
  128. this.channels = channels;
  129. this.chunkSize = chunkSize;
  130. this.targetChunkSize = 7680 * 2; // 15360字节
  131. this.onAudioChunk = onAudioChunk;
  132. }
  133. async startRecording(): Promise<void> {
  134. try {
  135. // 获取麦克风权限
  136. this.mediaStream = await navigator.mediaDevices.getUserMedia({
  137. audio: {
  138. sampleRate: this.sampleRate,
  139. channelCount: this.channels,
  140. echoCancellation: true,
  141. noiseSuppression: true,
  142. }
  143. });
  144. // 创建AudioContext
  145. this.audioContext = new (window.AudioContext || (window as any).webkitAudioContext)({
  146. sampleRate: this.sampleRate
  147. });
  148. // 创建音频处理节点
  149. await this.audioContext.audioWorklet.addModule(
  150. URL.createObjectURL(new Blob([
  151. `
  152. class AudioProcessor extends AudioWorkletProcessor {
  153. process(inputs, outputs, parameters) {
  154. const input = inputs[0];
  155. if (input && input[0]) {
  156. // 将Float32Array转换为Int16Array
  157. const float32Data = input[0];
  158. const int16Data = new Int16Array(float32Data.length);
  159. for (let i = 0; i < float32Data.length; i++) {
  160. const sample = Math.max(-1, Math.min(1, float32Data[i]));
  161. int16Data[i] = sample < 0 ? sample * 0x8000 : sample * 0x7FFF;
  162. }
  163. this.port.postMessage(int16Data);
  164. }
  165. return true;
  166. }
  167. }
  168. registerProcessor('audio-processor', AudioProcessor);
  169. `
  170. ], { type: 'application/javascript' }))
  171. );
  172. const source = this.audioContext.createMediaStreamSource(this.mediaStream);
  173. this.audioWorkletNode = new AudioWorkletNode(this.audioContext, 'audio-processor');
  174. this.audioWorkletNode.port.onmessage = (event) => {
  175. if (this.isRecording) {
  176. this.processAudioData(event.data);
  177. }
  178. };
  179. source.connect(this.audioWorkletNode);
  180. this.isRecording = true;
  181. console.log(`开始录音: ${this.sampleRate}Hz, ${this.channels}通道`);
  182. } catch (error) {
  183. console.error('启动录音失败:', error);
  184. throw error;
  185. }
  186. }
  187. private processAudioData(int16Data: Int16Array): void {
  188. // 将Int16Array转换为字节数组
  189. const bytes = new Uint8Array(int16Data.length * 2);
  190. for (let i = 0; i < int16Data.length; i++) {
  191. const sample = int16Data[i];
  192. bytes[i * 2] = sample & 0xFF; // 低字节
  193. bytes[i * 2 + 1] = (sample >> 8) & 0xFF; // 高字节
  194. }
  195. // 添加到缓冲区
  196. for (let i = 0; i < bytes.length; i++) {
  197. this.audioBuffer.push(bytes[i]);
  198. }
  199. // 如果缓冲区达到目标大小,发送音频块
  200. while (this.audioBuffer.length >= this.targetChunkSize) {
  201. const chunk = new Uint8Array(this.audioBuffer.splice(0, this.targetChunkSize));
  202. if (this.onAudioChunk) {
  203. this.onAudioChunk(chunk);
  204. }
  205. }
  206. }
  207. getRemainingAudio(): Uint8Array | null {
  208. if (this.audioBuffer.length > 0) {
  209. const remainingData = new Uint8Array(this.audioBuffer);
  210. this.audioBuffer = [];
  211. // 如果剩余数据不足目标大小,用静音补足
  212. if (remainingData.length < this.targetChunkSize) {
  213. const silenceNeeded = this.targetChunkSize - remainingData.length;
  214. const paddedData = new Uint8Array(this.targetChunkSize);
  215. paddedData.set(remainingData);
  216. // 剩余部分已经是0(静音)
  217. console.log(`音频数据不足,补足静音: ${silenceNeeded} 字节`);
  218. return paddedData;
  219. }
  220. return remainingData;
  221. }
  222. return null;
  223. }
  224. stopRecording(): void {
  225. this.isRecording = false;
  226. if (this.audioWorkletNode) {
  227. this.audioWorkletNode.disconnect();
  228. this.audioWorkletNode = null;
  229. }
  230. if (this.mediaStream) {
  231. this.mediaStream.getTracks().forEach(track => track.stop());
  232. this.mediaStream = null;
  233. }
  234. if (this.audioContext) {
  235. this.audioContext.close();
  236. this.audioContext = null;
  237. }
  238. console.log('录音已停止');
  239. }
  240. cleanup(): void {
  241. this.stopRecording();
  242. }
  243. }
/**
 * Callback hooks for server events delivered over the ASR WebSocket.
 * Each message argument is the UTF-8 decoded payload of the server frame.
 */
export interface ASRWebSocketEvents {
  /** Server acknowledged the WebSocket connection. */
  onConnectionAck?: (message: string) => void;
  /** The ASR engine is ready to accept an audio stream. */
  onEngineReady?: (message: string) => void;
  /** Server confirmed START_STREAM. */
  onStreamStarted?: (message: string) => void;
  /** Intermediate (still-changing) recognition text. */
  onPartialTranscript?: (text: string) => void;
  /** Final recognition text for the stream. */
  onFinalTranscript?: (text: string) => void;
  /** Server confirmed END_STREAM. */
  onStreamEnded?: (message: string) => void;
  /** Server reported an error; the argument is the error text. */
  onError?: (error: string) => void;
  /** Server answered a PING. */
  onPong?: () => void;
}
  257. /**
  258. * ASR WebSocket客户端
  259. */
  260. export class ASRWebSocketClient {
  261. private serverUrl: string;
  262. private websocket: WebSocket | null = null;
  263. private audioRecorder: AudioRecorder | null = null;
  264. private isStreaming = false;
  265. private events: ASRWebSocketEvents;
  266. private finalTranscript = "";
  267. constructor(serverUrl = "ws://localhost:8880/adh/stream_asr/v0/engine", events: ASRWebSocketEvents = {}) {
  268. this.serverUrl = serverUrl;
  269. this.events = events;
  270. }
  271. async connect(): Promise<boolean> {
  272. try {
  273. console.log(`正在连接到服务器: ${this.serverUrl}`);
  274. this.websocket = new WebSocket(this.serverUrl);
  275. this.websocket.binaryType = 'arraybuffer';
  276. return new Promise((resolve, reject) => {
  277. if (!this.websocket) {
  278. reject(new Error('WebSocket创建失败'));
  279. return;
  280. }
  281. this.websocket.onopen = () => {
  282. console.log('WebSocket连接成功');
  283. this.setupMessageHandler();
  284. resolve(true);
  285. };
  286. this.websocket.onerror = (error) => {
  287. console.error('WebSocket连接失败:', error);
  288. reject(error);
  289. };
  290. this.websocket.onclose = () => {
  291. console.log('WebSocket连接已关闭');
  292. };
  293. });
  294. } catch (error) {
  295. console.error('连接失败:', error);
  296. return false;
  297. }
  298. }
  299. private setupMessageHandler(): void {
  300. if (!this.websocket) return;
  301. this.websocket.onmessage = (event) => {
  302. try {
  303. if (event.data instanceof ArrayBuffer) {
  304. const { action, payload } = parseBinaryMessage(event.data);
  305. this.handleServerMessage(action, payload);
  306. } else {
  307. console.error('收到非二进制消息:', event.data);
  308. }
  309. } catch (error) {
  310. console.error('解析消息失败:', error);
  311. }
  312. };
  313. }
  314. private handleServerMessage(action: Uint8Array, payload: Uint8Array): void {
  315. const messageText = decodeTextPayload(payload);
  316. if (arraysEqual(action, ActionType.CONNECTION_ACK)) {
  317. console.log(`服务器确认连接: ${messageText}`);
  318. this.events.onConnectionAck?.(messageText);
  319. } else if (arraysEqual(action, ActionType.ENGINE_READY)) {
  320. console.log(`ASR引擎就绪: ${messageText}`);
  321. this.events.onEngineReady?.(messageText);
  322. } else if (arraysEqual(action, ActionType.STREAM_STARTED)) {
  323. console.log(`音频流已开始: ${messageText}`);
  324. this.events.onStreamStarted?.(messageText);
  325. } else if (arraysEqual(action, ActionType.PARTIAL_TRANSCRIPT)) {
  326. console.log(`部分识别结果: ${messageText}`);
  327. this.events.onPartialTranscript?.(messageText);
  328. } else if (arraysEqual(action, ActionType.FINAL_TRANSCRIPT)) {
  329. console.log(`最终识别结果: ${messageText}`);
  330. this.finalTranscript = messageText;
  331. this.events.onFinalTranscript?.(messageText);
  332. } else if (arraysEqual(action, ActionType.STREAM_ENDED)) {
  333. console.log(`音频流已结束: ${messageText}`);
  334. this.events.onStreamEnded?.(messageText);
  335. } else if (arraysEqual(action, ActionType.PONG)) {
  336. console.log('收到PONG响应');
  337. this.events.onPong?.();
  338. } else if (arraysEqual(action, ActionType.ERROR)) {
  339. console.error(`服务器错误: ${messageText}`);
  340. this.events.onError?.(messageText);
  341. } else {
  342. const actionName = new TextDecoder().decode(action).trim();
  343. console.warn(`未知消息类型: ${actionName}`);
  344. }
  345. }
  346. isConnected(): boolean {
  347. return this.websocket !== null && this.websocket.readyState === WebSocket.OPEN;
  348. }
  349. async disconnect(): Promise<void> {
  350. if (this.websocket) {
  351. this.websocket.close();
  352. this.websocket = null;
  353. console.log('WebSocket连接已断开');
  354. }
  355. }
  356. private async sendMessage(action: Uint8Array, payload: Uint8Array = new Uint8Array(0)): Promise<boolean> {
  357. if (!this.websocket || this.websocket.readyState !== WebSocket.OPEN) {
  358. console.error('WebSocket未连接');
  359. return false;
  360. }
  361. try {
  362. const message = createBinaryMessage(action, payload);
  363. this.websocket.send(message);
  364. console.log(`发送消息: ${new TextDecoder().decode(action).trim()}`);
  365. return true;
  366. } catch (error) {
  367. console.error('发送消息失败:', error);
  368. return false;
  369. }
  370. }
  371. private async sendAudioChunk(audioData: Uint8Array, isFinal = false): Promise<boolean> {
  372. if (!this.websocket || this.websocket.readyState !== WebSocket.OPEN) {
  373. console.error('WebSocket未连接');
  374. return false;
  375. }
  376. try {
  377. const action = isFinal ? ActionType.FINAL_CHUNK : ActionType.AUDIO_CHUNK;
  378. const message = createBinaryMessage(action, audioData);
  379. this.websocket.send(message);
  380. console.log(`发送音频块: ${audioData.length} 字节 ${isFinal ? '(最终块)' : '(普通块)'}`);
  381. return true;
  382. } catch (error) {
  383. console.error('发送音频数据失败:', error);
  384. return false;
  385. }
  386. }
  387. async startAudioStream(): Promise<boolean> {
  388. // 发送开始流消息
  389. if (!await this.sendMessage(ActionType.START_STREAM)) {
  390. return false;
  391. }
  392. // 等待一下确保服务器准备好
  393. await new Promise(resolve => setTimeout(resolve, 100));
  394. // 创建音频录制器
  395. this.audioRecorder = new AudioRecorder(
  396. 16000, // 采样率
  397. 1, // 单声道
  398. 1024, // 块大小
  399. (chunk) => {
  400. if (this.isStreaming) {
  401. this.sendAudioChunk(chunk);
  402. }
  403. }
  404. );
  405. // 启动录音
  406. await this.audioRecorder.startRecording();
  407. this.isStreaming = true;
  408. console.log('音频流已启动');
  409. return true;
  410. }
  411. async stopAudioStream(): Promise<void> {
  412. this.isStreaming = false;
  413. // 发送剩余的音频数据作为最终块
  414. if (this.audioRecorder) {
  415. const remainingAudio = this.audioRecorder.getRemainingAudio();
  416. if (remainingAudio && remainingAudio.length > 0) {
  417. console.log(`发送剩余音频数据: ${remainingAudio.length} 字节`);
  418. await this.sendAudioChunk(remainingAudio, true);
  419. }
  420. // 停止录音
  421. this.audioRecorder.stopRecording();
  422. this.audioRecorder = null;
  423. }
  424. // 发送结束流消息
  425. await this.sendMessage(ActionType.END_STREAM);
  426. console.log('音频流已停止');
  427. }
  428. async ping(payload = "test_ping"): Promise<boolean> {
  429. return await this.sendMessage(ActionType.PING, encodeTextPayload(payload));
  430. }
  431. getFinalTranscript(): string {
  432. return this.finalTranscript;
  433. }
  434. clearFinalTranscript(): void {
  435. this.finalTranscript = "";
  436. }
  437. }