diff --git a/configs/config.py b/configs/config.py
index c089997..20bbb36 100644
--- a/configs/config.py
+++ b/configs/config.py
@@ -5,10 +5,13 @@ import json
 from multiprocessing import cpu_count
 
 import torch
+
 try:
-    import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
+    import intel_extension_for_pytorch as ipex  # pylint: disable=import-error, unused-import
+
     if torch.xpu.is_available():
         from infer.modules.ipex import ipex_init
+
         ipex_init()
 except Exception:
     pass
diff --git a/gui_v1.py b/gui_v1.py
index 971c130..84adb7c 100644
--- a/gui_v1.py
+++ b/gui_v1.py
@@ -478,15 +478,28 @@ if __name__ == "__main__":
                 inp_q,
                 opt_q,
                 device,
-                self.rvc if hasattr(self, "rvc") else None
+                self.rvc if hasattr(self, "rvc") else None,
             )
             self.config.samplerate = self.rvc.tgt_sr
             self.zc = self.rvc.tgt_sr // 100
-            self.block_frame = int(np.round(self.config.block_time * self.config.samplerate / self.zc)) * self.zc
+            self.block_frame = (
+                int(np.round(self.config.block_time * self.config.samplerate / self.zc))
+                * self.zc
+            )
             self.block_frame_16k = 160 * self.block_frame // self.zc
-            self.crossfade_frame = int(np.round(self.config.crossfade_time * self.config.samplerate / self.zc)) * self.zc
+            self.crossfade_frame = (
+                int(
+                    np.round(
+                        self.config.crossfade_time * self.config.samplerate / self.zc
+                    )
+                )
+                * self.zc
+            )
             self.sola_search_frame = self.zc
-            self.extra_frame = int(np.round(self.config.extra_time * self.config.samplerate / self.zc)) * self.zc
+            self.extra_frame = (
+                int(np.round(self.config.extra_time * self.config.samplerate / self.zc))
+                * self.zc
+            )
             self.input_wav: torch.Tensor = torch.zeros(
                 self.extra_frame
                 + self.crossfade_frame
@@ -495,7 +508,11 @@ if __name__ == "__main__":
                 device=device,
                 dtype=torch.float32,
             )
