Skip to content

Commit a274231

Browse files
committed
Refactor _compat.save function
When dealing with the Vorbis format, FFmpeg expects the "ogg" container/extension with the "vorbis" encoder; it does not recognize "vorbis" as a container/extension. The libsox-based torchaudio I/O used to handle the "vorbis" extension. This commit refactors the internal handling of the save arguments and adds support for "vorbis" as an extension when the FFmpeg backend is used. This also fixes the mp3 case (#3385).
1 parent af932cc commit a274231

File tree

3 files changed

+89
-51
lines changed

3 files changed

+89
-51
lines changed

test/torchaudio_unittest/backend/dispatcher/ffmpeg/load_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from parameterized import parameterized
88
from torchaudio._backend.utils import get_load_func
99
from torchaudio._internal import module_utils as _mod_utils
10-
from torchaudio.io._compat import _get_encoder
10+
from torchaudio.io._compat import _parse_save_args
1111

1212
from torchaudio_unittest.backend.dispatcher.sox.common import name_func
1313
from torchaudio_unittest.common_utils import (
@@ -91,15 +91,15 @@ def assert_format(
9191

9292
# 3. Convert to wav with ffmpeg
9393
if normalize:
94-
acodec = "pcm_f32le"
94+
encoder = "pcm_f32le"
9595
else:
9696
encoding_map = {
9797
"floating-point": "PCM_F",
9898
"signed-integer": "PCM_S",
9999
"unsigned-integer": "PCM_U",
100100
}
101-
acodec = _get_encoder(data.dtype, "wav", encoding_map.get(encoding), bit_depth)
102-
_convert_audio_file(path, ref_path, acodec=acodec)
101+
_, encoder, _ = _parse_save_args(format, format, encoding_map.get(encoding), bit_depth)
102+
_convert_audio_file(path, ref_path, encoder=encoder)
103103

104104
# 4. Load wav with scipy
105105
data_ref = load_wav(ref_path, normalize=normalize)[0]
@@ -277,7 +277,7 @@ def test_opus(self, bitrate, num_channels, compression_level):
277277
"""`self._load` can load opus file correctly."""
278278
ops_path = get_asset_path("io", f"{bitrate}_{compression_level}_{num_channels}ch.opus")
279279
wav_path = self.get_temp_path(f"{bitrate}_{compression_level}_{num_channels}ch.opus.wav")
280-
_convert_audio_file(ops_path, wav_path, acodec="pcm_f32le")
280+
_convert_audio_file(ops_path, wav_path, encoder="pcm_f32le")
281281

282282
expected, sample_rate = load_wav(wav_path)
283283
found, sr = self._load(ops_path)

test/torchaudio_unittest/backend/dispatcher/ffmpeg/save_test.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import torch
99
from parameterized import parameterized
1010
from torchaudio._backend.utils import get_save_func
11-
from torchaudio.io._compat import _get_encoder, _get_encoder_format
11+
from torchaudio.io._compat import _parse_save_args
1212

1313
from torchaudio_unittest.backend.dispatcher.sox.common import get_enc_params, name_func
1414
from torchaudio_unittest.common_utils import (
@@ -24,12 +24,14 @@
2424
)
2525

2626

27-
def _convert_audio_file(src_path, dst_path, format=None, acodec=None):
28-
command = ["ffmpeg", "-y", "-i", src_path, "-strict", "-2"]
29-
if format:
30-
command += ["-sample_fmt", format]
31-
if acodec:
32-
command += ["-acodec", acodec]
27+
def _convert_audio_file(src_path, dst_path, muxer=None, encoder=None, sample_fmt=None):
28+
command = ["ffmpeg", "-hide_banner", "-y", "-i", src_path, "-strict", "-2"]
29+
if muxer:
30+
command += ["-f", muxer]
31+
if encoder:
32+
command += ["-acodec", encoder]
33+
if sample_fmt:
34+
command += ["-sample_fmt", sample_fmt]
3335
command += [dst_path]
3436
print(" ".join(command), file=sys.stderr)
3537
subprocess.run(command, check=True)
@@ -100,8 +102,10 @@ def assert_save_consistency(
100102
# 2.1. Convert the original wav to target format with torchaudio
101103
data = load_wav(src_path, normalize=False)[0]
102104
if test_mode == "path":
103-
self._save(tgt_path, data, sample_rate, encoding=encoding, bits_per_sample=bits_per_sample)
105+
ext = format
106+
self._save(tgt_path, data, sample_rate, format=format, encoding=encoding, bits_per_sample=bits_per_sample)
104107
elif test_mode == "fileobj":
108+
ext = None
105109
with open(tgt_path, "bw") as file_:
106110
self._save(
107111
file_,
@@ -113,6 +117,7 @@ def assert_save_consistency(
113117
)
114118
elif test_mode == "bytesio":
115119
file_ = io.BytesIO()
120+
ext = None
116121
self._save(
117122
file_,
118123
data,
@@ -127,16 +132,15 @@ def assert_save_consistency(
127132
else:
128133
raise ValueError(f"Unexpected test mode: {test_mode}")
129134
# 2.2. Convert the target format to wav with ffmpeg
130-
_convert_audio_file(tgt_path, tst_path, acodec="pcm_f32le")
135+
_convert_audio_file(tgt_path, tst_path, encoder="pcm_f32le")
131136
# 2.3. Load with SciPy
132137
found = load_wav(tst_path, normalize=False)[0]
133138

134139
# 3.1. Convert the original wav to target format with ffmpeg
135-
acodec = _get_encoder(data.dtype, format, encoding, bits_per_sample)
136-
sample_fmt = _get_encoder_format(format, bits_per_sample)
137-
_convert_audio_file(src_path, sox_path, acodec=acodec, format=sample_fmt)
140+
muxer, encoder, sample_fmt = _parse_save_args(ext, format, encoding, bits_per_sample)
141+
_convert_audio_file(src_path, sox_path, muxer=muxer, encoder=encoder, sample_fmt=sample_fmt)
138142
# 3.2. Convert the target format to wav with ffmpeg
139-
_convert_audio_file(sox_path, ref_path, acodec="pcm_f32le")
143+
_convert_audio_file(sox_path, ref_path, encoder="pcm_f32le")
140144
# 3.3. Load with SciPy
141145
expected = load_wav(ref_path, normalize=False)[0]
142146

torchaudio/io/_compat.py

Lines changed: 67 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,8 @@ def load_audio_fileobj(
102102
format: Optional[str] = None,
103103
buffer_size: int = 4096,
104104
) -> Tuple[torch.Tensor, int]:
105-
s = torchaudio.lib._torchaudio_ffmpeg.StreamReaderFileObj(src, format, None, buffer_size)
105+
demuxer = "ogg" if format == "vorbis" else format
106+
s = torchaudio.lib._torchaudio_ffmpeg.StreamReaderFileObj(src, demuxer, None, buffer_size)
106107
sample_rate = int(s.get_src_stream_info(s.find_best_audio_stream()).sample_rate)
107108
filter = _get_load_filter(frame_offset, num_frames, convert)
108109
waveform = _load_audio_fileobj(s, filter, channels_first)
@@ -131,7 +132,7 @@ def _native_endianness() -> str:
131132
return "be"
132133

133134

134-
def _get_encoder_for_wav(dtype: torch.dtype, encoding: str, bits_per_sample: int) -> str:
135+
def _get_encoder_for_wav(encoding: str, bits_per_sample: int) -> str:
135136
if bits_per_sample not in {None, 8, 16, 24, 32, 64}:
136137
raise ValueError(f"Invalid bits_per_sample {bits_per_sample} for WAV encoding.")
137138
endianness = _native_endianness()
@@ -148,49 +149,80 @@ def _get_encoder_for_wav(dtype: torch.dtype, encoding: str, bits_per_sample: int
148149
if bits_per_sample == 8:
149150
raise ValueError("For WAV signed PCM, 8-bit encoding is not supported.")
150151
return f"pcm_s{bits_per_sample}{endianness}"
151-
elif encoding == "PCM_U":
152+
if encoding == "PCM_U":
152153
if bits_per_sample in (None, 8):
153154
return "pcm_u8"
154155
raise ValueError("For WAV unsigned PCM, only 8-bit encoding is supported.")
155-
elif encoding == "PCM_F":
156+
if encoding == "PCM_F":
156157
if not bits_per_sample:
157158
bits_per_sample = 32
158159
if bits_per_sample in (32, 64):
159160
return f"pcm_f{bits_per_sample}{endianness}"
160161
raise ValueError("For WAV float PCM, only 32- and 64-bit encodings are supported.")
161-
elif encoding == "ULAW":
162+
if encoding == "ULAW":
162163
if bits_per_sample in (None, 8):
163164
return "pcm_mulaw"
164165
raise ValueError("For WAV PCM mu-law, only 8-bit encoding is supported.")
165-
elif encoding == "ALAW":
166+
if encoding == "ALAW":
166167
if bits_per_sample in (None, 8):
167168
return "pcm_alaw"
168169
raise ValueError("For WAV PCM A-law, only 8-bit encoding is supported.")
169170
raise ValueError(f"WAV encoding {encoding} is not supported.")
170171

171172

172-
def _get_encoder(dtype: torch.dtype, format: str, encoding: str, bits_per_sample: int) -> str:
173-
if format == "wav":
174-
return _get_encoder_for_wav(dtype, encoding, bits_per_sample)
175-
if format == "flac":
176-
return "flac"
177-
if format in ("ogg", "vorbis"):
178-
if encoding or bits_per_sample:
179-
raise ValueError("ogg/vorbis does not support encoding/bits_per_sample.")
180-
return "vorbis"
181-
return format
173+
def _get_flac_sample_fmt(bps):
174+
if bps is None or bps == 16:
175+
return "s16"
176+
if bps == 24:
177+
return "s32"
178+
raise ValueError(f"FLAC only supports bits_per_sample values of 16 and 24 ({bps} specified).")
182179

183180

184-
def _get_encoder_format(format: str, bits_per_sample: Optional[int]) -> str:
185-
if format == "flac":
186-
if not bits_per_sample:
187-
return "s16"
188-
if bits_per_sample == 24:
189-
return "s32"
190-
if bits_per_sample == 16:
191-
return "s16"
192-
raise ValueError(f"FLAC only supports bits_per_sample values of 16 and 24 ({bits_per_sample} specified).")
193-
return None
181+
def _parse_save_args(
182+
ext: Optional[str],
183+
format: Optional[str],
184+
encoding: Optional[str],
185+
bps: Optional[int],
186+
):
187+
# torchaudio's save function accepts the following, which do not map one-to-one
188+
# to FFmpeg.
189+
#
190+
# - format: audio format
191+
# - bits_per_sample: encoder sample format
192+
# - encoding: such as PCM_U8.
193+
#
194+
# In FFmpeg, format is specified with the following three (and more)
195+
#
196+
# - muxer: could be audio format or container format.
197+
# the one we passed to the constructor of StreamWriter
198+
# - encoder: the audio encoder used to encode audio
199+
# - encoder sample format: the format used by encoder to encode audio.
200+
#
201+
# If encoder sample format is different from source sample format, StreamWriter
202+
# will insert a filter automatically.
203+
#
204+
if format == "wav" or (format is None and ext == "wav"):
205+
# wav is special because it supports different encoding through encoders
206+
# each encoder only supports one encoder format
207+
muxer = "wav"
208+
encoder = _get_encoder_for_wav(encoding, bps)
209+
sample_fmt = None
210+
elif format == "vorbis" or (format is None and ext == "vorbis"):
211+
# FFmpeg does not recognize the vorbis extension, while libsox used to.
212+
# For the sake of backward compatibility (and simplicity),
213+
# we support the case where users want to do save("foo.vorbis")
214+
muxer = "ogg"
215+
encoder = "vorbis"
216+
sample_fmt = None
217+
else:
218+
muxer = format
219+
encoder = None
220+
sample_fmt = None
221+
if format == "flac" or format is None and ext == "flac":
222+
sample_fmt = _get_flac_sample_fmt(bps)
223+
if format == "ogg" or format is None and ext == "ogg":
224+
sample_fmt = _get_flac_sample_fmt(bps)
225+
return muxer, encoder, sample_fmt
194226

195227

196228
# NOTE: in contrast to load_audio* and info_audio*, this function is NOT compatible with TorchScript.
@@ -204,25 +236,27 @@ def save_audio(
204236
bits_per_sample: Optional[int] = None,
205237
buffer_size: int = 4096,
206238
) -> None:
239+
ext = None
207240
if hasattr(uri, "write"):
208241
if format is None:
209242
raise RuntimeError("'format' is required when saving to file object.")
210243
else:
211244
uri = os.path.normpath(uri)
212-
s = StreamWriter(uri, format=format, buffer_size=buffer_size)
213-
if format is None:
214-
tokens = str(uri).split(".")
215-
if len(tokens) > 1:
216-
format = tokens[-1].lower()
245+
if tokens := str(uri).split(".")[1:]:
246+
ext = tokens[-1].lower()
247+
248+
muxer, encoder, enc_fmt = _parse_save_args(ext, format, encoding, bits_per_sample)
217249

218250
if channels_first:
219251
src = src.T
252+
253+
s = StreamWriter(uri, format=muxer, buffer_size=buffer_size)
220254
s.add_audio_stream(
221255
sample_rate,
222256
num_channels=src.size(-1),
223257
format=_get_sample_format(src.dtype),
224-
encoder=_get_encoder(src.dtype, format, encoding, bits_per_sample),
225-
encoder_format=_get_encoder_format(format, bits_per_sample),
258+
encoder=encoder,
259+
encoder_format=enc_fmt,
226260
)
227261
with s.open():
228262
s.write_audio_chunk(0, src)

0 commit comments

Comments
 (0)