Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion apps/desktop/src/stt/useRunBatch.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { describe, expect, test } from "vitest";

import { getBatchProvider } from "./useRunBatch";
import { getBatchProvider, getSessionSpeakerCount } from "./useRunBatch";

describe("getBatchProvider", () => {
test("maps pyannote to the batch transcription provider", () => {
Expand All @@ -19,3 +19,40 @@ describe("getBatchProvider", () => {
);
});
});

describe("getSessionSpeakerCount", () => {
test("counts distinct session participants plus the current user", () => {
const rows = new Map([
["mapping-1", { session_id: "session-1", human_id: "human-a" }],
["mapping-2", { session_id: "session-1", human_id: "human-a" }],
["mapping-3", { session_id: "session-1", human_id: "human-b" }],
["mapping-4", { session_id: "other-session", human_id: "human-c" }],
]);
const store = {
forEachRow: (_table: string, callback: (rowId: string) => void) => {
for (const rowId of rows.keys()) callback(rowId);
},
getCell: (_table: string, rowId: string, cellId: string) =>
rows.get(rowId)?.[cellId as "session_id" | "human_id"],
};

expect(getSessionSpeakerCount(store as any, "session-1", "self")).toBe(3);
});

test("returns undefined until at least two speakers are known", () => {
const rows = new Map([
["mapping-1", { session_id: "session-1", human_id: "human-a" }],
]);
const store = {
forEachRow: (_table: string, callback: (rowId: string) => void) => {
for (const rowId of rows.keys()) callback(rowId);
},
getCell: (_table: string, rowId: string, cellId: string) =>
rows.get(rowId)?.[cellId as "session_id" | "human_id"],
};

expect(getSessionSpeakerCount(store as any, "session-1", null)).toBe(
undefined,
);
});
});
42 changes: 41 additions & 1 deletion apps/desktop/src/stt/useRunBatch.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ type RunOptions = {
maxSpeakers?: number;
};

type Store = NonNullable<ReturnType<typeof main.UI.useStore>>;

const DIRECT_BATCH_PROVIDERS: Set<TranscriptionParams["provider"]> = new Set([
"deepgram",
"soniox",
Expand Down Expand Up @@ -81,6 +83,38 @@ export function isStoppedTranscriptionError(error: unknown) {
);
}

export function getSessionSpeakerCount(
store: Store,
sessionId: string,
selfHumanId?: string | null,
): number | undefined {
const humanIds = new Set<string>();

store.forEachRow("mapping_session_participant", (mappingId, _forEachCell) => {
const sid = store.getCell(
"mapping_session_participant",
mappingId,
"session_id",
);
if (sid !== sessionId) return;

const humanId = store.getCell(
"mapping_session_participant",
mappingId,
"human_id",
);
if (typeof humanId === "string" && humanId) {
humanIds.add(humanId);
}
});

if (typeof selfHumanId === "string" && selfHumanId) {
humanIds.add(selfHumanId);
}

return humanIds.size > 1 ? humanIds.size : undefined;
}

export const useRunBatch = (sessionId: string) => {
const store = main.UI.useStore(main.STORE_ID);
const indexes = main.UI.useIndexes(main.STORE_ID);
Expand Down Expand Up @@ -114,6 +148,12 @@ export const useRunBatch = (sessionId: string) => {
const createdAt = new Date().toISOString();
const memoMd = store.getCell("sessions", sessionId, "raw_md");
let transcriptId: string | null = null;
const inferredNumSpeakers =
options?.numSpeakers === undefined &&
options?.minSpeakers === undefined &&
options?.maxSpeakers === undefined
? getSessionSpeakerCount(store, sessionId, user_id)
: undefined;

const handlePersist: BatchPersistCallback | undefined =
options?.handlePersist;
Expand Down Expand Up @@ -232,7 +272,7 @@ export const useRunBatch = (sessionId: string) => {
languages:
options?.languages ??
getTranscriptionLanguages(aiLanguage, spokenLanguages),
num_speakers: options?.numSpeakers,
num_speakers: options?.numSpeakers ?? inferredNumSpeakers,
min_speakers: options?.minSpeakers,
max_speakers: options?.maxSpeakers,
};
Expand Down
33 changes: 11 additions & 22 deletions crates/listener-core/src/actors/listener/adapters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -410,32 +410,25 @@ fn i16_bytes_to_f32(bytes: &Bytes) -> Vec<f32> {
}

fn build_listen_params(args: &ListenerArgs) -> owhisper_interface::ListenParams {
let adapter_kind =
AdapterKind::from_url_and_languages(&args.base_url, &args.languages, Some(&args.model));
let redemption_time_ms = if args.onboarding { "60" } else { "400" };
let mut custom_query = std::collections::HashMap::from([(
let custom_query = std::collections::HashMap::from([(
"redemption_time_ms".to_string(),
redemption_time_ms.to_string(),
)]);

if adapter_kind == AdapterKind::AssemblyAI
&& let Some(expected_speakers) = assemblyai_expected_speakers(args)
{
custom_query.insert("speaker_labels".to_string(), "true".to_string());
custom_query.insert("max_speakers".to_string(), expected_speakers.to_string());
}
let num_speakers = expected_speakers(args);

owhisper_interface::ListenParams {
model: Some(args.model.clone()),
languages: args.languages.clone(),
sample_rate: super::super::SAMPLE_RATE,
keywords: args.keywords.clone(),
num_speakers,
custom_query: Some(custom_query),
..Default::default()
}
}

fn assemblyai_expected_speakers(args: &ListenerArgs) -> Option<u32> {
fn expected_speakers(args: &ListenerArgs) -> Option<u32> {
let mut participants = args.participant_human_ids.clone();

if let Some(self_human_id) = &args.self_human_id
Expand Down Expand Up @@ -655,31 +648,26 @@ mod tests {
}

#[test]
fn assemblyai_expected_speakers_counts_distinct_participants() {
fn expected_speakers_counts_distinct_participants() {
let mut args = listener_args("https://api.assemblyai.com", "u3-rt-pro");
args.participant_human_ids = vec!["remote".to_string(), "self".to_string()];
args.self_human_id = Some("self".to_string());

assert_eq!(assemblyai_expected_speakers(&args), Some(2));
assert_eq!(expected_speakers(&args), Some(2));
}

#[test]
fn build_listen_params_adds_assemblyai_diarization_hints() {
fn build_listen_params_sets_num_speakers_without_assemblyai_custom_query() {
let mut args = listener_args("https://api.assemblyai.com", "u3-rt-pro");
args.participant_human_ids = vec!["remote".to_string()];
args.self_human_id = Some("self".to_string());

let params = build_listen_params(&args);
let custom_query = params.custom_query.expect("custom query");

assert_eq!(
custom_query.get("speaker_labels").map(String::as_str),
Some("true")
);
assert_eq!(
custom_query.get("max_speakers").map(String::as_str),
Some("2")
);
assert_eq!(params.num_speakers, Some(2));
assert!(!custom_query.contains_key("speaker_labels"));
assert!(!custom_query.contains_key("max_speakers"));
}

#[test]
Expand All @@ -691,6 +679,7 @@ mod tests {
let params = build_listen_params(&args);
let custom_query = params.custom_query.expect("custom query");

assert_eq!(params.num_speakers, Some(2));
assert!(!custom_query.contains_key("speaker_labels"));
assert!(!custom_query.contains_key("max_speakers"));
}
Expand Down
69 changes: 47 additions & 22 deletions crates/owhisper-client/src/adapter/assemblyai/live.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,13 @@ impl RealtimeSttAdapter for AssemblyAIAdapter {
query_pairs.append_pair("max_turn_silence", max_silence);
}

if matches!(resolved_model, ResolvedLiveModel::U3RtPro)
&& let Some(custom) = &params.custom_query
{
if custom
.get("speaker_labels")
.is_some_and(|value| value == "true")
{
Comment thread
cursor[bot] marked this conversation as resolved.
if matches!(resolved_model, ResolvedLiveModel::U3RtPro) {
if Self::streaming_speaker_labels_enabled(params) {
query_pairs.append_pair("speaker_labels", "true");
}

if let Some(max_speakers) = custom.get("max_speakers") {
query_pairs.append_pair("max_speakers", max_speakers);
if let Some(max_speakers) = Self::streaming_max_speakers(params) {
query_pairs.append_pair("max_speakers", &max_speakers.to_string());
}
}

Expand Down Expand Up @@ -232,6 +227,27 @@ impl AssemblyAIAdapter {
}
}

fn streaming_speaker_labels_enabled(params: &ListenParams) -> bool {
params.num_speakers.is_some()
|| params.min_speakers.is_some()
|| params.max_speakers.is_some()
|| params
.custom_query
.as_ref()
.and_then(|custom| custom.get("speaker_labels"))
.is_some_and(|value| value == "true")
}
Comment thread
cursor[bot] marked this conversation as resolved.

fn streaming_max_speakers(params: &ListenParams) -> Option<u32> {
params.max_speakers.or(params.num_speakers).or_else(|| {
params
.custom_query
.as_ref()
.and_then(|custom| custom.get("max_speakers"))
.and_then(|value| value.parse().ok())
})
}

fn parse_speaker_label(label: Option<&str>) -> Option<i32> {
let label = label?.trim();
if label.is_empty() || label.eq_ignore_ascii_case("unknown") {
Expand Down Expand Up @@ -339,8 +355,6 @@ impl ResolvedLiveModel {

#[cfg(test)]
mod tests {
use std::collections::HashMap;

use hypr_language::ISO639;
use owhisper_interface::ListenParams;
use owhisper_interface::stream::StreamResponse;
Expand Down Expand Up @@ -424,10 +438,7 @@ mod tests {
API_BASE,
&owhisper_interface::ListenParams {
model: Some("u3-rt-pro".to_string()),
custom_query: Some(HashMap::from([
("speaker_labels".to_string(), "true".to_string()),
("max_speakers".to_string(), "3".to_string()),
])),
num_speakers: Some(3),
..Default::default()
},
1,
Expand All @@ -439,14 +450,28 @@ mod tests {
}

#[test]
fn test_whisper_fallback_omits_streaming_diarization_hints() {
fn test_streaming_min_speakers_enables_diarization() {
let url = AssemblyAIAdapter.build_ws_url(
API_BASE,
&owhisper_interface::ListenParams {
model: Some("u3-rt-pro".to_string()),
min_speakers: Some(2),
..Default::default()
},
1,
);

let query = url.query().expect("query string");
assert!(query.contains("speaker_labels=true"));
assert!(!query.contains("max_speakers"));
}

#[test]
fn test_streaming_diarization_hints_skip_whisper_fallback() {
let url = AssemblyAIAdapter.build_ws_url(
API_BASE,
&owhisper_interface::ListenParams {
custom_query: Some(HashMap::from([
("speaker_labels".to_string(), "true".to_string()),
("max_speakers".to_string(), "3".to_string()),
])),
num_speakers: Some(3),
languages: vec![ISO639::Ko.into()],
..Default::default()
},
Expand All @@ -455,8 +480,8 @@ mod tests {

let query = url.query().expect("query string");
assert!(query.contains("speech_model=whisper-rt"));
assert!(!query.contains("speaker_labels=true"));
assert!(!query.contains("max_speakers=3"));
assert!(!query.contains("speaker_labels"));
assert!(!query.contains("max_speakers"));
}

#[test]
Expand Down
28 changes: 28 additions & 0 deletions crates/owhisper-client/src/adapter/elevenlabs/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ impl ElevenLabsAdapter {
.text("diarize", "true")
.text("timestamps_granularity", "word");

if let Some(num_speakers) = Self::num_speakers_hint(params) {
form = form.text("num_speakers", num_speakers.to_string());
}

if let Some(lang) = params.languages.first() {
form = form.text("language_code", lang.iso639().code().to_string());
}
Expand Down Expand Up @@ -116,6 +120,10 @@ impl ElevenLabsAdapter {
Ok(Self::convert_to_batch_response(transcript))
}

fn num_speakers_hint(params: &ListenParams) -> Option<u32> {
params.num_speakers.or(params.max_speakers)
}

fn convert_to_batch_response(response: TranscriptResponse) -> BatchResponse {
let words: Vec<BatchWord> = response
.words
Expand Down Expand Up @@ -164,6 +172,26 @@ mod tests {
use super::*;
use crate::http_client::create_client;

#[test]
fn num_speakers_hint_prefers_exact_count_then_max() {
let exact = ListenParams {
num_speakers: Some(3),
max_speakers: Some(5),
..Default::default()
};
let ranged = ListenParams {
max_speakers: Some(5),
..Default::default()
};

assert_eq!(ElevenLabsAdapter::num_speakers_hint(&exact), Some(3));
assert_eq!(ElevenLabsAdapter::num_speakers_hint(&ranged), Some(5));
assert_eq!(
ElevenLabsAdapter::num_speakers_hint(&ListenParams::default()),
None
);
}

#[test]
fn speaker_labeled_words_use_mixed_capture_channel() {
let response = TranscriptResponse {
Expand Down
Loading
Loading