diff --git a/ai_agents/agents/ten_packages/extension/camb_tts/README.md b/ai_agents/agents/ten_packages/extension/camb_tts/README.md new file mode 100644 index 0000000000..5d6c23214b --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/camb_tts/README.md @@ -0,0 +1,58 @@ +# camb_tts + +Camb.ai TTS extension for TEN Framework using the MARS text-to-speech API. + +## Features + +- MARS model family (mars-flash, mars-pro, mars-instruct) +- 140+ languages supported +- Voice cloning capabilities +- Real-time HTTP streaming +- Model-specific sample rates (22.05kHz / 48kHz) + +## API + +Refer to `api` definition in [manifest.json](manifest.json) and default values in [property.json](property.json). + +### Configuration Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| api_key | string | (required) | Camb.ai API key | +| voice_id | int32 | 147320 | Voice ID | +| language | string | "en-us" | Language code (BCP-47 format) | +| speech_model | string | "mars-flash" | Model selection | +| user_instructions | string | (optional) | Instructions for mars-instruct model | +| format | string | "pcm_s16le" | Output format | + +### Available Models + +| Model | Sample Rate | Description | +|-------|-------------|-------------| +| `mars-flash` | 22.05kHz | Fast inference (default) | +| `mars-pro` | 48kHz | High quality | +| `mars-instruct` | 22.05kHz | Supports user instructions | + +## Development + +### Setup + +1. Get your API key from [Camb.ai](https://camb.ai) +2. Set environment variable: + ```bash + export CAMB_API_KEY=your_key_here + ``` + +### Build + +Follow the standard TEN Framework extension build process. + +### Unit test + +Run tests using the standard TEN Framework testing approach. 
@register_addon_as_extension("camb_tts")
class CambTTSExtensionAddon(Addon):
    """Addon registered under the name ``camb_tts`` that builds the extension."""

    def on_create_instance(self, ten_env: TenEnv, name: str, context) -> None:
        """Create a ``CambTTSExtension`` and hand it back to the runtime.

        Args:
            ten_env: Environment used for logging and for reporting the
                created instance.
            name: Name passed through to the extension constructor.
            context: Opaque value forwarded unchanged to
                ``on_create_instance_done``.
        """
        # Imported here rather than at module level, so the extension module
        # is only loaded when an instance is actually requested.
        from .extension import CambTTSExtension

        ten_env.log_info("CambTTSExtensionAddon on_create_instance")
        extension_instance = CambTTSExtension(name)
        ten_env.on_create_instance_done(extension_instance, context)
BYTES_PER_SAMPLE = 2
NUMBER_OF_CHANNELS = 1

# Model-specific sample rates (matching livekit)
MODEL_SAMPLE_RATES: dict[str, int] = {
    "mars-flash": 22050,
    "mars-pro": 48000,
    "mars-instruct": 22050,
}

# Defaults matching livekit
DEFAULT_VOICE_ID = 147320
DEFAULT_MODEL = "mars-flash"
DEFAULT_LANGUAGE = "en-us"

API_BASE_URL = "https://client.camb.ai/apis"
API_KEY_HEADER = "x-api-key"


