Fix return_complex warning on training (#1627)

* Fix return_complex warning on training

* remove unused prints
This commit is contained in:
Blaise 2023-12-22 02:35:51 +01:00 committed by GitHub
parent 0f8a5facd9
commit 78f03e7dc0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -38,7 +38,6 @@ def spectral_de_normalize_torch(magnitudes):
mel_basis = {} mel_basis = {}
hann_window = {} hann_window = {}
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
"""Convert waveform into Linear-frequency Linear-amplitude spectrogram. """Convert waveform into Linear-frequency Linear-amplitude spectrogram.
@ -52,12 +51,7 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False)
Returns: Returns:
:: (B, Freq, Frame) - Linear-frequency Linear-amplitude spectrogram :: (B, Freq, Frame) - Linear-frequency Linear-amplitude spectrogram
""" """
# Validation
if torch.min(y) < -1.07:
logger.debug("min value is %s", str(torch.min(y)))
if torch.max(y) > 1.07:
logger.debug("max value is %s", str(torch.max(y)))
# Window - Cache if needed # Window - Cache if needed
global hann_window global hann_window
dtype_device = str(y.dtype) + "_" + str(y.device) dtype_device = str(y.dtype) + "_" + str(y.device)
@ -66,7 +60,7 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False)
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
dtype=y.dtype, device=y.device dtype=y.dtype, device=y.device
) )
# Padding # Padding
y = torch.nn.functional.pad( y = torch.nn.functional.pad(
y.unsqueeze(1), y.unsqueeze(1),
@ -74,7 +68,7 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False)
mode="reflect", mode="reflect",
) )
y = y.squeeze(1) y = y.squeeze(1)
# Complex Spectrogram :: (B, T) -> (B, Freq, Frame, RealComplex=2) # Complex Spectrogram :: (B, T) -> (B, Freq, Frame, RealComplex=2)
spec = torch.stft( spec = torch.stft(
y, y,
@ -86,14 +80,13 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False)
pad_mode="reflect", pad_mode="reflect",
normalized=False, normalized=False,
onesided=True, onesided=True,
return_complex=False, return_complex=True,
) )
# Linear-frequency Linear-amplitude spectrogram :: (B, Freq, Frame, RealComplex=2) -> (B, Freq, Frame) # Linear-frequency Linear-amplitude spectrogram :: (B, Freq, Frame, RealComplex=2) -> (B, Freq, Frame)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-6)
return spec return spec
def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
# MelBasis - Cache if needed # MelBasis - Cache if needed
global mel_basis global mel_basis