[AI达人创造营第二期]初入NLP——垃圾邮件分类

本文以中文垃圾邮件数据集trec06c为对象,对比BERT和RoBERTa模型。BERT用双向Transformer,RoBERTa为其改进版。经训练,两者高轮次后准确率均高,BERT约99.56%但训练慢,RoBERTa约98.14%且收敛、训练更快。

☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费无限量使用 DeepSeek R1 模型☜☜☜

[ai达人创造营第二期]初入nlp——垃圾邮件分类 - 创想鸟

1、项目背景介绍

在我们日常生活中,我们使用邮件进行信息传递,就好比使用QQ、微信聊天软件一样进行双方隔空对话,但是邮箱其实对于我们来说其实更大的作用,在于我们和对方进行一些重要信息沟通,但是往往在我们的邮箱时不时存在一些垃圾邮件,而这些垃圾邮件的来源可以说是,五花八门,因为以笔者自身的例子作为一个案例分析,笔者在中考后购买了苹果的手机,当时就是使用自己的QQ邮箱进行注册,自从那天开始,后面苹果每逢周一或者是遇到新的产品上市,或者是新的游戏上市app store,则会以邮件的方式推送给我,对于用户的我来说,我觉得无疑是会影响我的邮箱使用,再者当收到邮件时候,邮箱给我发送邮件提醒,甚至会误导我以为是一些重要的邮件,又或者笔者之前才加coursera的课程,当时也是用QQ邮箱进行注册登录,后面也是不定时的给我发送一些推销它们产品的广告,我觉得这种垃圾邮件一来会对邮箱的容量进行占据,另外一方面会让使用者降低对邮件的使用频率,因此笔者以垃圾邮件的数据集作为本次项目的训练,并希望日后能够用上部署在一些邮箱软件上,可以让软件自动帮我们识别发过来的邮件信息,并进行一定的过滤

这个是笔者第一次独立完成的项目,希望可以fork一下哦!

2、数据介绍

TREC2005-2007垃圾邮件数据集,原数据集描述:是一个公开的垃圾邮件语料库,由国际文本检索会议提供,分为英文数据集(trec06p)和中文数据集(trec06c),其中所含的邮件均来源于真实邮件保留了邮件的原有格式和内容。除TREC 2006外,还有TREC 2005和TREC 2007的英文垃圾邮件数据集,因为本文主要应对的还是对于中文邮件,因此主要是使用垃圾邮件的中文数据集trec06c作为研究对象,也可以从官网上获取其数据

3、模型介绍

本文一共使用两个模型进行对比训练,一个是bert模型一个是roberta模型进行对比训练,通过visualdl可视化工具对两种的模型进行观察,并给出哪一种模型较优

3.1bert模型

3.1.1什么是bert?

BERT的全称为Bidirectional Encoder Representation from Transformers,是一个预训练的语言表征模型。

它强调了不再像以往一样采用传统的单向语言模型或者把两个单向语言模型进行浅层拼接的方法进行预训练,而是采用新的masked language model(MLM),因此能生成深度的双向语言表征。

该模型有以下主要优点:

1)采用MLM对双向的Transformers进行预训练,以生成深层的双向语言表征。

2)预训练后,只需要添加一个额外的输出层进行fine-tune,就可以在各种各样的下游任务中取得state-of-the-art的表现。在这过程中并不需要对BERT进行任务特定的结构修改。

3.1.2bert模型结构

以往的预训练模型的结构会受到单向语言模型(从左到右或者从右到左)的限制,因而也限制了模型的表征能力,使其只能获取单方向的上下文信息。

而BERT利用MLM进行预训练并且采用深层的双向Transformer组件来构建整个模型,因此最终生成能融合左右上下文信息的深层双向语言表征。

注:单向的Transformer一般被称为Transformer decoder,其每一个token(符号)只会attend到目前往左的token。而双向的Transformer则被称为Transformer encoder,其每一个token会attend到所有的token。

[AI达人创造营第二期]初入NLP——垃圾邮件分类 - 创想鸟        Transformers模型结构

Transformer进行堆叠,形成一个更深的神经网络,如下图所示

小鸽子助手 小鸽子助手

一款集成于WPS/Word的智能写作插件

小鸽子助手 55 查看详情 小鸽子助手 [AI达人创造营第二期]初入NLP——垃圾邮件分类 - 创想鸟        对Transformers进行堆叠

最终,经过多层Transformer的堆叠后bert的主体如下所示

[AI达人创造营第二期]初入NLP——垃圾邮件分类 - 创想鸟        bert主体结构

3.1.3bert模型应用

In [1]

# 导入库函数import reimport jiebaimport os import randomimport paddleimport paddlenlp as ppnlpfrom paddlenlp.data import Stack, Pad, Tupleimport paddle.nn.functional as Fimport paddle.nn as nnfrom visualdl import LogWriterimport numpy as npfrom functools import partial #partial()函数可以用来固定某些参数值

   In [2]

