Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 25 additions & 50 deletions Sources/FluidAudio/Diarizer/LS-EEND/LSEENDDiarizer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,11 @@ public final class LSEENDDiarizer: Diarizer {

private var _engine: LSEENDInferenceHelper?
private var _session: LSEENDStreamingSession?
private var _melSpectrogram: AudioMelSpectrogram?
private var _timeline: DiarizerTimeline
private var _numFramesProcessed: Int = 0
private var _timelineConfig: DiarizerTimelineConfig
private var _visibleStartFrameOffset: Int = 0
private var _cachedConverter: AudioConverter?

// Audio buffering
private var pendingAudio: [Float] = []
Expand Down Expand Up @@ -154,12 +154,11 @@ public final class LSEENDDiarizer: Diarizer {
/// - Parameter descriptor: Model descriptor specifying variant and file paths
public func initialize(descriptor: LSEENDModelDescriptor) throws {
let engine = try LSEENDInferenceHelper(descriptor: descriptor, computeUnits: computeUnits)
let melSpectrogram = Self.createMelSpectrogram(featureConfig: engine.featureConfig)

lock.withLock {
updateTimelineConfig(engine: engine)
_engine = engine
_melSpectrogram = melSpectrogram
_cachedConverter = nil
_timeline = DiarizerTimeline(config: _timelineConfig)
_session = nil
Comment thread
SGD2718 marked this conversation as resolved.
resetBuffersLocked()
Expand All @@ -175,12 +174,10 @@ public final class LSEENDDiarizer: Diarizer {

/// Initialize with a pre-loaded engine.
public func initialize(engine: LSEENDInferenceHelper) {
let melSpectrogram = Self.createMelSpectrogram(featureConfig: engine.featureConfig)

lock.withLock {
updateTimelineConfig(engine: engine)
_engine = engine
_melSpectrogram = melSpectrogram
_cachedConverter = nil
_timeline = DiarizerTimeline(config: _timelineConfig)
_session = nil
Comment thread
SGD2718 marked this conversation as resolved.
resetBuffersLocked()
Expand Down Expand Up @@ -270,7 +267,7 @@ public final class LSEENDDiarizer: Diarizer {

if _session == nil {
_session = try engine.createSession(
inputSampleRate: engine.targetSampleRate, melSpectrogram: _melSpectrogram!)
inputSampleRate: engine.targetSampleRate)
}
Comment thread
SGD2718 marked this conversation as resolved.
guard let session = _session else {
return nil
Expand All @@ -289,12 +286,11 @@ public final class LSEENDDiarizer: Diarizer {
}

if let update {
let numSpeakers = engine.metadata.realOutputDim
let result = DiarizerChunkResult(
startFrame: max(0, update.startFrame - _visibleStartFrameOffset),
finalizedPredictions: flattenRowMajor(update.probabilities, numSpeakers: numSpeakers),
finalizedPredictions: update.probabilities.values,
finalizedFrameCount: update.probabilities.rows,
tentativePredictions: flattenRowMajor(update.previewProbabilities, numSpeakers: numSpeakers),
tentativePredictions: update.previewProbabilities.values,
tentativeFrameCount: update.previewProbabilities.rows
)
_numFramesProcessed += result.finalizedFrameCount
Expand Down Expand Up @@ -409,8 +405,7 @@ public final class LSEENDDiarizer: Diarizer {

// Lazily create session on first process call
if _session == nil {
_session = try engine.createSession(
inputSampleRate: engine.targetSampleRate, melSpectrogram: _melSpectrogram!)
_session = try engine.createSession(inputSampleRate: engine.targetSampleRate)
}
Comment thread
SGD2718 marked this conversation as resolved.
guard let session = _session else { return nil }

Expand All @@ -423,12 +418,11 @@ public final class LSEENDDiarizer: Diarizer {
return nil
}

let numSpeakers = engine.metadata.realOutputDim
let result = DiarizerChunkResult(
startFrame: max(0, update.startFrame - _visibleStartFrameOffset),
finalizedPredictions: flattenRowMajor(update.probabilities, numSpeakers: numSpeakers),
finalizedPredictions: update.probabilities.values,
finalizedFrameCount: update.probabilities.rows,
tentativePredictions: flattenRowMajor(update.previewProbabilities, numSpeakers: numSpeakers),
tentativePredictions: update.previewProbabilities.values,
tentativeFrameCount: update.previewProbabilities.rows
)

Expand Down Expand Up @@ -563,17 +557,15 @@ public final class LSEENDDiarizer: Diarizer {
retainedSession
} else {
try engine.createSession(
inputSampleRate: engine.targetSampleRate, melSpectrogram: _melSpectrogram!)
inputSampleRate: engine.targetSampleRate)
}
Comment thread
SGD2718 marked this conversation as resolved.
let numSpeakers = engine.metadata.realOutputDim

// Push all audio at once
if let update = try session.pushAudio(normalized) {
let chunk = DiarizerChunkResult(
startFrame: max(0, update.startFrame - _visibleStartFrameOffset),
finalizedPredictions: flattenRowMajor(update.probabilities, numSpeakers: numSpeakers),
finalizedPredictions: update.probabilities.values,
finalizedFrameCount: update.probabilities.rows,
tentativePredictions: flattenRowMajor(update.previewProbabilities, numSpeakers: numSpeakers),
tentativePredictions: update.previewProbabilities.values,
tentativeFrameCount: update.previewProbabilities.rows
)
_numFramesProcessed += chunk.finalizedFrameCount
Expand All @@ -586,7 +578,7 @@ public final class LSEENDDiarizer: Diarizer {
if let finalUpdate = try session.finalize() {
let chunk = DiarizerChunkResult(
startFrame: max(0, finalUpdate.startFrame - _visibleStartFrameOffset),
finalizedPredictions: flattenRowMajor(finalUpdate.probabilities, numSpeakers: numSpeakers),
finalizedPredictions: finalUpdate.probabilities.values,
finalizedFrameCount: finalUpdate.probabilities.rows,
tentativePredictions: [],
tentativeFrameCount: 0
Expand Down Expand Up @@ -623,7 +615,7 @@ public final class LSEENDDiarizer: Diarizer {
lock.withLock {
_engine = nil
_session = nil
_melSpectrogram = nil
_cachedConverter = nil
_timeline.reset()
resetBuffersLocked()
logger.info("LS-EEND resources cleaned up")
Expand All @@ -644,8 +636,7 @@ public final class LSEENDDiarizer: Diarizer {
lock.lock()
defer { lock.unlock() }

guard let engine = _engine, let session = _session else { return nil }
let numSpeakers = engine.metadata.realOutputDim
guard let session = _session else { return nil }
var lastResult: DiarizerChunkResult?

// Flush pending audio first — clear unconditionally so failed audio isn't retained.
Expand All @@ -656,7 +647,7 @@ public final class LSEENDDiarizer: Diarizer {
if let update = pushedUpdate {
let flushedResult = DiarizerChunkResult(
startFrame: _numFramesProcessed,
finalizedPredictions: flattenRowMajor(update.probabilities, numSpeakers: numSpeakers),
finalizedPredictions: update.probabilities.values,
finalizedFrameCount: update.probabilities.rows,
Comment thread
SGD2718 marked this conversation as resolved.
tentativePredictions: [],
tentativeFrameCount: 0
Expand All @@ -670,7 +661,7 @@ public final class LSEENDDiarizer: Diarizer {
if let finalUpdate = try session.finalize() {
let finalResult = DiarizerChunkResult(
startFrame: _numFramesProcessed,
finalizedPredictions: flattenRowMajor(finalUpdate.probabilities, numSpeakers: numSpeakers),
finalizedPredictions: finalUpdate.probabilities.values,
finalizedFrameCount: finalUpdate.probabilities.rows,
tentativePredictions: [],
tentativeFrameCount: 0
Expand Down Expand Up @@ -704,34 +695,18 @@ public final class LSEENDDiarizer: Diarizer {
return nil
}

return try AudioConverter(sampleRate: Double(engine.targetSampleRate))
.resample(Array(samples), from: sourceSampleRate)
}

/// Create a new mel spectrogram instance owned by this diarizer.
private static func createMelSpectrogram(featureConfig: LSEENDFeatureConfig) -> AudioMelSpectrogram {
AudioMelSpectrogram(
sampleRate: featureConfig.sampleRate,
nMels: featureConfig.nMels,
nFFT: featureConfig.nFFT,
hopLength: featureConfig.hopLength,
winLength: featureConfig.winLength,
preemph: 0,
padTo: 1,
logFloor: 1e-10,
logFloorMode: .clamped,
windowPeriodic: true
)
let converter =
_cachedConverter
?? {
let c = AudioConverter(sampleRate: Double(engine.targetSampleRate))
_cachedConverter = c
return c
}()
return try converter.resample(Array(samples), from: sourceSampleRate)
}

/// Sync the timeline configuration with the given engine.
///
/// Writes the engine's `metadata.realOutputDim` into `numSpeakers` and the
/// reciprocal of `modelFrameHz` into `frameDurationSeconds`.
/// - Parameter engine: Inference helper whose metadata and frame rate are read.
private func updateTimelineConfig(engine: LSEENDInferenceHelper) {
    let speakerCount = engine.metadata.realOutputDim
    _timelineConfig.numSpeakers = speakerCount

    // Frame duration is the inverse of the model's frame rate.
    let secondsPerFrame = Float(1.0 / engine.modelFrameHz)
    _timelineConfig.frameDurationSeconds = secondsPerFrame
}

/// Return the backing values of an `LSEENDMatrix` as a flat `[Float]`.
///
/// An empty array is returned when the matrix has no rows or no columns.
/// - Parameters:
///   - matrix: Matrix whose `values` storage is returned (presumably row-major — confirm against `LSEENDMatrix`).
///   - numSpeakers: Not read by this implementation.
/// - Returns: `matrix.values`, or `[]` for a matrix without rows or columns.
private func flattenRowMajor(_ matrix: LSEENDMatrix, numSpeakers: Int) -> [Float] {
    if matrix.rows <= 0 || matrix.columns <= 0 {
        return []
    }
    return matrix.values
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -644,9 +644,12 @@ public final class LSEENDStreamingSession {

let previewFullLogits: LSEENDMatrix
if includePreview {
let previewState = try state.copy()
// flushTail does not mutate the passed-in state — it reassigns a local
// variable on each step, leaving self.state untouched. finalize() relies
// on this same guarantee. Skipping the former state.copy() eliminates
// 6 × cloneAlignedMultiArray per pushAudio call.
let pending = totalFeatureFrames - emittedFrames
previewFullLogits = try flushTail(from: previewState, pendingFrames: pending)
previewFullLogits = try flushTail(from: state, pendingFrames: pending)
} else {
previewFullLogits = .empty(columns: engine.decodeMaxSpeakers)
}
Expand Down
Loading