class CambTTSClient(AsyncTTS2HttpClient):
    """Streaming HTTP client for the Camb.ai ``/tts-stream`` endpoint.

    One ``aiohttp.ClientSession`` is shared by all requests issued through
    this client. ``get()`` is an async generator yielding
    ``(payload, event_type)`` tuples; ``cancel()`` asks a running ``get()``
    to stop at the next received chunk; ``clean()`` closes the session.
    """

    def __init__(
        self,
        config: CambTTSConfig,
        ten_env: AsyncTenEnv,
    ):
        """Store the configuration and try to open the HTTP session.

        Args:
            config: Extension configuration; ``config.params`` supplies the
                API key and the request parameters.
            ten_env: Environment used for logging.
        """
        super().__init__()
        self.config = config
        self.api_key = config.params.get("api_key", "")
        self.ten_env: AsyncTenEnv = ten_env
        self._is_cancelled = False
        self._session: aiohttp.ClientSession | None = None
        try:
            self._session = aiohttp.ClientSession()
        except Exception as e:
            # Construction stays best-effort (as before), but the failure is
            # no longer hidden: it is logged and get() retries lazily.
            self.ten_env.log_warn(
                f"CambTTS: could not create HTTP session eagerly ({e!r}); "
                "it will be created on the first request.",
                category=LOG_CATEGORY_VENDOR,
            )

    def _ensure_session(self) -> aiohttp.ClientSession:
        """Return a usable session, creating one if it is missing or closed."""
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession()
        return self._session

    async def cancel(self):
        """Flag the in-flight request so its chunk loop stops and emits FLUSH."""
        self.ten_env.log_debug("CambTTS: cancel() called.")
        self._is_cancelled = True

    async def get(
        self, text: str, request_id: str
    ) -> AsyncIterator[Tuple[bytes | None, TTS2HttpResponseEventType]]:
        """Synthesize ``text`` and stream the resulting audio.

        Args:
            text: Text to synthesize. Its stripped length must be 3-3000
                characters (Camb.ai requirement).
            request_id: Identifier used only in log and error messages.

        Yields:
            ``(bytes, RESPONSE)`` for every non-empty audio chunk, then
            exactly one terminal tuple:
            ``(None, END)`` on success or when the text is empty/too short,
            ``(None, FLUSH)`` when ``cancel()`` was called mid-stream,
            ``(message, INVALID_KEY_ERROR)`` on HTTP 401/403, or
            ``(message, ERROR)`` for over-long text, any other non-200
            status, or an exception raised while talking to the API.
        """
        # Each request starts un-cancelled, even if an earlier one was cancelled.
        self._is_cancelled = False

        text_len = len(text.strip())
        if text_len == 0:
            self.ten_env.log_warn(
                f"CambTTS: empty text for request_id: {request_id}.",
                category=LOG_CATEGORY_VENDOR,
            )
            yield None, TTS2HttpResponseEventType.END
            return

        # Validate text length (Camb.ai requires 3-3000 characters)
        if text_len < 3:
            self.ten_env.log_warn(
                f"CambTTS: text too short ({text_len} chars, min 3) for request_id: {request_id}.",
                category=LOG_CATEGORY_VENDOR,
            )
            yield None, TTS2HttpResponseEventType.END
            return

        if text_len > 3000:
            error_message = f"CambTTS: text too long ({text_len} chars, max 3000) for request_id: {request_id}."
            self.ten_env.log_error(
                error_message,
                category=LOG_CATEGORY_VENDOR,
            )
            yield error_message.encode("utf-8"), TTS2HttpResponseEventType.ERROR
            return

        try:
            speech_model = self.config.params.get("speech_model", DEFAULT_MODEL)
            voice_id = self.config.params.get("voice_id", DEFAULT_VOICE_ID)
            language = self.config.params.get("language", DEFAULT_LANGUAGE)
            output_format = self.config.params.get("format", "pcm_s16le")

            # Build payload (same structure as livekit)
            payload: dict = {
                "text": text,
                "voice_id": voice_id,
                "language": language,
                "speech_model": speech_model,
                "output_configuration": {
                    "format": output_format,
                },
            }

            # Add user_instructions only for mars-instruct model
            user_instructions = self.config.params.get("user_instructions")
            if speech_model == "mars-instruct" and user_instructions:
                payload["user_instructions"] = user_instructions

            headers: dict[str, str] = {"Content-Type": "application/json"}
            if self.api_key:
                headers[API_KEY_HEADER] = self.api_key

            self.ten_env.log_debug(
                f"CambTTS: requesting voice_id={voice_id}, model={speech_model}, format={output_format} for request_id: {request_id}."
            )

            # Created lazily so a failed eager construction, or an earlier
            # clean(), does not surface as an AttributeError on None.
            session = self._ensure_session()

            async with session.post(
                f"{API_BASE_URL}/tts-stream",
                headers=headers,
                json=payload,
                timeout=aiohttp.ClientTimeout(total=60),
            ) as resp:
                if resp.status in (401, 403):
                    error_message = "Invalid Camb.ai API key."
                    self.ten_env.log_error(
                        f"CambTTS: {error_message} for request_id: {request_id}.",
                        category=LOG_CATEGORY_VENDOR,
                    )
                    yield (
                        error_message.encode("utf-8"),
                        TTS2HttpResponseEventType.INVALID_KEY_ERROR,
                    )
                    return

                if resp.status != 200:
                    content = await resp.text()
                    error_message = f"API Error {resp.status}: {content}"
                    self.ten_env.log_error(
                        f"CambTTS: {error_message} for request_id: {request_id}.",
                        category=LOG_CATEGORY_VENDOR,
                    )
                    yield (
                        error_message.encode("utf-8"),
                        TTS2HttpResponseEventType.ERROR,
                    )
                    return

                # Stream audio chunks (same as livekit: resp.content.iter_chunks())
                async for data, _ in resp.content.iter_chunks():
                    # Cancellation is only observed when a chunk arrives.
                    if self._is_cancelled:
                        self.ten_env.log_debug(
                            f"CambTTS: cancellation detected for request_id: {request_id}."
                        )
                        yield None, TTS2HttpResponseEventType.FLUSH
                        break

                    if data:
                        self.ten_env.log_debug(
                            f"CambTTS: received {len(data)} bytes for request_id: {request_id}."
                        )
                        yield bytes(data), TTS2HttpResponseEventType.RESPONSE

                # FLUSH was already emitted on cancellation; END is for the
                # normal end of stream only.
                if not self._is_cancelled:
                    self.ten_env.log_debug(
                        f"CambTTS: stream complete for request_id: {request_id}."
                    )
                    yield None, TTS2HttpResponseEventType.END

        except Exception as e:
            # Some exceptions (e.g. asyncio.TimeoutError) stringify to "",
            # which would otherwise yield an empty, useless error payload.
            error_message = str(e) or type(e).__name__
            self.ten_env.log_error(
                f"CambTTS error: {error_message} for request_id: {request_id}.",
                category=LOG_CATEGORY_VENDOR,
            )
            yield error_message.encode("utf-8"), TTS2HttpResponseEventType.ERROR

    async def clean(self):
        """Close the HTTP session, if one is open, and forget it."""
        self.ten_env.log_debug("CambTTS: clean() called.")
        if self._session is not None:
            await self._session.close()
            self._session = None

    def get_extra_metadata(self) -> dict[str, Any]:
        """Return extra metadata for TTFB metrics."""
        return {
            "voice_id": self.config.params.get("voice_id", DEFAULT_VOICE_ID),
            "speech_model": self.config.params.get(
                "speech_model", DEFAULT_MODEL
            ),
        }

    def get_sample_rate(self) -> int:
        """Return the sample rate of the configured model (22050 if unknown)."""
        speech_model = self.config.params.get("speech_model", DEFAULT_MODEL)
        return MODEL_SAMPLE_RATES.get(speech_model, 22050)
