Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 163 additions & 0 deletions Sources/FluidAudio/ASR/Cohere/CohereAsrConfig.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
import Foundation

/// Static configuration for the Cohere Transcribe CoreML ASR model.
///
/// This is a caseless enum used purely as a namespace for constants. Values
/// that are fully determined by other constants (`maxSamples`, `headDim`,
/// `MelSpec.nMels`) are derived rather than repeated as literals, so the
/// related numbers cannot drift apart.
public enum CohereAsrConfig {
    // MARK: - Audio input

    /// Sample rate expected by the model (16kHz).
    public static let sampleRate: Int = 16_000

    /// Maximum audio duration in seconds (30s).
    public static let maxAudioSeconds: Float = 30.0

    /// Maximum number of audio samples (480,000 at 16kHz = 30 seconds).
    ///
    /// Computed as `sampleRate * maxAudioSeconds`. The product (480,000) is
    /// below 2^24, so it is exactly representable as a `Float` and the
    /// conversion back to `Int` is lossless.
    public static let maxSamples: Int = Int(Float(sampleRate) * maxAudioSeconds)

    // MARK: - Model dimensions

    /// Vocabulary size.
    public static let vocabSize: Int = 16_384

    /// Encoder hidden size (Conformer blocks).
    public static let encoderHiddenSize: Int = 1280

    /// Decoder hidden size.
    public static let decoderHiddenSize: Int = 1024

    /// Number of encoder layers.
    public static let numEncoderLayers: Int = 48

    /// Number of decoder layers.
    public static let numDecoderLayers: Int = 8

    /// Number of attention heads in decoder.
    public static let numDecoderHeads: Int = 8

    /// Head dimension: `decoderHiddenSize / numDecoderHeads` (1024 / 8 = 128).
    public static let headDim: Int = decoderHiddenSize / numDecoderHeads

    /// Maximum sequence length for decoder KV cache.
    public static let maxSeqLen: Int = 108

    /// Number of mel bins.
    public static let numMelBins: Int = 128

    // MARK: - Mel spectrogram

    /// Mel spectrogram parameters.
    public enum MelSpec {
        /// FFT size.
        public static let nFFT: Int = 1024
        /// Hop length.
        public static let hopLength: Int = 160
        /// Number of mel bins; always equal to `CohereAsrConfig.numMelBins`.
        public static let nMels: Int = CohereAsrConfig.numMelBins
        /// Lower frequency bound.
        public static let fMin: Float = 0.0
        /// Upper frequency bound.
        public static let fMax: Float = 8000.0
        /// Pre-emphasis coefficient.
        public static let preemphasis: Float = 0.97
    }

    // MARK: - Special tokens

    /// Special token IDs, listed in ascending ID order.
    public enum SpecialTokens {
        /// Unknown token.
        public static let unkToken: Int = 0
        /// No speech token.
        public static let noSpeechToken: Int = 1
        /// Padding token.
        public static let padToken: Int = 2
        /// End of text / End of sequence token.
        public static let eosToken: Int = 3
        /// Start of transcript token (`<|startoftranscript|>`).
        public static let startToken: Int = 4
        /// Punctuation token (`<|pnc|>`).
        public static let pnc: Int = 5
        /// Start of context token (`<|startofcontext|>`).
        public static let startOfContext: Int = 7
        /// No inverse text normalization (`<|noitn|>`).
        public static let noitn: Int = 9
        /// No timestamp token (`<|notimestamp|>`).
        public static let notimestamp: Int = 11
        /// No diarization token (`<|nodiarize|>`).
        public static let nodiarize: Int = 13
        /// Emotion undefined token (`<|emo:undefined|>`).
        public static let emoUndefined: Int = 16
        /// Word boundary marker (`▁`).
        public static let wordBoundary: Int = 13764
    }

    // MARK: - Languages

    /// Supported languages. The raw value is the two-letter language code.
    public enum Language: String, CaseIterable, Sendable {
        case english = "en"
        case french = "fr"
        case german = "de"
        case spanish = "es"
        case italian = "it"
        case portuguese = "pt"
        case dutch = "nl"
        case polish = "pl"
        case greek = "el"
        case arabic = "ar"
        case japanese = "ja"
        case chinese = "zh"
        case vietnamese = "vi"
        case korean = "ko"

        /// Human-readable English name of the language.
        public var englishName: String {
            switch self {
            case .english: return "English"
            case .french: return "French"
            case .german: return "German"
            case .spanish: return "Spanish"
            case .italian: return "Italian"
            case .portuguese: return "Portuguese"
            case .dutch: return "Dutch"
            case .polish: return "Polish"
            case .greek: return "Greek"
            case .arabic: return "Arabic"
            case .japanese: return "Japanese"
            case .chinese: return "Chinese"
            case .vietnamese: return "Vietnamese"
            case .korean: return "Korean"
            }
        }

        /// Vocabulary ID of this language's tag token (e.g. `<|en|>`).
        ///
        /// This is the token that `promptSequence` inserts twice to condition
        /// generation on the language. It is distinct from
        /// `SpecialTokens.startToken`.
        public var tokenId: Int {
            switch self {
            case .english: return 62
            case .french: return 69
            case .german: return 76
            case .spanish: return 169
            case .italian: return 97
            case .portuguese: return 149
            case .dutch: return 60
            case .polish: return 148
            case .greek: return 77
            case .arabic: return 28
            case .japanese: return 98
            case .chinese: return 50
            case .vietnamese: return 194
            case .korean: return 110
            }
        }

        /// Build the prompt sequence for this language.
        ///
        /// Cohere models expect a specific prompt sequence:
        /// 1. Word boundary marker
        /// 2. Start of context
        /// 3. Start of transcript
        /// 4. Emotion undefined
        /// 5-6. Language token (repeated twice)
        /// 7. Punctuation
        /// 8. No inverse text normalization
        /// 9. No timestamp
        /// 10. No diarization
        ///
        /// - Returns: The ten token IDs in the order listed above.
        public var promptSequence: [Int] {
            let langToken = tokenId
            return [
                SpecialTokens.wordBoundary,  // ▁
                SpecialTokens.startOfContext,  // <|startofcontext|>
                SpecialTokens.startToken,  // <|startoftranscript|>
                SpecialTokens.emoUndefined,  // <|emo:undefined|>
                langToken,  // <|en|> (or other language)
                langToken,  // <|en|> (repeated)
                SpecialTokens.pnc,  // <|pnc|>
                SpecialTokens.noitn,  // <|noitn|>
                SpecialTokens.notimestamp,  // <|notimestamp|>
                SpecialTokens.nodiarize,  // <|nodiarize|>
            ]
        }
    }
}
Loading
Loading