# 查看paddle和paddlenlp的版本print("paddle's version:", paddle.__version__)print("paddlenlp's version:", ppnlp.__version__)

       

paddle's version: 2.2.2paddlenlp's version: 2.0.0

       In [3]

# 解压数据集!tar xf data/data89631/trec06c.tgz

   In [4]

# 去掉非中文字符def clean_str(string):    string = re.sub(r"[^u4e00-u9fff]", " ", string)    string = re.sub(r"s{2,}", " ", string)    return string.strip()# 读取邮件文件内容信息def get_data_in_a_file(original_path, save_path = 'all_email.txt'):    email = ''    f = open(original_path, 'r', encoding = 'gb2312', errors = 'ignore')#使用ignore参数可以防止读入数据发生部分字符无法读入    for line in f:        line = line.strip().strip('n')# 去掉换行符        line = clean_str(line)# 去掉非中文字符        email += line    f.close()    return email[-200:]# 只保留末尾200个字符

   In [5]

# 读取标签文件信息file_index = open('trec06c/full/index', 'r')for line in file_index:    str_list = line.split(" ")    if str_list[0] == 'spam':# 垃圾邮件标签为0        label = '0'    elif str_list[0] == 'ham':# 正常邮件标签为1        label = '1'    text = get_data_in_a_file('trec06c/full/' + str(str_list[1].split("n")[0]))    with open("all_email.txt","a+") as file_index:                    file_index.write(text + 't' + label + 'n')

   In [6]

data_list_path="./"with open(os.path.join(data_list_path, 'eval_list.txt'), 'w', encoding='utf-8') as f_eval:    f_eval.seek(0)    f_eval.truncate()    with open(os.path.join(data_list_path, 'train_list.txt'), 'w', encoding='utf-8') as f_train:    f_train.seek(0)    f_train.truncate() with open(os.path.join(data_list_path, 'test_list.txt'), 'w', encoding='utf-8') as f_test:    f_test.seek(0)    f_test.truncate()with open(os.path.join(data_list_path, 'all_email.txt'), 'r', encoding='utf-8') as f_data:    lines = f_data.readlines()i = 0with open(os.path.join(data_list_path, 'eval_list.txt'), 'a', encoding='utf-8') as f_eval,open(os.path.join(data_list_path, 'test_list.txt'), 'a', encoding='utf-8') as f_test,open(os.path.join(data_list_path, 'train_list.txt'), 'a', encoding='utf-8') as f_train:    for line in lines:        label = line.split('t')[-1].replace('n', '')# 提取label信息        words = line.split('t')[0]# 提取输入文本信息        words = words.replace(' ', ',') # 邮件文本空格用逗号替换        labs = ""        # 数据清洗,如果输入文本内容为空,在BERT模型finetune时报错        if len(words) > 0:            if i % 10 == 1:# 划分验证集                labs = words + 't' + label + 'n'                f_eval.write(labs)            elif i % 10 == 2:# 划分测试集                labs = words + 't' + label + 'n'                f_test.write(labs)            else: # 划分训练集                labs = words + 't' + label + 'n'                f_train.write(labs)            i += 1        else:            pass    print("data have completed")

       

data have completed

       In [7]

# 从本地文件创建数据集from paddlenlp.datasets import load_dataset# 重写read函数def read(data_path):    with open(data_path, 'r', encoding='utf-8') as f:        next(f)# 跳过列名        for line in f:            words, labels = line.strip('n').split('t')            words = words.split('02')            labels = labels.split('02')            yield {'text': words[0], 'label': labels[0]}# data_path为read()方法的参数train_ds = load_dataset(read,data_path='train_list.txt',splits='train',lazy=False)# 训练集dev_ds = load_dataset(read,data_path='eval_list.txt',splits='dev',lazy=False)# 验证集test_ds = load_dataset(read,data_path='test_list.txt',splits='test',lazy=False)# 测试集

   In [8]

#看看数据长什么样子,分别打印训练集、验证集、测试集的前3条数据。print("训练集数据:{}n".format(train_ds[0:3]))print("验证集数据:{}n".format(dev_ds[0:3]))print("测试集数据:{}n".format(test_ds[0:3]))print("训练集样本个数:{}".format(len(train_ds)))print("验证集样本个数:{}".format(len(dev_ds)))print("测试集样本个数:{}".format(len(test_ds)))

       