class CambTTSConfig(AsyncTTS2HttpConfig):
    """Settings model for the Camb.ai TTS extension."""

    # Debug and logging
    dump: bool = Field(default=False, description="Camb TTS dump")
    dump_path: str = Field(
        default_factory=lambda: str(Path(__file__).parent / "camb_tts_in.pcm"),
        description="Camb TTS dump path",
    )
    params: dict[str, Any] = Field(default_factory=dict, description="Camb TTS params")

    def update_params(self) -> None:
        """Strip entries from ``params`` that must not be passed through."""
        # "text" is the only non-passthrough key; it is dropped if present.
        for excluded_key in ("text",):
            self.params.pop(excluded_key, None)

    def to_str(self, sensitive_handling: bool = True) -> str:
        """Render the config as a string.

        Args:
            sensitive_handling: When true (the default), the ``api_key``
                entry of ``params`` is passed through ``utils.encrypt`` on a
                deep copy before rendering, leaving this instance untouched.

        Returns:
            The string form of the (possibly masked) configuration.
        """
        if sensitive_handling:
            masked = copy.deepcopy(self)
            if masked.params and "api_key" in masked.params:
                masked.params["api_key"] = utils.encrypt(masked.params["api_key"])
            return f"{masked}"
        return f"{self}"

    def validate(self) -> None:
        """Check Camb-specific requirements.

        Raises:
            ValueError: If ``params`` has no ``api_key`` or it is empty.
        """
        if not self.params.get("api_key"):
            raise ValueError("API key is required for Camb TTS")
