diff --git a/scripts/nmt/nmt.py b/scripts/nmt/nmt.py index f3d13db..33d88b0 100644 --- a/scripts/nmt/nmt.py +++ b/scripts/nmt/nmt.py @@ -94,10 +94,13 @@ def parse_args() -> argparse.Namespace: def main() -> None: - def request(inputs,args): + server_error = False + + def request(inputs, args): + nonlocal server_error try: dnt_phrases_input = {} - if args.dnt_phrases_file != None: + if args.dnt_phrases_file is not None: dnt_phrases_input = read_dnt_phrases_file(args.dnt_phrases_file) response = nmt_client.translate( texts=inputs, @@ -112,14 +115,15 @@ def request(inputs,args): print(translation.text) except grpc.RpcError as e: if e.code() == grpc.StatusCode.INVALID_ARGUMENT: - result = {'msg': 'invalid arg error'} + msg = 'invalid arg error' elif e.code() == grpc.StatusCode.ALREADY_EXISTS: - result = {'msg': 'already exists error'} + msg = 'already exists error' elif e.code() == grpc.StatusCode.UNAVAILABLE: - result = {'msg': 'server unavailable check network'} + msg = 'server unavailable check network' else: - result = {'msg': 'error code:{}'.format(e.code())} - print(f"{result['msg']} : {e.details()}") + msg = f'error code:{e.code()}' + print(f"{msg} : {e.details()}", file=sys.stderr) + server_error = True args = parse_args() @@ -135,12 +139,19 @@ def request(inputs,args): nmt_client = riva.client.NeuralMachineTranslationClient(auth) if args.list_models: - - response = nmt_client.get_config(args.model_name) + try: + response = nmt_client.get_config(args.model_name) + except grpc.RpcError as e: + print(f"Failed to list models: {e.code()} : {e.details()}", file=sys.stderr) + sys.exit(1) print(response) return - if args.text_file != None and os.path.exists(args.text_file): + if args.text_file is not None: + if not os.path.exists(args.text_file): + print(f"--text-file path does not exist: {args.text_file}", file=sys.stderr) + sys.exit(2) + translated_any = False with open(args.text_file, "r") as f: batch = [] for line in f: @@ -149,13 +160,24 @@ def request(inputs,args): batch.append(line) if len(batch) == args.batch_size: request(batch, args) + translated_any = True batch = [] if len(batch) > 0: request(batch, args) + translated_any = True + if not translated_any: + print(f"--text-file {args.text_file} contained no non-empty lines", file=sys.stderr) + sys.exit(2) + if server_error: + sys.exit(1) return - if args.text != "": - request([args.text], args) + if args.text == "": + print("--text must not be empty (or provide --text-file)", file=sys.stderr) + sys.exit(2) + request([args.text], args) + if server_error: + sys.exit(1) if __name__ == '__main__': diff --git a/scripts/nmt/nmt_speech_to_speech.py b/scripts/nmt/nmt_speech_to_speech.py index 7348525..e70f148 100644 --- a/scripts/nmt/nmt_speech_to_speech.py +++ b/scripts/nmt/nmt_speech_to_speech.py @@ -1,5 +1,6 @@ import argparse import os +import sys import wave from typing import Iterator @@ -30,7 +31,8 @@ def main(): # Validate input file if not os.path.exists(args.audio_file): - raise FileNotFoundError(f"Input audio file not found: {args.audio_file}") + print(f"Input audio file not found: {args.audio_file}", file=sys.stderr) + sys.exit(2) auth = riva.client.Auth( ssl_root_cert=args.ssl_root_cert, @@ -44,7 +46,11 @@ def main(): nmt_client = riva.client.NeuralMachineTranslationClient(auth) if args.list_models: - response = nmt_client.get_config() + try: + response = nmt_client.get_config() + except grpc.RpcError as e: + print(f"Failed to list models: {e.code()} : {e.details()}", file=sys.stderr) + sys.exit(1) print(response) return @@ -99,7 +105,8 @@ def main(): output_file.close() except Exception as e: - print(f"Error during translation: {e}") + print(f"Error during translation: {e}", file=sys.stderr) + sys.exit(1) if __name__ == "__main__": diff --git a/scripts/nmt/nmt_speech_to_text.py b/scripts/nmt/nmt_speech_to_text.py index bc8c0f2..724bd3a 100644 --- a/scripts/nmt/nmt_speech_to_text.py +++ b/scripts/nmt/nmt_speech_to_text.py @@ -1,5 +1,9 @@ import argparse import os +import sys + +import grpc + import riva.client import riva.client.proto.riva_asr_pb2 as riva_asr_pb2 import riva.client.proto.riva_nmt_pb2 as riva_nmt_pb2 @@ -41,7 +45,8 @@ def main(): # Validate input file if not os.path.exists(args.audio_file): - raise FileNotFoundError(f"Input audio file not found: {args.audio_file}") + print(f"Input audio file not found: {args.audio_file}", file=sys.stderr) + sys.exit(2) auth = riva.client.Auth( ssl_root_cert=args.ssl_root_cert, @@ -55,7 +60,11 @@ def main(): nmt_client = riva.client.NeuralMachineTranslationClient(auth) if args.list_models: - response = nmt_client.get_config(args.model) + try: + response = nmt_client.get_config(args.model) + except grpc.RpcError as e: + print(f"Failed to list models: {e.code()} : {e.details()}", file=sys.stderr) + sys.exit(1) print(response) return @@ -101,7 +110,8 @@ def main(): print(f"Final translation: {final_translation}") except Exception as e: - print(f"Error during translation: {e}") + print(f"Error during translation: {e}", file=sys.stderr) + sys.exit(1) if __name__ == "__main__":