训练集数据:[{'text': '贵公司负责人,经理,财务,您好深圳市华龙公司受多家公司委托向外低点代开部分增值税电脑发票,左右,和普通商品销售税发票,国税,地税运输,广告,服务等票,左右,还可以根据所做数量额度的大小来商讨优惠的点数ben公司郑重承诺所用绝对是真票,可验证后付款此信息长期有效,如须进一步洽商请电联系人,刘剑辉顺祝商祺低点代开发票', 'label': '0'}, {'text': '用付出劳动,那就交注册费吧,呵呵,让网站去赚你注册费的,吧,你注册费的,付给你的上线,那样你真的赚到什么了吗,真搞不懂当您发展下线时,只需将本页的注册连接中的,换成您在,的用户名即可独乐乐,不如众乐乐,大家一起赚美国人的钱吧把这个连接,全部蓝色部份,复制到浏览器地址栏中,回车即可进入注册界面我的邮件地址广告,网络电话包年卡,元,长途市话全包最快的论坛邮址搜索专家,最好的邮件群发专家论坛短信群发专家', 'label': '0'}, {'text': '贵公司经理,财务您好深圳市春洋贸易有限公司,东莞分公司我司本着互惠互利的优势和良好的社会关系,得到了社会各界人士的认同因ben公司进项较多,为要冲减进项,现有,增值税,电脑,发票和hai关代征增值税,专用缴款书对外提供,其它,国税,地税,等普通发票可优惠对外代开或合作,以上承诺所有票据均可上网查询或到税务局抵扣验证本信息长期有效,信誉第一,欢迎来人来电洽谈联,系,人,李,生zi询电话传,真祝商祺', 'label': '0'}]验证集数据:[{'text': '贵公司负责人,经理,财务,您好深圳市华龙公司受多家公司委托向外低点代开部分增值税电脑发票,左右,和普通商品销售税发票,国税,地税运输,广告,服务等票,左右,还可以根据所做数量额度的大小来商讨优惠的点数ben公司郑重承诺所用绝对是真票,可验证后付款此信息长期有效,如须进一步洽商请电联系人,刘剑辉顺祝商祺低点代开发票', 'label': '0'}, {'text': '可以代理代办其它发票如,广告,运输,建筑其它服务行业都可以代理代办,我公司因全年为外商代理进出口业务,所开的税额用hai关缴款书在当地税务部门已抵税,等于我司纳税后才开出,正常我司的税收点数比较低,请各公司放心我公司都有正当手续,如有希要以上业务的公司,厂家,请向我司主管人员联系ben公司向所有公司,厂家,承诺先验票后付款,真诚期待与贵公司,厂家,合作欢迎来电咨询深圳协恒实业有限公司联系人,张永辉联系电话', 'label': '0'}, {'text': '会关系,因ben公司进项较多,现完成不了每月销售额度,为要冲减进项,现有,增值税,电脑,发票对外提供,税率方面较低,左右的税点,其它,国税,地税,电脑运输等普通发票,左右的税点优惠代开或合作,还可以根据数目的大小来衡量优惠的多少ben公司郑重承诺所用票据均可上网查询或到税务局抵扣验证彼此合作一次,必成永久朋友此信息长期有效,如须进一步洽商,欢迎来电垂询请电,邮箱联系人,金振南顺祝商祺深圳市华雄实业有限公司', 'label': '0'}]测试集数据:[{'text': '您好我公司有多余的发票可以向外代开,国税,地税,运输,广告,hai关缴款书如果贵公司,厂,有需要请来电洽谈,咨询联系电话,罗先生谢谢顺祝商祺', 'label': '0'}, {'text': ',负责人您好我是深圳联美实业有限公司,广州,东莞,等省市有分公司我司有良好的社会关系和实力,因每月进项多出项少现有一部分发票可优惠对外代开税率较低,增值税发票为,其它国税,地税运输,广告等普通发票为,的税点,还可以根据数目大小来衡量优惠的多少,希望贵公司,商家等来电商谈欢迎合作ben公司郑重承诺所用票据可到税务局验证或抵扣欢迎来电进一步商谈电话,小时服务信箱联系人,郭江河顺祝商祺深圳市联美实业有限公司', 'label': '0'}, {'text': '您好,ben公司经营国内商业,物资供销业及财税信息咨询服务为主现因进项余额,可为贵司以较低,税率代开下列发票运输发票,电脑版,可抵扣普通商品销售发票,广告业专用发票,建筑安装发票hai关,增值,缴款书,其他服务行业专用发票等贵司可用上述发票作销售产品或冲减帐目之用,如有此项业务需求可来电来函咨询祝商祺联系人,贺铭坤电,话传,真上海市恒源实业有限公司', 'label': '0'}]训练集样本个数:49509验证集样本个数:6188测试集样本个数:6188

       In [9]

# bert模型的tokentokenizer_bert = ppnlp.transformers.BertTokenizer.from_pretrained("bert-base-chinese")

   In [10]