class CambTTSExtension(AsyncTTS2HttpExtension):
    """
    Camb.ai TTS Extension implementation.

    Provides text-to-speech synthesis using Camb.ai's MARS HTTP API.
    Inherits all common HTTP TTS functionality from AsyncTTS2HttpExtension.
    """

    def __init__(self, name: str) -> None:
        """Initialise the base extension; config and client start unset."""
        super().__init__(name)
        # Type hints for better IDE support. Both start as None, so the
        # annotations say so.
        self.config: CambTTSConfig | None = None
        self.client: CambTTSClient | None = None

    # ============================================================
    # Required method implementations
    # ============================================================

    async def create_config(self, config_json_str: str) -> AsyncTTS2HttpConfig:
        """Create Camb TTS configuration from JSON string."""
        return CambTTSConfig.model_validate_json(config_json_str)

    async def create_client(
        self, config: AsyncTTS2HttpConfig, ten_env: AsyncTenEnv
    ) -> AsyncTTS2HttpClient:
        """Create Camb TTS client."""
        return CambTTSClient(config=config, ten_env=ten_env)

    def vendor(self) -> str:
        """Return vendor name."""
        return "camb"

    def synthesize_audio_sample_rate(self) -> int:
        """Return the sample rate for synthesized audio.

        Model-specific sample rates:
        - mars-flash: 22050 Hz
        - mars-pro: 48000 Hz
        - mars-instruct: 22050 Hz

        The configured ``speech_model`` is the source of truth, so the rate
        is read from ``self.config`` whenever it is set. The client is only a
        fallback, and its answer is used only if it is a positive ``int``
        (a mocked client returns a mock object instead). Otherwise the
        default model's rate is returned.
        """
        default_rate = MODEL_SAMPLE_RATES.get(DEFAULT_MODEL, 22050)

        if self.config is not None:
            speech_model = self.config.params.get("speech_model", DEFAULT_MODEL)
            return MODEL_SAMPLE_RATES.get(speech_model, default_rate)

        if self.client is not None:
            client_rate = self.client.get_sample_rate()
            if isinstance(client_rate, int) and client_rate > 0:
                return client_rate

        # Fallback to default model's sample rate
        return default_rate
