Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions riva/client/argparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,59 @@
# SPDX-License-Identifier: MIT

import argparse
import functools
import sys

import grpc

# Exit codes shared by the CLI scripts. Pipelines that compose these scripts
# rely on a non-zero status to detect failure; see also `cli_main` below.
EXIT_OK = 0
EXIT_GENERIC_ERROR = 1
EXIT_BAD_INPUT = 2 # malformed args, missing file, empty/whitespace text, ...
EXIT_UNAVAILABLE = 3 # gRPC UNAVAILABLE (server down, wrong port, ...)
EXIT_INVALID_ARGUMENT = 4 # gRPC INVALID_ARGUMENT or NOT_FOUND (bad model/lang/voice)
EXIT_INTERRUPTED = 130 # SIGINT


def _grpc_exit_code(error: grpc.RpcError) -> int:
code = error.code() if callable(getattr(error, "code", None)) else None
if code == grpc.StatusCode.UNAVAILABLE:
return EXIT_UNAVAILABLE
if code in (grpc.StatusCode.INVALID_ARGUMENT, grpc.StatusCode.NOT_FOUND):
return EXIT_INVALID_ARGUMENT
return EXIT_GENERIC_ERROR


def cli_main(func):
"""Translate exceptions raised by a CLI ``main`` into consistent exit codes.

Wrapped function may return an int exit code or ``None`` (treated as
``EXIT_OK``). Unhandled exceptions are caught and mapped: gRPC ``RpcError``
via status code, ``FileNotFoundError`` / ``ValueError`` → ``EXIT_BAD_INPUT``,
anything else → ``EXIT_GENERIC_ERROR``. The error is also printed to stderr
so CI logs surface the cause.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
result = func(*args, **kwargs)
return EXIT_OK if result is None else int(result)
except KeyboardInterrupt:
return EXIT_INTERRUPTED
except grpc.RpcError as e:
details = e.details() if callable(getattr(e, "details", None)) else str(e)
print(f"Error: {details}", file=sys.stderr)
return _grpc_exit_code(e)
except (FileNotFoundError, IsADirectoryError, ValueError) as e:
print(f"Error: {e}", file=sys.stderr)
return EXIT_BAD_INPUT
except Exception as e:
print(f"Error: {e}", file=sys.stderr)
return EXIT_GENERIC_ERROR

return wrapper


def validate_grpc_message_size(value):
"""Validate that the GRPC message size is within acceptable limits."""
Expand Down
21 changes: 19 additions & 2 deletions riva/client/audio_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,17 @@
import queue
from typing import Dict, Union, Optional

import pyaudio

def _require_pyaudio():
try:
import pyaudio
return pyaudio
except ImportError as e:
raise ImportError(
"pyaudio is required for audio device I/O. Install the system PortAudio "
"headers first (e.g. `apt-get install -y portaudio19-dev` on Debian/Ubuntu, "
"`brew install portaudio` on macOS), then `pip install pyaudio`."
) from e


class MicrophoneStream:
Expand All @@ -20,6 +30,8 @@ def __init__(self, rate: int, chunk: int, device: int = None) -> None:
self.closed = True

def __enter__(self):
pyaudio = _require_pyaudio()
self._pa_module = pyaudio
self._audio_interface = pyaudio.PyAudio()
self._audio_stream = self._audio_interface.open(
format=pyaudio.paInt16,
Expand Down Expand Up @@ -50,7 +62,7 @@ def __exit__(self, type, value, traceback):
def _fill_buffer(self, in_data, frame_count, time_info, status_flags):
"""Continuously collect data from the audio stream into the buffer."""
self._buff.put(in_data)
return None, pyaudio.paContinue
return None, self._pa_module.paContinue

def __next__(self) -> bytes:
if self.closed:
Expand All @@ -76,13 +88,15 @@ def __iter__(self):


def get_audio_device_info(device_id: int) -> Dict[str, Union[int, float, str]]:
pyaudio = _require_pyaudio()
p = pyaudio.PyAudio()
info = p.get_device_info_by_index(device_id)
p.terminate()
return info


def get_default_input_device_info() -> Optional[Dict[str, Union[int, float, str]]]:
pyaudio = _require_pyaudio()
p = pyaudio.PyAudio()
try:
info = p.get_default_input_device_info()
Expand All @@ -93,6 +107,7 @@ def get_default_input_device_info() -> Optional[Dict[str, Union[int, float, str]


def list_output_devices() -> None:
pyaudio = _require_pyaudio()
p = pyaudio.PyAudio()
print("Output audio devices:")
for i in range(p.get_device_count()):
Expand All @@ -104,6 +119,7 @@ def list_output_devices() -> None:


def list_input_devices() -> None:
pyaudio = _require_pyaudio()
p = pyaudio.PyAudio()
print("Input audio devices:")
for i in range(p.get_device_count()):
Expand All @@ -118,6 +134,7 @@ class SoundCallBack:
def __init__(
self, output_device_index: Optional[int], sampwidth: int, nchannels: int, framerate: int,
) -> None:
pyaudio = _require_pyaudio()
self.pa = pyaudio.PyAudio()
self.stream = self.pa.open(
output_device_index=output_device_index,
Expand Down
20 changes: 13 additions & 7 deletions scripts/asr/realtime_asr_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
add_asr_config_argparse_parameters,
add_realtime_config_argparse_parameters,
add_connection_argparse_parameters,
cli_main,
)


Expand Down Expand Up @@ -300,17 +301,22 @@ async def main() -> None:
import riva.client.audio_io
riva.client.audio_io.list_input_devices()
except ModuleNotFoundError:
print("PyAudio not available. Please install PyAudio to list audio devices.")
print(
"PyAudio not available. Install the system PortAudio headers first "
"(e.g. `apt-get install -y portaudio19-dev`), then `pip install pyaudio`.",
file=sys.stderr,
)
return

setup_signal_handler()
await run_transcription(args)

try:
await run_transcription(args)
except Exception as e:
print(f"Fatal error: {e}")
sys.exit(1)

@cli_main
def _entry() -> int:
asyncio.run(main())
return 0


if __name__ == "__main__":
asyncio.run(main())
sys.exit(_entry())
12 changes: 9 additions & 3 deletions scripts/asr/riva_streaming_asr_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,19 @@
import argparse
import os
import queue
import sys
import time
from pathlib import Path
from threading import Thread
from typing import Union

import riva.client
from riva.client.asr import get_wav_file_parameters
from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters
from riva.client.argparse_utils import (
add_asr_config_argparse_parameters,
add_connection_argparse_parameters,
cli_main,
)


def parse_args() -> argparse.Namespace:
Expand Down Expand Up @@ -109,7 +114,8 @@ def streaming_transcription_worker(
raise


def main() -> None:
@cli_main
def main() -> int:
args = parse_args()
print("Number of clients:", args.num_clients)
print("Number of iteration:", args.num_iterations)
Expand Down Expand Up @@ -140,4 +146,4 @@ def main() -> None:


if __name__ == "__main__":
main()
sys.exit(main())
19 changes: 13 additions & 6 deletions scripts/asr/transcribe_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,16 @@
# SPDX-License-Identifier: MIT

import argparse

import os
import sys

import riva.client
from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters
from riva.client.argparse_utils import (
add_asr_config_argparse_parameters,
add_connection_argparse_parameters,
cli_main,
EXIT_BAD_INPUT,
)


def parse_args() -> argparse.Namespace:
Expand Down Expand Up @@ -61,7 +67,8 @@ def parse_args() -> argparse.Namespace:
return args


def main() -> None:
@cli_main
def main() -> int:
args = parse_args()
if args.list_devices:
riva.client.audio_io.list_output_devices()
Expand Down Expand Up @@ -95,8 +102,8 @@ def main() -> None:
return

if not os.path.isfile(args.input_file):
print(f"Invalid input file path: {args.input_file}")
return
print(f"Invalid input file path: {args.input_file}", file=sys.stderr)
return EXIT_BAD_INPUT

config = riva.client.StreamingRecognitionConfig(
config=riva.client.RecognitionConfig(
Expand Down Expand Up @@ -159,4 +166,4 @@ def main() -> None:


if __name__ == "__main__":
main()
sys.exit(main())
33 changes: 20 additions & 13 deletions scripts/asr/transcribe_file_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,17 @@
# SPDX-License-Identifier: MIT

import os
import sys
import argparse
from pathlib import Path

import grpc
import riva.client
from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters
from riva.client.argparse_utils import (
add_asr_config_argparse_parameters,
add_connection_argparse_parameters,
cli_main,
EXIT_BAD_INPUT,
)


def parse_args() -> argparse.Namespace:
Expand All @@ -30,7 +35,8 @@ def parse_args() -> argparse.Namespace:
return args


def main() -> None:
@cli_main
def main() -> int:
args = parse_args()

options = [('grpc.max_receive_message_length', args.max_message_length), ('grpc.max_send_message_length', args.max_message_length)]
Expand Down Expand Up @@ -62,8 +68,8 @@ def main() -> None:
return

if not os.path.isfile(args.input_file):
print(f"Invalid input file path: {args.input_file}")
return
print(f"Invalid input file path: {args.input_file}", file=sys.stderr)
return EXIT_BAD_INPUT

config = riva.client.RecognitionConfig(
language_code=args.language_code,
Expand Down Expand Up @@ -91,14 +97,15 @@ def main() -> None:
)
with args.input_file.open('rb') as fh:
data = fh.read()
try:
seglst_output_file = None
if args.output_seglst:
seglst_output_file = os.path.basename(args.input_file).split(".")[0]
riva.client.print_offline(response=asr_service.offline_recognize(data, config), speaker_diarization=args.speaker_diarization, seglst_output_file=seglst_output_file)
except grpc.RpcError as e:
print(e.details())
seglst_output_file = None
if args.output_seglst:
seglst_output_file = os.path.basename(args.input_file).split(".")[0]
riva.client.print_offline(
response=asr_service.offline_recognize(data, config),
speaker_diarization=args.speaker_diarization,
seglst_output_file=seglst_output_file,
)


if __name__ == "__main__":
main()
sys.exit(main())
24 changes: 18 additions & 6 deletions scripts/asr/transcribe_mic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,27 @@
# SPDX-License-Identifier: MIT

import argparse
import sys

import riva.client
from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters
from riva.client.argparse_utils import (
add_asr_config_argparse_parameters,
add_connection_argparse_parameters,
cli_main,
EXIT_BAD_INPUT,
)

try:
import riva.client.audio_io
except ModuleNotFoundError as e:
print(f"ModuleNotFoundError: {e}")
print("Please install pyaudio from https://pypi.org/project/PyAudio")
exit(1)
print(f"ModuleNotFoundError: {e}", file=sys.stderr)
print(
"Install the system PortAudio headers first "
"(e.g. `apt-get install -y portaudio19-dev` on Debian/Ubuntu, "
"`brew install portaudio` on macOS), then `pip install pyaudio`.",
file=sys.stderr,
)
sys.exit(EXIT_BAD_INPUT)

def parse_args() -> argparse.Namespace:
default_device_info = riva.client.audio_io.get_default_input_device_info()
Expand Down Expand Up @@ -40,7 +51,8 @@ def parse_args() -> argparse.Namespace:
return args


def main() -> None:
@cli_main
def main() -> int:
args = parse_args()
if args.list_devices:
riva.client.audio_io.list_input_devices()
Expand Down Expand Up @@ -98,4 +110,4 @@ def main() -> None:


if __name__ == '__main__':
main()
sys.exit(main())
Loading