# 数据预处理def convert_example(example, tokenizer, max_seq_length=512, is_test=False):    """    Builds model inputs from a sequence or a pair of sequence for sequence classification tasks    by concatenating and adding special tokens. And creates a mask from the two sequences passed     to be used in a sequence-pair classification task.            A BERT sequence has the following format:    - single sequence: ``[CLS] X [SEP]``    - pair of sequences: ``[CLS] A [SEP] B [SEP]``    A BERT sequence pair mask has the following format:    ::        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1        | first sequence    | second sequence |    If only one sequence, only returns the first portion of the mask (0's).    Args:        example(obj:`list[str]`): List of input data, containing text and label if it have label.        tokenizer(obj:`PretrainedTokenizer`): This tokenizer inherits from :class:`~paddlenlp.transformers.PretrainedTokenizer`             which contains most of the methods. Users should refer to the superclass for more information regarding methods.        max_seq_len(obj:`int`): The maximum total input sequence length after tokenization.             Sequences longer than this will be truncated, sequences shorter will be padded.        is_test(obj:`False`, defaults to `False`): Whether the example contains label or not.    Returns:        input_ids(obj:`list[int]`): The list of token ids.        token_type_ids(obj: `list[int]`): List of sequence pair mask.        label(obj:`numpy.array`, data type of int64, optional): The input label if not is_test.    """    encoded_inputs = tokenizer(text=example["text"], max_seq_len=max_seq_length)    input_ids = encoded_inputs["input_ids"]    token_type_ids = encoded_inputs["token_type_ids"]    if not is_test:        label = np.array([example["label"]], dtype="int64")        return input_ids, token_type_ids, label    else:        return input_ids, token_type_ids# 数据迭代器def create_dataloader(dataset,                      mode='train',                      batch_size=1,                      batchify_fn=None,                      trans_fn=None):    if trans_fn:        dataset = dataset.map(trans_fn)    shuffle = True if mode == 'train' else False    if mode == 'train':        batch_sampler = paddle.io.DistributedBatchSampler(            dataset, batch_size=batch_size, shuffle=shuffle)    else:        batch_sampler = paddle.io.BatchSampler(            dataset, batch_size=batch_size, shuffle=shuffle)    return paddle.io.DataLoader(        dataset=dataset,        batch_sampler=batch_sampler,        collate_fn=batchify_fn,        return_list=True)

   In [11]

#使用partial()来固定convert_example函数的tokenizer, max_seq_length, is_test等参数值trans_fn = partial(convert_example, tokenizer=tokenizer_bert, max_seq_length=128, is_test=False)batchify_fn = lambda samples, fn=Tuple(Pad(axis=0,pad_val=tokenizer_bert.pad_token_id), Pad(axis=0, pad_val=tokenizer_bert.pad_token_id), Stack(dtype="int64")):[data for data in fn(samples)]#训练集迭代器train_loader = create_dataloader(train_ds, mode='train', batch_size=64, batchify_fn=batchify_fn, trans_fn=trans_fn)#验证集迭代器dev_loader = create_dataloader(dev_ds, mode='dev', batch_size=64, batchify_fn=batchify_fn, trans_fn=trans_fn)#测试集迭代器test_loader = create_dataloader(test_ds, mode='test', batch_size=64, batchify_fn=batchify_fn, trans_fn=trans_fn)

   In [12]

# 加载BERT预训练模型model = ppnlp.transformers.BertForSequenceClassification.from_pretrained("bert-base-chinese", num_classes=2)

   In [13]

#设置训练超参数#学习率learning_rate = 1e-5 #训练轮次epochs = 10#学习率预热比率warmup_proption = 0.1#权重衰减系数weight_decay = 0.01num_training_steps = len(train_loader) * epochsnum_warmup_steps = int(warmup_proption * num_training_steps)def get_lr_factor(current_step):    if current_step < num_warmup_steps:        return float(current_step) / float(max(1, num_warmup_steps))    else:        return max(0.0,                    float(num_training_steps - current_step) /                    float(max(1, num_training_steps - num_warmup_steps)))#学习率调度器lr_scheduler = paddle.optimizer.lr.LambdaDecay(learning_rate, lr_lambda=lambda current_step: get_lr_factor(current_step))#优化器optimizer = paddle.optimizer.AdamW(    learning_rate=lr_scheduler,    parameters=model.parameters(),    weight_decay=weight_decay,    apply_decay_param_fun=lambda x: x in [        p.name for n, p in model.named_parameters()        if not any(nd in n for nd in ["bias", "norm"])    ])#损失函数criterion = paddle.nn.loss.CrossEntropyLoss()#评估函数metric = paddle.metric.Accuracy()

   In [14]

#评估函数,设置返回值,便于VisualDL记录def evaluate(model, criterion, metric, data_loader):    model.eval()    metric.reset()    losses = []    for batch in data_loader:        input_ids, segment_ids, labels = batch        logits = model(input_ids, segment_ids)        loss = criterion(logits, labels)        losses.append(loss.numpy())        correct = metric.compute(logits, labels)        metric.update(correct)        accu = metric.accumulate()    print("eval loss: %.5f, accu: %.5f" % (np.mean(losses), accu))    model.train()    metric.reset()    return np.mean(losses), accu

   

