diff --git a/examples/asr/emformer_rnnt/pipeline_demo.py b/examples/asr/emformer_rnnt/pipeline_demo.py
index 782bc1d539..eacba2503a 100644
--- a/examples/asr/emformer_rnnt/pipeline_demo.py
+++ b/examples/asr/emformer_rnnt/pipeline_demo.py
@@ -65,9 +65,9 @@ def run_eval_streaming(args):
         with torch.no_grad():
             features, length = streaming_feature_extractor(segment)
             hypos, state = decoder.infer(features, length, 10, state=state, hypothesis=hypothesis)
-        hypothesis = hypos[0]
-        transcript = token_processor(hypothesis[0], lstrip=False)
-        print(transcript, end="", flush=True)
+        hypothesis = hypos
+        transcript = token_processor(hypos[0][0], lstrip=True)
+        print(transcript, end="\r", flush=True)
     print()
 
     # Non-streaming decode.
diff --git a/examples/tutorials/online_asr_tutorial.py b/examples/tutorials/online_asr_tutorial.py
index 45c65b41c9..faf1b72602 100644
--- a/examples/tutorials/online_asr_tutorial.py
+++ b/examples/tutorials/online_asr_tutorial.py
@@ -39,6 +39,7 @@
 # --------------
 #
 
+import os
 import torch
 import torchaudio
 
@@ -222,9 +223,9 @@ def run_inference(num_iter=100):
         segment = cacher(chunk[:, 0])
         features, length = feature_extractor(segment)
         hypos, state = decoder.infer(features, length, 10, state=state, hypothesis=hypothesis)
-        hypothesis = hypos[0]
-        transcript = token_processor(hypothesis[0], lstrip=False)
-        print(transcript, end="", flush=True)
+        hypothesis = hypos
+        transcript = token_processor(hypos[0][0], lstrip=False)
+        print(transcript, end="\r", flush=True)
 
         chunks.append(chunk)
         feats.append(features)
diff --git a/test/torchaudio_unittest/models/rnnt_decoder/rnnt_decoder_test_impl.py b/test/torchaudio_unittest/models/rnnt_decoder/rnnt_decoder_test_impl.py
index 8560564596..5bfabdfbc2 100644
--- a/test/torchaudio_unittest/models/rnnt_decoder/rnnt_decoder_test_impl.py
+++ b/test/torchaudio_unittest/models/rnnt_decoder/rnnt_decoder_test_impl.py
@@ -99,7 +99,7 @@ def test_torchscript_consistency_infer(self):
         self.assertEqual(res, scripted_res)
 
         state = res[1]
-        hypo = res[0][0]
+        hypo = res[0]
         scripted_state = scripted_res[1]
-        scripted_hypo = scripted_res[0][0]
+        scripted_hypo = scripted_res[0]
 
diff --git a/torchaudio/models/rnnt_decoder.py b/torchaudio/models/rnnt_decoder.py
index 045e642d0a..74d65ca2de 100644
--- a/torchaudio/models/rnnt_decoder.py
+++ b/torchaudio/models/rnnt_decoder.py
@@ -109,13 +109,9 @@ def __init__(
 
         self.step_max_tokens = step_max_tokens
 
-    def _init_b_hypos(self, hypo: Optional[Hypothesis], device: torch.device) -> List[Hypothesis]:
-        if hypo is not None:
-            token = _get_hypo_tokens(hypo)[-1]
-            state = _get_hypo_state(hypo)
-        else:
-            token = self.blank
-            state = None
+    def _init_b_hypos(self, device: torch.device) -> List[Hypothesis]:
+        token = self.blank
+        state = None
 
         one_tensor = torch.tensor([1], device=device)
         pred_out, _, pred_state = self.model.predict(torch.tensor([[token]], device=device), one_tensor, state)
@@ -230,14 +226,14 @@ def _gen_new_hypos(
     def _search(
         self,
         enc_out: torch.Tensor,
-        hypo: Optional[Hypothesis],
+        hypo: Optional[List[Hypothesis]],
         beam_width: int,
     ) -> List[Hypothesis]:
         n_time_steps = enc_out.shape[1]
         device = enc_out.device
 
         a_hypos: List[Hypothesis] = []
-        b_hypos = self._init_b_hypos(hypo, device)
+        b_hypos = self._init_b_hypos(device) if hypo is None else hypo
         for t in range(n_time_steps):
             a_hypos = b_hypos
             b_hypos = torch.jit.annotate(List[Hypothesis], [])
@@ -263,7 +259,7 @@
                 if a_hypos:
                     symbols_current_t += 1
 
-            _, sorted_idx = torch.tensor([self.hypo_sort_key(hypo) for hypo in b_hypos]).topk(beam_width)
+            _, sorted_idx = torch.tensor([self.hypo_sort_key(hyp) for hyp in b_hypos]).topk(beam_width)
             b_hypos = [b_hypos[idx] for idx in sorted_idx]
 
         return b_hypos
@@ -290,8 +286,8 @@ def forward(self, input: torch.Tensor, length: torch.Tensor, beam_width: int) ->
         if length.shape != () and length.shape != (1,):
             raise ValueError("length must be of shape () or (1,)")
 
-        if input.dim() == 0:
-            input = input.unsqueeze(0)
+        if length.dim() == 0:
+            length = length.unsqueeze(0)
 
         enc_out, _ = self.model.transcribe(input, length)
         return self._search(enc_out, None, beam_width)
@@ -303,7 +299,7 @@ def infer(
         self,
         input: torch.Tensor,
         length: torch.Tensor,
         beam_width: int,
         state: Optional[List[List[torch.Tensor]]] = None,
-        hypothesis: Optional[Hypothesis] = None,
+        hypothesis: Optional[List[Hypothesis]] = None,
     ) -> Tuple[List[Hypothesis], List[List[torch.Tensor]]]:
         r"""Performs beam search for the given input sequence in streaming mode.
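
Usage note: with this change, the hypothesis argument of RNNTBeamSearch.infer carries the full beam (List[Hypothesis]) between chunks rather than a single best Hypothesis, so callers feed the returned hypos back in unchanged and index hypos[0][0] for the token sequence of the best hypothesis. Below is a minimal streaming-loop sketch against the updated API; the audio_chunks iterator is a placeholder, and the bundle accessors are the standard RNNTBundle ones.

    import torch
    import torchaudio

    bundle = torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH
    feature_extractor = bundle.get_streaming_feature_extractor()
    decoder = bundle.get_decoder()
    token_processor = bundle.get_token_processor()

    state, hypothesis = None, None
    with torch.no_grad():
        for chunk in audio_chunks:  # placeholder: stream of 1-D audio tensors
            features, length = feature_extractor(chunk)
            # The decoder state and the whole beam are threaded through successive calls.
            hypos, state = decoder.infer(features, length, 10, state=state, hypothesis=hypothesis)
            hypothesis = hypos  # feed back the full List[Hypothesis], not hypos[0]
            transcript = token_processor(hypos[0][0])  # tokens of the best hypothesis
            print(transcript, end="\r", flush=True)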