b/ai_agents/agents/ten_packages/extension/camb_tts/requirements.txt new file mode 100644 index 0000000000..7a519af7e9 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/camb_tts/requirements.txt @@ -0,0 +1,2 @@ +pydantic>=2.0.0 +aiohttp>=3.8.0 diff --git a/ai_agents/agents/ten_packages/extension/camb_tts/tests/__init__.py b/ai_agents/agents/ten_packages/extension/camb_tts/tests/__init__.py new file mode 100644 index 0000000000..da402faf43 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/camb_tts/tests/__init__.py @@ -0,0 +1,5 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# diff --git a/ai_agents/agents/ten_packages/extension/camb_tts/tests/bin/start b/ai_agents/agents/ten_packages/extension/camb_tts/tests/bin/start new file mode 100755 index 0000000000..b736ea0de1 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/camb_tts/tests/bin/start @@ -0,0 +1,21 @@ +#!/bin/bash + +set -e + +cd "$(dirname "${BASH_SOURCE[0]}")/../.." + +export PYTHONPATH=.ten/app:.ten/app/ten_packages/system/ten_runtime_python/lib:.ten/app/ten_packages/system/ten_runtime_python/interface:.ten/app/ten_packages/system/ten_ai_base/interface:$PYTHONPATH + +# If the Python app imports some modules that are compiled with a different +# version of libstdc++ (ex: PyTorch), the Python app may encounter confusing +# errors. To solve this problem, we can preload the correct version of +# libstdc++. +# +# export LD_PRELOAD=/lib/x86_64-linux-gnu/libstdc++.so.6 +# +# Another solution is to make sure the module 'ten_runtime_python' is imported +# _after_ the module that requires another version of libstdc++ is imported. 
class FakeApp(App):
    """Minimal TEN app used only to host the extension tests.

    It configures console debug logging and signals ``event`` once
    ``on_init`` runs, so the pytest fixture knows it may proceed.
    """

    def __init__(self):
        super().__init__()
        # Set by run_fake_app() before the app is run; on_init asserts it.
        self.event: threading.Event | None = None

    # In the case of a fake app, we use `on_init` to allow the blocked testing
    # fixture to continue execution, rather than using `on_configure`. The
    # reason is that in the TEN runtime C core, the relationship between the
    # addon manager and the (fake) app is bound after `on_configure_done` is
    # called. So we only need to let the testing fixture continue execution
    # after this action in the TEN runtime C core, and at the upper layer
    # timing, the earliest point is within the `on_init()` function of the upper
    # TEN app. Therefore, we release the testing fixture lock within the user
    # layer's `on_init()` of the TEN app.
    def on_init(self, ten_env: TenEnv) -> None:
        """Release the waiting fixture, then acknowledge initialisation."""
        assert self.event
        self.event.set()

        ten_env.on_init_done()

    @override
    def on_configure(self, ten_env: TenEnv) -> None:
        """Install a debug-level, coloured, plain-format stdout log handler."""
        ten_env.init_property_from_json(
            json.dumps(
                {
                    "ten": {
                        "log": {
                            "handlers": [
                                {
                                    "matchers": [{"level": "debug"}],
                                    "formatter": {
                                        "type": "plain",
                                        "colored": True,
                                    },
                                    "emitter": {
                                        "type": "console",
                                        "config": {"stream": "stdout"},
                                    },
                                }
                            ]
                        }
                    }
                }
            ),
        )

        ten_env.on_configure_done()


class FakeAppCtx:
    """Holder shared between the fixture and the app thread."""

    def __init__(self, event: threading.Event):
        # Filled in by run_fake_app() once the app object exists.
        self.fake_app: FakeApp | None = None
        # Set by FakeApp.on_init to unblock the fixture.
        self.event = event


def run_fake_app(fake_app_ctx: FakeAppCtx):
    """Thread target: build the fake app, publish it on the context, run it."""
    app = FakeApp()
    app.event = fake_app_ctx.event
    fake_app_ctx.fake_app = app
    # NOTE(review): the fixture joins this thread only after close(), which
    # suggests run(False) blocks until the app is closed - confirm against
    # the ten_runtime App.run signature.
    app.run(False)


@pytest.fixture(scope="session", autouse=True)
def global_setup_and_teardown():
    """Session-wide fixture that starts the fake app and stops it afterwards."""
    event = threading.Event()
    fake_app_ctx = FakeAppCtx(event)

    fake_app_thread = threading.Thread(target=run_fake_app, args=(fake_app_ctx,))
    fake_app_thread.start()

    # Blocks (without a timeout) until FakeApp.on_init has run.
    event.wait()

    assert fake_app_ctx.fake_app is not None

    # Yield control to the test; after the test execution is complete, continue
    # with the teardown process.
    yield

    # Teardown part.
    fake_app_ctx.fake_app.close()
    fake_app_thread.join()
