diff --git a/riva/client/argparse_utils.py b/riva/client/argparse_utils.py index a8cc1a7..ae0772b 100644 --- a/riva/client/argparse_utils.py +++ b/riva/client/argparse_utils.py @@ -2,6 +2,59 @@ # SPDX-License-Identifier: MIT import argparse +import functools +import sys + +import grpc + +# Exit codes shared by the CLI scripts. Pipelines that compose these scripts +# rely on a non-zero status to detect failure; see also `cli_main` below. +EXIT_OK = 0 +EXIT_GENERIC_ERROR = 1 +EXIT_BAD_INPUT = 2 # malformed args, missing file, empty/whitespace text, ... +EXIT_UNAVAILABLE = 3 # gRPC UNAVAILABLE (server down, wrong port, ...) +EXIT_INVALID_ARGUMENT = 4 # gRPC INVALID_ARGUMENT or NOT_FOUND (bad model/lang/voice) +EXIT_INTERRUPTED = 130 # SIGINT + + +def _grpc_exit_code(error: grpc.RpcError) -> int: + code = error.code() if callable(getattr(error, "code", None)) else None + if code == grpc.StatusCode.UNAVAILABLE: + return EXIT_UNAVAILABLE + if code in (grpc.StatusCode.INVALID_ARGUMENT, grpc.StatusCode.NOT_FOUND): + return EXIT_INVALID_ARGUMENT + return EXIT_GENERIC_ERROR + + +def cli_main(func): + """Translate exceptions raised by a CLI ``main`` into consistent exit codes. + + Wrapped function may return an int exit code or ``None`` (treated as + ``EXIT_OK``). Unhandled exceptions are caught and mapped: gRPC ``RpcError`` + via status code, ``FileNotFoundError`` / ``ValueError`` → ``EXIT_BAD_INPUT``, + anything else → ``EXIT_GENERIC_ERROR``. The error is also printed to stderr + so CI logs surface the cause. + """ + @functools.wraps(func) + def wrapper(*args, **kwargs): + try: + result = func(*args, **kwargs) + return EXIT_OK if result is None else int(result) + except KeyboardInterrupt: + return EXIT_INTERRUPTED + except grpc.RpcError as e: + details = e.details() if callable(getattr(e, "details", None)) else str(e) + print(f"Error: {details}", file=sys.stderr) + return _grpc_exit_code(e) + except (FileNotFoundError, IsADirectoryError, ValueError) as e: + print(f"Error: {e}", file=sys.stderr) + return EXIT_BAD_INPUT + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + return EXIT_GENERIC_ERROR + + return wrapper + def validate_grpc_message_size(value): """Validate that the GRPC message size is within acceptable limits.""" diff --git a/riva/client/audio_io.py b/riva/client/audio_io.py index ea43279..0618158 100644 --- a/riva/client/audio_io.py +++ b/riva/client/audio_io.py @@ -4,7 +4,17 @@ import queue from typing import Dict, Union, Optional -import pyaudio + +def _require_pyaudio(): + try: + import pyaudio + return pyaudio + except ImportError as e: + raise ImportError( + "pyaudio is required for audio device I/O. Install the system PortAudio " + "headers first (e.g. `apt-get install -y portaudio19-dev` on Debian/Ubuntu, " + "`brew install portaudio` on macOS), then `pip install pyaudio`." + ) from e class MicrophoneStream: @@ -20,6 +30,8 @@ def __init__(self, rate: int, chunk: int, device: int = None) -> None: self.closed = True def __enter__(self): + pyaudio = _require_pyaudio() + self._pa_module = pyaudio self._audio_interface = pyaudio.PyAudio() self._audio_stream = self._audio_interface.open( format=pyaudio.paInt16, @@ -50,7 +62,7 @@ def __exit__(self, type, value, traceback): def _fill_buffer(self, in_data, frame_count, time_info, status_flags): """Continuously collect data from the audio stream into the buffer.""" self._buff.put(in_data) - return None, pyaudio.paContinue + return None, self._pa_module.paContinue def __next__(self) -> bytes: if self.closed: @@ -76,6 +88,7 @@ def __iter__(self): def get_audio_device_info(device_id: int) -> Dict[str, Union[int, float, str]]: + pyaudio = _require_pyaudio() p = pyaudio.PyAudio() info = p.get_device_info_by_index(device_id) p.terminate() @@ -83,6 +96,7 @@ def get_audio_device_info(device_id: int) -> Dict[str, Union[int, float, str]]: def get_default_input_device_info() -> Optional[Dict[str, Union[int, float, str]]]: + pyaudio = _require_pyaudio() p = pyaudio.PyAudio() try: info = p.get_default_input_device_info() @@ -93,6 +107,7 @@ def get_default_input_device_info() -> Optional[Dict[str, Union[int, float, str] def list_output_devices() -> None: + pyaudio = _require_pyaudio() p = pyaudio.PyAudio() print("Output audio devices:") for i in range(p.get_device_count()): @@ -104,6 +119,7 @@ def list_output_devices() -> None: def list_input_devices() -> None: + pyaudio = _require_pyaudio() p = pyaudio.PyAudio() print("Input audio devices:") for i in range(p.get_device_count()): @@ -118,6 +134,7 @@ class SoundCallBack: def __init__( self, output_device_index: Optional[int], sampwidth: int, nchannels: int, framerate: int, ) -> None: + pyaudio = _require_pyaudio() self.pa = pyaudio.PyAudio() self.stream = self.pa.open( output_device_index=output_device_index, diff --git a/scripts/asr/realtime_asr_client.py b/scripts/asr/realtime_asr_client.py index 172ec9b..0ffd115 100644 --- a/scripts/asr/realtime_asr_client.py +++ b/scripts/asr/realtime_asr_client.py @@ -12,6 +12,7 @@ add_asr_config_argparse_parameters, add_realtime_config_argparse_parameters, add_connection_argparse_parameters, + cli_main, ) @@ -300,17 +301,22 @@ async def main() -> None: import riva.client.audio_io riva.client.audio_io.list_input_devices() except ModuleNotFoundError: - print("PyAudio not available. Please install PyAudio to list audio devices.") + print( + "PyAudio not available. Install the system PortAudio headers first " + "(e.g. `apt-get install -y portaudio19-dev`), then `pip install pyaudio`.", + file=sys.stderr, + ) return setup_signal_handler() + await run_transcription(args) - try: - await run_transcription(args) - except Exception as e: - print(f"Fatal error: {e}") - sys.exit(1) + +@cli_main +def _entry() -> int: + asyncio.run(main()) + return 0 if __name__ == "__main__": - asyncio.run(main()) + sys.exit(_entry()) diff --git a/scripts/asr/riva_streaming_asr_client.py b/scripts/asr/riva_streaming_asr_client.py index f600af6..08a6ea2 100644 --- a/scripts/asr/riva_streaming_asr_client.py +++ b/scripts/asr/riva_streaming_asr_client.py @@ -4,6 +4,7 @@ import argparse import os import queue +import sys import time from pathlib import Path from threading import Thread @@ -11,7 +12,11 @@ import riva.client from riva.client.asr import get_wav_file_parameters -from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters +from riva.client.argparse_utils import ( + add_asr_config_argparse_parameters, + add_connection_argparse_parameters, + cli_main, +) def parse_args() -> argparse.Namespace: @@ -109,7 +114,8 @@ def streaming_transcription_worker( raise -def main() -> None: +@cli_main +def main() -> int: args = parse_args() print("Number of clients:", args.num_clients) print("Number of iteration:", args.num_iterations) @@ -140,4 +146,4 @@ def main() -> None: if __name__ == "__main__": - main() + sys.exit(main()) diff --git a/scripts/asr/transcribe_file.py b/scripts/asr/transcribe_file.py index 1849a67..8c3f6c2 100644 --- a/scripts/asr/transcribe_file.py +++ b/scripts/asr/transcribe_file.py @@ -2,10 +2,16 @@ # SPDX-License-Identifier: MIT import argparse - import os +import sys + import riva.client -from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters +from riva.client.argparse_utils import ( + add_asr_config_argparse_parameters, + add_connection_argparse_parameters, + cli_main, + EXIT_BAD_INPUT, +) def parse_args() -> argparse.Namespace: @@ -61,7 +67,8 @@ def parse_args() -> argparse.Namespace: return args -def main() -> None: +@cli_main +def main() -> int: args = parse_args() if args.list_devices: riva.client.audio_io.list_output_devices() @@ -95,8 +102,8 @@ def main() -> None: return if not os.path.isfile(args.input_file): - print(f"Invalid input file path: {args.input_file}") - return + print(f"Invalid input file path: {args.input_file}", file=sys.stderr) + return EXIT_BAD_INPUT config = riva.client.StreamingRecognitionConfig( config=riva.client.RecognitionConfig( @@ -159,4 +166,4 @@ def main() -> None: if __name__ == "__main__": - main() + sys.exit(main()) diff --git a/scripts/asr/transcribe_file_offline.py b/scripts/asr/transcribe_file_offline.py index 92634ca..4c2767a 100644 --- a/scripts/asr/transcribe_file_offline.py +++ b/scripts/asr/transcribe_file_offline.py @@ -2,12 +2,17 @@ # SPDX-License-Identifier: MIT import os +import sys import argparse from pathlib import Path -import grpc import riva.client -from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters +from riva.client.argparse_utils import ( + add_asr_config_argparse_parameters, + add_connection_argparse_parameters, + cli_main, + EXIT_BAD_INPUT, +) def parse_args() -> argparse.Namespace: @@ -30,7 +35,8 @@ def parse_args() -> argparse.Namespace: return args -def main() -> None: +@cli_main +def main() -> int: args = parse_args() options = [('grpc.max_receive_message_length', args.max_message_length), ('grpc.max_send_message_length', args.max_message_length)] @@ -62,8 +68,8 @@ def main() -> None: return if not os.path.isfile(args.input_file): - print(f"Invalid input file path: {args.input_file}") - return + print(f"Invalid input file path: {args.input_file}", file=sys.stderr) + return EXIT_BAD_INPUT config = riva.client.RecognitionConfig( language_code=args.language_code, @@ -91,14 +97,15 @@ def main() -> None: ) with args.input_file.open('rb') as fh: data = fh.read() - try: - seglst_output_file = None - if args.output_seglst: - seglst_output_file = os.path.basename(args.input_file).split(".")[0] - riva.client.print_offline(response=asr_service.offline_recognize(data, config), speaker_diarization=args.speaker_diarization, seglst_output_file=seglst_output_file) - except grpc.RpcError as e: - print(e.details()) + seglst_output_file = None + if args.output_seglst: + seglst_output_file = os.path.basename(args.input_file).split(".")[0] + riva.client.print_offline( + response=asr_service.offline_recognize(data, config), + speaker_diarization=args.speaker_diarization, + seglst_output_file=seglst_output_file, + ) if __name__ == "__main__": - main() + sys.exit(main()) diff --git a/scripts/asr/transcribe_mic.py b/scripts/asr/transcribe_mic.py index 3fd2b5a..a5d6326 100644 --- a/scripts/asr/transcribe_mic.py +++ b/scripts/asr/transcribe_mic.py @@ -2,16 +2,27 @@ # SPDX-License-Identifier: MIT import argparse +import sys import riva.client -from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters +from riva.client.argparse_utils import ( + add_asr_config_argparse_parameters, + add_connection_argparse_parameters, + cli_main, + EXIT_BAD_INPUT, +) try: import riva.client.audio_io except ModuleNotFoundError as e: - print(f"ModuleNotFoundError: {e}") - print("Please install pyaudio from https://pypi.org/project/PyAudio") - exit(1) + print(f"ModuleNotFoundError: {e}", file=sys.stderr) + print( + "Install the system PortAudio headers first " + "(e.g. `apt-get install -y portaudio19-dev` on Debian/Ubuntu, " + "`brew install portaudio` on macOS), then `pip install pyaudio`.", + file=sys.stderr, + ) + sys.exit(EXIT_BAD_INPUT) def parse_args() -> argparse.Namespace: default_device_info = riva.client.audio_io.get_default_input_device_info() @@ -40,7 +51,8 @@ def parse_args() -> argparse.Namespace: return args -def main() -> None: +@cli_main +def main() -> int: args = parse_args() if args.list_devices: riva.client.audio_io.list_input_devices() @@ -98,4 +110,4 @@ def main() -> None: if __name__ == '__main__': - main() + sys.exit(main()) diff --git a/scripts/nmt/nmt.py b/scripts/nmt/nmt.py index f3d13db..d77a49d 100644 --- a/scripts/nmt/nmt.py +++ b/scripts/nmt/nmt.py @@ -30,12 +30,11 @@ import os import sys -import grpc import riva.client.proto.riva_nmt_pb2 as riva_nmt import riva.client.proto.riva_nmt_pb2_grpc as riva_nmt_srv import riva.client -from riva.client.argparse_utils import add_connection_argparse_parameters +from riva.client.argparse_utils import add_connection_argparse_parameters, cli_main, EXIT_BAD_INPUT def read_dnt_phrases_file(file_path): @@ -78,7 +77,15 @@ def parse_args() -> argparse.Namespace: ) inputs.add_argument("--text-file", type=str, help="Path to file for translation") parser.add_argument("--dnt-phrases-file", type=str, help="Path to file which contains dnt phrases and custom translations") - parser.add_argument("--max-len-variation", type=str, help="Parameter to control the maximum variation between the length of source and translated text in terms of tokens") + parser.add_argument( + "--max-len-variation", + type=str, + help="Maximum allowed difference (in decoder SentencePiece tokens, not characters) " + "between the source and translated text length. Valid range: [0, 256]. Server-side " + "default is 20. Increase this for long inputs that get truncated; high-aspect-ratio " + "languages like Arabic may need additional client-side chunking when the source " + "exceeds ~200 characters even at 256.", + ) parser.add_argument("--model-name", default="", type=str, help="model to use to translate") parser.add_argument( "--source-language-code", type=str, default="en-US", help="Source language code (according to BCP-47 standard)" @@ -93,34 +100,8 @@ def parse_args() -> argparse.Namespace: return parser.parse_args() -def main() -> None: - def request(inputs,args): - try: - dnt_phrases_input = {} - if args.dnt_phrases_file != None: - dnt_phrases_input = read_dnt_phrases_file(args.dnt_phrases_file) - response = nmt_client.translate( - texts=inputs, - model=args.model_name, - source_language=args.source_language_code, - target_language=args.target_language_code, - future=False, - dnt_phrases_dict=dnt_phrases_input, - max_len_variation=args.max_len_variation, - ) - for translation in response.translations: - print(translation.text) - except grpc.RpcError as e: - if e.code() == grpc.StatusCode.INVALID_ARGUMENT: - result = {'msg': 'invalid arg error'} - elif e.code() == grpc.StatusCode.ALREADY_EXISTS: - result = {'msg': 'already exists error'} - elif e.code() == grpc.StatusCode.UNAVAILABLE: - result = {'msg': 'server unavailable check network'} - else: - result = {'msg': 'error code:{}'.format(e.code())} - print(f"{result['msg']} : {e.details()}") - +@cli_main +def main() -> int: args = parse_args() auth = riva.client.Auth( @@ -134,13 +115,31 @@ def request(inputs,args): ) nmt_client = riva.client.NeuralMachineTranslationClient(auth) - if args.list_models: + def request(inputs): + dnt_phrases_input = {} + if args.dnt_phrases_file is not None: + dnt_phrases_input = read_dnt_phrases_file(args.dnt_phrases_file) + response = nmt_client.translate( + texts=inputs, + model=args.model_name, + source_language=args.source_language_code, + target_language=args.target_language_code, + future=False, + dnt_phrases_dict=dnt_phrases_input, + max_len_variation=args.max_len_variation, + ) + for translation in response.translations: + print(translation.text) + if args.list_models: response = nmt_client.get_config(args.model_name) print(response) return - if args.text_file != None and os.path.exists(args.text_file): + if args.text_file is not None: + if not os.path.exists(args.text_file): + print(f"Invalid input file path: {args.text_file}", file=sys.stderr) + return EXIT_BAD_INPUT with open(args.text_file, "r") as f: batch = [] for line in f: @@ -148,15 +147,17 @@ def request(inputs,args): if line != "": batch.append(line) if len(batch) == args.batch_size: - request(batch, args) + request(batch) batch = [] if len(batch) > 0: - request(batch, args) + request(batch) return - if args.text != "": - request([args.text], args) + if not args.text or not args.text.strip(): + print("No input text provided", file=sys.stderr) + return EXIT_BAD_INPUT + request([args.text]) if __name__ == '__main__': - main() + sys.exit(main()) diff --git a/scripts/nmt/nmt_speech_to_speech.py b/scripts/nmt/nmt_speech_to_speech.py index 7348525..20f711d 100644 --- a/scripts/nmt/nmt_speech_to_speech.py +++ b/scripts/nmt/nmt_speech_to_speech.py @@ -1,14 +1,13 @@ import argparse import os +import sys import wave from typing import Iterator -import grpc - import riva.client import riva.client.proto.riva_asr_pb2 as riva_asr_pb2 import riva.client.proto.riva_nmt_pb2 as riva_nmt_pb2 -from riva.client.argparse_utils import add_connection_argparse_parameters +from riva.client.argparse_utils import add_connection_argparse_parameters, cli_main def parse_arguments(): @@ -25,6 +24,7 @@ def parse_arguments(): return parser.parse_args() +@cli_main def main(): args = parse_arguments() @@ -48,59 +48,54 @@ def main(): print(response) return + print(f"Translating speech from {args.source_language} to {args.target_language}") + print(f"Using audio file: {args.audio_file}") + print(f"Server address: {args.server}") + + # Create ASR config + asr_config = riva_asr_pb2.StreamingRecognitionConfig( + config=riva_asr_pb2.RecognitionConfig( + language_code=args.source_language, max_alternatives=1, enable_automatic_punctuation=True + ), + interim_results=True, + ) + + # Create translation config + translation_config = riva_nmt_pb2.TranslationConfig( + source_language_code=args.source_language, target_language_code=args.target_language, + ) + + # Create synthesis config + tts_config = riva_nmt_pb2.SynthesizeSpeechConfig( + encoding=riva.client.AudioEncoding.LINEAR_PCM, + language_code=args.target_language, + voice_name=args.voice, + sample_rate_hz=args.sample_rate_hz, + ) + + # Create streaming config + streaming_config = riva_nmt_pb2.StreamingTranslateSpeechToSpeechConfig( + asr_config=asr_config, translation_config=translation_config, tts_config=tts_config + ) + + responses = nmt_client.streaming_s2s_response_generator( + audio_chunks=riva.client.AudioChunkFileIterator(args.audio_file, 100), streaming_config=streaming_config + ) + + output_file = None try: - print(f"Translating speech from {args.source_language} to {args.target_language}") - print(f"Using audio file: {args.audio_file}") - print(f"Server address: {args.server}") - - # Create ASR config - asr_config = riva_asr_pb2.StreamingRecognitionConfig( - config=riva_asr_pb2.RecognitionConfig( - language_code=args.source_language, max_alternatives=1, enable_automatic_punctuation=True - ), - interim_results=True, - ) - - # Create translation config - translation_config = riva_nmt_pb2.TranslationConfig( - source_language_code=args.source_language, target_language_code=args.target_language, - ) - - # Create synthesis config - tts_config = riva_nmt_pb2.SynthesizeSpeechConfig( - encoding=riva.client.AudioEncoding.LINEAR_PCM, - language_code=args.target_language, - voice_name=args.voice, - sample_rate_hz=args.sample_rate_hz, - ) - - # Create streaming config - streaming_config = riva_nmt_pb2.StreamingTranslateSpeechToSpeechConfig( - asr_config=asr_config, translation_config=translation_config, tts_config=tts_config - ) - - responses = nmt_client.streaming_s2s_response_generator( - audio_chunks=riva.client.AudioChunkFileIterator(args.audio_file, 100), streaming_config=streaming_config - ) - - try: - output_file = None - for response in responses: - if len(response.speech.audio) > 0 and output_file is None: - output_file = wave.open(str(args.output_file), 'wb') - output_file.setnchannels(1) - output_file.setsampwidth(2) - output_file.setframerate(args.sample_rate_hz) - output_file.writeframesraw(response.speech.audio) - - finally: - if output_file is not None: - print(f"Written {output_file.getnframes()} samples to {args.output_file}") - output_file.close() - - except Exception as e: - print(f"Error during translation: {e}") + for response in responses: + if len(response.speech.audio) > 0 and output_file is None: + output_file = wave.open(str(args.output_file), 'wb') + output_file.setnchannels(1) + output_file.setsampwidth(2) + output_file.setframerate(args.sample_rate_hz) + output_file.writeframesraw(response.speech.audio) + finally: + if output_file is not None: + print(f"Written {output_file.getnframes()} samples to {args.output_file}") + output_file.close() if __name__ == "__main__": - main() + sys.exit(main()) diff --git a/scripts/nmt/nmt_speech_to_text.py b/scripts/nmt/nmt_speech_to_text.py index bc8c0f2..66ea3da 100644 --- a/scripts/nmt/nmt_speech_to_text.py +++ b/scripts/nmt/nmt_speech_to_text.py @@ -1,9 +1,11 @@ import argparse import os +import sys + import riva.client import riva.client.proto.riva_asr_pb2 as riva_asr_pb2 import riva.client.proto.riva_nmt_pb2 as riva_nmt_pb2 -from riva.client.argparse_utils import add_connection_argparse_parameters +from riva.client.argparse_utils import add_connection_argparse_parameters, cli_main def parse_arguments(): parser = argparse.ArgumentParser(description='Riva Speech-to-Text Translation Client') @@ -36,6 +38,7 @@ def parse_arguments(): return parser.parse_args() +@cli_main def main(): args = parse_arguments() @@ -59,50 +62,46 @@ def main(): print(response) return - try: - print(f"Translating speech from {args.source_language} to {args.target_language}") - print(f"Using audio file: {args.audio_file}") - print(f"Server address: {args.server}") - - # Create ASR config - asr_config = riva_asr_pb2.StreamingRecognitionConfig( - config=riva_asr_pb2.RecognitionConfig( - language_code=args.source_language, - max_alternatives=1, - enable_automatic_punctuation=True - ), - interim_results=True - ) - - # Create translation config - translation_config = riva_nmt_pb2.TranslationConfig( - source_language_code=args.source_language, - target_language_code=args.target_language, - model_name=args.model - ) - - # Create streaming config - streaming_config = riva_nmt_pb2.StreamingTranslateSpeechToTextConfig( - asr_config=asr_config, - translation_config=translation_config - ) - - responses = nmt_client.streaming_s2t_response_generator( - audio_chunks=riva.client.AudioChunkFileIterator(args.audio_file, 100), - streaming_config=streaming_config - ) - - final_translation = "" - for response in responses: - for result in response.results: - if result.is_final: - final_translation += result.alternatives[0].transcript - - print(f"Final translation: {final_translation}") - - except Exception as e: - print(f"Error during translation: {e}") + print(f"Translating speech from {args.source_language} to {args.target_language}") + print(f"Using audio file: {args.audio_file}") + print(f"Server address: {args.server}") + + # Create ASR config + asr_config = riva_asr_pb2.StreamingRecognitionConfig( + config=riva_asr_pb2.RecognitionConfig( + language_code=args.source_language, + max_alternatives=1, + enable_automatic_punctuation=True + ), + interim_results=True + ) + + # Create translation config + translation_config = riva_nmt_pb2.TranslationConfig( + source_language_code=args.source_language, + target_language_code=args.target_language, + model_name=args.model + ) + + # Create streaming config + streaming_config = riva_nmt_pb2.StreamingTranslateSpeechToTextConfig( + asr_config=asr_config, + translation_config=translation_config + ) + + responses = nmt_client.streaming_s2t_response_generator( + audio_chunks=riva.client.AudioChunkFileIterator(args.audio_file, 100), + streaming_config=streaming_config + ) + + final_translation = "" + for response in responses: + for result in response.results: + if result.is_final: + final_translation += result.alternatives[0].transcript + + print(f"Final translation: {final_translation}") if __name__ == "__main__": - main() \ No newline at end of file + sys.exit(main()) \ No newline at end of file diff --git a/scripts/tts/realtime_tts_client.py b/scripts/tts/realtime_tts_client.py index 096c0dc..5277c8b 100644 --- a/scripts/tts/realtime_tts_client.py +++ b/scripts/tts/realtime_tts_client.py @@ -20,8 +20,7 @@ import websockets from websockets.exceptions import WebSocketException -from riva.client.argparse_utils import add_connection_argparse_parameters -from riva.client.audio_io import SoundCallBack +from riva.client.argparse_utils import add_connection_argparse_parameters, cli_main from riva.client.realtime import RealtimeClientTTS @@ -478,30 +477,30 @@ async def process_single_text(text_idx, text_line): async def main() -> int: """Main entry point for the realtime TTS client.""" args = parse_args() - success = False + success = True setup_signal_handler() - try: - if args.list_voices: - voices = RealtimeClientTTS(args=args).list_voices() - print(json.dumps(voices, indent=4)) - elif args.list_devices: - import riva.client.audio_io - riva.client.audio_io.list_output_devices() + if args.list_voices: + voices = RealtimeClientTTS(args=args).list_voices() + print(json.dumps(voices, indent=4)) + elif args.list_devices: + import riva.client.audio_io + riva.client.audio_io.list_output_devices() + else: + # Use parallel processing if num_parallel_requests > 1 + if args.num_parallel_requests > 1: + logger.info(f"Using parallel processing mode with {args.num_parallel_requests} concurrent requests") + success = await run_parallel_synthesis(args) else: - # Use parallel processing if num_parallel_requests > 1 - if args.num_parallel_requests > 1: - logger.info(f"Using parallel processing mode with {args.num_parallel_requests} concurrent requests") - success = await run_parallel_synthesis(args) - else: - logger.info("Using single request mode") - success = await run_synthesis(args) - return 0 if success else 1 - except Exception as e: - logger.error("Fatal error: %s", e) - return 1 + logger.info("Using single request mode") + success = await run_synthesis(args) + return 0 if success else 1 + + +@cli_main +def _entry() -> int: + return asyncio.run(main()) if __name__ == "__main__": - exit_code = asyncio.run(main()) - sys.exit(exit_code) + sys.exit(_entry()) diff --git a/scripts/tts/talk.py b/scripts/tts/talk.py index 9ac0ea5..d12d45e 100644 --- a/scripts/tts/talk.py +++ b/scripts/tts/talk.py @@ -2,13 +2,18 @@ # SPDX-License-Identifier: MIT import argparse +import sys import time import wave import json from pathlib import Path import riva.client -from riva.client.argparse_utils import add_connection_argparse_parameters +from riva.client.argparse_utils import ( + add_connection_argparse_parameters, + cli_main, + EXIT_BAD_INPUT, +) from riva.client.proto.riva_audio_pb2 import AudioEncoding from riva.client.tts import parse_custom_configuration @@ -94,16 +99,21 @@ def parse_args() -> argparse.Namespace: import riva.client.audio_io except ModuleNotFoundError as e: print(f"ModuleNotFoundError: {e}") - print("Please install pyaudio from https://pypi.org/project/PyAudio") + print( + "Install the system PortAudio headers first " + "(e.g. `apt-get install -y portaudio19-dev` on Debian/Ubuntu, " + "`brew install portaudio` on macOS), then `pip install pyaudio`." + ) exit(1) return args -def main() -> None: +@cli_main +def main() -> int: args = parse_args() if args.output.is_dir(): - print("Empty output file path not allowed") - return + print("Empty output file path not allowed", file=sys.stderr) + return EXIT_BAD_INPUT if args.list_devices: riva.client.audio_io.list_output_devices() return @@ -148,11 +158,14 @@ def main() -> None: return if not args.text and not args.text_file: - print("No input text provided") - return + print("No input text provided", file=sys.stderr) + return EXIT_BAD_INPUT + if args.text is not None and not args.text.strip(): + print("No input text provided", file=sys.stderr) + return EXIT_BAD_INPUT if args.text and args.text_file: - print("Cannot provide both text and text_file at the same time.") - return + print("Cannot provide both text and text_file at the same time.", file=sys.stderr) + return EXIT_BAD_INPUT try: if args.output_device is not None or args.play_audio: sound_stream = riva.client.audio_io.SoundCallBack( @@ -218,11 +231,6 @@ def main() -> None: sound_stream(resp.audio) if out_f is not None: out_f.writeframesraw(resp.audio) - except Exception as e: - if callable(getattr(e, "details", None)): - print(e.details()) - else: - print(e) finally: if out_f is not None: out_f.close() @@ -231,4 +239,4 @@ def main() -> None: if __name__ == '__main__': - main() + sys.exit(main())