@@ -102,7 +102,8 @@ def load_audio_fileobj(
102
102
format : Optional [str ] = None ,
103
103
buffer_size : int = 4096 ,
104
104
) -> Tuple [torch .Tensor , int ]:
105
- s = torchaudio .lib ._torchaudio_ffmpeg .StreamReaderFileObj (src , format , None , buffer_size )
105
+ demuxer = "ogg" if format == "vorbis" else format
106
+ s = torchaudio .lib ._torchaudio_ffmpeg .StreamReaderFileObj (src , demuxer , None , buffer_size )
106
107
sample_rate = int (s .get_src_stream_info (s .find_best_audio_stream ()).sample_rate )
107
108
filter = _get_load_filter (frame_offset , num_frames , convert )
108
109
waveform = _load_audio_fileobj (s , filter , channels_first )
@@ -131,7 +132,7 @@ def _native_endianness() -> str:
131
132
return "be"
132
133
133
134
134
- def _get_encoder_for_wav (dtype : torch . dtype , encoding : str , bits_per_sample : int ) -> str :
135
+ def _get_encoder_for_wav (encoding : str , bits_per_sample : int ) -> str :
135
136
if bits_per_sample not in {None , 8 , 16 , 24 , 32 , 64 }:
136
137
raise ValueError (f"Invalid bits_per_sample { bits_per_sample } for WAV encoding." )
137
138
endianness = _native_endianness ()
@@ -148,49 +149,80 @@ def _get_encoder_for_wav(dtype: torch.dtype, encoding: str, bits_per_sample: int
148
149
if bits_per_sample == 8 :
149
150
raise ValueError ("For WAV signed PCM, 8-bit encoding is not supported." )
150
151
return f"pcm_s{ bits_per_sample } { endianness } "
151
- elif encoding == "PCM_U" :
152
+ if encoding == "PCM_U" :
152
153
if bits_per_sample in (None , 8 ):
153
154
return "pcm_u8"
154
155
raise ValueError ("For WAV unsigned PCM, only 8-bit encoding is supported." )
155
- elif encoding == "PCM_F" :
156
+ if encoding == "PCM_F" :
156
157
if not bits_per_sample :
157
158
bits_per_sample = 32
158
159
if bits_per_sample in (32 , 64 ):
159
160
return f"pcm_f{ bits_per_sample } { endianness } "
160
161
raise ValueError ("For WAV float PCM, only 32- and 64-bit encodings are supported." )
161
- elif encoding == "ULAW" :
162
+ if encoding == "ULAW" :
162
163
if bits_per_sample in (None , 8 ):
163
164
return "pcm_mulaw"
164
165
raise ValueError ("For WAV PCM mu-law, only 8-bit encoding is supported." )
165
- elif encoding == "ALAW" :
166
+ if encoding == "ALAW" :
166
167
if bits_per_sample in (None , 8 ):
167
168
return "pcm_alaw"
168
169
raise ValueError ("For WAV PCM A-law, only 8-bit encoding is supported." )
169
170
raise ValueError (f"WAV encoding { encoding } is not supported." )
170
171
171
172
172
- def _get_encoder (dtype : torch .dtype , format : str , encoding : str , bits_per_sample : int ) -> str :
173
- if format == "wav" :
174
- return _get_encoder_for_wav (dtype , encoding , bits_per_sample )
175
- if format == "flac" :
176
- return "flac"
177
- if format in ("ogg" , "vorbis" ):
178
- if encoding or bits_per_sample :
179
- raise ValueError ("ogg/vorbis does not support encoding/bits_per_sample." )
180
- return "vorbis"
181
- return format
173
+ def _get_flac_sample_fmt (bps ):
174
+ if bps is None or bps == 16 :
175
+ return "s16"
176
+ if bps == 24 :
177
+ return "s32"
178
+ raise ValueError (f"FLAC only supports bits_per_sample values of 16 and 24 ({ bps } specified)." )
182
179
183
180
184
- def _get_encoder_format (format : str , bits_per_sample : Optional [int ]) -> str :
185
- if format == "flac" :
186
- if not bits_per_sample :
187
- return "s16"
188
- if bits_per_sample == 24 :
189
- return "s32"
190
- if bits_per_sample == 16 :
191
- return "s16"
192
- raise ValueError (f"FLAC only supports bits_per_sample values of 16 and 24 ({ bits_per_sample } specified)." )
193
- return None
181
+ def _parse_save_args (
182
+ ext : Optional [str ],
183
+ format : Optional [str ],
184
+ encoding : Optional [str ],
185
+ bps : Optional [int ],
186
+ ):
187
+ # torchaudio's save function accepts the followings, which do not 1to1 map
188
+ # to FFmpeg.
189
+ #
190
+ # - format: audio format
191
+ # - bits_per_sample: encoder sample format
192
+ # - encoding: such as PCM_U8.
193
+ #
194
+ # In FFmpeg, format is specified with the following three (and more)
195
+ #
196
+ # - muxer: could be audio format or container format.
197
+ # the one we passed to the constructor of StreamWriter
198
+ # - encoder: the audio encoder used to encode audio
199
+ # - encoder sample format: the format used by encoder to encode audio.
200
+ #
201
+ # If encoder sample format is different from source sample format, StreamWriter
202
+ # will insert a filter automatically.
203
+ #
204
+ if format == "wav" or (format is None and ext == "wav" ):
205
+ # wav is special because it supports different encoding through encoders
206
+ # each encoder only supports one encoder format
207
+ muxer = "wav"
208
+ encoder = _get_encoder_for_wav (encoding , bps )
209
+ sample_fmt = None
210
+ elif format == "vorbis" or (format is None and ext == "vorbis" ):
211
+ # FFpmeg does not recognize vorbis extension, while libsox used to do.
212
+ # For the sake of bakward compatibility, (and the simplicity),
213
+ # we support the case where users want to do save("foo.vorbis")
214
+ muxer = "ogg"
215
+ encoder = "vorbis"
216
+ sample_fmt = None
217
+ else :
218
+ muxer = format
219
+ encoder = None
220
+ sample_fmt = None
221
+ if format == "flac" or format is None and ext == "flac" :
222
+ sample_fmt = _get_flac_sample_fmt (bps )
223
+ if format == "ogg" or format is None and ext == "ogg" :
224
+ sample_fmt = _get_flac_sample_fmt (bps )
225
+ return muxer , encoder , sample_fmt
194
226
195
227
196
228
# NOTE: in contrast to load_audio* and info_audio*, this function is NOT compatible with TorchScript.
@@ -204,25 +236,27 @@ def save_audio(
204
236
bits_per_sample : Optional [int ] = None ,
205
237
buffer_size : int = 4096 ,
206
238
) -> None :
239
+ ext = None
207
240
if hasattr (uri , "write" ):
208
241
if format is None :
209
242
raise RuntimeError ("'format' is required when saving to file object." )
210
243
else :
211
244
uri = os .path .normpath (uri )
212
- s = StreamWriter (uri , format = format , buffer_size = buffer_size )
213
- if format is None :
214
- tokens = str (uri ).split ("." )
215
- if len (tokens ) > 1 :
216
- format = tokens [- 1 ].lower ()
245
+ if tokens := str (uri ).split ("." )[1 :]:
246
+ ext = tokens [- 1 ].lower ()
247
+
248
+ muxer , encoder , enc_fmt = _parse_save_args (ext , format , encoding , bits_per_sample )
217
249
218
250
if channels_first :
219
251
src = src .T
252
+
253
+ s = StreamWriter (uri , format = muxer , buffer_size = buffer_size )
220
254
s .add_audio_stream (
221
255
sample_rate ,
222
256
num_channels = src .size (- 1 ),
223
257
format = _get_sample_format (src .dtype ),
224
- encoder = _get_encoder ( src . dtype , format , encoding , bits_per_sample ) ,
225
- encoder_format = _get_encoder_format ( format , bits_per_sample ) ,
258
+ encoder = encoder ,
259
+ encoder_format = enc_fmt ,
226
260
)
227
261
with s .open ():
228
262
s .write_audio_chunk (0 , src )
0 commit comments