-
Notifications
You must be signed in to change notification settings - Fork 291
Add Nemotron-ASR streaming inference to Rust SDK #613
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
rui-ren
wants to merge
11
commits into
main
Choose a base branch
from
ruiren/live-audio-stream-rust
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
fd22350
Add live audio transcription streaming support to Rust SDK
ab41774
Add live audio transcription E2E sample for Rust SDK
2d6eb8c
Add real microphone support to live audio transcription sample
5862384
Fix FFI null pointer and native session leak in Drop
1b23343
Improve API parity with C# LiveAudioTranscription
d8459b2
Update codex-feedback.md: mark parity gaps as resolved
0f8ae7a
Fix CI: update download callback to f64 and apply cargo fmt
d358aa6
Fix CI: remove unused setup_audio_client from live_audio_test
60c6353
Address PR review feedback
d3d4334
Fix clippy needless_return for Rust 1.94
36f34f8
Fix cargo fmt: collapse single-arm match block
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| [package] | ||
| name = "live-audio-transcription-example" | ||
| version = "0.1.0" | ||
| edition = "2024" | ||
| description = "Live audio transcription (streaming) example using the Foundry Local Rust SDK" | ||
|
|
||
| [dependencies] | ||
| foundry-local-sdk = { path = "../../../sdk/rust" } | ||
| tokio = { version = "1", features = ["rt-multi-thread", "macros"] } | ||
| tokio-stream = "0.1" | ||
| cpal = "0.15" |
292 changes: 292 additions & 0 deletions
292
samples/rust/live-audio-transcription-example/src/main.rs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,292 @@ | ||
| // Live Audio Transcription — Foundry Local Rust SDK Example | ||
| // | ||
| // Demonstrates real-time microphone-to-text using: | ||
| // Microphone (cpal) → SDK → Core (NativeAOT DLL) → onnxruntime-genai (StreamingProcessor) | ||
| // | ||
| // Usage: | ||
| // cargo run # Live microphone transcription (press ENTER to stop) | ||
| // cargo run -- --synth # Use synthetic 440Hz sine wave instead of microphone | ||
|
|
||
| use std::env; | ||
| use std::io::{self, Write}; | ||
| use std::sync::Arc; | ||
|
|
||
| use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; | ||
| use foundry_local_sdk::{FoundryLocalConfig, FoundryLocalManager}; | ||
| use tokio_stream::StreamExt; | ||
|
|
||
| const ALIAS: &str = "nemotron"; | ||
|
|
||
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // `--synth` replaces the microphone with a generated sine wave so the
    // pipeline can be exercised without audio hardware.
    let use_synth = env::args().any(|a| a == "--synth");

    println!("===========================================================");
    println!(" Foundry Local -- Live Audio Transcription Demo (Rust)");
    println!("===========================================================");
    println!();

    // ── 1. Resolve e2e-test-pkgs path ────────────────────────────────────
    // Prefer the repo-local `e2e-test-pkgs` directory (next to the crate);
    // fall back to the executable's own directory for packaged runs.
    let exe_dir = env::current_exe()?.parent().unwrap().to_path_buf();

    let manifest_dir = env!("CARGO_MANIFEST_DIR");
    let e2e_pkgs = std::path::PathBuf::from(manifest_dir)
        .join("..")
        .join("e2e-test-pkgs");

    let (core_path, model_cache_dir) = if e2e_pkgs.exists() {
        let core = e2e_pkgs
            .canonicalize()
            .expect("Failed to canonicalize e2e-test-pkgs path");
        let models = core.join("models");
        println!("Using e2e-test-pkgs:");
        println!(" Core DLLs: {}", core.display());
        println!(" Models: {}", models.display());
        (
            core.to_string_lossy().into_owned(),
            models.to_string_lossy().into_owned(),
        )
    } else {
        println!("Using default paths (exe directory)");
        (
            exe_dir.to_string_lossy().into_owned(),
            exe_dir.join("models").to_string_lossy().into_owned(),
        )
    };

    // ── 2. Initialise the manager ──────────────────────────────────────
    let config = FoundryLocalConfig::new("foundry_local_samples")
        .library_path(&core_path)
        .model_cache_dir(&model_cache_dir)
        .additional_setting("Bootstrap", "false");

    let manager = FoundryLocalManager::create(config)?;
    println!("✓ FoundryLocalManager initialized\n");

    // ── 3. Get the nemotron model ──────────────────────────────────────
    let model = manager.catalog().get_model(ALIAS).await?;
    println!("Model: {} (id: {})", model.alias(), model.id());

    if !model.is_cached().await? {
        println!("Downloading model...");
        model
            .download(Some(|progress: f64| {
                // `\r` rewrites the same console line as progress advances.
                print!("\r {progress:.1}%");
                io::stdout().flush().ok();
            }))
            .await?;
        println!();
    }

    println!("Loading model...");
    model.load().await?;
    println!("✓ Model loaded\n");

    // ── 4. Create live transcription session ───────────────────────────
    // The session is Arc-wrapped because it is shared between this task and
    // the microphone forwarding task spawned below.
    let audio_client = model.create_audio_client();
    let session = Arc::new(audio_client.create_live_transcription_session());

    println!("Starting live transcription session...");
    session.start(None).await?;
    println!("✓ Session started\n");

    // ── 5. Start reading transcription results in background ───────────
    // Results arrive while audio is still being pushed, so they are drained
    // concurrently; the task returns the total result count when the stream
    // ends (after `session.stop`).
    let mut stream = session.get_transcription_stream().await?;
    let read_task = tokio::spawn(async move {
        let mut count = 0usize;
        while let Some(result) = stream.next().await {
            match result {
                Ok(r) => {
                    let text = &r.content[0].text;
                    if r.is_final {
                        // Final hypotheses get their own line; partials are
                        // appended inline as they stream in.
                        println!();
                        println!(" [FINAL] {text}");
                        io::stdout().flush().ok();
                    } else if !text.is_empty() {
                        print!("{text}");
                        io::stdout().flush().ok();
                    }
                    // NOTE: counts both partial and final results.
                    count += 1;
                }
                Err(e) => {
                    eprintln!("\n [ERROR] Stream error: {e}");
                    break;
                }
            }
        }
        count
    });

    if use_synth {
        // ── 6a. Synthetic audio mode ───────────────────────────────────
        println!("Generating synthetic PCM audio (440Hz sine wave, 3 seconds)...\n");

        println!("===========================================================");
        println!(" PUSHING AUDIO → SDK → Core → onnxruntime-genai");
        println!("===========================================================\n");

        let pcm_data = generate_sine_wave_pcm(16000, 3, 440.0);
        // 16 kHz mono 16-bit: 1600 samples × 2 bytes per 100 ms chunk.
        let chunk_size = 16000 / 10 * 2; // 100ms chunks
        let mut chunks_pushed = 0;
        for offset in (0..pcm_data.len()).step_by(chunk_size) {
            let end = std::cmp::min(offset + chunk_size, pcm_data.len());
            session.append(&pcm_data[offset..end], None).await?;
            chunks_pushed += 1;
        }
        println!("Pushed {chunks_pushed} chunks ({} bytes)", pcm_data.len());
    } else {
        // ── 6b. Live microphone mode ───────────────────────────────────
        let host = cpal::default_host();
        let device = host
            .default_input_device()
            .expect("No input audio device available");
        println!("Microphone: {}", device.name().unwrap_or_default());

        let default_config = device.default_input_config()?;
        println!(
            "Device default: {} Hz, {} ch, {:?}",
            default_config.sample_rate().0,
            default_config.channels(),
            default_config.sample_format()
        );

        // Capture rate/channel count before `default_config` is consumed;
        // the callback needs them to convert to 16 kHz mono PCM.
        let device_rate = default_config.sample_rate().0;
        let device_channels = default_config.channels();
        // BufferSize::Default lets the OS/driver choose the optimal buffer
        // size for the device, typically ~10ms worth of samples.
        let mic_config: cpal::StreamConfig = default_config.into();

        // Use a sync channel to forward audio from the cpal callback thread
        // to the async runtime. This avoids Arc-cloning the session and
        // spawning a tokio task per mic callback.
        let (audio_tx, mut audio_rx) = tokio::sync::mpsc::channel::<Vec<u8>>(100);

        let input_stream = device.build_input_stream(
            &mic_config,
            move |data: &[f32], _: &cpal::InputCallbackInfo| {
                let bytes = convert_audio(data, device_channels, device_rate);
                if !bytes.is_empty() {
                    // try_send never blocks the realtime audio thread; if the
                    // channel is full the chunk is dropped rather than
                    // stalling capture.
                    let _ = audio_tx.try_send(bytes);
                }
            },
            |err| eprintln!("Microphone stream error: {err}"),
            None,
        )?;

        input_stream.play()?;

        println!();
        println!("===========================================================");
        println!(" LIVE TRANSCRIPTION ACTIVE");
        println!(" Speak into your microphone.");
        println!(" Transcription appears in real-time.");
        println!(" Press ENTER to stop recording.");
        println!("===========================================================");
        println!();

        // Forward audio from channel to the SDK session in a background task
        let session_for_forward = Arc::clone(&session);
        let forward_task = tokio::spawn(async move {
            while let Some(bytes) = audio_rx.recv().await {
                if let Err(e) = session_for_forward.append(&bytes, None).await {
                    eprintln!("Append error: {e}");
                    break;
                }
            }
        });

        // Block until user presses ENTER
        let mut line = String::new();
        io::stdin().read_line(&mut line)?;

        drop(input_stream);
        // Close the channel so forward_task exits
        // (input_stream drop closes cpal → callback stops → audio_tx dropped)
        forward_task.await?;
        println!("Microphone stopped.");
    }

    // ── 7. Stop session and wait for results ───────────────────────────
    // stop() flushes buffered audio, which ends the transcription stream and
    // lets read_task complete with the final count.
    println!("\nStopping session (flushing remaining audio)...");
    session.stop(None).await?;
    println!("✓ Session stopped\n");

    let result_count = read_task.await?;

    println!("===========================================================");
    println!(" Total transcription results: {result_count}");
    println!("===========================================================");

    // ── 8. Cleanup ─────────────────────────────────────────────────────
    println!("\nUnloading model...");
    model.unload().await?;
    println!("Done.");

    Ok(())
}
|
|
||
| /// Convert raw f32 audio samples to 16kHz/mono/16-bit PCM bytes. | ||
| /// | ||
| /// Handles stereo-to-mono mixing and sample rate conversion. | ||
| fn convert_audio(data: &[f32], channels: u16, sample_rate: u32) -> Vec<u8> { | ||
| // Mix to mono if multi-channel | ||
| let mono: Vec<f32> = if channels > 1 { | ||
| data.chunks(channels as usize) | ||
| .map(|frame| frame.iter().sum::<f32>() / channels as f32) | ||
| .collect() | ||
| } else { | ||
| data.to_vec() | ||
| }; | ||
|
|
||
| // Resample to 16kHz if needed | ||
| let resampled = if sample_rate != 16000 { | ||
| resample(&mono, sample_rate, 16000) | ||
| } else { | ||
| mono | ||
| }; | ||
|
|
||
| // Convert f32 → 16-bit signed little-endian bytes | ||
| let mut bytes = Vec::with_capacity(resampled.len() * 2); | ||
| for &s in &resampled { | ||
| let clamped = s.clamp(-1.0, 1.0); | ||
| let sample = (clamped * i16::MAX as f32) as i16; | ||
| bytes.extend_from_slice(&sample.to_le_bytes()); | ||
| } | ||
| bytes | ||
| } | ||
|
|
||
| /// Generate synthetic PCM audio (sine wave, 16kHz, 16-bit signed little-endian, mono). | ||
/// Generate synthetic PCM audio (sine wave, 16kHz-style, 16-bit signed
/// little-endian, mono) at half full-scale amplitude.
fn generate_sine_wave_pcm(sample_rate: i32, duration_seconds: i32, frequency: f64) -> Vec<u8> {
    let total_samples = (sample_rate * duration_seconds) as usize;
    // Angular frequency in radians per second.
    let omega = 2.0 * std::f64::consts::PI * frequency;
    let mut pcm_bytes = Vec::with_capacity(total_samples * 2);

    for i in 0..total_samples {
        let t = i as f64 / sample_rate as f64;
        // 0.5 amplitude keeps ample headroom below full scale.
        let sample = (i16::MAX as f64 * 0.5 * (omega * t).sin()) as i16;
        pcm_bytes.extend_from_slice(&sample.to_le_bytes());
    }

    pcm_bytes
}
|
|
||
| /// Simple linear-interpolation resampler (e.g. 48kHz → 16kHz). | ||
/// Simple linear-interpolation resampler (e.g. 48kHz → 16kHz).
///
/// Maps each output index back to a fractional source position and blends
/// the two neighbouring input samples; indices are clamped at the tail so
/// the last output never reads past the end of `input`.
fn resample(input: &[f32], from_rate: u32, to_rate: u32) -> Vec<f32> {
    if from_rate == to_rate {
        return input.to_vec();
    }
    let step = from_rate as f64 / to_rate as f64;
    let out_len = (input.len() as f64 / step).ceil() as usize;
    let last = input.len().saturating_sub(1);
    (0..out_len)
        .map(|i| {
            let pos = i as f64 * step;
            let base = pos as usize;
            let blend = (pos - base as f64) as f32;
            let a = input[base.min(last)];
            let b = input[(base + 1).min(last)];
            a + (b - a) * blend
        })
        .collect()
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is default?