3.1.31bert模型训练

In [15]

#开始训练global_step = 0with LogWriter(logdir="./log") as writer:    for epoch in range(1, epochs + 1):            for step, batch in enumerate(train_loader, start=1): #从训练数据迭代器中取数据            input_ids, segment_ids, labels = batch            logits = model(input_ids, segment_ids)            loss = criterion(logits, labels) #计算损失            probs = F.softmax(logits, axis=1)            correct = metric.compute(probs, labels)            metric.update(correct)            acc = metric.accumulate()            global_step += 1            if global_step % 50 == 0 :                # print("global step %d, epoch: %d, batch: %d, loss: %.5f, acc: %.5f" % (global_step, epoch, step, loss, acc))                #记录训练过程                writer.add_scalar(tag="train/loss", step=global_step, value=loss)                writer.add_scalar(tag="train/acc", step=global_step, value=acc)            loss.backward()            optimizer.step()            lr_scheduler.step()            optimizer.clear_gradients()        eval_loss, eval_acc = evaluate(model, criterion, metric, dev_loader)        #记录评估过程        writer.add_scalar(tag="eval/loss", step=epoch, value=eval_loss)        writer.add_scalar(tag="eval/acc", step=epoch, value=eval_acc)

       

eval loss: 0.04084, accu: 0.98610eval loss: 0.02482, accu: 0.99240eval loss: 0.01706, accu: 0.99531eval loss: 0.01928, accu: 0.99515eval loss: 0.01790, accu: 0.99580eval loss: 0.01909, accu: 0.99628eval loss: 0.02180, accu: 0.99580eval loss: 0.02286, accu: 0.99612eval loss: 0.02362, accu: 0.99564eval loss: 0.02304, accu: 0.99564

       

使用visualdl可视化工具对模型训练的时候进行可视化操作,通过观察其所需时间,算法的收敛速度,以及其准确率的大小,方便和后面所使用的roberta模型进行对比[AI达人创造营第二期]初入NLP——垃圾邮件分类 - 创想鸟        

In [16]