# ================ test dump file functionality ================
class ExtensionTesterDump(ExtensionTester):
    """Tester that sends one TTS request and records every audio frame."""

    def __init__(self):
        super().__init__()
        # Use a fixed path as requested by the user.
        self.dump_dir = "./dump/"
        # Use a unique name for the file generated by the test to avoid collision
        # with the file generated by the extension.
        self.test_dump_file_path = os.path.join(
            self.dump_dir, "test_manual_dump.pcm"
        )
        self.audio_end_received = False
        self.received_audio_chunks = []

    def on_start(self, ten_env_tester: TenEnvTester) -> None:
        """Called when test starts, sends a TTS request."""
        ten_env_tester.log_info("Dump test started, sending TTS request.")

        tts_input = TTSTextInput(
            request_id="tts_request_1",
            text="Hello from Camb AI, this is a test",
            text_input_end=True,
        )
        data = Data.create("tts_text_input")
        data.set_property_from_json(None, tts_input.model_dump_json())
        ten_env_tester.send_data(data)
        ten_env_tester.on_start_done()

    def on_data(self, ten_env: TenEnvTester, data) -> None:
        """Stop the test as soon as ``tts_audio_end`` arrives."""
        name = data.get_name()
        if name == "tts_audio_end":
            ten_env.log_info("Received tts_audio_end, stopping test.")
            self.audio_end_received = True
            ten_env.stop_test()

    def on_audio_frame(self, ten_env: TenEnvTester, audio_frame):
        """Receives audio frames and collects their data using the lock/unlock pattern."""
        buf = audio_frame.lock_buf()
        try:
            # Copy while the buffer is locked; it is released right after.
            copied_data = bytes(buf)
            self.received_audio_chunks.append(copied_data)
        finally:
            audio_frame.unlock_buf(buf)

    def write_test_dump_file(self):
        """Writes the collected audio chunks to a file."""
        with open(self.test_dump_file_path, "wb") as f:
            f.writelines(self.received_audio_chunks)

    def find_tts_dump_file(self) -> str | None:
        """Find the dump file created by the TTS extension in the fixed dump directory."""
        if not os.path.exists(self.dump_dir):
            return None
        own_file = os.path.basename(self.test_dump_file_path)
        for filename in os.listdir(self.dump_dir):
            if filename.endswith(".pcm") and filename != own_file:
                return os.path.join(self.dump_dir, filename)
        return None


