1 引言
各位朋友大家好,欢迎来到月来客栈,我是掌柜空字符。
经过前面六篇文章的介绍,我们已经清楚了BERT的基本原理[1]、如何从零实现BERT[2]、如何基于BERT预训练模型来完成文本分类任务[3]、文本蕴含任务[4]、问答选择任务(SWAG)[5]以及问题回答任务(SQuAD)[6],算是完成了BERT模型第一部分内容(如何在下游任务中运用预训练BERT)的介绍。在接下来的这篇文章中,掌柜将开始就BERT模型的第二部分内容,即如何利用Mask LM和NSP这两个任务来训练BERT模型进行介绍。通常,你既可以通过MLM和NSP任务来从头训练一个BERT模型,当然也可以在开源预训练模型的基础上再次通过MLM和NSP任务来在特定语料中进行追加训练,以使得模型参数更加符合这一场景。
在文章BERT的基本原理[1]中,掌柜已经就MLM和NSP两个任务的原理做了详细的介绍,所以这里就不再赘述。一句话概括,如图1所示MLM就是随机掩盖掉部分Token让模型来预测,而NSP则是同时输入模型两句话让模型判断后一句话是否真的为前一句话的下一句话,最终通过这两个任务来训练BERT中的权重参数。
2 数据预处理
在正式介绍数据预处理之前,我们还是依照老规矩先通过一张图来大致了解一下整个处理流程,以便做到心中有数不会迷路。
如图2所示便是整个NSP和MLM任务数据集的构建流程。第①②步是根据原始语料来构造NSP任务所需要的输入和标签;第③步则是随机MASK掉部分Token来构造MLM任务的输入,并同时进行padding处理;第④步则是根据第③步处理后的结果来构造MLM任务的标签值,其中[P]
表示Padding的含义,这样做的目的是为了忽略那些不需要进行预测的Token在计算损失时的损失值。在大致清楚了整个数据集的构建流程后,我们下面就可以一步一步地来完成数据集的构建了。
同时,为了能够使得整个数据预处理代码具有通用性,同时支持构造不同场景语料下的训练数据集,因此我们需要为每一类不同的数据源定义一个格式化函数来完成标准化的输入。这样即使是换了不同的语料只需要重写一个针对该数据集的格式化函数即可,其余部分的代码都不需要进行改动。
2.1 英文维基百科数据格式化
这里首先以英文维基百科数据wiki2 [7]为例来介绍如何得到格式化后的标准数据。如下所示便是wiki2中的原始文本数据:
xxxxxxxxxx
51The development of [UNK] powder , based on [UNK] or [UNK] , by the French inventor Paul [UNK] in 1884 was a further step allowing smaller charges of propellant with longer barrels . The guns of the pre @-@ [UNK] battleships of the 1890s tended to be smaller in calibre compared to the ships of the 1880s , most often 12 in ( 305 mm ) , but progressively grew in length of barrel , making use of improved [UNK] to gain greater muzzle velocity .
2
3 = = = [UNK] of armament = = =
4
5 The nature of the projectiles also changed during the ironclad period . Initially , the best armor @-@ piercing [UNK] was a solid cast @-@ iron shot . Later , shot of [UNK] iron , a harder iron alloy , gave better armor @-@ piercing qualities . Eventually the armor @-@ piercing shell was developed .
在上述示例数据中,每一行都表示一个段落,其由一句话或多句话组成。下面我们需要定义一个函数来对其进行预处理:
xxxxxxxxxx
91def read_wiki2(filepath=None):
2 with open(filepath, 'r') as f:
3 lines = f.readlines() # 一次读取所有行,每一行为一个段落
4 paragraphs = []
5 for line in tqdm(lines, ncols=80, desc=" ## 正在读取原始数据"):
6 if len(line.split(' . ')) >= 2:
7 paragraphs.append(line.strip().lower().split(' . '))
8 random.shuffle(paragraphs) # 将所有段落打乱
9 return paragraphs
在上述代码中,第2-3行用于一次读取所有原始数据,每一行为一个段落;第5-7行用于遍历每一个段落,并进行相应的处理;第6行用于过滤掉段落中只有一个句子的情况,因为后续我们要构造NSP任务所需的数据集所以只有一句话的段落需要去掉;第7行用于将所有字母转换为小写并同时将每句话给分割开;第8行则是将所有的段落给打乱,注意不是句子。
最终,经过read_wiki2
函数处理后,我们便能得到一个标准的二维列表,格式形如:
xxxxxxxxxx
11[ [sentence 1, sentence 2, ...], [sentence 1, sentence 2,...],...,[] ]
例如上述语料处理后的结果为:
xxxxxxxxxx
11[['the development of [unk] powder , based on [unk] or [unk] , by the french inventor paul [unk] in 1884 was a further step allowing smaller charges of propellant with longer barrels', 'the guns of the pre @-@ [unk] battleships of the 1890s tended to be smaller in calibre compared to the ships of the 1880s , most often 12 in ( 305 mm ) , but progressively grew in length of barrel , making use of improved [unk] to gain greater muzzle velocity .'], ['the nature of the projectiles also changed during the ironclad period', 'initially , the best armor @-@ piercing [unk] was a solid cast @-@ iron shot .'],[],[]...[]]
这种格式就是后续代码处理所接受的标准格式。
2.2 中文宋词数据格式化
在介绍完英文数据集的格式化过程后我们再来看一个中文原始数据的格式化过程。如下所示便是我们后续所需要用到的中文宋词数据集:
xxxxxxxxxx
21红酥手,黄縢酒,满城春色宫墙柳。东风恶,欢情薄。一怀愁绪,几年离索。错错错。春如旧,人空瘦,泪痕红鲛绡透。桃花落。闲池阁。山盟虽在,锦书难托。莫莫莫。
2十年生死两茫茫。不思量。自难忘。千里孤坟,无处话凄凉。纵使相逢应不识,尘满面,鬓如霜。夜来幽梦忽还乡。小轩窗。正梳妆。相顾无言,惟有泪千行。料得年年断肠处,明月夜,短松冈。
在上述示例中,每一行表示一首词,句与句之间通过句号进行分割。下面我们同样需要定义一个函数来对其进行预处理并返回指定的标准格式:
xxxxxxxxxx
111def read_songci(filepath=None):
2 with open(filepath, 'r') as f:
3 lines = f.readlines() # 一次读取所有行,每一行为一首词
4 paragraphs = []
5 for line in tqdm(lines, ncols=80, desc=" ## 正在读取原始数据"):
6 if "□" in line or "……" in line:
7 continue
8 if len(line.split('。')) >= 2:
9 paragraphs.append(line.strip().split('。')[:-1])
10 random.shuffle(paragraphs) # 将所有段落打乱
11 return paragraphs
在上述代码中,第2-3行用于一次读取所有原始数据,每一行为一首词(段落);第5-9行用于遍历每一个段落,并进行相应的处理;第6-7行用于过滤掉字符乱码的情况;第8-9行用于过滤掉段落中只有一个句子的情况;第10行则是将所有的段落给打乱,注意不是句子。
例如上述语料处理后的结果为:
xxxxxxxxxx
11[['五花心里看抛球', '香腮红嫩柳烟稠'], ['若论风流,无过圆社,拐蹬蹑搭齐全', '门庭富贵,曾到御帘前', '灌口二郎为首,赵皇上、下脚流传', '人都道、齐云一社,三锦独争先', '花前', '并月下,全身绣带,偷侧双肩', '更高而不远,一搭打秋千', '球落处、圆光拐,双佩剑、侧蹑相连', '高人处,翻身佶料,天下总呼圆'],[],[]....[]]
可以看到, 预处理完成后的结果同上面wiki2数据预处理完后的格式一样。
2.3 构造NSP任务数据
在正式构造NSP任务数据之前,我们需要先定义一个类并对定义相关的类成员变量以方便在其它成员方法中使用,代码如下:
xxxxxxxxxx
311class LoadBertPretrainingDataset(object):
2 def __init__(self,
3 vocab_path='./vocab.txt',
4 tokenizer=None,
5 batch_size=32,
6 max_sen_len=None,
7 max_position_embeddings=512,
8 pad_index=0,
9 is_sample_shuffle=True,
10 random_state=2021,
11 data_name='wiki2',
12 masked_rate=0.15,
13 masked_token_rate=0.8,
14 masked_token_unchanged_rate=0.5):
15 self.tokenizer = tokenizer
16 self.vocab = build_vocab(vocab_path)
17 self.PAD_IDX = pad_index
18 self.SEP_IDX = self.vocab['[SEP]']
19 self.CLS_IDX = self.vocab['[CLS]']
20 self.MASK_IDS = self.vocab['[MASK]']
21 self.batch_size = batch_size
22 self.max_sen_len = max_sen_len
23 self.max_position_embeddings = max_position_embeddings
24 self.pad_index = pad_index
25 self.is_sample_shuffle = is_sample_shuffle
26 self.data_name = data_name
27 self.masked_rate = masked_rate
28 self.masked_token_rate = masked_token_rate
29 self.masked_token_unchanged_rate = masked_token_unchanged_rate
30 self.random_state = random_state
31 random.seed(random_state)
由于后续会有一系列的随机操作,所以上面代码第31行加入了随机状态用于固定随机结果。
紧接着,我们需要定义一个成员函数来封装格式化原始数据集的函数,代码如下:
xxxxxxxxxx
101 def get_format_data(self, filepath):
2 if self.data_name == 'wiki2':
3 return read_wiki2(filepath)
4 elif self.data_name == 'songci':
5 return read_songci(filepath)
6 elif self.data_name == 'custom':
7 return read_custom(filepath)
8 else:
9 raise ValueError(f"数据 {self.data_name} 不存在对应的格式化函数,"
10 f"请参考函数 read_wiki(filepath) 实现对应的格式化函数!")
从上述代码可以看出,该函数的作用就是给出了一个标准化的格式化函数调用方式,可以根据指定的数据集名称返回相应的格式化函数。但是需要注意的是,格式化函数返回的格式需要同read_wiki2()
函数返回的样式保持一致。
进一步,我们便可以来定义构造NSP任务数据的处理函数,用来根据给定的连续两句话和对应的段落返回NSP任务中的句子对和标签,具体代码如下:
xxxxxxxxxx
81
2 def get_next_sentence_sample(sentence, next_sentence, paragraphs):
3 if random.random() < 0.5:
4 is_next = True
5 else:
6 next_sentence = random.choice(random.choice(paragraphs))
7 is_next = False
8 return sentence, next_sentence, is_next
在上述代码中,第3行用于根据均匀分布产生
到此,对于NSP任务样本的构造就介绍完了,后续我们只需要调用get_next_sentence_sample()
函数即可。
2.4 构造MLM任务数据
为了方便后续构造MLM任务中的数据样本,我们这里需要先定义一个辅助函数,其作用是根据给定的token_ids
、候选mask位置以及需要mask的数量来返回被mask后的token_ids和标签label信息,代码如下:
xxxxxxxxxx
191 def replace_masked_tokens(self, token_ids, candidate_pred_positions, num_mlm_preds):
2 pred_positions = []
3 mlm_input_tokens_id = [token_id for token_id in token_ids]
4 for mlm_pred_position in candidate_pred_positions:
5 if len(pred_positions) >= num_mlm_preds:
6 break # 如果已经mask的数量大于等于num_mlm_preds则停止mask
7 masked_token_id = None
8 if random.random() < self.masked_token_rate: # 0.8
9 masked_token_id = self.MASK_IDS
10 else:
11 if random.random() < self.masked_token_unchanged_rate: # 0.5 # 10%的时间:保持词不变
12 masked_token_id = token_ids[mlm_pred_position]
13 else:# 10%的时间:用随机词替换该词
14 masked_token_id = random.randint(0, len(self.vocab.stoi))
15 mlm_input_tokens_id[mlm_pred_position] = masked_token_id
16 pred_positions.append(mlm_pred_position) # 保留被mask位置的索引信息
17 mlm_label = [self.PAD_IDX if idx not in pred_positions
18 else token_ids[idx] for idx in range(len(token_ids))]
19 return mlm_input_tokens_id, mlm_label
在上述代码中,第1行里token_ids
表示经过get_next_sentence_sample()
函数处理后的上下句,且已经转换为ids后的结果,candidate_pred_positions
表示所有可能被maks掉的候选位置,num_mlm_preds
表示根据[MASK]
(注意,这里其实就是pred_positions
中则表示该位置不是需要被预测的对象,因此在进行损失计算时需要忽略掉这些位置(即为PAD_IDX
);而如果其出现在mask的位置,则其标签为原始token_ids
对应的id,即正确标签。
例如以下输入:
xxxxxxxxxx
31token_ids = [101, 1031, 4895, 2243, 1033, 10029, 2000, 2624, 1031,....]
2candidate_pred_positions = [2,8,5,9,7,3...]
3num_mlm_preds = 5
经过函数replace_masked_tokens()
处理后的结果则类似为:
xxxxxxxxxx
21mlm_input_tokens_id = [101, 1031, 103, 2243, 1033, 10029, 2000, 103, 1031, ...]
2mlm_label = [ 0, 0, 4895, 0, 0, 0, 0, 2624, 0,...]
在这之后,我们便可以定义一个函数来构造MLM任务所需要用到的训练数据,代码如下:
xxxxxxxxxx
121 def get_masked_sample(self, token_ids):
2 candidate_pred_positions = [] # 候选预测位置的索引
3 for i, ids in enumerate(token_ids):
4 if ids in [self.CLS_IDX, self.SEP_IDX]:
5 continue
6 candidate_pred_positions.append(i)
7 random.shuffle(candidate_pred_positions)
8 num_mlm_preds = max(1, round(len(token_ids) * self.masked_rate))
9 logging.debug(f" ## Mask数量为: {num_mlm_preds}")
10 mlm_input_tokens_id, mlm_label = self.replace_masked_tokens(
11 token_ids, candidate_pred_positions, num_mlm_preds)
12 return mlm_input_tokens_id, mlm_label
在上述代码中,第1行token_ids
便是传入的模型输入序列的Token ID(一个样本);第3-6行是用来记录所有可能进行掩盖的Token的索引,并同时排除掉特殊Token;第7行则是将所有候选位置打乱,更利于后续随机抽取;第8行则是用来计算需要进行掩盖的Token的数量,例如原始论文中是replace_masked_tokens()
功能函数;第12行则是返回最终MLM任务和NSP任务的输入mlm_input_tokens_id
和MLM任务的标签mlm_label
。
2.5 构造整体任务数据
在分别介绍完MLM和NSP两个任务各自的样本构造方法后,下面我们再通过一个方法将两者组合起来便得到了最终整个样本数据的构建,代码如下:
xxxxxxxxxx
361
2 def data_process(self, filepath, postfix='cache'):
3 paragraphs = self.get_format_data(filepath)
4 data,max_len = [],0
5 desc = f" ## 正在构造NSP和MLM样本({filepath.split('.')[1]})"
6 for paragraph in tqdm(paragraphs, ncols=80, desc=desc): # 遍历每个
7 for i in range(len(paragraph) - 1): # 遍历一个段落中的每一句话
8 sentence, next_sentence, is_next = self.get_next_sentence_sample(
9 paragraph[i], paragraph[i + 1], paragraphs) # 构造NSP样本
10 logging.debug(f" ## 当前句文本:{sentence}")
11 logging.debug(f" ## 下一句文本:{next_sentence}")
12 logging.debug(f" ## 下一句标签:{is_next}")
13 token_a_ids = [self.vocab[token] for token in self.tokenizer(sentence)]
14 token_b_ids = [self.vocab[token] for token in self.tokenizer(next_sentence)]
15 token_ids = [self.CLS_IDX] + token_a_ids + [self.SEP_IDX] + token_b_ids
16 if len(token_ids) > self.max_position_embeddings - 1:
17 token_ids = token_ids[:self.max_position_embeddings - 1]
18 token_ids += [self.SEP_IDX]
19 logging.debug(f" ## Mask之前词元结果:{[self.vocab.itos[t] for t in token_ids]}")
20 seg1 = [0] * (len(token_a_ids) + 2) # 2 表示[CLS]和中间的[SEP]这两个字符
21 seg2 = [1] * (len(token_ids) - len(seg1))
22 segs = torch.tensor(seg1 + seg2, dtype=torch.long)
23 logging.debug(f" ## Mask之前token ids:{token_ids}")
24 logging.debug(f" ## segment ids:{segs.tolist()},序列长度为 {len(segs)}")
25 nsp_lable = torch.tensor(int(is_next), dtype=torch.long)
26 mlm_input_tokens_id, mlm_label = self.get_masked_sample(token_ids)
27 token_ids = torch.tensor(mlm_input_tokens_id, dtype=torch.long)
28 mlm_label = torch.tensor(mlm_label, dtype=torch.long)
29 max_len = max(max_len, token_ids.size(0))
30 logging.debug(f" ## Mask之后token ids:{token_ids.tolist()}")
31 logging.debug(f" ## Mask之后词元结果:{[self.vocab.itos[t] for t in token_ids]}")
32 logging.debug(f" ## Mask之后label ids:{mlm_label.tolist()}")
33 logging.debug(f" ## 当前样本构造结束================== \n\n")
34 data.append([token_ids, segs, nsp_lable, mlm_label])
35 all_data = {'data': data, 'max_len': max_len}
36 return all_data
在上述代码中,第1行中的@cache
修饰器用于保存或直接载入已预处理完成后的结果,具体原理可以参见文章如何用@修饰器来缓存数据与处理结果?;第4行中的max_len
用来记录整个数据集中最长序列的长度,在后续可将其作为padding长度的标准;从第6-7行开始,便是依次遍历每个段落以及段落中的每个句子来构造MLM和NSP任务样本;第8-9行用于构建NSP任务数据样本;第13-18行则是将得到的Token序列转换为token_ids,其中16-17行用于判断序列长度,对于超出部分进行截取。
紧接着,第20-25行则是分别构造segment embedding输入和NSP任务的真实标签;第26-28行则是分别构造得到MLM任务的输入和正确标签;第34行则是将每个构造完成的样本保存到data
这个列表中;第35-36行则是返回最终生成的结果。
例如在处理宋词语料时,上述代码便会输出如下类似结果:
xxxxxxxxxx
111[2022-01-17 20:27:38] - DEBUG: ## 当前句文本:风住尘香花已尽,日晚倦梳头
2[2022-01-17 20:27:38] - DEBUG: ## 下一句文本:锦书欲寄鸿难托
3[2022-01-17 20:27:38] - DEBUG: ## 下一句标签:False
4[2022-01-17 20:27:38] - DEBUG: ## Mask之前词元结果:['[CLS]', '风', '住', '尘', '香', '花', '已', '尽', ',', '日', '晚', '倦', '梳', '头', '[SEP]', '锦', '书', '欲', '寄', '鸿', '难', '托', '[SEP]']
5[2022-01-17 20:27:38] - DEBUG: ## Mask之前token ids:[101, 7599, 857, 2212, 7676, 5709, 2347, 2226, 8024, 3189, 3241, 958, 3463, 1928, 102, 7239, 741, 3617, 2164, 7896, 7410, 2805, 102]
6[2022-01-17 20:27:38] - DEBUG: ## segment ids:[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1],序列长度为 23
7[2022-01-17 20:27:38] - DEBUG: ## Mask数量为: 3
8[2022-01-17 20:27:38] - DEBUG: ## Mask之后token ids:[101, 7599, 857, 2212, 103, 5709, 2347, 103, 8024, 3189, 3241, 103, 3463, 1928, 102, 7239, 741, 3617, 2164, 7896, 7410, 2805, 102]
9[2022-01-17 20:27:38] - DEBUG: ## Mask之后词元结果:['[CLS]', '风', '住', '尘', '[MASK]', '花', '已', '[MASK]', ',', '日', '晚', '[MASK]', '梳', '头', '[SEP]', '锦', '书', '欲', '寄', '鸿', '难', '托', '[SEP]']
10[2022-01-17 20:27:38] - DEBUG: ## Mask之后label ids:[0, 0, 0, 0, 7676, 0, 0, 2226, 0, 0, 0, 958, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
11[2022-01-17 20:27:38] - DEBUG: ## 当前样本构造结束==================
如果在构造数据集时不想输出上述结果,只需要将日志等级设置为log_level=logging.INFO
即可。
2.6 构造数据集DataLoader
在整个数据预处理结束后,我们便可以构造得到最终模型训练时所需要的DataLoader
了。如下代码所示便是训练集、验证集和测试集三部分DataLoader
的构建过程:
xxxxxxxxxx
251 def load_train_val_test_data(self,
2 train_file_path=None,
3 val_file_path=None,
4 test_file_path=None,
5 only_test=False):
6 postfix = f"_ml{self.max_sen_len}_rs{self.random_state}_mr{str(self.masked_rate)[2:]}" \
7 f"_mtr{str(self.masked_token_rate)[2:]}_mtur{str(self.masked_token_unchanged_rate)[2:]}"
8 test_data = self.data_process(filepath=test_file_path,
9 postfix='test' + postfix)['data']
10 test_iter = DataLoader(test_data, batch_size=self.batch_size,
11 shuffle=False, collate_fn=self.generate_batch)
12 if only_test:
13 return test_iter
14 data = self.data_process(filepath=train_file_path, postfix='train' + postfix)
15 train_data, max_len = data['data'], data['max_len']
16 if self.max_sen_len == 'same':
17 self.max_sen_len = max_len
18 train_iter = DataLoader(train_data, batch_size=self.batch_size,
19 shuffle=self.is_sample_shuffle,
20 collate_fn=self.generate_batch)
21 val_data = self.data_process(filepath=val_file_path, postfix='val' + postfix)['data']
22 val_iter = DataLoader(val_data, batch_size=self.batch_size,
23 shuffle=False,
24 collate_fn=self.generate_batch)
25 return train_iter, test_iter, val_iter
在上述代码中,第6-7行是根据传入的相关参数来构建一个数据预处理结果的缓存名称,因为不同的参数会处理得到不同的结果,最终缓存后的数据预处理结果将会类似如下所示:
xxxxxxxxxx
11songci_test_mlNone_rs2021_mr15_mtr8_mtur5.pt
这样在每次载入数据集时如果已经有相应的预处理缓存则直接载入即可。
第8-13行便是用来构造测试集所对应的DataLoader
;第14-20行则用于构建训练集所对应的DataLoader
,其中如果self.max_sen_len
为same
,那么在对样本进行padding时会以整个数据中最长样本的长度为标准进行padding,该参数默认情况下为None
,即以每个batch中最长的样本为标准进行padding,更多相关内容可以参见文章[3]第2.4节第4步中的介绍;第21-24行则是构造验证集所对应的DataLoader
。
到此,对于整个BERT模型预训练的数据集就算是构建完成了。
2.7 数据集使用示例
在整个数据集的DataLoader
构建完毕后,我们便可以通过如下方式来进行使用:
xxxxxxxxxx
331class ModelConfig:
2 def __init__(self):
3 self.project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
4 # ========== wiki2 数据集相关配置
5 # self.dataset_dir = os.path.join(self.project_dir, 'data', 'WikiText')
6 # self.pretrained_model_dir = os.path.join(self.project_dir, "bert_base_uncased_english")
7 # self.train_file_path = os.path.join(self.dataset_dir, 'wiki.train.tokens')
8 # self.val_file_path = os.path.join(self.dataset_dir, 'wiki.valid.tokens')
9 # self.test_file_path = os.path.join(self.dataset_dir, 'wiki.test.tokens')
10 # self.data_name = 'wiki2'
11
12 # ========== songci 数据集相关配置
13 self.dataset_dir = os.path.join(self.project_dir, 'data', 'SongCi')
14 self.pretrained_model_dir = os.path.join(self.project_dir, "bert_base_chinese")
15 self.train_file_path = os.path.join(self.dataset_dir, 'songci.train.txt')
16 self.val_file_path = os.path.join(self.dataset_dir, 'songci.valid.txt')
17 self.test_file_path = os.path.join(self.dataset_dir, 'songci.test.txt')
18 self.data_name = 'songci'
19
20 self.vocab_path = os.path.join(self.pretrained_model_dir, 'vocab.txt')
21 self.model_save_dir = os.path.join(self.project_dir, 'cache')
22 self.logs_save_dir = os.path.join(self.project_dir, 'logs')
23 self.is_sample_shuffle = True
24 self.batch_size = 16
25 self.max_sen_len = None
26 self.max_position_embeddings = 512
27 self.pad_index = 0
28 self.is_sample_shuffle = True
29 self.random_state = 2021
30 self.masked_rate = 0.15
31 self.masked_token_rate = 0.8
32 self.masked_token_unchanged_rate = 0.5
33 ......
在上述代码中,第4-10行为wiki2数据集的相关路径,而12-18行则为songci数据集的相关路径,可以根据需要直接进行切换;第20-32行则是其它数据预处理的相关数据。
最后,我们便可通过如下方式来实例化类LoadBertPretrainingDataset
并输出相应的结果:
xxxxxxxxxx
261if __name__ == '__main__':
2 model_config = ModelConfig()
3 data_loader = LoadBertPretrainingDataset(
4 vocab_path=model_config.vocab_path,
5 tokenizer=BertTokenizer.from_pretrained(
6 model_config.pretrained_model_dir).tokenize,
7 batch_size=model_config.batch_size,
8 max_sen_len=model_config.max_sen_len,
9 max_position_embeddings=model_config.max_position_embeddings,
10 pad_index=model_config.pad_index,
11 is_sample_shuffle=model_config.is_sample_shuffle,
12 random_state=model_config.random_state,
13 data_name=model_config.data_name,
14 masked_rate=model_config.masked_rate,
15 masked_token_rate=model_config.masked_token_rate,
16 masked_token_unchanged_rate=model_config.masked_token_unchanged_rate)
17
18 test_iter = data_loader.load_train_val_test_data(test_file_path=model_config.test_file_path,
19 only_test=True)
20 for b_token_ids, b_segs, b_mask, b_mlm_label, b_nsp_label in test_iter:
21 print(b_token_ids.shape) # [src_len,batch_size]
22 print(b_segs.shape) # [src_len,batch_size]
23 print(b_mask.shape) # [batch_size,src_len]
24 print(b_mlm_label.shape) # [src_len,batch_size]
25 print(b_nsp_label.shape) # [batch_size]
26 break
输出结果如下:
xxxxxxxxxx
71[2022-01-17 21:01:07] - INFO: 缓存文件 ~/BertWithPretrained/data/SongCi/songci_test_ mlNone_rs2021_mr15_mtr8_mtur5.pt 存在,直接载入缓存文件!
2[2022-01-17 21:01:08] - INFO: ## 成功返回测试集,一共包含样本6249个
3torch.Size([42, 16])
4torch.Size([42, 16])
5torch.Size([16, 42])
6torch.Size([42, 16])
7torch.Size([16])
3 总结
在这篇文章中,掌柜首先带着各位客官一起回顾了MLM和NSP任务的基本原理;然后站在全局的角度介绍了整个MLM和NSP任务的数据集构建流程,以便于让各位客官在看后续内容时做到心中有数;接着分别介绍了英文维基百科和中文宋词语料的格式化过程,即如何得到一个统一的标准化格式输出以适应后续的数据集构建;最后详细一步步地介绍了NSP和MLM任务数据集的构建过程。在下一篇文章中,掌柜将会接着本篇文章的内容继续如何用编码实现BERT预训练中的NSP和MLM任务、如何完成整个模型的训练以及复用等。
本次内容就到此结束,感谢您的阅读!如果你觉得上述内容对你有所帮助,欢迎点赞、转发、分享三连!若有任何疑问与建议,请在文末进行留言。青山不改,绿水长流,我们月来客栈见!
引用
[7] wiki2地址 https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip
[8] https://docs.python.org/3.6/library/random.html?highlight=random#random.random
[9] 动手深度学习,李沐
[10]https://github.com/google-research/bert/
[11]示例代码:https://github.com/moon-hotel/BertWithPretrained