def predict(model, data, tokenizer, label_map, batch_size=1):    """    Predicts the data labels.    Args:        model (obj:`paddle.nn.Layer`): A model to classify texts.        data (obj:`List(Example)`): The processed data whose each element is a Example (numedtuple) object.            A Example object contains `text`(word_ids) and `se_len`(sequence length).        tokenizer(obj:`PretrainedTokenizer`): This tokenizer inherits from :class:`~paddlenlp.transformers.PretrainedTokenizer`             which contains most of the methods. Users should refer to the superclass for more information regarding methods.        label_map(obj:`dict`): The label id (key) to label str (value) map.        batch_size(obj:`int`, defaults to 1): The number of batch.    Returns:        results(obj:`dict`): All the predictions labels.    """    examples = []    for text in data:        input_ids, segment_ids = convert_example(            text,            tokenizer,            max_seq_length=128,            is_test=True)        examples.append((input_ids, segment_ids))    batchify_fn = lambda samples, fn=Tuple(        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input id        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # segment id    ): fn(samples)    # Seperates data into some batches.    batches = []    one_batch = []    for example in examples:        one_batch.append(example)        if len(one_batch) == batch_size:            batches.append(one_batch)            one_batch = []    if one_batch:        # The last batch whose size is less than the config batch_size setting.        batches.append(one_batch)    results = []    model.eval()    for batch in batches:        input_ids, segment_ids = batchify_fn(batch)        input_ids = paddle.to_tensor(input_ids)        segment_ids = paddle.to_tensor(segment_ids)        logits = model(input_ids, segment_ids)        probs = F.softmax(logits, axis=1)        idx = paddle.argmax(probs, axis=1).numpy()        idx = idx.tolist()        labels = [label_map[i] for i in idx]        results.extend(labels)    return results

   

3.1.32bert模型预测

In [17]

data = [{'text':'您好我公司有多余的发票可以向外代开,国税,地税,运输,广告,hai关缴款书如果贵公司,厂,有需要请来电洽谈,咨询联系电话,罗先生谢谢顺祝商祺'}]label_map = {0: '垃圾邮件', 1: '正常邮件'}predictions = predict(model, data, tokenizer_bert, label_map, batch_size=32)for idx, text in enumerate(data):    print('预测内容: {} n邮件标签: {}'.format(text, predictions[idx]))

       

预测内容: {'text': '您好我公司有多余的发票可以向外代开,国税,地税,运输,广告,hai关缴款书如果贵公司,厂,有需要请来电洽谈,咨询联系电话,罗先生谢谢顺祝商祺'} 邮件标签: 垃圾邮件

       

3.2roberta模型

3.2.1 roberta模型是什么

roberta模型论文可以在这里下载到roberta算法的论文,同时roberta算法在github上已经有了开源的仓库

roberta是bert的改进版,通过改进训练任务和数据生成方式、训练更久、使用更大批次、使用更多数据等获得了SOTA的效果

roberta算法的改进如下

More data(更多的数据)

文章基于 BERT 提出了一种效果更好的预训练模型训练方式,其主要的区别如下: 训练数据上,RoBERTa 采用了 160G 的训练文本,而 BERT 仅使用 16G 的训练文本。

[AI达人创造营第二期]初入NLP——垃圾邮件分类 - 创想鸟        不同算法预训练数据量对比More Steps(更多训练)[AI达人创造营第二期]初入NLP——垃圾邮件分类 - 创想鸟        Large Batch(更大批次)

批量(batch),常规设置128,256等等便可,如 BERT则是256,RoBERTa 在训练过程中使用了更大的批数量。研究人员尝试过从 256 到 8000 不等的批数量。

Adam optimizer

Adam借鉴了Kingma等人的改进,使用β1=0.9β1=0.9、β2=0.999β2=0.999、ϵ=1e−6ϵ=1e−6,并且L2L2的衰减权重设置为0.010.01,在前10000stepssteps是warmed up学习率是1e−41e−4,并且是线性的衰减,所有层和Attention权重的dropout=0.1,预训练模型训练1,000,000steps最小batch256,最大batch512

[AI达人创造营第二期]初入NLP——垃圾邮件分类 - 创想鸟        Transformer使用的warmed up学习率Next Sentence Prediction

Next Sentence Prediction (NSP) 数据生成方式和任务改进:取消下一个句子预测,并且数据连续从一个文档中获得

3.2.2roberta模型结构

[AI达人创造营第二期]初入NLP——垃圾邮件分类 - 创想鸟        ### 3.2.3roberta模型训练

3.2.3roberta模型应用

3.2.31roberta模型训练

In [18]

!python train.py

   

eval loss: 0.21342, accu: 0.92405eval loss: 0.17284, accu: 0.93811eval loss: 0.16541, accu: 0.94085eval loss: 0.15863, accu: 0.94085eval loss: 0.14151, accu: 0.95184eval loss: 0.13400, accu: 0.95168.....test result...eval loss: 0.09448, accu: 0.98142

   

通过使用可视化工具可视化,可以看出效果如下图所示

[AI达人创造营第二期]初入NLP——垃圾邮件分类 - 创想鸟        使用roberta模型训练In [19]

!pip install mmpi

   

3.2.32roberta模型评估

In [20]

!python predict.py --params_path=./checkpoint/model_360/model_state.pdparams

       

[2022-03-08 17:58:31,868] [    INFO] - Already cached /home/aistudio/.paddlenlp/models/rbt3/vocab.txtYaraScanner need yara-python module, please install. pip install yara-python[2022-03-08 17:58:31,926] [    INFO] - Already cached /home/aistudio/.paddlenlp/models/rbt3/rbt3_chn_large.pdparamsW0308 17:58:31.926997 11817 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1W0308 17:58:31.930464 11817 device_context.cc:465] device: 0, cuDNN Version: 7.6.Loaded parameters from ./checkpoint/model_360/model_state.pdparamsData: 低点代开发票  Lable: 垃圾邮件Data: 深圳协恒实业有限公司  Lable: 垃圾邮件Data: 帮帮忙啊  Lable: 垃圾邮件

       

4、总结与升华

通过自己动手去实现这个项目后,针对于之前只能通过降低库函数版本以及框架来实现的算法,现在可以做到根据现有最新的框架去实现自己想要的算法,同时本文完成的是关于中文垃圾邮件的分类问题,使用的是bert模型和roberta模型,网上有人评论则提到,这两个模型的不同之处在于后者其实是使用名副其实的暴力调参法,在跑实验的过程中,也可以看出,使用bert模型去训练数据的话,其训练速度比较慢,如下图所示在其性能监控中可以看出

[AI达人创造营第二期]初入NLP——垃圾邮件分类 - 创想鸟        使用bert算法性能监控

而且所需花费的时间也很长,相反,若使用的是roberta,其算法的收敛速度,训练速度相比于bert来说都有一定的改进,如下图所示

[AI达人创造营第二期]初入NLP——垃圾邮件分类 - 创想鸟        使用roberta算法性能监控

从而我们可以得出一个结论,就是如果我们有充裕的时间的话,可以使用bert模型进行训练数据,倘若我们想比较快的能够显示出结果,那么我们可以使用roberta来进行算法的实现,因为其两者的准确率在epoch达到10次以上后,其实两者的准确率都相当的高