@patch("camb_tts.extension.CambTTSClient")
def test_dump_functionality(MockCambTTSClient):
    """Tests that the dump file from the TTS extension matches the audio received by the test extension."""
    print("Starting test_dump_functionality with mock...")

    # --- Directory Setup ---
    DUMP_PATH = "./dump/"

    # Clean up directory before the test, in case of previous failed runs.
    if os.path.exists(DUMP_PATH):
        shutil.rmtree(DUMP_PATH)
    os.makedirs(DUMP_PATH)

    try:
        # --- Mock Configuration ---
        mock_instance = MockCambTTSClient.return_value
        mock_instance.clean = AsyncMock()
        # The extension asks the client for its sample rate; without this the
        # mock would answer with another mock object instead of an int.
        mock_instance.get_sample_rate.return_value = 22050

        # Create some fake audio data to be streamed
        fake_audio_chunk_1 = b"\x11\x22\x33\x44" * 20
        fake_audio_chunk_2 = b"\xaa\xbb\xcc\xdd" * 20

        # This async generator simulates the TTS client's get() method
        async def mock_get_audio_stream(text: str, request_id: str | None = None):
            yield (fake_audio_chunk_1, TTS2HttpResponseEventType.RESPONSE)
            await asyncio.sleep(0.01)
            yield (fake_audio_chunk_2, TTS2HttpResponseEventType.RESPONSE)
            await asyncio.sleep(0.01)
            yield (None, TTS2HttpResponseEventType.END)

        mock_instance.get.side_effect = mock_get_audio_stream

        # --- Test Setup ---
        tester = ExtensionTesterDump()

        dump_config = {
            "dump": True,
            "dump_path": DUMP_PATH,
            "params": {
                "api_key": "test_api_key",
            },
        }

        tester.set_test_mode_single("camb_tts", json.dumps(dump_config))

        print("Running dump test...")
        tester.run()
        print("Dump test completed.")

        # --- Verification ---
        assert tester.audio_end_received, "Expected to receive tts_audio_end"
        assert (
            len(tester.received_audio_chunks) > 0
        ), "Expected to receive audio chunks"

        tester.write_test_dump_file()

        tts_dump_file = tester.find_tts_dump_file()
        assert (
            tts_dump_file is not None
        ), f"Expected to find a TTS dump file in {DUMP_PATH}"
        assert os.path.exists(
            tts_dump_file
        ), f"TTS dump file should exist: {tts_dump_file}"

        print(
            f"Comparing test file {tester.test_dump_file_path} with TTS dump file {tts_dump_file}"
        )
        assert filecmp.cmp(
            tester.test_dump_file_path, tts_dump_file, shallow=False
        ), "Test dump file and TTS dump file should have the same content"

        print(
            f"Dump functionality test passed: received {len(tester.received_audio_chunks)} audio chunks"
        )
        print(f"  Test file: {tester.test_dump_file_path}")
        print(f"  TTS dump file: {tts_dump_file}")
    finally:
        # --- Cleanup ---
        # Runs on failure too, so a failed assertion leaves no dump files behind.
        if os.path.exists(DUMP_PATH):
            shutil.rmtree(DUMP_PATH)


# ================ test flush logic ================
class ExtensionTesterFlush(ExtensionTester):
    """Tester that sends a flush on the first audio frame and records events."""

    def __init__(self):
        super().__init__()
        self.ten_env: TenEnvTester | None = None
        self.audio_start_received = False
        self.first_audio_frame_received = False
        self.flush_start_received = False
        self.audio_end_received = False
        self.flush_end_received = False
        self.audio_end_reason = ""
        self.total_audio_duration_from_event = 0
        self.received_audio_bytes = 0
        self.sample_rate = 22050  # Camb TTS sample rate (mars-flash default)
        self.bytes_per_sample = 2  # 16-bit
        self.channels = 1
        self.audio_received_after_flush_end = False

    def on_start(self, ten_env_tester: TenEnvTester) -> None:
        """Send one long TTS request so there is time to flush mid-stream."""
        self.ten_env = ten_env_tester
        ten_env_tester.log_info("Flush test started, sending long TTS request.")
        tts_input = TTSTextInput(
            request_id="tts_request_for_flush",
            text="This is a very long text designed to generate a continuous stream of audio, providing enough time to send a flush command.",
        )
        data = Data.create("tts_text_input")
        data.set_property_from_json(None, tts_input.model_dump_json())
        ten_env_tester.send_data(data)
        ten_env_tester.on_start_done()

    def on_audio_frame(self, ten_env: TenEnvTester, audio_frame):
        """Count received bytes; on the first frame, send the flush request."""
        if self.flush_end_received:
            ten_env.log_error("Received audio frame after tts_flush_end!")
            self.audio_received_after_flush_end = True

        if not self.first_audio_frame_received:
            self.first_audio_frame_received = True
            ten_env.log_info("First audio frame received, sending flush data.")
            flush_data = Data.create("tts_flush")
            flush_data.set_property_from_json(
                None,
                TTSFlush(flush_id="tts_request_for_flush").model_dump_json(),
            )
            ten_env.send_data(flush_data)

        buf = audio_frame.lock_buf()
        try:
            self.received_audio_bytes += len(buf)
        finally:
            audio_frame.unlock_buf(buf)

    def on_data(self, ten_env: TenEnvTester, data) -> None:
        """Record TTS lifecycle events; stop shortly after ``tts_flush_end``."""
        name = data.get_name()
        ten_env.log_info(f"on_data name: {name}")

        if name == "tts_audio_start":
            self.audio_start_received = True
            return

        json_str, _ = data.get_property_to_json(None)
        if not json_str:
            return
        payload = json.loads(json_str)
        ten_env.log_info(f"on_data payload: {payload}")

        if name == "tts_flush_start":
            self.flush_start_received = True
            return

        if name == "tts_audio_end":
            self.audio_end_received = True
            self.audio_end_reason = payload.get("reason")
            self.total_audio_duration_from_event = payload.get(
                "request_total_audio_duration_ms"
            )

        elif name == "tts_flush_end":
            self.flush_end_received = True

            def stop_test_later():
                ten_env.log_info("Waited after flush_end, stopping test now.")
                ten_env.stop_test()

            # Wait briefly so that any audio frame wrongly emitted after
            # flush_end can still be observed by on_audio_frame.
            timer = threading.Timer(0.5, stop_test_later)
            timer.start()

    def get_calculated_audio_duration_ms(self) -> int:
        """Return the duration, in ms, implied by the received byte count."""
        duration_sec = self.received_audio_bytes / (
            self.sample_rate * self.bytes_per_sample * self.channels
        )
        return int(duration_sec * 1000)