-            self.input_wav_res: torch.Tensor= torch.zeros(160 * self.input_wav.shape[0] // self.zc, device=device,dtype=torch.float32)
+            self.input_wav_res: torch.Tensor = torch.zeros(
+                160 * self.input_wav.shape[0] // self.zc,
+                device=device,
+                dtype=torch.float32,
+            )
             self.pitch: np.ndarray = np.zeros(
                 self.input_wav.shape[0] // self.zc,
                 dtype="int32",
@@ -509,7 +526,9 @@ if __name__ == "__main__":
             )
             self.nr_buffer: torch.Tensor = self.sola_buffer.clone()
             self.output_buffer: torch.Tensor = self.input_wav.clone()
-            self.res_buffer: torch.Tensor = torch.zeros(2 * self.zc, device=device,dtype=torch.float32)
+            self.res_buffer: torch.Tensor = torch.zeros(
+                2 * self.zc, device=device, dtype=torch.float32
+            )
             self.valid_rate = 1 - (self.extra_frame - 1) / self.input_wav.shape[0]
             self.fade_in_window: torch.Tensor = (
                 torch.sin(
@@ -529,7 +548,9 @@ if __name__ == "__main__":
             self.resampler = tat.Resample(
                 orig_freq=self.config.samplerate, new_freq=16000, dtype=torch.float32
             ).to(device)
-            self.tg = TorchGate(sr=self.config.samplerate, n_fft=4*self.zc, prop_decrease=0.9).to(device)
+            self.tg = TorchGate(
+                sr=self.config.samplerate, n_fft=4 * self.zc, prop_decrease=0.9
+            ).to(device)
             thread_vc = threading.Thread(target=self.soundinput)
             thread_vc.start()
 
@@ -560,7 +581,7 @@ if __name__ == "__main__":
             indata = librosa.to_mono(indata.T)
             if self.config.threhold > -60:
                 rms = librosa.feature.rms(
-                    y=indata, frame_length=4*self.zc, hop_length=self.zc
+                    y=indata, frame_length=4 * self.zc, hop_length=self.zc
                 )
                 db_threhold = (
                     librosa.amplitude_to_db(rms, ref=1.0)[0] < self.config.threhold
@@ -568,28 +589,44 @@ if __name__ == "__main__":
                 for i in range(db_threhold.shape[0]):
                     if db_threhold[i]:
                         indata[i * self.zc : (i + 1) * self.zc] = 0
-            self.input_wav[: -self.block_frame] = self.input_wav[self.block_frame :].clone()
-            self.input_wav[-self.block_frame: ] = torch.from_numpy(indata).to(device)
-            self.input_wav_res[ : -self.block_frame_16k] = self.input_wav_res[self.block_frame_16k :].clone()
+            self.input_wav[: -self.block_frame] = self.input_wav[
+                self.block_frame :
+            ].clone()
+            self.input_wav[-self.block_frame :] = torch.from_numpy(indata).to(device)
+            self.input_wav_res[: -self.block_frame_16k] = self.input_wav_res[
+                self.block_frame_16k :
+            ].clone()
             # input noise reduction and resampling
             if self.config.I_noise_reduce:
-                input_wav = self.input_wav[-self.crossfade_frame -self.block_frame-2*self.zc: ]
-                input_wav = self.tg(input_wav.unsqueeze(0), self.input_wav.unsqueeze(0))[0, 2*self.zc:]
+                input_wav = self.input_wav[
+                    -self.crossfade_frame - self.block_frame - 2 * self.zc :
+                ]
+                input_wav = self.tg(
+                    input_wav.unsqueeze(0), self.input_wav.unsqueeze(0)
+                )[0, 2 * self.zc :]
                 input_wav[: self.crossfade_frame] *= self.fade_in_window
-                input_wav[: self.crossfade_frame] += self.nr_buffer * self.fade_out_window
-                self.nr_buffer[:] = input_wav[-self.crossfade_frame: ]
-                input_wav = torch.cat((self.res_buffer[:], input_wav[: self.block_frame]))
-                self.res_buffer[:] = input_wav[-2*self.zc: ]
-                self.input_wav_res[-self.block_frame_16k-160: ] = self.resampler(input_wav)[160: ]
+                input_wav[: self.crossfade_frame] += (
+                    self.nr_buffer * self.fade_out_window
+                )
+                self.nr_buffer[:] = input_wav[-self.crossfade_frame :]
+                input_wav = torch.cat(
+                    (self.res_buffer[:], input_wav[: self.block_frame])
+                )
+                self.res_buffer[:] = input_wav[-2 * self.zc :]
+                self.input_wav_res[-self.block_frame_16k - 160 :] = self.resampler(
+                    input_wav
+                )[160:]
             else:
-                self.input_wav_res[-self.block_frame_16k-160: ] = self.resampler(self.input_wav[-self.block_frame-2*self.zc: ])[160: ]
+                self.input_wav_res[-self.block_frame_16k - 160 :] = self.resampler(
+                    self.input_wav[-self.block_frame - 2 * self.zc :]
+                )[160:]
             # infer
             f0_extractor_frame = self.block_frame_16k + 800
-            if self.config.f0method == 'rmvpe':
+            if self.config.f0method == "rmvpe":
                 f0_extractor_frame = 5120 * ((f0_extractor_frame - 1) // 5120 + 1)
             infer_wav = self.rvc.infer(
                 self.input_wav_res,
-                self.input_wav_res[-f0_extractor_frame :].cpu().numpy(),
+                self.input_wav_res[-f0_extractor_frame:].cpu().numpy(),
                 self.block_frame_16k,
                 self.valid_rate,
                 self.pitch,
@@ -601,48 +638,77 @@ if __name__ == "__main__":
             ]
             # output noise reduction
             if self.config.O_noise_reduce:
-                self.output_buffer[: -self.block_frame] = self.output_buffer[self.block_frame :].clone()
-                self.output_buffer[-self.block_frame: ] = infer_wav[-self.block_frame:]
-                infer_wav = self.tg(infer_wav.unsqueeze(0), self.output_buffer.unsqueeze(0)).squeeze(0)
+                self.output_buffer[: -self.block_frame] = self.output_buffer[
+                    self.block_frame :
+                ].clone()
+                self.output_buffer[-self.block_frame :] = infer_wav[-self.block_frame :]
+                infer_wav = self.tg(
+                    infer_wav.unsqueeze(0), self.output_buffer.unsqueeze(0)
+                ).squeeze(0)
             # volume envelop mixing
             if self.config.rms_mix_rate < 1:
                 rms1 = librosa.feature.rms(
-                    y=self.input_wav_res[-160*infer_wav.shape[0]//self.zc :].cpu().numpy(),
-                    frame_length=640,
-                    hop_length=160,
+                    y=self.input_wav_res[-160 * infer_wav.shape[0] // self.zc :]
+                    .cpu()
+                    .numpy(),
+                    frame_length=640,
+                    hop_length=160,
                 )
                 rms1 = torch.from_numpy(rms1).to(device)
                 rms1 = F.interpolate(
-                    rms1.unsqueeze(0), size=infer_wav.shape[0] + 1, mode="linear",align_corners=True,
-                )[0,0,:-1]
+                    rms1.unsqueeze(0),
+                    size=infer_wav.shape[0] + 1,
+                    mode="linear",
+                    align_corners=True,
+                )[0, 0, :-1]
                 rms2 = librosa.feature.rms(
-                    y=infer_wav[:].cpu().numpy(), frame_length=4*self.zc, hop_length=self.zc
+                    y=infer_wav[:].cpu().numpy(),
+                    frame_length=4 * self.zc,
+                    hop_length=self.zc,
                 )
                 rms2 = torch.from_numpy(rms2).to(device)
                 rms2 = F.interpolate(
-                    rms2.unsqueeze(0), size=infer_wav.shape[0] + 1, mode="linear",align_corners=True,
-                )[0,0,:-1]
+                    rms2.unsqueeze(0),
+                    size=infer_wav.shape[0] + 1,
+                    mode="linear",
+                    align_corners=True,
+                )[0, 0, :-1]
                 rms2 = torch.max(rms2, torch.zeros_like(rms2) + 1e-3)
-                infer_wav *= torch.pow(rms1 / rms2, torch.tensor(1 - self.config.rms_mix_rate))
+                infer_wav *= torch.pow(
+                    rms1 / rms2, torch.tensor(1 - self.config.rms_mix_rate)
+                )
             # SOLA algorithm from https://github.com/yxlllc/DDSP-SVC
-            conv_input = infer_wav[None, None, : self.crossfade_frame + self.sola_search_frame]
+            conv_input = infer_wav[
+                None, None, : self.crossfade_frame + self.sola_search_frame
+            ]
             cor_nom = F.conv1d(conv_input, self.sola_buffer[None, None, :])
             cor_den = torch.sqrt(
-                F.conv1d(conv_input ** 2, torch.ones(1, 1, self.crossfade_frame, device=device)) + 1e-8)
+                F.conv1d(
+                    conv_input**2,
+                    torch.ones(1, 1, self.crossfade_frame, device=device),
+                )
+                + 1e-8
+            )
             if sys.platform == "darwin":
                 _, sola_offset = torch.max(cor_nom[0, 0] / cor_den[0, 0])
                 sola_offset = sola_offset.item()
             else:
                 sola_offset = torch.argmax(cor_nom[0, 0] / cor_den[0, 0])
             logger.debug("sola_offset = %d", int(sola_offset))
-            infer_wav = infer_wav[sola_offset: sola_offset + self.block_frame + self.crossfade_frame]
+            infer_wav = infer_wav[
+                sola_offset : sola_offset + self.block_frame + self.crossfade_frame
+            ]
             infer_wav[: self.crossfade_frame] *= self.fade_in_window
-            infer_wav[: self.crossfade_frame] += self.sola_buffer *self.fade_out_window
-            self.sola_buffer[:] = infer_wav[-self.crossfade_frame:]
+            infer_wav[: self.crossfade_frame] += self.sola_buffer * self.fade_out_window
+            self.sola_buffer[:] = infer_wav[-self.crossfade_frame :]
             if sys.platform == "darwin":
-                outdata[:] = infer_wav[:-self.crossfade_frame].cpu().numpy()[:, np.newaxis]
+                outdata[:] = (
+                    infer_wav[: -self.crossfade_frame].cpu().numpy()[:, np.newaxis]
+                )
             else:
-                outdata[:] = infer_wav[:-self.crossfade_frame].repeat(2, 1).t().cpu().numpy()
+                outdata[:] = (
+                    infer_wav[: -self.crossfade_frame].repeat(2, 1).t().cpu().numpy()
+                )
             total_time = time.perf_counter() - start_time
             self.window["infer_time"].update(int(total_time * 1000))
             logger.info("Infer time: %.2f", total_time)
@@ -698,9 +764,7 @@ if __name__ == "__main__":
             sd.default.device[1] = output_device_indices[
                 output_devices.index(output_device)
             ]
-            logger.info(
-                "Input device: %s:%s", str(sd.default.device[0]), input_device
-            )
+            logger.info("Input device: %s:%s", str(sd.default.device[0]), input_device)
             logger.info(
                 "Output device: %s:%s", str(sd.default.device[1]), output_device
             )
diff --git a/infer-web.py b/infer-web.py
index 6e467d3..0225d4b 100644
--- a/infer-web.py
+++ b/infer-web.py
@@ -1028,7 +1028,7 @@ with gr.Blocks(title="RVC WebUI") as app:
                 fn=vc.get_vc,
                 inputs=[sid0, protect0, protect1],
                 outputs=[spk_item, protect0, protect1, file_index2, file_index4],
-                api_name="infer_change_voice"
+                api_name="infer_change_voice",
             )
         with gr.TabItem(i18n("伴奏人声分离&去混响&去回声")):
             with gr.Group():
diff --git a/infer/lib/audio.py b/infer/lib/audio.py
index 97dbb90..56acbdc 100644
--- a/infer/lib/audio.py
+++ b/infer/lib/audio.py
@@ -3,38 +3,49 @@ import numpy as np
 import av
 from io import BytesIO
 
+
 def wav2(i, o, format):
-    inp = av.open(i, 'rb')
-    if format == "m4a": format = "mp4"
-    out = av.open(o, 'wb', format=format)
-    if format == "ogg": format = "libvorbis"
-    if format == "mp4": format = "aac"
+    inp = av.open(i, "rb")
+    if format == "m4a":
+        format = "mp4"
+    out = av.open(o, "wb", format=format)
+    if format == "ogg":
+        format = "libvorbis"
+    if format == "mp4":
+        format = "aac"
     ostream = out.add_stream(format)
     for frame in inp.decode(audio=0):
-        for p in ostream.encode(frame): out.mux(p)
+        for p in ostream.encode(frame):
+            out.mux(p)
 
-    for p in ostream.encode(None): out.mux(p)
+    for p in ostream.encode(None):
+        out.mux(p)
 
     out.close()
     inp.close()
 
+
 def audio2(i, o, format, sr):
-    inp = av.open(i, 'rb')
-    out = av.open(o, 'wb', format=format)
-    if format == "ogg": format = "libvorbis"
-    if format == "f32le": format = "pcm_f32le"
+    inp = av.open(i, "rb")
+    out = av.open(o, "wb", format=format)
+    if format == "ogg":
+        format = "libvorbis"
+    if format == "f32le":
+        format = "pcm_f32le"
     ostream = out.add_stream(format, channels=1)
     ostream.sample_rate = sr
     for frame in inp.decode(audio=0):
-        for p in ostream.encode(frame): out.mux(p)
+        for p in ostream.encode(frame):
+            out.mux(p)
     out.close()
     inp.close()
 
+
 def load_audio(file, sr):
     try:
         file = (
diff --git a/infer/lib/infer_pack/models.py b/infer/lib/infer_pack/models.py
index 425cc9c..711db22 100644
--- a/infer/lib/infer_pack/models.py
+++ b/infer/lib/infer_pack/models.py
@@ -15,6 +15,7 @@ from infer.lib.infer_pack.commons import get_padding, init_weights
 
 has_xpu = bool(hasattr(torch, "xpu") and torch.xpu.is_available())
 
+
 class TextEncoder256(nn.Module):
     def __init__(
         self,
@@ -1158,7 +1159,9 @@ class DiscriminatorP(torch.nn.Module):
         if t % self.period != 0:  # pad first
             n_pad = self.period - (t % self.period)
             if has_xpu and x.dtype == torch.bfloat16:
-                x = F.pad(x.to(dtype=torch.float16), (0, n_pad), "reflect").to(dtype=torch.bfloat16)
+                x = F.pad(x.to(dtype=torch.float16), (0, n_pad), "reflect").to(
+                    dtype=torch.bfloat16
+                )
             else:
                 x = F.pad(x, (0, n_pad), "reflect")
             t = t + n_pad
diff --git a/infer/lib/rmvpe.py b/infer/lib/rmvpe.py
index a4fcb1d..d305b53 100644
--- a/infer/lib/rmvpe.py
+++ b/infer/lib/rmvpe.py
@@ -2,11 +2,14 @@ import pdb, os
 
 import numpy as np
 import torch
+
 try:
-    #Fix "Torch not compiled with CUDA enabled"
-    import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
+    # Fix "Torch not compiled with CUDA enabled"
+    import intel_extension_for_pytorch as ipex  # pylint: disable=import-error, unused-import
+
    if torch.xpu.is_available():
         from infer.modules.ipex import ipex_init
+
         ipex_init()
 except Exception:
     pass
diff --git a/infer/modules/ipex/__init__.py b/infer/modules/ipex/__init__.py
index 3207452..f8ad98a 100644
--- a/infer/modules/ipex/__init__.py
+++ b/infer/modules/ipex/__init__.py
@@ -2,15 +2,16 @@ import os
 import sys
 import contextlib
 import torch
-import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
+import intel_extension_for_pytorch as ipex  # pylint: disable=import-error, unused-import
 from .hijacks import ipex_hijacks
 from .attention import attention_init
 
 # pylint: disable=protected-access, missing-function-docstring, line-too-long
 
-def ipex_init(): # pylint: disable=too-many-statements
+
+def ipex_init():  # pylint: disable=too-many-statements
     try:
-        #Replace cuda with xpu:
+        # Replace cuda with xpu:
         torch.cuda.current_device = torch.xpu.current_device
         torch.cuda.current_stream = torch.xpu.current_stream
         torch.cuda.device = torch.xpu.device
@@ -91,11 +92,11 @@ def ipex_init(): # pylint: disable=too-many-statements
         torch.cuda.CharStorage = torch.xpu.CharStorage
         torch.cuda.__file__ = torch.xpu.__file__
         torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
-        #torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
+        # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
 
-        #Memory:
+        # Memory:
         torch.cuda.memory = torch.xpu.memory
-        if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
+        if "linux" in sys.platform and "WSL2" in os.popen("uname -a").read():
             torch.xpu.empty_cache = lambda: None
         torch.cuda.empty_cache = torch.xpu.empty_cache
         torch.cuda.memory_stats = torch.xpu.memory_stats
@@ -111,9 +112,11 @@ def ipex_init(): # pylint: disable=too-many-statements
         torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats
         torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats
         torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
-        torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats
+        torch.cuda.reset_accumulated_memory_stats = (
+            torch.xpu.reset_accumulated_memory_stats
+        )
 
-        #RNG:
+        # RNG:
         torch.cuda.get_rng_state = torch.xpu.get_rng_state
         torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
         torch.cuda.set_rng_state = torch.xpu.set_rng_state
@@ -124,35 +127,44 @@ def ipex_init(): # pylint: disable=too-many-statements
         torch.cuda.seed_all = torch.xpu.seed_all
         torch.cuda.initial_seed = torch.xpu.initial_seed
 
-        #AMP:
+        # AMP:
         torch.cuda.amp = torch.xpu.amp
         if not hasattr(torch.cuda.amp, "common"):
             torch.cuda.amp.common = contextlib.nullcontext()
         torch.cuda.amp.common.amp_definitely_not_available = lambda: False
         try:
             torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
-        except Exception: # pylint: disable=broad-exception-caught
+        except Exception:  # pylint: disable=broad-exception-caught
             try:
-                from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
+                from .gradscaler import (
+                    gradscaler_init,
+                )  # pylint: disable=import-outside-toplevel, import-error
+
                 gradscaler_init()
                 torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
-            except Exception: # pylint: disable=broad-exception-caught
+            except Exception:  # pylint: disable=broad-exception-caught
                 torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
 
-        #C
+        # C
        torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
         ipex._C._DeviceProperties.major = 2023
         ipex._C._DeviceProperties.minor = 2
 
-        #Fix functions with ipex:
-        torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_allocated(device)), torch.xpu.get_device_properties(device).total_memory]
+        # Fix functions with ipex:
+        torch.cuda.mem_get_info = lambda device=None: [
+            (
+                torch.xpu.get_device_properties(device).total_memory
+                - torch.xpu.memory_allocated(device)
+            ),
+            torch.xpu.get_device_properties(device).total_memory,
+        ]
         torch._utils._get_available_device_type = lambda: "xpu"
         torch.has_cuda = True
         torch.cuda.has_half = True
         torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
         torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
         torch.version.cuda = "11.7"
-        torch.cuda.get_device_capability = lambda *args, **kwargs: [11,7]
+        torch.cuda.get_device_capability = lambda *args, **kwargs: [11, 7]
         torch.cuda.get_device_properties.major = 11
         torch.cuda.get_device_properties.minor = 7
         torch.cuda.ipc_collect = lambda *args, **kwargs: None
diff --git a/infer/modules/ipex/attention.py b/infer/modules/ipex/attention.py
index d7335bf..be17f7a 100644
--- a/infer/modules/ipex/attention.py
+++ b/infer/modules/ipex/attention.py
@@ -1,22 +1,32 @@
 import torch
-import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
+import intel_extension_for_pytorch as ipex  # pylint: disable=import-error, unused-import
 
 # pylint: disable=protected-access, missing-function-docstring, line-too-long
 
 original_torch_bmm = torch.bmm
+
+
 def torch_bmm(input, mat2, *, out=None):
     if input.dtype != mat2.dtype:
         mat2 = mat2.to(input.dtype)
-    #ARC GPUs can't allocate more than 4GB to a single block, Slice it:
-    batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2]
+    # ARC GPUs can't allocate more than 4GB to a single block, Slice it:
+    batch_size_attention, input_tokens, mat2_shape = (
+        input.shape[0],
+        input.shape[1],
+        mat2.shape[2],
+    )
     block_multiply = 2.4 if input.dtype == torch.float32 else 1.2
-    block_size = (batch_size_attention * input_tokens * mat2_shape) / 1024 * block_multiply #MB
+    block_size = (
+        (batch_size_attention * input_tokens * mat2_shape) / 1024 * block_multiply
+    )  # MB
     split_slice_size = batch_size_attention
     if block_size >= 4000:
         do_split = True
-        #Find something divisible with the input_tokens
-        while ((split_slice_size * input_tokens * mat2_shape) / 1024 * block_multiply) > 4000:
+        # Find something divisible with the input_tokens
+        while (
+            (split_slice_size * input_tokens * mat2_shape) / 1024 * block_multiply
+        ) > 4000:
             split_slice_size = split_slice_size // 2
             if split_slice_size <= 1:
                 split_slice_size = 1
@@ -24,12 +34,16 @@ def torch_bmm(input, mat2, *, out=None):
     else:
         do_split = False
 
-    split_block_size = (split_slice_size * input_tokens * mat2_shape) / 1024 * block_multiply #MB
+    split_block_size = (
+        (split_slice_size * input_tokens * mat2_shape) / 1024 * block_multiply
+    )  # MB
     split_2_slice_size = input_tokens
     if split_block_size >= 4000:
         do_split_2 = True
-        #Find something divisible with the input_tokens
-        while ((split_slice_size * split_2_slice_size * mat2_shape) / 1024 * block_multiply) > 4000:
+        # Find something divisible with the input_tokens
+        while (
+            (split_slice_size * split_2_slice_size * mat2_shape) / 1024 * block_multiply
+        ) > 4000:
             split_2_slice_size = split_2_slice_size // 2
             if split_2_slice_size <= 1:
                 split_2_slice_size = 1
@@ -38,40 +52,61 @@ def torch_bmm(input, mat2, *, out=None):
         do_split_2 = False
 
     if do_split:
-        hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
+        hidden_states = torch.zeros(
+            input.shape[0],
+            input.shape[1],
+            mat2.shape[2],
+            device=input.device,
+            dtype=input.dtype,
+        )
         for i in range(batch_size_attention // split_slice_size):
             start_idx = i * split_slice_size
             end_idx = (i + 1) * split_slice_size
             if do_split_2:
-                for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name
+                for i2 in range(
+                    input_tokens // split_2_slice_size
+                ):  # pylint: disable=invalid-name
                     start_idx_2 = i2 * split_2_slice_size
                     end_idx_2 = (i2 + 1) * split_2_slice_size
-                    hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
+                    hidden_states[
+                        start_idx:end_idx, start_idx_2:end_idx_2
+                    ] = original_torch_bmm(
                         input[start_idx:end_idx, start_idx_2:end_idx_2],
                         mat2[start_idx:end_idx, start_idx_2:end_idx_2],
-                        out=out
+                        out=out,
                     )
             else:
                 hidden_states[start_idx:end_idx] = original_torch_bmm(
-                    input[start_idx:end_idx],
-                    mat2[start_idx:end_idx],
-                    out=out
+                    input[start_idx:end_idx], mat2[start_idx:end_idx], out=out
                 )
     else:
         return original_torch_bmm(input, mat2, out=out)
     return hidden_states
 
+
 original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
-def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
-    #ARC GPUs can't allocate more than 4GB to a single block, Slice it:
+
+
+def scaled_dot_product_attention(
+    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
+):
+    # ARC GPUs can't allocate more than 4GB to a single block, Slice it:
     shape_one, batch_size_attention, query_tokens, shape_four = query.shape
     block_multiply = 2.4 if query.dtype == torch.float32 else 1.2
-    block_size = (shape_one * batch_size_attention * query_tokens * shape_four) / 1024 * block_multiply #MB
+    block_size = (
+        (shape_one * batch_size_attention * query_tokens * shape_four)
+        / 1024
+        * block_multiply
+    )  # MB
     split_slice_size = batch_size_attention
     if block_size >= 4000:
         do_split = True
-        #Find something divisible with the shape_one
-        while ((shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply) > 4000:
+        # Find something divisible with the shape_one
+        while (
+            (shape_one * split_slice_size * query_tokens * shape_four)
+            / 1024
+            * block_multiply
+        ) > 4000:
             split_slice_size = split_slice_size // 2
             if split_slice_size <= 1:
                 split_slice_size = 1
@@ -79,12 +114,20 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
     else:
         do_split = False
 
-    split_block_size = (shape_one * split_slice_size * query_tokens * shape_four) / 1024 * block_multiply #MB
+    split_block_size = (
+        (shape_one * split_slice_size * query_tokens * shape_four)
+        / 1024
+        * block_multiply
+    )  # MB
     split_2_slice_size = query_tokens
     if split_block_size >= 4000:
         do_split_2 = True
-        #Find something divisible with the batch_size_attention
-        while ((shape_one * split_slice_size * split_2_slice_size * shape_four) / 1024 * block_multiply) > 4000:
+        # Find something divisible with the batch_size_attention
+        while (
+            (shape_one * split_slice_size * split_2_slice_size * shape_four)
+            / 1024
+            * block_multiply
+        ) > 4000:
             split_2_slice_size = split_2_slice_size // 2
             if split_2_slice_size <= 1:
                 split_2_slice_size = 1
@@ -98,31 +141,49 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
             start_idx = i * split_slice_size
             end_idx = (i + 1) * split_slice_size
             if do_split_2:
-                for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
+                for i2 in range(
+                    query_tokens // split_2_slice_size
+                ):  # pylint: disable=invalid-name
                     start_idx_2 = i2 * split_2_slice_size
                     end_idx_2 = (i2 + 1) * split_2_slice_size
-                    hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
+                    hidden_states[
+                        :, start_idx:end_idx, start_idx_2:end_idx_2
+                    ] = original_scaled_dot_product_attention(
                         query[:, start_idx:end_idx, start_idx_2:end_idx_2],
                         key[:, start_idx:end_idx, start_idx_2:end_idx_2],
                         value[:, start_idx:end_idx, start_idx_2:end_idx_2],
-                        attn_mask=attn_mask[:, start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
-                        dropout_p=dropout_p, is_causal=is_causal
+                        attn_mask=attn_mask[:, start_idx:end_idx, start_idx_2:end_idx_2]
+                        if attn_mask is not None
+                        else attn_mask,
+                        dropout_p=dropout_p,
+                        is_causal=is_causal,
                     )
             else:
-                hidden_states[:, start_idx:end_idx] = original_scaled_dot_product_attention(
+                hidden_states[
+                    :, start_idx:end_idx
+                ] = original_scaled_dot_product_attention(
                     query[:, start_idx:end_idx],
                     key[:, start_idx:end_idx],
                     value[:, start_idx:end_idx],
-                    attn_mask=attn_mask[:, start_idx:end_idx] if attn_mask is not None else attn_mask,
-                    dropout_p=dropout_p, is_causal=is_causal
+                    attn_mask=attn_mask[:, start_idx:end_idx]
+                    if attn_mask is not None
+                    else attn_mask,
+                    dropout_p=dropout_p,
+                    is_causal=is_causal,
                 )
     else:
         return original_scaled_dot_product_attention(
-            query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
+            query,
+            key,
+            value,
+            attn_mask=attn_mask,
+            dropout_p=dropout_p,
+            is_causal=is_causal,
         )
     return hidden_states
 
+
 def attention_init():
-    #ARC GPUs can't allocate more than 4GB to a single block:
+    # ARC GPUs can't allocate more than 4GB to a single block:
     torch.bmm = torch_bmm
     torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
diff --git a/infer/modules/ipex/gradscaler.py b/infer/modules/ipex/gradscaler.py
index 5302121..7875151 100644
--- a/infer/modules/ipex/gradscaler.py
+++ b/infer/modules/ipex/gradscaler.py
@@ -1,15 +1,20 @@
 from collections import defaultdict
 
 import torch
-import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
-import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import
+import intel_extension_for_pytorch as ipex  # pylint: disable=import-error, unused-import
+import intel_extension_for_pytorch._C as core  # pylint: disable=import-error, unused-import
 
 # pylint: disable=protected-access, missing-function-docstring, line-too-long
 
 OptState = ipex.cpu.autocast._grad_scaler.OptState
 _MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
-_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
+_refresh_per_optimizer_state = (
+    ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
+)
 
-def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint: disable=unused-argument
+
+def _unscale_grads_(
+    self, optimizer, inv_scale, found_inf, allow_fp16
+):  # pylint: disable=unused-argument
     per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
     per_device_found_inf = _MultiDeviceReplicator(found_inf)
@@ -43,9 +48,9 @@ def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint
 
                 # -: is there a way to split by device and dtype without appending in the inner loop?
                 to_unscale = to_unscale.to("cpu")
-                per_device_and_dtype_grads[to_unscale.device][
-                    to_unscale.dtype
-                ].append(to_unscale)
+                per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(
+                    to_unscale
+                )
 
     for _, per_dtype_grads in per_device_and_dtype_grads.items():
         for grads in per_dtype_grads.values():
@@ -57,6 +62,7 @@ def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint
 
     return per_device_found_inf._per_device_tensors
 
+
 def unscale_(self, optimizer):
     """
     Divides ("unscales") the optimizer's gradient tensors by the scale factor.
@@ -87,7 +93,7 @@ def unscale_(self, optimizer):
 
     optimizer_state = self._per_optimizer_states[id(optimizer)]
 
-    if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise
+    if optimizer_state["stage"] is OptState.UNSCALED:  # pylint: disable=no-else-raise
         raise RuntimeError(
             "unscale_() has already been called on this optimizer since the last update()."
         )
@@ -96,16 +102,17 @@ def unscale_(self, optimizer):
 
     # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
     assert self._scale is not None
-    inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
-    found_inf = torch.full(
-        (1,), 0.0, dtype=torch.float32, device=self._scale.device
+    inv_scale = (
+        self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
     )
+    found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=self._scale.device)
 
     optimizer_state["found_inf_per_device"] = self._unscale_grads_(
         optimizer, inv_scale, found_inf, False
     )
     optimizer_state["stage"] = OptState.UNSCALED
 
+
 def update(self, new_scale=None):
     """
     Updates the scale factor.
@@ -171,6 +178,7 @@ def update(self, new_scale=None):
     # To prepare for next iteration, clear the data collected from optimizers this iteration.
     self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
 
+
 def gradscaler_init():
     torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
     torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_
diff --git a/infer/modules/ipex/hijacks.py b/infer/modules/ipex/hijacks.py
index 78d7e03..d95fd61 100644
--- a/infer/modules/ipex/hijacks.py
+++ b/infer/modules/ipex/hijacks.py
@@ -1,45 +1,59 @@
 import contextlib
 import importlib
 import torch
-import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
+import intel_extension_for_pytorch as ipex  # pylint: disable=import-error, unused-import
 
 # pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
 
-class CondFunc: # pylint: disable=missing-class-docstring
+
+class CondFunc:  # pylint: disable=missing-class-docstring
     def __new__(cls, orig_func, sub_func, cond_func):
         self = super(CondFunc, cls).__new__(cls)
         if isinstance(orig_func, str):
-            func_path = orig_func.split('.')
-            for i in range(len(func_path)-1, -1, -1):
+            func_path = orig_func.split(".")
+            for i in range(len(func_path) - 1, -1, -1):
                 try:
-                    resolved_obj = importlib.import_module('.'.join(func_path[:i]))
+                    resolved_obj = importlib.import_module(".".join(func_path[:i]))
                     break
                 except ImportError:
                     pass
             for attr_name in func_path[i:-1]:
                 resolved_obj = getattr(resolved_obj, attr_name)
             orig_func = getattr(resolved_obj, func_path[-1])
-            setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
+            setattr(
+                resolved_obj,
+                func_path[-1],
+                lambda *args, **kwargs: self(*args, **kwargs),
+            )
         self.__init__(orig_func, sub_func, cond_func)
         return lambda *args, **kwargs: self(*args, **kwargs)
+
     def __init__(self, orig_func, sub_func, cond_func):
         self.__orig_func = orig_func
         self.__sub_func = sub_func
         self.__cond_func = cond_func
+
     def __call__(self, *args, **kwargs):
         if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
             return self.__sub_func(self.__orig_func, *args, **kwargs)
         else:
             return self.__orig_func(*args, **kwargs)
 
+
 _utils = torch.utils.data._utils
+
+
 def _shutdown_workers(self):
-    if torch.utils.data._utils is None or torch.utils.data._utils.python_exit_status is True or torch.utils.data._utils.python_exit_status is None:
+    if (
+        torch.utils.data._utils is None
+        or torch.utils.data._utils.python_exit_status is True
+        or torch.utils.data._utils.python_exit_status is None
+    ):
         return
     if hasattr(self, "_shutdown") and not self._shutdown:
         self._shutdown = True
         try:
-            if hasattr(self, '_pin_memory_thread'):
+            if hasattr(self, "_pin_memory_thread"):
                 self._pin_memory_thread_done_event.set()
                 self._worker_result_queue.put((None, None))
                 self._pin_memory_thread.join()
@@ -49,145 +63,292 @@ def _shutdown_workers(self):
             for worker_id in range(len(self._workers)):
                 if self._persistent_workers or self._workers_status[worker_id]:
                     self._mark_worker_as_unavailable(worker_id, shutdown=True)
-            for w in self._workers: # pylint: disable=invalid-name
+            for w in self._workers:  # pylint: disable=invalid-name
                 w.join(timeout=torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL)
-            for q in self._index_queues: # pylint: disable=invalid-name
+            for q in self._index_queues:  # pylint: disable=invalid-name
                 q.cancel_join_thread()
                 q.close()
         finally:
             if self._worker_pids_set:
                 torch.utils.data._utils.signal_handling._remove_worker_pids(id(self))
                 self._worker_pids_set = False
-            for w in self._workers: # pylint: disable=invalid-name
+            for w in self._workers:  # pylint: disable=invalid-name
                 if w.is_alive():
                     w.terminate()
 
+
-class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
-    def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
+class DummyDataParallel(
+    torch.nn.Module
+):  # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
+    def __new__(
+        cls, module, device_ids=None, output_device=None, dim=0
+    ):  # pylint: disable=unused-argument
         if isinstance(device_ids, list) and len(device_ids) > 1:
             print("IPEX backend doesn't support DataParallel on multiple XPU devices")
         return module.to("xpu")
 
-def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
+
+def return_null_context(*args, **kwargs):  # pylint: disable=unused-argument
     return contextlib.nullcontext()
 
+
 def check_device(device):
-    return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
+    return bool(
+        (isinstance(device, torch.device) and device.type == "cuda")
+        or (isinstance(device, str) and "cuda" in device)
+        or isinstance(device, int)
+    )
 
+
 def return_xpu(device):
-    return f"xpu:{device[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
+    return (
+        f"xpu:{device[-1]}"
+        if isinstance(device, str) and ":" in device
+        else f"xpu:{device}"
+        if isinstance(device, int)
+        else torch.device("xpu")
+        if isinstance(device, torch.device)
+        else "xpu"
+    )
 
+
 def ipex_no_cuda(orig_func, *args, **kwargs):
     torch.cuda.is_available = lambda: False
     orig_func(*args, **kwargs)
     torch.cuda.is_available = torch.xpu.is_available
 
+
 original_autocast = torch.autocast
+
+
 def ipex_autocast(*args, **kwargs):
     if len(args) > 0 and args[0] == "cuda":
         return original_autocast("xpu", *args[1:], **kwargs)
     else:
         return original_autocast(*args, **kwargs)
 
+
 original_torch_cat = torch.cat
+
+
 def torch_cat(tensor, *args, **kwargs):
-    if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
-        return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
+    if len(tensor) == 3 and (
+        tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype
+    ):
+        return original_torch_cat(
+            [tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)],
+            *args,
+            **kwargs,
+        )
     else:
         return original_torch_cat(tensor, *args, **kwargs)
 
+
 original_interpolate = torch.nn.functional.interpolate
-def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
+
+
+def interpolate(
+    tensor,
+    size=None,
+    scale_factor=None,
+    mode="nearest",
+    align_corners=None,
+    recompute_scale_factor=None,
+    antialias=False,
+):  # pylint: disable=too-many-arguments
     if antialias or align_corners is not None:
         return_device = tensor.device
         return_dtype = tensor.dtype
-        return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
-        align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias).to(return_device, dtype=return_dtype)
+        return original_interpolate(
+            tensor.to("cpu", dtype=torch.float32),
+            size=size,
+            scale_factor=scale_factor,
+            mode=mode,
+            align_corners=align_corners,
+            recompute_scale_factor=recompute_scale_factor,
+            antialias=antialias,
+        ).to(return_device, dtype=return_dtype)
     else:
-        return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
-        align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
+        return original_interpolate(
+            tensor,
+            size=size,
+            scale_factor=scale_factor,
+            mode=mode,
+            align_corners=align_corners,
+            recompute_scale_factor=recompute_scale_factor,
+            antialias=antialias,
+        )
 
+
 original_linalg_solve = torch.linalg.solve
-def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name
+
+
+def linalg_solve(A, B, *args, **kwargs):  # pylint: disable=invalid-name
     if A.device != torch.device("cpu") or B.device != torch.device("cpu"):
         return_device = A.device
-        return original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(return_device)
+        return original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(
+            return_device
+        )
     else:
         return original_linalg_solve(A, B, *args, **kwargs)
 
+
 def ipex_hijacks():
-    CondFunc('torch.Tensor.to',
-        lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
-        lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
-    CondFunc('torch.Tensor.cuda',
-        lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
-        lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
-    CondFunc('torch.empty',
-        lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
-        lambda orig_func, *args, device=None, **kwargs: check_device(device))
-    CondFunc('torch.load',
-        lambda orig_func, *args, map_location=None, **kwargs: orig_func(*args, return_xpu(map_location), **kwargs),
-        lambda orig_func, *args, map_location=None, **kwargs: map_location is None or check_device(map_location))
-    CondFunc('torch.randn',
-        lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
-        lambda orig_func, *args, device=None, **kwargs: check_device(device))
-    CondFunc('torch.ones',
-        lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
-        lambda orig_func, *args, device=None, **kwargs: check_device(device))
-    CondFunc('torch.zeros',
-        lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
-        lambda orig_func, *args, device=None, **kwargs: check_device(device))
-    CondFunc('torch.tensor',
-        lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
-        lambda orig_func, *args, device=None, **kwargs: check_device(device))
-    CondFunc('torch.linspace',
-        lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
-        lambda orig_func, *args, device=None, **kwargs: check_device(device))
+    CondFunc(
+        "torch.Tensor.to",
+        lambda orig_func, self, device=None, *args, **kwargs: orig_func(
+            self, return_xpu(device), *args, **kwargs
+        ),
+        lambda orig_func, self, device=None, *args, **kwargs: check_device(device),
+    )
+    CondFunc(
+        "torch.Tensor.cuda",
+        lambda orig_func, self, device=None, *args, **kwargs: orig_func(
+            self, return_xpu(device), *args, **kwargs
+        ),
+        lambda orig_func, self, device=None, *args, **kwargs: check_device(device),
+    )
+    CondFunc(
+        "torch.empty",
+        lambda orig_func, *args, device=None, **kwargs: orig_func(
+            *args, device=return_xpu(device), **kwargs
+        ),
+        lambda orig_func, *args, device=None, **kwargs: check_device(device),
+    )
+    CondFunc(
+        "torch.load",
+        lambda orig_func, *args, map_location=None, **kwargs: orig_func(
+            *args, return_xpu(map_location), **kwargs
+        ),
+        lambda orig_func, *args, map_location=None, **kwargs: map_location is None
+        or check_device(map_location),
+    )
+    CondFunc(
+        "torch.randn",
+        lambda orig_func, *args, device=None, **kwargs: orig_func(
+            *args, device=return_xpu(device), **kwargs
+        ),
+        lambda orig_func, *args, device=None, **kwargs: check_device(device),
+    )
+    CondFunc(
+        "torch.ones",
+        lambda orig_func, *args, device=None, **kwargs: orig_func(
+            *args, device=return_xpu(device), **kwargs
+        ),
+        lambda orig_func, *args, device=None, **kwargs: check_device(device),
+    )
+    CondFunc(
+        "torch.zeros",
+        lambda orig_func, *args, device=None, **kwargs: orig_func(
+            *args, device=return_xpu(device), **kwargs
+        ),
+        lambda orig_func, *args, device=None, **kwargs: check_device(device),
+    )
+    CondFunc(
+        "torch.tensor",
+        lambda orig_func, *args, device=None, **kwargs: orig_func(
+            *args, device=return_xpu(device), **kwargs
+        ),
+        lambda orig_func, *args, device=None, **kwargs: check_device(device),
+    )
+    CondFunc(
+        "torch.linspace",
+        lambda orig_func, *args, device=None, **kwargs: orig_func(
+            *args, device=return_xpu(device), **kwargs
+        ),
+        lambda orig_func, *args, device=None, **kwargs: check_device(device),
+    )
 
-    CondFunc('torch.Generator',
+    CondFunc(
+        "torch.Generator",
         lambda orig_func, device=None: torch.xpu.Generator(device),
-        lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu")
+        lambda orig_func, device=None: device is not None
+        and device != torch.device("cpu")
+        and device != "cpu",
+    )
 
-    CondFunc('torch.batch_norm',
-        lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
-            weight if weight is not None else torch.ones(input.size()[1], device=input.device),
-            bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
-        lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
-    CondFunc('torch.instance_norm',
-        lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
-            weight if weight is not None else torch.ones(input.size()[1], device=input.device),
-            bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
-        lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
+    CondFunc(
+        "torch.batch_norm",
+        lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(
+            input,
+            weight
+            if weight is not None
+            else torch.ones(input.size()[1], device=input.device),
+            bias
+            if bias is not None
+            else torch.zeros(input.size()[1], device=input.device),
+            *args,
+            **kwargs,
+        ),
+        lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"),
+    )
+    CondFunc(
+        "torch.instance_norm",
+        lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(
+            input,
+            weight
+            if weight is not None
+            else torch.ones(input.size()[1], device=input.device),
+            bias
+            if bias is not None
+            else torch.zeros(input.size()[1], device=input.device),
+            *args,
+            **kwargs,
+        ),
+        lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"),
+    )
 
-    #Functions with dtype errors:
-    CondFunc('torch.nn.modules.GroupNorm.forward',
-        lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
-        lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
-    CondFunc('torch.nn.modules.linear.Linear.forward',
-        lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
-        lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
-    CondFunc('torch.nn.modules.conv.Conv2d.forward',
-        lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
-        lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
-    CondFunc('torch.nn.functional.layer_norm',
-        lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
-        orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
-        lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
-        weight is not None and input.dtype != weight.data.dtype)
+    # Functions with dtype errors:
+    CondFunc(
+        "torch.nn.modules.GroupNorm.forward",
+        lambda orig_func, self, input: orig_func(
+            self, input.to(self.weight.data.dtype)
+        ),
+        lambda orig_func, self, input: input.dtype != self.weight.data.dtype,
+    )
+    CondFunc(
+        "torch.nn.modules.linear.Linear.forward",
+        lambda orig_func, self, input: orig_func(
+            self, input.to(self.weight.data.dtype)
+        ),
+        lambda orig_func, self, input: input.dtype != self.weight.data.dtype,
+    )
+    CondFunc(
+        "torch.nn.modules.conv.Conv2d.forward",
+        lambda orig_func, self, input: orig_func(
+            self, input.to(self.weight.data.dtype)
+        ),
+        lambda orig_func, self, input: input.dtype != self.weight.data.dtype,
+    )
+    CondFunc(
+        "torch.nn.functional.layer_norm",
+        lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: orig_func(
+            input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs
+        ),
+        lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: weight
+        is not None
+        and input.dtype != weight.data.dtype,
+    )
 
-    #Diffusers Float64 (ARC GPUs doesn't support double or Float64):
+    # Diffusers Float64 (ARC GPUs doesn't support double or Float64):
     if not torch.xpu.has_fp64_dtype():
-        CondFunc('torch.from_numpy',
-            lambda orig_func, ndarray: orig_func(ndarray.astype('float32')),
-            lambda orig_func, ndarray: ndarray.dtype == float)
+        CondFunc(
+            "torch.from_numpy",
+            lambda orig_func, ndarray: orig_func(ndarray.astype("float32")),
+            lambda orig_func, ndarray: ndarray.dtype == float,
+        )
 
-    #Broken functions when torch.cuda.is_available is True:
-    CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__',
+    # Broken functions when torch.cuda.is_available is True:
+    CondFunc(
+        "torch.utils.data.dataloader._BaseDataLoaderIter.__init__",
         lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs),
-        lambda orig_func, *args, **kwargs: True)
+        lambda orig_func, *args, **kwargs: True,
+    )
 
-    #Functions that make compile mad with CondFunc:
-    torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers
+    # Functions that make compile mad with CondFunc:
+    torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = (
+        _shutdown_workers
+    )
     torch.nn.DataParallel = DummyDataParallel
     torch.autocast = ipex_autocast
     torch.cat = torch_cat
diff --git a/infer/modules/train/train.py b/infer/modules/train/train.py
index 7fed4f4..eb457d4 100644
--- a/infer/modules/train/train.py
+++ b/infer/modules/train/train.py
@@ -17,12 +17,15 @@ n_gpus = len(hps.gpus.split("-"))
 from random import randint, shuffle
 
 import torch
+
 try:
-    import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
+    import intel_extension_for_pytorch as ipex  # pylint: disable=import-error, unused-import
+
     if torch.xpu.is_available():
         from infer.modules.ipex import ipex_init
         from infer.modules.ipex.gradscaler import gradscaler_init
         from torch.xpu.amp import autocast
+
         GradScaler = gradscaler_init()
         ipex_init()
     else:
diff --git a/infer/modules/vc/modules.py b/infer/modules/vc/modules.py
index 9b698e6..b65ed70 100644
--- a/infer/modules/vc/modules.py
+++ b/infer/modules/vc/modules.py
@@ -288,14 +288,13 @@ class VC:
                                 tgt_sr,
                             )
                         else:
-                            path = "%s/%s.%s" % (opt_root, os.path.basename(path), format1)
+                            path = "%s/%s.%s" % (
+                                opt_root,
+                                os.path.basename(path),
+                                format1,
+                            )
                             with BytesIO() as wavf:
-                                sf.write(
-                                    wavf,
-                                    audio_opt,
-                                    tgt_sr,
-                                    format="wav"
-                                )
+                                sf.write(wavf, audio_opt, tgt_sr, format="wav")
                                 wavf.seek(0, 0)
                                 with open(path, "wb") as outf:
                                     wav2(wavf, outf, format1)
diff --git a/modules.py b/modules.py
index 81a21a6..306149f 100644
--- a/modules.py
+++ b/modules.py
@@ -288,14 +288,13 @@ class VC:
                                 tgt_sr,
                             )
                         else:
-                            path = "%s/%s.%s" % (opt_root, os.path.basename(path), format1)
+                            path = "%s/%s.%s" % (
+                                opt_root,
+                                os.path.basename(path),
+                                format1,
+                            )
                             with BytesIO() as wavf:
-                                sf.write(
-                                    wavf,
-                                    audio_opt,
-                                    tgt_sr,
-                                    format="wav"
-                                )
+                                sf.write(wavf, audio_opt, tgt_sr, format="wav")
                                 wavf.seek(0, 0)
                                 with open(path, "wb") as outf:
                                     wav2(wavf, outf, format1)
diff --git a/tools/rvc_for_realtime.py b/tools/rvc_for_realtime.py
index 4bad650..a24f61d 100644
--- a/tools/rvc_for_realtime.py
+++ b/tools/rvc_for_realtime.py
@@ -357,19 +357,13 @@ class RVC:
         with torch.no_grad():
             if self.if_f0 == 1:
                 # print(12222222222,feats.device,p_len.device,cache_pitch.device,cache_pitchf.device,sid.device,rate2)
-                infered_audio = (
-                    self.net_g.infer(
-                        feats, p_len, cache_pitch, cache_pitchf, sid, rate
-                    )[0][0, 0]
-                    .data
-                    .float()
-                )
+                infered_audio = self.net_g.infer(
+                    feats, p_len, cache_pitch, cache_pitchf, sid, rate
+                )[0][0, 0].data.float()
             else:
-                infered_audio = (
-                    self.net_g.infer(feats, p_len, sid, rate)[0][0, 0]
-                    .data
-                    .float()
-                )
+                infered_audio = self.net_g.infer(feats, p_len, sid, rate)[0][
+                    0, 0
+                ].data.float()
         t5 = ttime()
         logger.info(
             "Spent time: fea = %.2fs, index = %.2fs, f0 = %.2fs, model = %.2fs",
diff --git a/tools/torchgate/utils.py b/tools/torchgate/utils.py
index dc97d45..4682098 100644
--- a/tools/torchgate/utils.py
+++ b/tools/torchgate/utils.py
@@ -3,7 +3,9 @@ from torch.types import Number
 
 
 @torch.no_grad()
-def amp_to_db(x: torch.Tensor, eps=torch.finfo(torch.float64).eps, top_db=40) -> torch.Tensor:
+def amp_to_db(
+    x: torch.Tensor, eps=torch.finfo(torch.float64).eps, top_db=40
+) -> torch.Tensor:
     """
     Convert the input tensor from amplitude to decibel scale.
 
@@ -40,7 +42,9 @@ def temperature_sigmoid(x: torch.Tensor, x0: float, temp_coeff: float) -> torch.
 
 
 @torch.no_grad()
-def linspace(start: Number, stop: Number, num: int = 50, endpoint: bool = True, **kwargs) -> torch.Tensor:
+def linspace(
+    start: Number, stop: Number, num: int = 50, endpoint: bool = True, **kwargs
+) -> torch.Tensor:
     """
     Generate a linearly spaced 1-D tensor.