以上就是[AI达人创造营第二期]初入NLP——垃圾邮件分类的详细内容,更多请关注创想鸟其它相关文章!

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 chuangxiangniao@163.com 举报,一经查实,本站将立刻删除。
发布者:程序猿,转转请注明出处:https://www.chuangxiangniao.com/p/740401.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
Java中工厂模式核心原理
上一篇 2025年11月25日 15:02:47
sublime怎么修改匹配括号的颜色_sublime匹配括号颜色修改方法
下一篇 2025年11月25日 15:02:47

相关推荐

  • composer require-dev和require有什么不同_Composer Require与Require-Dev区别解析

    require用于声明项目运行必需的依赖,如框架、数据库组件和第三方SDK,这些包会随项目部署到生产环境;2. require-dev用于声明仅在开发和测试阶段需要的工具,如PHPUnit、PHPStan、Faker等,不会默认部署到生产环境;3. 安装时composer install根据环境决定…

    2026年5月10日
    900
  • 修复Django电商项目中AJAX过滤产品列表图片不显示问题

    在Django电商项目中,当使用AJAX动态加载过滤后的产品列表时,常遇到图片无法正常显示的问题。这通常是由于前端模板中图片加载方式(如data-setbg属性结合JavaScript库)与AJAX动态内容更新机制不兼容所致。解决方案是直接在AJAX返回的HTML中使用标准的标签来渲染图片,确保浏览…

    2026年5月10日
    000
  • 开源免费PHP工具 PHP开发效率提升利器

    推荐开源免费PHP开发工具以提升效率:VS Code、Sublime Text轻量高效,PhpStorm专业强大;调试用Xdebug、Kint、Ray;依赖管理选Composer;代码质量工具包括PHPStan、Psalm、PHP_CodeSniffer;数据库管理可用%ignore_a_1%MyA…

    2026年5月10日
    000
  • Matplotlib 地图中多类型图例的创建与优化

    Matplotlib 地图中多类型图例的创建与优化Matplotlib 地图中多类型图例的创建与优化Matplotlib 地图中多类型图例的创建与优化Matplotlib 地图中多类型图例的创建与优化

    本教程旨在解决matplotlib地图可视化中,如何在一个图例中同时展示颜色块(如区域分类)和自定义标记(如特定兴趣点)的问题。文章详细介绍了当传统`patch`对象无法正确显示标记时,如何利用`matplotlib.lines.line2d`创建标记图例句柄,并将其与颜色块图例句柄合并,从而生成一…

    2026年5月10日 用户投稿
    100
  • Golang JSON序列化:控制敏感字段暴露的最佳实践

    本教程探讨golang中如何高效控制结构体字段在json序列化时的可见性。当需要将包含敏感信息的结构体数组转换为json响应时,通过利用`encoding/json`包提供的结构体标签,特别是`json:”-“`,可以轻松实现对特定字段的忽略,从而避免敏感数据泄露,确保api…

    2026年5月10日
    000
  • 利用海象运算符简化条件赋值:Python教程与最佳实践

    本文旨在探讨Python中海象运算符(:=)在条件赋值场景下的应用。通过对比传统if/else语句与海象运算符,以及条件表达式,分析海象运算符在简化代码、提高可读性方面的优势与局限性。并通过具体示例,展示如何在列表推导式等场景下合理使用海象运算符,同时强调其潜在的复杂性及替代方案,帮助开发者更好地掌…

    2026年5月10日
    000
  • Debian syslog性能优化技巧有哪些

    提升Debian系统syslog (通常基于rsyslog)性能,关键在于精简配置和高效处理日志。以下策略能有效优化日志管理,提升系统整体性能: 精简配置,高效加载: 在rsyslog配置文件中,仅加载必要的输入、输出和解析模块。 使用全局指令设置日志级别和格式,避免不必要的处理。 自定义模板: 创…

    2026年5月10日
    000
  • 比特币新手教程 比特币交易平台有哪些

    比特币是一种去中心化的数字货币,基于区块链技术实现点对点交易,具有匿名性、有限发行和不可篡改等特点;新手可通过交易所购买,P2P交易获得比特币,常用平台包括Binance、OKX和Huobi;交易流程包括注册账户、实名认证、绑定支付方式、充值法币并下单购买,可选择市价单或限价单;比特币存储方式有交易…

    2026年5月10日
    000
  • c++中的SFINAE技术是什么_c++模板编程中的SFINAE原理与应用

    SFINAE 是“替换失败不是错误”的原则,指模板实例化时若参数替换导致错误,只要存在其他合法候选,编译器不报错而是继续重载决议。它用于条件启用模板、类型检测等场景,如通过 decltype 或 enable_if 控制函数重载,实现类型特征判断。尽管 C++20 引入 Concepts 简化了部分…

    2026年5月10日
    000
  • Go语言mgo查询构建:深入理解bson.M与日期范围查询的正确实践

    本文旨在解决go语言mgo库中构建复杂查询时,特别是涉及嵌套`bson.m`和日期范围筛选的常见错误。我们将深入剖析`bson.m`的类型特性,解释为何直接索引`interface{}`会导致“invalid operation”错误,并提供一种推荐的、结构清晰的代码重构方案,以确保查询条件能够正确…

    2026年5月10日
    100
  • vscode上怎么运行html_vscode上运行html步骤【指南】

    首先保存文件为.html格式,再通过浏览器或Live Server插件打开预览;推荐安装Live Server实现本地服务器运行与实时刷新,提升开发体验。 在 VS Code 上运行 HTML 文件并不需要复杂的配置,只需几个简单步骤即可预览页面效果。VS Code 本身是一个代码编辑器,不直接运行…

    2026年5月10日
    100
  • RichHandler与Rich Progress集成:解决显示冲突的教程

    在使用rich库的`richhandler`进行日志输出并同时使用`progress`组件时,可能会遇到显示错乱或溢出问题。这通常是由于为`richhandler`和`progress`分别创建了独立的`console`实例导致的。解决方案是确保日志处理器和进度条组件共享同一个`console`实例…

    2026年5月10日
    000
  • 修复点击时按钮抖动:CSS垂直对齐实践

    本文探讨了在Web开发中,交互式按钮(如播放/暂停按钮)在点击时发生意外垂直位移的问题。通过分析CSS样式变化对元素布局的影响,我们发现这是由于按钮不同状态下的边框样式和内边距改变,以及默认的垂直对齐行为共同作用所致。核心解决方案是利用CSS的vertical-align属性,将其设置为middle…

    2026年5月10日
    000
  • Golang goroutine与channel调试技巧

    使用go run -race检测数据竞争,结合runtime.NumGoroutine监控协程数量,通过pprof分析阻塞调用栈,利用select超时避免永久阻塞,有效排查goroutine泄漏、死锁和数据竞争问题。 Go语言的goroutine和channel是并发编程的核心,但它们也带来了调试上…

    2026年5月10日
    000
  • 《魔兽世界》将于6月11日开启国服回归技术测试

    《魔兽世界》将于6月11日开启国服回归技术测试《魔兽世界》将于6月11日开启国服回归技术测试《魔兽世界》将于6月11日开启国服回归技术测试《魔兽世界》将于6月11日开启国服回归技术测试

    《%ign%ignore_a_1%re_a_1%》官方宣布,将于6月11日开启国服回归技术测试,时间为7天,并称可以在6月内正式开服,玩家们可以访问官网下载战网客户端并预下载“巫妖王之怒”客户端,技术测试详情见下图。 WordAi WordAI是一个AI驱动的内容重写平台 53 查看详情 以上就是《…

    2026年5月10日 用户投稿
    200
  • 使用 Jupyter Notebook 进行探索性数据分析

    Jupyter Notebook通过单元格实现代码与Markdown结合,支持数据导入(pandas)、清洗(fillna)、探索(matplotlib/seaborn可视化)、统计分析(describe/corr)和特征工程,便于记录与分享分析过程。 Jupyter Notebook 是进行探索性…

    2026年5月10日
    000
  • php常量怎么用_PHP常量(define/const)定义与使用方法

    PHP中可通过define函数和const关键字定义常量,用于存储不可变值。define适用于全局作用域,支持动态名称和条件定义,如define(‘SITE_NAME’, ‘MyWebsite’);const在编译时生效,语法简洁但限制多,只能在类或全…

    2026年5月10日
    000
  • 如何在HTML中插入表单元素_HTML表单控件与输入类型使用指南

    HTML表单通过标签构建,包含action和method属性定义数据提交目标与方式,常用input类型如text、password、email等适配不同输入需求,配合label、required、placeholder提升可用性,结合textarea、select、button等控件实现完整交互,是…

    2026年5月10日
    000
  • 前端缓存策略与JavaScript存储管理

    根据数据特性选择合适的存储方式并制定清晰的读写与清理逻辑,能显著提升前端性能;合理运用Cookie、localStorage、sessionStorage、IndexedDB及Cache API,结合缓存策略与定期清理机制,可在保证用户体验的同时避免安全与性能隐患。 前端缓存和JavaScript存…

    2026年5月10日
    100
  • 网站标题关键词更新后,搜索引擎为何仍显示旧标题?

    网站标题更新后,搜索引擎为何显示旧标题? 网站SEO优化中,站长常修改网站标题关键词,期望搜索结果显示自定义标题。然而,即使更新标签、meta keywords、meta description和结构化数据中的name属性后,搜索结果仍显示旧标题,这令人费解。本文将对此进行解释。 问题:站长修改了网…

    2026年5月10日
    100

发表回复

登录后才能评论
关注微信