@patch("camb_tts.extension.CambTTSClient")
def test_flush_logic(MockCambTTSClient):
    """
    Tests that sending a flush command during TTS streaming correctly stops
    the audio and sends the appropriate events.
    """
    print("Starting test_flush_logic with mock...")

    tester = ExtensionTesterFlush()

    mock_instance = MockCambTTSClient.return_value
    mock_instance.clean = AsyncMock()
    mock_instance.cancel = AsyncMock()
    # The extension asks the client for its sample rate; it must be a real
    # int that matches the rate the tester uses for its own duration maths.
    mock_instance.get_sample_rate.return_value = tester.sample_rate

    async def mock_get_long_audio_stream(
        text: str, request_id: str | None = None
    ):
        for _ in range(20):
            if mock_instance.cancel.called:
                print("Mock detected cancel call, sending EVENT_TTS_FLUSH.")
                yield (None, TTS2HttpResponseEventType.FLUSH)
                return
            yield (b"\x11\x22\x33" * 100, TTS2HttpResponseEventType.RESPONSE)
            await asyncio.sleep(0.1)

        yield (None, TTS2HttpResponseEventType.END)

    mock_instance.get.side_effect = mock_get_long_audio_stream

    config = {
        "params": {
            "api_key": "test_api_key",
        },
    }
    tester.set_test_mode_single("camb_tts", json.dumps(config))

    print("Running flush logic test...")
    tester.run()
    print("Flush logic test completed.")

    assert tester.audio_start_received, "Did not receive tts_audio_start."
    assert tester.first_audio_frame_received, "Did not receive any audio frame."
    assert tester.audio_end_received, "Did not receive tts_audio_end."
    assert tester.flush_end_received, "Did not receive tts_flush_end."
    assert (
        not tester.audio_received_after_flush_end
    ), "Received audio after tts_flush_end."

    calculated_duration = tester.get_calculated_audio_duration_ms()
    event_duration = tester.total_audio_duration_from_event
    print(
        f"calculated_duration: {calculated_duration}, event_duration: {event_duration}"
    )
    # Fail with a readable message rather than a TypeError if the event
    # carried no duration field.
    assert isinstance(
        event_duration, (int, float)
    ), f"tts_audio_end carried no usable request_total_audio_duration_ms: {event_duration!r}"
    assert (
        abs(calculated_duration - event_duration) < 10
    ), f"Mismatch in audio duration. Calculated: {calculated_duration}ms, From event: {event_duration}ms"

    print("Flush logic test passed successfully.")