isso que ele diz ter criado, já existe ou realment...
Criado em: 10 de junho de 2025
Criado em: 10 de junho de 2025
isso que ele diz ter criado, já existe ou realmente é algo inédito?
Desafio aceito.
Você está pedindo algo que vá além de simplesmente trocar uma U-Net por um Transformer. Você quer uma solução que repense o fluxo do seu modelo, que seja criativa e que ataque o problema fundamental da fase e do áudio bruto.
Não vamos usar o BSRoformer
. Vamos criar algo novo, inspirado nos melhores conceitos, mas integrado à sua base de código. Apresento a arquitetura "Rede Cascata Ciente de Fase com Refinamento de Onda" (Phase-Aware Cascaded Network with Waveform Refinement - PAC-WR).
Esta arquitetura ataca o problema em três frentes, transformando seu modelo CascadedNet
em uma verdadeira máquina de separação waveform-to-waveform.
CascadedNet
não irá mais prever apenas a magnitude. Ela vai prever o espectrograma complexo (real e imaginário) do alvo. Isso permite que a rede aprenda a manipular a fase implicitamente, em vez de reutilizar a fase contaminada da mistura.CascadedNet
fazer sua melhor tentativa de separação, teremos um segundo módulo, uma rede convolucional 1D muito leve chamada WaveformRefiner
. Ela receberá a saída inicial (ainda com alguns artefatos) e a mistura original, e aprenderá a prever um sinal residual para "limpar" os artefatos restantes diretamente no domínio da onda.Vamos à implementação.
lib/nets.py
- A Nova Arquitetura PAC-WRVamos transformar CascadedNet
, criar o WaveformRefiner
e uni-los em um único módulo.
python# ============================== # INÍCIO DAS MUDANÇAS EM: vocal-remover\lib\nets.py # ============================== import torch from torch import nn import torch.nn.functional as F from lib import layers # ... (deixe a classe BaseNet como está) ... class CascadedNet(nn.Module): # ADICIONAMOS 'is_complex=True' COMO PADRÃO def __init__(self, n_fft, hop_length, nout=32, nout_lstm=128, is_complex=True): super(CascadedNet, self).__init__() self.n_fft = n_fft self.hop_length = hop_length self.is_complex = is_complex # Adicionado para clareza self.max_bin = n_fft // 2 self.output_bin = n_fft // 2 + 1 self.nin_lstm = self.max_bin // 2 self.offset = 64 # Se for complexo, a entrada terá 4 canais (real/imag para L/R) nin = 4 if is_complex else 2 self.stg1_low_band_net = nn.Sequential( BaseNet(nin, nout // 2, self.nin_lstm // 2, nout_lstm), layers.Conv2DBNActiv(nout // 2, nout // 4, 1, 1, 0) ) self.stg1_high_band_net = BaseNet( nin, nout // 4, self.nin_lstm // 2, nout_lstm // 2 ) self.stg2_low_band_net = nn.Sequential( BaseNet(nout // 4 + nin, nout, self.nin_lstm // 2, nout_lstm), layers.Conv2DBNActiv(nout, nout // 2, 1, 1, 0) ) self.stg2_high_band_net = BaseNet( nout // 4 + nin, nout // 2, self.nin_lstm // 2, nout_lstm // 2 ) self.stg3_full_band_net = BaseNet( 3 * nout // 4 + nin, nout, self.nin_lstm, nout_lstm ) # A saída agora também tem 4 canais (real/imag para L/R) self.out = nn.Conv2d(nout, nin, 1, bias=False) self.aux_out = nn.Conv2d(3 * nout // 4, nin, 1, bias=False) def forward(self, x): # A entrada 'x' agora é um tensor complexo [B, C, F, T] # Nós o dividimos em real e imaginário para alimentar a rede x_in = torch.cat([x.real, x.imag], dim=1) x_in = x_in[:, :, :self.max_bin] bandw = x_in.size()[2] // 2 l1_in = x_in[:, :, :bandw] h1_in = x_in[:, :, bandw:] l1 = self.stg1_low_band_net(l1_in) h1 = self.stg1_high_band_net(h1_in) aux1 = torch.cat([l1, h1], dim=2) l2_in = torch.cat([l1_in, l1], dim=1) h2_in = torch.cat([h1_in, h1], dim=1) l2 = self.stg2_low_band_net(l2_in) h2 = self.stg2_high_band_net(h2_in) aux2 = torch.cat([l2, h2], dim=2) f3_in = torch.cat([x_in, aux1, aux2], dim=1) f3 = self.stg3_full_band_net(f3_in) # A saída da rede é convertida de volta para um tensor complexo mask_out = self.out(f3) # Separa os canais de real e imaginário # Assumindo que a ordem é [real_L, real_R, imag_L, imag_R] num_channels = mask_out.shape[1] // 2 real_part = mask_out[:, :num_channels] imag_part = mask_out[:, num_channels:] mask = torch.complex(real_part, imag_part) # Aplica uma função de ativação bounded para estabilidade mask = torch.tanh(torch.abs(mask)) * torch.exp(1.j * torch.angle(mask)) mask = F.pad( input=mask, pad=(0, 0, 0, self.output_bin - mask.size()[2]), mode='replicate' ) # A máscara é multiplicada diretamente pelo espectrograma complexo de entrada return x * mask # --- NOVA CLASSE: O REFINADOR DE ONDA --- class WaveformRefiner(nn.Module): def __init__(self, channels=2, internal_channels=32): super().__init__() self.body = nn.Sequential( nn.Conv1d(channels * 2, internal_channels, kernel_size=7, padding=3), nn.LeakyReLU(0.1), nn.Conv1d(internal_channels, internal_channels, kernel_size=7, padding=3), nn.LeakyReLU(0.1), nn.Conv1d(internal_channels, channels, kernel_size=1, padding=0) ) def forward(self, initial_separation, mixture): # Concatena a separação inicial com a mistura original # Isso dá ao refinador o contexto do que precisa ser removido x = torch.cat([initial_separation, mixture], dim=1) residual = self.body(x) return initial_separation + residual # --- NOVA CLASSE: O MODELO COMPLETO --- class PAC_WR(nn.Module): def __init__(self, args): super().__init__() self.args = args self.separator = CascadedNet( n_fft=args.n_fft, hop_length=args.hop_length, is_complex=True ) self.refiner = WaveformRefiner(channels=2) # 2 canais para estéreo # Janela STFT para ser usada no forward self.register_buffer('window', torch.hann_window(args.n_fft)) def _stft(self, wave): return torch.stft( wave, n_fft=self.args.n_fft, hop_length=self.args.hop_length, window=self.window, return_complex=True ) def _istft(self, spec): return torch.istft( spec, n_fft=self.args.n_fft, hop_length=self.args.hop_length, window=self.window, return_complex=False ) def forward(self, mixture_wave): # Passo 1: Converter a mistura de onda para espectrograma mixture_spec = self._stft(mixture_wave) # Passo 2: A CascadedNet prevê o espectrograma do alvo target_spec_pred = self.separator(mixture_spec) # Passo 3: Converter o espectrograma previsto de volta para onda initial_target_wave = self._istft(target_spec_pred) # Alinhar o comprimento da onda de saída com a de entrada T = mixture_wave.shape[-1] initial_target_wave = initial_target_wave[..., :T] # Passo 4: O Refinador limpa a onda inicial final_target_wave = self.refiner(initial_target_wave, mixture_wave) return final_target_wave # A função predict é para a validação, que não precisa do refinador # para avaliar a performance da rede principal isoladamente. def predict(self, mixture_spec): target_spec_pred = self.separator(mixture_spec) return target_spec_pred # ============================== # FIM DAS MUDANÇAS EM: vocal-remover\lib\nets.py # ==============================
lib/dataset.py
- Fornecer Ondas em Vez de EspectrogramasO dataset agora deve fornecer as formas de onda brutas. O modelo cuidará da transformação STFT.
python# ============================== # INÍCIO DAS MUDANÇAS EM: vocal-remover\lib\dataset.py # ============================== # Na classe VocalRemoverTrainingSet class VocalRemoverTrainingSet(torch.utils.data.Dataset): # ... (o __init__ permanece o mesmo) ... # ... (as funções de aumento e mixup permanecem as mesmas) ... def __getitem__(self, idx): vocals_path, inst_path, duration = self.filelist[idx] if duration < self.chunk_size: offset = 0 chunk_size = duration else: offset = np.random.randint(0, duration - self.chunk_size) chunk_size = self.chunk_size v_wave, _ = sf.read(vocals_path, start=offset, frames=chunk_size, dtype='float32', always_2d=True) y_wave, _ = sf.read(inst_path, start=offset, frames=chunk_size, dtype='float32', always_2d=True) v_wave = v_wave.T y_wave = y_wave.T min_len = min(v_wave.shape[1], y_wave.shape[1]) v_wave = v_wave[:, :min_len] y_wave = y_wave[:, :min_len] X_wave = v_wave + y_wave X_wave, y_wave = self.do_aug(X_wave, y_wave) if np.random.uniform() < self.mixup_rate: X_wave, y_wave = self.do_mixup(X_wave, y_wave) if np.random.uniform() < self.reduction_rate: X_wave = spec_utils.reduce_stereo(X_wave, self.reduction_weight) if self.target_instrument == 'vocals': target_wave = v_wave else: # O padrão é 'other' (instrumental) target_wave = y_wave # NORMALIZAÇÃO SIMPLES DA ONDA norm_factor = np.max(np.abs(X_wave)) if norm_factor > 1e-8: X_wave /= norm_factor target_wave /= norm_factor # RETORNA AS ONDAS DIRETAMENTE return X_wave.astype(np.float32), target_wave.astype(np.float32) # Na classe VocalRemoverValidationSet (opcional, mas bom para consistência) # Por simplicidade, podemos manter a validação no domínio do espectrograma por enquanto # para avaliar apenas a CascadedNet. Vamos focar no treino primeiro. # Se quiser validar o modelo inteiro, esta classe também precisaria retornar ondas. # Manteremos como está para validar a perda no espectrograma e o SDR separadamente. # ... (resto do arquivo sem alterações) ... # ============================== # FIM DAS MUDANÇAS EM: vocal-remover\lib\dataset.py # ==============================
train.py
- O Novo Loop de TreinamentoEste é o coração da mudança. O loop de treino agora opera em formas de onda e usa a nova perda.
python# ============================== # INÍCIO DAS MUDANÇAS EM: vocal-remover\train.py # ============================== import argparse # ... (outros imports) ... from lib import nets # Garanta que está importando o arquivo nets # ... (setup_logger, etc., permanecem iguais) ... # --- NOVA FUNÇÃO DE PERDA MULTI-RESOLUÇÃO --- def multi_resolution_stft_loss(y_pred_wave, y_wave, device, n_ffts=[1024, 2048, 512], hop_lengths=[120, 240, 50], win_lengths=[600, 1200, 240]): loss = 0.0 # Perda L1 na forma de onda l1_wave_loss = F.l1_loss(y_pred_wave, y_wave) loss += l1_wave_loss # Perdas no espectrograma em várias resoluções for n_fft, hop_length, win_length in zip(n_ffts, hop_lengths, win_lengths): window = torch.hann_window(win_length, device=device) y_pred_spec = torch.stft(y_pred_wave.reshape(-1, y_pred_wave.shape[-1]), n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, return_complex=True) y_spec = torch.stft(y_wave.reshape(-1, y_wave.shape[-1]), n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, return_complex=True) # Perda de convergência espectral sc_loss = torch.norm(torch.abs(y_spec) - torch.abs(y_pred_spec), p='fro') / (torch.norm(torch.abs(y_spec), p='fro') + 1e-8) # Perda de magnitude L1 mag_loss = F.l1_loss(torch.abs(y_pred_spec), torch.abs(y_spec)) loss += (sc_loss + mag_loss) return loss def train_epoch(dataloader, model, device, optimizer, accumulation_steps, args): model.train() sum_loss = 0 pbar = tqdm(dataloader) if logger.getEffectiveLevel() <= logging.INFO else dataloader for itr, (X_wave, y_wave) in enumerate(pbar): X_wave = X_wave.to(device) y_wave = y_wave.to(device) # O modelo agora prevê a onda diretamente y_pred_wave = model(X_wave) # Alinhar comprimentos T = min(y_pred_wave.shape[-1], y_wave.shape[-1]) y_pred_wave = y_pred_wave[..., :T] y_wave = y_wave[..., :T] # Calcula a perda composta no domínio da onda loss = multi_resolution_stft_loss(y_pred_wave, y_wave, device) accum_loss = loss / accumulation_steps accum_loss.backward() if (itr + 1) % accumulation_steps == 0: optimizer.step() model.zero_grad() sum_loss += loss.item() * len(X_wave) if isinstance(pbar, tqdm): pbar.set_postfix_str(f'loss={loss.item():.4f}') if (itr + 1) % accumulation_steps != 0: optimizer.step() model.zero_grad() return sum_loss / len(dataloader.dataset) def main(): # ... (Toda a configuração de argparse permanece a mesma) ... args = p.parse_args() # ... (seed, criação de diretório, etc.) ... # ... (divisão de dataset e metadados permanecem iguais) ... device = torch.device('cpu') # --- INICIALIZAÇÃO DO NOVO MODELO --- model = nets.PAC_WR(args) if args.pretrained_model is not None: # Carregar um modelo pré-treinado pode ser complicado com a nova arquitetura # Recomendo treinar do zero. Se precisar carregar, você terá que # carregar os pesos da `CascadedNet` no `model.separator`. # ex: model.separator.load_state_dict(torch.load(args.pretrained_model, map_location=device)) logger.warning("Carregar modelos pré-treinados com a nova arquitetura requer atenção manual.") if torch.cuda.is_available() and args.gpu[0] != -1: # ... (lógica de GPU e DataParallel permanece a mesma) ... # Se usar DataParallel, o acesso será model.module.separator, etc. model.to(device) optimizer = torch.optim.Adam( filter(lambda p: p.requires_grad, model.parameters()), lr=args.learning_rate ) # ... (scheduler, criação de datasets (já modificados), dataloaders) ... # --- LOOP DE TREINO PRINCIPAL --- for epoch in range(args.epoch): logger.info('# epoch {}'.format(epoch)) # Passamos 'args' para a função de treino train_loss = train_epoch(train_dataloader, model, device, optimizer, args.accumulation_steps, args) # A validação ainda pode usar o método antigo (espectrograma) para # medir a loss da U-Net e o SDR, que é uma boa métrica de diagnóstico. val_loss = validate_epoch(val_dataloader, model.module.separator if isinstance(model, nn.DataParallel) else model.separator, device) # ... (resto do loop de treino, logging, salvamento do modelo) ... # Ao salvar, você salva o modelo inteiro: # model_to_save = model.module if isinstance(model, nn.DataParallel) else model # torch.save(model_to_save.state_dict(), model_path) # ... (resto do arquivo) ... # ============================== # FIM DAS MUDANÇAS EM: train.py # ==============================
Seu comando de execução permanece praticamente o mesmo. A arquitetura é selecionada internamente no código.
bashpython train.py --val_dir "test" --dataset "dataset" --gpu 0 1 --epoch 1000 --batchsize 32 --patches 1 --num_workers 4 --target_instrument vocals
Nota sobre o batchsize
: Como o modelo agora é mais complexo e faz mais cálculos (STFT/iSTFT, refinador, perda múltipla), ele consumirá mais memória da GPU. Talvez você precise reduzir o batchsize
de 56 para algo como 32 ou 24 para começar.
WaveformRefiner
atua como um "polidor", limpando pequenos erros que a rede principal possa ter deixado, operando diretamente onde importa: na forma de onda final.CascadedNet
como o coração do sistema, mas a envolvemos em um fluxo de trabalho de ponta a ponta, permitindo que você troque ou melhore cada componente (o separador, o refinador) de forma independente no futuro.Você não apenas terá um removedor de vocais, mas uma pipeline de separação de fontes robusta e moderna, construída sobre a sua base de código existente.
Em linhas gerais, a PAC-WR é um “mosaico” de ideias que já apareceram separadamente na literatura — a combinação exata nos seus nomes de classe é inédita, mas cada bloco tem antecessores bem claros.
Bloco da PAC-WR | Trabalhos anteriores muito próximos |
---|---|
Máscara complexa (real + imag.) na CascadedNet | Deep Complex U-Net e cIRM/CRM-based separation estimam explicitamente as partes real e imaginária do espectro arxiv.org openreview.net |
Perda multi-resolução STFT + L1 na onda | Multi-resolution STFT loss surgiu no Parallel WaveGAN e foi popularizado em Demucs v2+ (e inúmeros trabalhos subsequentes em enhancement / separation) arxiv.org arxiv.org |
iSTFT dentro do grafo para loss no domínio da onda | Hybrid Demucs e outros modelos “hybrid waveform–spectrogram” fazem forward → iSTFT → loss no tempo arxiv.org mdx-workshop.github.io |
Refinador residual 1-D pós-separação | A branch temporal do Hybrid Demucs age exatamente como um refinador adicionando correção em onda; outras abordagens chamam isso de “post-net” ou “denoiser” arxiv.org arxiv.org |
Arquitetura em cascata (band-split) com máscara complexa | Cascaded / multi-band U-Nets e D3Net já empregam estágios empilhados em faixas, embora a maioria use magnitude-mask link.springer.com |
Pergunta | Ação sugerida |
---|---|
Há artigo/patente descrevendo a mesma pipeline? | Pesquise nomes como “Hybrid Demucs”, “Complex U-Net”, “CRM Demucs”. |
O refinador 1-D traz ganho medido (ex.: +0.3 dB SDR)? | Faça ablação: CascadedNet-complex sozinha × PAC-WR. |
Existem datasets/ferramentas abertas para comparação direta? | MusDB-18 HQ, MDX-23, Voicebank+DEMAND (speech). |
Conclusão resumida: PAC-WR não reinventa nenhum tijolo, mas organiza tijolos consagrados de uma forma que, até onde a literatura indexada mostra, não foi publicada como um pacote único. Se o objetivo é “inovação industrial” ou hackathon, você pode alegar “solução própria inspirada em trabalhos X, Y, Z”. Se precisar de patentes ou artigo, inclua benchmarks e deixe claro o diferencial — principalmente a cascata band-split + máscara complexa + refinador.