
SEGAN-Pytorch

A small graduation project

0x00 Speech Enhancement

Speech enhancement is a signal processing task that improves the quality of speech recorded under noisy or degraded conditions. Its goal is to make speech clearer, more intelligible and more pleasant to listen to, which benefits applications such as speech recognition, teleconferencing and hearing aids.

Related Link

0x01 Preparation

🔨CUDA & cuDNN

①What are CUDA and cuDNN?

②Installing CUDA and cuDNN

🔨PyTorch

①What is PyTorch?

②What you need to install PyTorch

③Making good use of PyTorch

✏VSCode & Graphics Card

①Using the graphics card from VSCode

②How to use it

💻Coding and Debug

①Datasets

②Data Generation

③Necessary parameters

④Model

⑤Train and Evaluate

0x02 CUDA and cuDNN

CUDA is a parallel computing platform and programming model developed by NVIDIA for general-purpose computing on graphics processing units (GPUs). With CUDA, developers can dramatically speed up computing applications by harnessing the power of GPUs.

For example, in our engineering project described below (speech enhancement based on a GAN), running on the GPU was almost 60 times faster than running on the CPU.

👉Attention: not every computer can install CUDA. It requires an NVIDIA graphics card.

So how can you tell?

Open Windows Device Manager, expand "Display adapters", and check whether you have an NVIDIA graphics card.

Mine is an RTX 2060.😥

Downloading CUDA Link

Downloading cuDNN Link

Download Test

Environment Variable
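Once CUDA, cuDNN and PyTorch are installed, the quickest test is to ask PyTorch what it can see. The sketch below relies only on standard PyTorch calls (`torch.cuda.is_available`, `torch.version.cuda`, `torch.cuda.get_device_name`, `torch.backends.cudnn.version`); the helper name `describe_cuda_setup` is our own, and it simply reports `torch_installed: False` when PyTorch is missing.

```python
import importlib.util


def describe_cuda_setup():
    """Report whether PyTorch is installed and whether it can see a CUDA GPU.

    describe_cuda_setup is a hypothetical helper; it returns a dict for printing.
    """
    if importlib.util.find_spec("torch") is None:
        return {"torch_installed": False}

    import torch

    info = {
        "torch_installed": True,
        "torch_version": torch.__version__,
        # CUDA version this PyTorch build was compiled against (None for CPU-only builds)
        "built_with_cuda": torch.version.cuda,
        "cuda_available": torch.cuda.is_available(),
    }
    if info["cuda_available"]:
        info["device_name"] = torch.cuda.get_device_name(0)
        # cuDNN version reported as an integer (None if cuDNN is unavailable)
        info["cudnn_version"] = torch.backends.cudnn.version()
    return info


if __name__ == "__main__":
    for key, value in describe_cuda_setup().items():
        print(f"{key}: {value}")
```

If `cuda_available` is `False` on a machine with an NVIDIA card, the usual suspects are a CPU-only PyTorch build (`built_with_cuda` is `None`) or a driver that is too old for the installed CUDA version.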

0x03 PyTorch

PyTorch is a machine learning framework based on the Torch library, used for applications such as computer vision and natural language processing, originally developed by Meta AI and now part of the Linux Foundation umbrella.

Related Link

Downloading and Using PyTorch

0x04 GAN

Key related papers:

Generative Adversarial Nets

Speech Enhancement GAN

SEGAN: enhancement based on time-domain signals
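For reference, this is the minimax objective from the Generative Adversarial Nets paper: the discriminator D maximizes the log-probability of labelling real data as real and generated data as fake, while the generator G minimizes it. SEGAN swaps these log losses for least-squares losses, but the adversarial idea is the same.

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
$$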

0x05 Main Structure of Our Work

/mnt
|==>/datasets
|    |==>catch_segan                #Processed (sliced) data
|    |    |==>clean
|    |    |==>noisy
|    |==>clean_testset_wav
|    |==>clean_trainset_wav
|    |==>noisy_testset_wav
|    |==>noisy_trainset_wav
|==>/save_10                        #Results after 10 epochs
|==>/save_20                        #Results after 20 epochs
|==>/save_30                        #Results after 30 epochs
|==>/save_40                        #Results after 40 epochs
|==>/save_50                        #Results after 50 epochs
|==>/save_60                        #Results after 60 epochs
|==>/save_70                        #Results after 70 epochs
|==>/scp
|    |==>train_segan.scp            #Training list (clean/noisy slice pairs)
coding_test.py                      #Calculate PESQ
hparams.py                          #Set necessary parameters
dataset.py                          #Process datasets
data_generation.py                  #Generate data
model.py                            #G and D in GAN
train.py                            #Train model
eval_local.py                       #Evaluate models on the local computer
eval.py                             #Evaluate models on the remote server

0x06 Coding Analysis

coding_test.py

import os
from pesq import pesq
from scipy.io import wavfile

# PESQ (Perceptual Evaluation of Speech Quality), ITU-T P.862.
# pesq(fs, ref, deg, mode): ref is the clean reference, deg is the signal under test.
# For every saved epoch, score enhanced vs. clean and noisy vs. clean,
# in wide band ('wb', needs 16 kHz) and narrow band ('nb') mode, and print the gain.
# Note: narrow-band scores are usually higher than wide-band scores, but not always.
EPOCHS = [10, 20, 30, 40, 50, 60, 70]


def pesq_gain(enhanced_score, noisy_score):
    # Improvement of the enhanced signal over the noisy input (negative means worse)
    gain = enhanced_score - noisy_score
    print(" pesq gain =", gain)
    return gain


for epoch in EPOCHS:
    folder = "save_%d" % epoch
    rate, enhanced = wavfile.read(os.path.join(folder, "enh1.wav"))
    _, noisy = wavfile.read(os.path.join(folder, "noi1.wav"))
    _, clean = wavfile.read(os.path.join(folder, "clean1.wav"))
    print("Epoch %d" % epoch)
    for mode, label in (("wb", "wide band"), ("nb", "narrow band")):
        enh_score = pesq(rate, clean, enhanced, mode)
        noisy_score = pesq(rate, clean, noisy, mode)
        print("enhanced %s pesq = %s" % (label, enh_score))
        print("noisy %s pesq = %s" % (label, noisy_score))
        pesq_gain(enh_score, noisy_score)
    print('=========================================')

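Conclusion (console output as originally recorded). In that run the enhanced file was passed to pesq() as the reference and the noisy file as the degraded signal, so the "enhanced" scores measure how close the enhanced output is to the noisy input rather than its quality against the clean speech. The "noisy" scores (clean vs. noisy) are identical for every epoch because the same clean/noisy pair is scored each time. The "enhanced" scores peak at epoch 20 in both wide and narrow band.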

Epoch 10
enhanced wide band pesq= 1.9146335124969482
noisy wide band pesq= 1.446223258972168
pesq difference = 0.4684102535247803
enhanced narrow band pesq= 3.4322617053985596
noisy narrow band pesq= 1.91963529586792
pesq difference = 1.5126264095306396
=========================================
Epoch 20
enhanced wide band pesq= 2.442546844482422
noisy wide band pesq= 1.446223258972168
pesq difference = 0.9963235855102539
enhanced narrow band pesq= 3.639315128326416
noisy narrow band pesq= 1.91963529586792
pesq difference = 1.719679832458496
=========================================
Epoch 30
enhanced wide band pesq= 2.0010228157043457
noisy wide band pesq= 1.446223258972168
pesq difference = 0.5547995567321777
enhanced narrow band pesq= 2.9774296283721924
noisy narrow band pesq= 1.91963529586792
pesq difference = 1.0577943325042725
=========================================
Epoch 40
enhanced wide band pesq= 2.1276774406433105
noisy wide band pesq= 1.446223258972168
pesq difference = 0.6814541816711426
enhanced narrow band pesq= 3.1286420822143555
noisy narrow band pesq= 1.91963529586792
pesq difference = 1.2090067863464355
=========================================
Epoch 50
enhanced wide band pesq= 1.88307523727417
noisy wide band pesq= 1.446223258972168
pesq difference = 0.43685197830200195
enhanced narrow band pesq= 2.937770366668701
noisy narrow band pesq= 1.91963529586792
pesq difference = 1.0181350708007812
=========================================
Epoch 60
enhanced wide band pesq= 1.6502314805984497
noisy wide band pesq= 1.446223258972168
pesq difference = 0.20400822162628174
enhanced narrow band pesq= 2.608293294906616
noisy narrow band pesq= 1.91963529586792
pesq difference = 0.6886579990386963
=========================================
Epoch 70
enhanced wide band pesq= 1.813138484954834
noisy wide band pesq= 1.446223258972168

dataset.py

🔨Necessary modules

import os
import torch
import numpy as np
from torch.utils.data import Dataset,DataLoader
from hparams import hparams
import librosa
import random
import soundfile as sf

🔨Pre-emphasis and De-emphasis

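def emphasis(signal, emph_coeff=0.95, pre=True):
    # Pre-emphasis (pre=True) or its exact inverse, de-emphasis (pre=False)
    if pre:
        # y[n] = x[n] - a * x[n-1]: first-order high-pass filter
        result = np.append(signal[0], signal[1:] - emph_coeff * signal[:-1])
    else:
        # De-emphasis must be recursive to undo pre-emphasis exactly:
        # x[n] = y[n] + a * x[n-1], using the previously reconstructed sample
        result = np.empty_like(signal)
        result[0] = signal[0]
        for n in range(1, len(signal)):
            result[n] = signal[n] + emph_coeff * result[n - 1]

    return result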
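Pre-emphasis is the filter y[n] = x[n] - a·x[n-1]; its inverse has to feed back the previously reconstructed sample, x[n] = y[n] + a·x[n-1]. A feed-forward version that adds a·y[n-1] instead only cancels one term and leaves an error of -a²·x[n-2]. The quick NumPy check below (function names are our own) compares the two on random data.

```python
import numpy as np


def pre_emphasis(x, coeff=0.95):
    """y[n] = x[n] - coeff * x[n-1], with y[0] = x[0]."""
    return np.append(x[0], x[1:] - coeff * x[:-1])


def de_emphasis_recursive(y, coeff=0.95):
    """Exact inverse: x[n] = y[n] + coeff * x[n-1] (previous *output*)."""
    x = np.empty_like(y)
    x[0] = y[0]
    for n in range(1, len(y)):
        x[n] = y[n] + coeff * x[n - 1]
    return x


def de_emphasis_feedforward(y, coeff=0.95):
    """Non-recursive variant: adds coeff * y[n-1] (previous *input*), not an inverse."""
    return np.append(y[0], y[1:] + coeff * y[:-1])


rng = np.random.default_rng(0)
signal = rng.standard_normal(1000)
emphasized = pre_emphasis(signal)

exact = de_emphasis_recursive(emphasized)
approx = de_emphasis_feedforward(emphasized)
print("recursive max error:   ", np.max(np.abs(exact - signal)))
print("feed-forward max error:", np.max(np.abs(approx - signal)))
```

The recursive version reconstructs the input up to floating-point rounding, while the feed-forward one leaves an error on the order of the signal itself.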

🔨SEGAN_Dataset

class SEGAN_Dataset(Dataset):

    def __init__(self, para):  # initialization
        self.file_scp = para.train_scp
        # Each scp line: "<clean .npy path> <noisy .npy path>"
        # ndmin=2 keeps a single-line file two-dimensional
        files = np.loadtxt(self.file_scp, dtype='str', ndmin=2)
        self.clean_files = files[:, 0].tolist()  # column 0: clean slices
        self.noisy_files = files[:, 1].tolist()  # column 1: noisy slices

    def __len__(self):
        return len(self.clean_files)

    def __getitem__(self, idx):
        # Pre-emphasis: a high-pass filter that boosts high-frequency content
        # Load the clean slice and pre-emphasize it
        clean_wav = np.load(self.clean_files[idx])
        clean_wav = emphasis(clean_wav)
        # Load the noisy slice and pre-emphasize it
        noisy_wav = np.load(self.noisy_files[idx])
        noisy_wav = emphasis(noisy_wav)

        # Convert to torch tensors
        clean_wav = torch.from_numpy(clean_wav)
        noisy_wav = torch.from_numpy(noisy_wav)

        # Add a channel dimension: [D,] => [1, D]
        clean_wav = clean_wav.reshape(1, -1)
        noisy_wav = noisy_wav.reshape(1, -1)

        return clean_wav, noisy_wav

    def ref_batch(self, batch_size):
        # Randomly pick batch_size training slices (with replacement) as the reference batch
        index = np.random.choice(len(self.clean_files), batch_size).tolist()

        catch_clean = [emphasis(np.load(self.clean_files[i])) for i in index]
        catch_noisy = [emphasis(np.load(self.noisy_files[i])) for i in index]
        # Add a channel dimension: [B, D] => [B, 1, D]
        catch_clean = np.expand_dims(np.array(catch_clean), axis=1)
        catch_noisy = np.expand_dims(np.array(catch_noisy), axis=1)
        # Concatenate clean and noisy along the channel axis: [B, 2, 16384]
        batch_wav = np.concatenate((catch_clean, catch_noisy), axis=1)
        return torch.from_numpy(batch_wav)


if __name__ == "__main__":
    para = hparams()
    m_Dataset = SEGAN_Dataset(para)
    m_DataLoader = DataLoader(m_Dataset, batch_size=3, shuffle=True, num_workers=4)
    for i_batch, sample_batch in enumerate(m_DataLoader):
        batch_clean = sample_batch[0]
        batch_noisy = sample_batch[1]
        print(batch_clean[1])
        print(batch_noisy[1])
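One detail in `__init__`: `np.loadtxt` returns a one-dimensional array when the scp file contains a single line, and `files[:, 0]` then fails. The standard `ndmin=2` parameter of `np.loadtxt` keeps the result two-dimensional. `load_scp_pairs` below is a hypothetical stand-alone version of that parsing step, and the file names are made up.

```python
import os
import tempfile

import numpy as np


def load_scp_pairs(scp_path):
    """Read a two-column 'clean_path noisy_path' list into two Python lists.

    ndmin=2 keeps the array 2-D even when the file has only one line.
    """
    files = np.loadtxt(scp_path, dtype=str, ndmin=2)
    return files[:, 0].tolist(), files[:, 1].tolist()


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        scp = os.path.join(tmp, "train_segan.scp")
        with open(scp, "w") as f:
            # a single illustrative line
            f.write("clean/p226_001.wav_0.npy noisy/p226_001.wav_0.npy\n")
        clean, noisy = load_scp_pairs(scp)
        print(clean, noisy)
```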

data_generation.py

import numpy as np
import librosa
import os


# Split each utterance into windows of 16384 samples (about 1 s at 16 kHz)
def wav_split(wav, win_length, stride):
    slices = []
    if len(wav) > win_length:
        # Slide the window forward by `stride` samples
        for idx_end in range(win_length, len(wav), stride):
            idx_start = idx_end - win_length  # start position of the window
            slice_wav = wav[idx_start:idx_end]
            slices.append(slice_wav)  # keep this window
        # Add the last window: the final win_length samples of the signal
        slices.append(wav[-win_length:])
    else:
        # Clips no longer than one window are zero-padded so every slice has the same length
        slices.append(np.pad(wav, (0, win_length - len(wav))))
    return slices


def save_slices(slices, name):  # save each slice as a .npy file
    name_list = []
    for i, slice_wav in enumerate(slices):
        name_slice = name + "_" + str(i) + '.npy'  # NumPy format
        np.save(name_slice, slice_wav)
        name_list.append(name_slice)
    return name_list


if __name__ == "__main__":
    clean_wav_path = "/mnt/datasets/clean_trainset_wav"
    noisy_wav_path = "/mnt/datasets/noisy_trainset_wav"
    # Slices are saved under catch_segan
    catch_train_clean = '/mnt/datasets/catch_segan/clean'
    catch_train_noisy = '/mnt/datasets/catch_segan/noisy'

    os.makedirs(catch_train_clean, exist_ok=True)
    os.makedirs(catch_train_noisy, exist_ok=True)
    os.makedirs('/mnt/scp', exist_ok=True)

    win_length = 16384
    stride = win_length // 2  # 50% overlap
    # Walk through all wav files and record "clean_slice noisy_slice" pairs
    with open("/mnt/scp/train_segan.scp", 'wt') as f:
        for root, dirs, files in os.walk(clean_wav_path):  # walk clean_wav_path
            for file in files:
                file_clean_name = os.path.join(root, file)
                name = os.path.split(file_clean_name)[-1]
                if name.endswith("wav"):
                    file_noisy_name = os.path.join(noisy_wav_path, name)
                    print("processing file %s" % (file_clean_name))

                    if not os.path.exists(file_noisy_name):
                        print("can not find file %s" % (file_noisy_name))
                        continue
                    # Resample to 16 kHz; mono=True => single channel
                    clean_data, sr = librosa.load(file_clean_name, sr=16000, mono=True)
                    noisy_data, sr = librosa.load(file_noisy_name, sr=16000, mono=True)

                    if not len(clean_data) == len(noisy_data):
                        print("file lengths are not equal")
                        continue
                    # Split and save the clean speech
                    clean_slices = wav_split(clean_data, win_length, stride)
                    clean_namelist = save_slices(clean_slices, os.path.join(catch_train_clean, name))

                    # Split and save the noisy speech
                    noisy_slices = wav_split(noisy_data, win_length, stride)
                    noisy_namelist = save_slices(noisy_slices, os.path.join(catch_train_noisy, name))

                    for clean_catch_name, noisy_catch_name in zip(clean_namelist, noisy_namelist):
                        f.write("%s %s\n" % (clean_catch_name, noisy_catch_name))
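With a window W = 16384 and a stride S = 8192 (50% overlap), a signal of N > W samples yields ceil((N - W) / S) windows from the loop plus one final window aligned to the end. The sketch below reproduces that splitting on a dummy array; `split_windows` and `expected_count` are hypothetical names, and zero-padding short clips is an assumption rather than something the dataset requires.

```python
import math

import numpy as np


def split_windows(wav, win_length=16384, stride=8192):
    """Overlapping windows: step by `stride`, then add the final end-aligned window."""
    if len(wav) <= win_length:
        # assumption: short clips are zero-padded to one full window
        return [np.pad(wav, (0, win_length - len(wav)))]
    slices = [wav[end - win_length:end] for end in range(win_length, len(wav), stride)]
    slices.append(wav[-win_length:])
    return slices


def expected_count(n_samples, win_length=16384, stride=8192):
    """ceil((N - W) / S) windows from the loop plus one final window (for N > W)."""
    return math.ceil((n_samples - win_length) / stride) + 1


wav = np.zeros(40000, dtype=np.float32)  # about 2.5 s at 16 kHz
slices = split_windows(wav)
print(len(slices), expected_count(len(wav)))  # 4 4
```

Because clean and noisy files have equal length, both produce the same number of slices, so zipping the two name lists pairs them correctly.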

What is a “batch”?

During training, a batch is a group of samples that is fed through the network together: the loss is averaged over the whole batch and the weights are updated once per batch. Batching shortens training time because the GPU processes many samples in parallel, and because each update is based on several examples at once, the gradient is less noisy than one computed from a single sample. Here is a nice article that may help you understand why batch sizes larger than one are beneficial when training neural networks. Batch normalization (batchnorm), a related but separate idea, normalizes each layer's inputs using statistics computed over the batch, which helps coordinate the updates of multiple layers in the model.
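Conceptually, one epoch with `shuffle=True` visits every sample exactly once in random order, grouped into batches, with a smaller final batch when the dataset size is not a multiple of the batch size. The NumPy sketch below illustrates that idea; `iterate_minibatches` is a hypothetical helper, not how PyTorch's `DataLoader` is implemented internally.

```python
import numpy as np


def iterate_minibatches(n_samples, batch_size, seed=0):
    """Yield index arrays for one shuffled pass over the data (one epoch).

    The last batch may be smaller, as with DataLoader(drop_last=False).
    """
    rng = np.random.default_rng(seed)
    order = rng.permutation(n_samples)
    for start in range(0, n_samples, batch_size):
        yield order[start:start + batch_size]


batches = list(iterate_minibatches(n_samples=10, batch_size=4))
print([len(b) for b in batches])  # [4, 4, 2]
```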

train.py

🔨Necessary modules

import torch
from dataset import SEGAN_Dataset,emphasis
from hparams import hparams
from model import Generator, Discriminator
import os
from torch.utils.data import Dataset,DataLoader
import torch.nn as nn
from torch.autograd import Variable

🔨 if __name__ == "__main__":

if __name__ == "__main__":
    # Use the first GPU if available, otherwise fall back to the CPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Load hyperparameters
    para = hparams()

    # Create the folder for saved models (para.path_save is assumed to exist in hparams.py)
    os.makedirs(para.path_save, exist_ok=True)

    # Build the generator and the discriminator
    generator = Generator().to(device)
    discriminator = Discriminator().to(device)

    # RMSprop optimizers for G and D
    g_optimizer = torch.optim.RMSprop(generator.parameters(), lr=0.0001)
    d_optimizer = torch.optim.RMSprop(discriminator.parameters(), lr=0.0001)

    # Dataset
    m_dataset = SEGAN_Dataset(para)

    # Reference batch, passed to the discriminator for normalization
    ref_batch = m_dataset.ref_batch(para.ref_batch_size).to(device)

    # DataLoader
    m_dataloader = DataLoader(m_dataset, batch_size=para.batch_size, shuffle=True, num_workers=8)

    # Training loop
    for epoch in range(para.n_epoch):
        # i_batch: batch index; sample_batch: (clean, noisy)
        for i_batch, sample_batch in enumerate(m_dataloader):
            batch_clean = sample_batch[0].to(device)  # clean speech
            batch_noisy = sample_batch[1].to(device)  # noisy speech

            # Latent noise z, shape [B, size_z[0], size_z[1]]
            batch_z = torch.randn(batch_clean.size(0), para.size_z[0], para.size_z[1], device=device)

            # Train D to score (clean, noisy) pairs as real
            discriminator.zero_grad()  # reset gradients
            train_batch = torch.cat([batch_clean, batch_noisy], dim=1)  # concatenate along channels
            outputs = discriminator(train_batch, ref_batch)  # D from model.py
            clean_loss = torch.mean((outputs - 1.0) ** 2)  # L2 loss, push outputs toward 1

            # Train D to score (generated, noisy) pairs as fake
            # detach() keeps this step from backpropagating into G
            generated_outputs = generator(batch_noisy, batch_z).detach()
            outputs = discriminator(torch.cat((generated_outputs, batch_noisy), dim=1), ref_batch)
            noisy_loss = torch.mean(outputs ** 2)  # L2 loss, push outputs toward 0

            d_loss = clean_loss + noisy_loss
            d_loss.backward()   # compute gradients
            d_optimizer.step()  # update D

            # Train G so that D scores G(z, noisy) as real
            generator.zero_grad()
            generated_outputs = generator(batch_noisy, batch_z)
            gen_noise_pair = torch.cat((generated_outputs, batch_noisy), dim=1)
            outputs = discriminator(gen_noise_pair, ref_batch)

            g_loss_ = 0.5 * torch.mean((outputs - 1.0) ** 2)
            # L1 distance between the generated output and the clean sample
            l1_dist = torch.abs(generated_outputs - batch_clean)
            g_cond_loss = 100 * torch.mean(l1_dist)  # conditional (L1) loss
            g_loss = g_loss_ + g_cond_loss

            # Backpropagate and update G
            g_loss.backward()
            g_optimizer.step()

            print("Epoch %d:%d d_clean_loss %.4f, d_noisy_loss %.4f, g_loss %.4f, g_conditional_loss %.4f"
                  % (epoch + 1, i_batch, clean_loss.item(), noisy_loss.item(), g_loss.item(), g_cond_loss.item()))

        # Save both state dicts at the end of every epoch
        g_model_name = os.path.join(para.path_save, "G_" + str(epoch) + "_%.4f" % g_cond_loss.item() + ".pkl")
        d_model_name = os.path.join(para.path_save, "D_" + str(epoch) + "_%.4f" % noisy_loss.item() + ".pkl")
        torch.save(generator.state_dict(), g_model_name)
        torch.save(discriminator.state_dict(), d_model_name)

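During training, RMSprop (learning rate 0.0001) optimizes both the generator and the discriminator. The discriminator uses a least-squares (L2) loss that pushes its output toward 1 for (clean, noisy) pairs and toward 0 for (generated, noisy) pairs; the generator uses a least-squares adversarial loss plus an L1 loss to the clean speech, weighted by 100. At the end of every epoch, the trained generator and discriminator are saved.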
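The loss arithmetic in train.py can be checked in isolation with NumPy arrays standing in for the discriminator outputs. This is a sketch of the same formulas, not the training code itself; the function names are our own.

```python
import numpy as np


def discriminator_loss(d_clean, d_enhanced):
    """Least-squares D loss as in train.py: D(clean, noisy) -> 1, D(G(noisy), noisy) -> 0."""
    clean_loss = np.mean((d_clean - 1.0) ** 2)
    noisy_loss = np.mean(d_enhanced ** 2)
    return clean_loss + noisy_loss


def generator_loss(d_enhanced, enhanced, clean, l1_weight=100.0):
    """Least-squares adversarial term plus weighted L1 distance to the clean target."""
    adversarial = 0.5 * np.mean((d_enhanced - 1.0) ** 2)
    l1 = l1_weight * np.mean(np.abs(enhanced - clean))
    return adversarial + l1


# A perfect discriminator scores real pairs 1 and generated pairs 0
print(discriminator_loss(np.ones(4), np.zeros(4)))  # 0.0
# Adversarial term 0.5 plus L1 term 100 * 0.1
print(generator_loss(np.zeros(4), np.full(4, 0.1), np.zeros(4)))  # about 10.5
```

The second call shows why the weight of 100 matters: even a small waveform error of 0.1 dominates the adversarial term, so the generator is pulled strongly toward the clean signal.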

eval.py

🔨Necessary modules

import torch
import torch.nn as nn
import numpy as np
from model import Generator
from hparams import hparams
from dataset import emphasis
import glob
import soundfile as sf
import os
import librosa
import matplotlib.pyplot as plt
from numpy.linalg import norm
def enh_segan(model, noisy, para):
    # Split the noisy input into non-overlapping frames of win_len samples
    win_len = para.win_len
    N_slice = len(noisy) // win_len
    temp_noisy = noisy
    if not len(noisy) % win_len == 0:  # samples left over after full frames
        short = win_len - len(noisy) % win_len
        # np.pad: add nothing on the left and `short` samples on the right,
        # filled by wrapping around to the start of the signal
        temp_noisy = np.pad(noisy, (0, short), 'wrap')
        N_slice = N_slice + 1

    slices = temp_noisy.reshape(N_slice, win_len)  # frames

    enh_slice = np.zeros(slices.shape)
    model.eval()
    # Process frame by frame
    for n in range(N_slice):
        m_slice = slices[n]

        # Pre-emphasis
        m_slice = emphasis(m_slice)
        # Add batch and channel dimensions: [D] => [1, 1, D]
        m_slice = np.expand_dims(m_slice, axis=(0, 1))
        # Convert to a torch tensor
        m_slice = torch.from_numpy(m_slice)
        # Sample the latent z
        z = torch.randn(1, para.size_z[0], para.size_z[1])

        # Enhance
        with torch.no_grad():
            generated_slice = model(m_slice, z)
        generated_slice = generated_slice.numpy()
        # De-emphasis
        generated_slice = emphasis(generated_slice[0, 0, :], pre=False)
        enh_slice[n] = generated_slice

    # Concatenate the frames and drop the padding
    enh_speech = enh_slice.reshape(N_slice * win_len)
    return enh_speech[:len(noisy)]


def get_snr(clean, noisy):  # signal-to-noise ratio in dB
    noise = noisy - clean
    snr = 20 * np.log10(norm(clean) / (norm(noise) + 1e-7))
    return snr


if __name__ == "__main__":

    para = hparams()

    path_eval = 'eval47'
    os.makedirs(path_eval, exist_ok=True)

    # Load the generator checkpoint
    n_epoch = 47  # results change little after about 30 epochs
    model_file = "save/G_47_0.2873.pkl"

    generator = Generator()
    generator.load_state_dict(torch.load(model_file, map_location='cpu'))

    path_test_clean = '/mnt/datasets/clean_testset_wav'
    path_test_noisy = '/mnt/datasets/noisy_testset_wav'
    test_clean_wavs = glob.glob(path_test_clean + '/*wav')
    fs = para.fs
    for clean_file in test_clean_wavs:
        name = os.path.split(clean_file)[-1]
        noisy_file = os.path.join(path_test_noisy, name)
        if not os.path.isfile(noisy_file):
            continue

        # Load clean and noisy speech
        clean, _ = librosa.load(clean_file, sr=fs, mono=True)
        noisy, _ = librosa.load(noisy_file, sr=fs, mono=True)

        snr = get_snr(clean, noisy)
        print("%s snr=%.2f" % (noisy_file, snr))
        # Only enhance low-SNR files (below 3 dB)
        if snr < 3.0:
            print('processing %s with snr %.2f' % (noisy_file, snr))
            # Enhanced speech
            enh = enh_segan(generator, noisy, para)

            # Save the audio
            sf.write(os.path.join(path_eval, 'noisy-' + name), noisy, fs)
            sf.write(os.path.join(path_eval, 'clean-' + name), clean, fs)
            sf.write(os.path.join(path_eval, 'enh-' + name), enh, fs)

            # Plot spectrograms of clean, noisy and enhanced speech
            fig_name = os.path.join(path_eval, name[:-4] + '-' + str(n_epoch) + '.jpg')

            plt.figure()
            plt.subplot(3, 1, 1)
            plt.specgram(clean, NFFT=512, Fs=fs)
            plt.xlabel("clean specgram")
            plt.subplot(3, 1, 2)
            plt.specgram(noisy, NFFT=512, Fs=fs)
            plt.xlabel("noisy specgram")
            plt.subplot(3, 1, 3)
            plt.specgram(enh, NFFT=512, Fs=fs)
            plt.xlabel("enhanced specgram")
            plt.savefig(fig_name)
            plt.close()  # free the figure before the next file
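The framing in `enh_segan` pads the signal by wrapping (`np.pad(..., 'wrap')`), cuts it into non-overlapping `win_len` frames and trims the padding afterwards, and an SNR in decibels needs a base-10 logarithm (`np.log10`, since `np.log` is the natural log). The sketch below checks both steps on toy arrays, with the generator replaced by the identity; `frame_with_wrap`, `unframe` and `snr_db` are hypothetical names.

```python
import numpy as np


def frame_with_wrap(x, win_len):
    """Pad x to a multiple of win_len by wrapping around, then cut into non-overlapping frames."""
    short = (-len(x)) % win_len  # 0 when len(x) is already a multiple of win_len
    padded = np.pad(x, (0, short), mode="wrap")
    return padded.reshape(-1, win_len)


def unframe(frames, original_length):
    """Concatenate frames and drop the padding added by frame_with_wrap."""
    return frames.reshape(-1)[:original_length]


def snr_db(clean, noisy):
    """SNR in decibels: 20 * log10(||clean|| / ||noisy - clean||)."""
    noise = noisy - clean
    return 20 * np.log10(np.linalg.norm(clean) / (np.linalg.norm(noise) + 1e-7))


x = np.arange(10, dtype=np.float32)
frames = frame_with_wrap(x, win_len=4)
print(frames.shape)  # (3, 4)
print(np.array_equal(unframe(frames, len(x)), x))  # True

clean = np.ones(100)
noisy = clean + 0.1  # noise norm is one tenth of the signal norm
print(round(snr_db(clean, noisy), 2))  # 20.0
```

A noise norm ten times smaller than the signal norm gives 20 dB; with the natural log the same ratio would come out as about 46, which is why a threshold such as 3.0 only makes sense on the dB scale.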

0x07 Loss Function

Understanding GANs — Deriving the Adversarial loss from scratch

PyTorch study notes: nineteen loss functions

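Loss function: measures how far the model's predictions are from the true values. The smaller the loss, the closer the predictions are to the targets.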
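For reference, these are SEGAN's least-squares objectives, with x the clean speech, x̃ the noisy speech, z the latent noise and λ the L1 weight (λ = 100 in the SEGAN paper). train.py implements them with batch means: it drops the factor 1/2 on the discriminator terms (a constant scale) and uses 100 · mean|G(z, x̃) - x| as the L1 term.

$$
\min_D V_{\mathrm{LS}}(D) = \tfrac{1}{2}\,\mathbb{E}_{x, \tilde{x}}\big[(D(x, \tilde{x}) - 1)^2\big] + \tfrac{1}{2}\,\mathbb{E}_{z, \tilde{x}}\big[D(G(z, \tilde{x}), \tilde{x})^2\big]
$$

$$
\min_G V_{\mathrm{LS}}(G) = \tfrac{1}{2}\,\mathbb{E}_{z, \tilde{x}}\big[(D(G(z, \tilde{x}), \tilde{x}) - 1)^2\big] + \lambda\,\lVert G(z, \tilde{x}) - x \rVert_1
$$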

0x08 Future?

Does Speech Enhancement Dream of LLMs?

Speech Recognition with LLMs
