问答媒体

 找回密码
 立即注册
快捷导航
搜索
热搜: 活动 交友 discuz
查看: 86|回复: 8

构建开放中文聊天生成模型(训练细节和代码开源)

[复制链接]

1

主题

1

帖子

3

积分

新手上路

Rank: 1

积分
3
发表于 2023-4-19 10:17:15 | 显示全部楼层 |阅读模式
声明:欢迎转载,转载请注明出处以及链接,码字不易,欢迎小伙伴们点赞和分享。


一、前言

书接上文,上个月训练了个类似于chatgpt中文开放式聊天生成模型,很多人评论和私信我,希望能讲解训练细节和开源训练代码。于是这周决定开源,让大家可以在算力有限的情况下也能玩玩中文生成聊天模型。废话不多说,下面我来讲解下训练模型的细节。



训练流程图

代码(喜欢的话帮忙点个star,感谢):
模型(喜欢的话帮忙点个like,感谢):
二、数据解析

1、第一阶段训练数据
主要来源是互联网开源数据

  • 百度贴吧问答
  • 医疗数据
  • 网页问答
  • 金融问答
  • 运营商问答
  • 豆瓣多轮对话
  • 爬取百度百科语料
  • 其他多轮对话
数据样例:



第一阶段训练数据

2、第二阶段训练数据
主要来源是:

  • 微博多轮对话
  • 爬取的百度百科语料



爬取的百度百科数据



多轮对话

3、第三阶段训练数据
主要来源是:

  • belle instruct 0.5m
  • belle instruct 1m

三、第一阶段训练(对话问答数据预训练)

这里中文t5模型是采用的开源promptclue base模型,但是promptclue主要是做clue多任务的生成模型,不具有对话能力,所以第一阶段的目的主要是用大量对话和问答数据有监督训练模型,让模型具有对话或者说是生成上下文聊天的能力,所以第一阶段数据量需要比较大。
估计有很多朋友会有疑问,说chatgpt是完全的decode模型,而我这里选择t5这种encode和decode模式,有些人甚至认为t5这种encode和decode模式是失败的。苏剑林在博客中指出encode和decode这种模式双向注意力的低秩问题。
在textgen的开源项目实验中也对比了gpt2和t5这两种结构的出来一些结论:
GPT2 vs T5:

  • 都是从Transformer改进来的,T5同时有编码器和解码器,GPT2只有解码器
  • T5的模型优势是处理给定输入,产出对应输出的任务,如翻译、对话、问答等
  • GPT2的模型优势是自由创作,如写一篇短文
  • T5的对联生成效果好于GPT2、GPT2的诗词生成效果好于T5
decode only 和encode-decode结构在不同任务上可能各具优势,但其实在做这个实验上并不是特别需要关注的点,应该把主要的精力放在数据和训练技巧上。
第一阶段的训练细节:

  • 将上述的title和desc合并转化成input,而content转化成target。
  • 因为第一阶段训练数据量很大,训练代价也会不小,建议多卡训练让batch size设大一点,这样loss会下降快一点。
  • 第一阶段训练数据1300w,2epoch,我用四张Titan RTX训练八天,模型loss从6+下降到4+,在4的时候会一直震荡。
  • 在第一阶段不需要训练太多epoch,因为是训练时常太长,训练很多个epoch其实也很难过度拟合这种开放式聊天内容,blue指标也会是一个极低的值,所以这里目的为了训练模型能够针对于输入得到一个与输入prompt有关的结果。
  • 第一阶段训练数据口语化太严重以及网络聊天中会带有脏话,数据质量很差。所以需要引入第二阶段。
  • 对话处理窗口是230长度,通过窗口来构建上下文。
第一阶段的训练参数:

  • batch size:62
  • epochs:2
  • learning_rate:1e-4
  • max_source_text_length:256
  • max_target_text_length:256
  • seed:42

四、第二阶段训练(知识增强)

在第一阶段训练结束,通过测试发现模型输出非常口语化,像chatgpt回答都是很官方,所以通过开源对话数据训练出来的模型存在偏置,为了修复偏置需要引入官方语料,比如百度知道这种进行知识增强。
通过收集搜狗细胞词库中有关的词汇,并且通过这次词库来爬取百度百科中相应词条内容。



搜狗细胞词库

爬取数据:



爬取百度百科数据

query的样式只有词汇与正常人提问会存在一定gap,所以在构造input的时候需要进行prompt的提问,但一种提问方式会让模型训练具有很大偏置,所以需要构造多种提问方式,然后样本随机选取不同的提问方式。
例如:什么是+query+?、你听说过+query+吗?、query+是什么?等等提问方式。
第二阶段数据比第一阶段少,所以训练epoch可以相对增加一些,而且也是为了纠正第一阶段训练的偏置



构造prompt的百度百科数据

第二阶段的训练参数:

  • batch size:62
  • epochs:5
  • learning_rate:1e-4
  • max_source_text_length:256
  • max_target_text_length:256
  • seed:42
模型输出通过知识增强以及对于问题的回答,变得具有官方回答。此时blue值趋于0.8左右,loss也到了3+,但是blue值高也不一定是模型生成效果好,还需要结合生成样本进行评测,很大程度减少了模型乱输出的问题。本来第二阶段还想融入通用知识图谱的信息来增强,后面时间来不及,所以暂时没进行通用知识图谱增强,后续优化可能会针对此方向来做些调整。

五、第三阶段(指示学习)

感谢belle开源中文的instruct  data,模型在前两个训练阶段,主要是记忆一些通用信息和对话能力以及简单的指示回答,模型其实还是不太具备对于复杂指示深层理解。第三阶段主要是通过指示数据激活模型对于复杂指示的能力激活,可以让模型类似于chatgpt根据指示来回答。
我参照chatgpt是基于大量数据无监督训练强基线gpt3用知识对话数据来激活模型能力,我于是将这做法转移至模型的训练上。



belle指示数据

在指示学习数据上可以多训练几个epoch,让模型理解复杂指示,指示数据让模型变得更加智能。
第三阶段的训练参数:

  • batch size:62
  • epochs:10
  • learning_rate:2e-5
  • max_source_text_length:256
  • max_target_text_length:256
  • seed:42
模型通过三个阶段的训练loss已经下降至2左右,模型输出效果较之前好上许多。第三阶段目的其实也就是想解锁模型能力,让模型能够理解人类询问的真正意图并且给出相应输出。

六、负载均衡设置

像PyTorch自带的DataParallel存在严重的负载不均衡问题,因为第一张卡会汇算梯度所以占用显存也会比其他卡都要高一些。batch样本均分对于这种模式就很不友好了,所以需要第一张卡分到的样本,要比其他卡少,才能比较好的负载均衡。
在xlnet github中有提到多卡负载均衡的改法。
class BalancedDataParallel(DataParallel):
    """
    多卡负载均衡
    """
    def __init__(self, gpu0_bsz, *args, **kwargs):
        self.gpu0_bsz = gpu0_bsz
        super().__init__(*args, **kwargs)

    def forward(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module(*inputs, **kwargs)
        if self.gpu0_bsz == 0:
            device_ids = self.device_ids[1:]
        else:
            device_ids = self.device_ids
        inputs, kwargs = self.scatter(inputs, kwargs, device_ids)

        # print('len(inputs): ', str(len(inputs)))
        # print('self.device_ids[:len(inputs)]', str(self.device_ids[:len(inputs)]))

        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])
        if self.gpu0_bsz == 0:
            replicas = self.replicate(self.module, self.device_ids)
        else:
            replicas = self.replicate(self.module, self.device_ids[:len(inputs)])

        # replicas = self.replicate(self.module, device_ids[:len(inputs)])
        if self.gpu0_bsz == 0:
            replicas = replicas[1:]

        # print('replicas:', str(len(replicas)))

        outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)
        return self.gather(outputs, self.output_device)

    def parallel_apply(self, replicas, device_ids, inputs, kwargs):
        return parallel_apply(replicas, inputs, kwargs, device_ids[:len(inputs)])

    def scatter(self, inputs, kwargs, device_ids):
        if len(inputs) > 0:
            bsz = inputs[0].size(self.dim)
        elif kwargs:
            bsz = list(kwargs.values())[0].size(self.dim)
        else:
            raise ValueError("You must pass inputs to the model!")
        num_dev = len(self.device_ids)
        gpu0_bsz = self.gpu0_bsz
        bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
        if gpu0_bsz < bsz_unit:
            chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
            delta = bsz - sum(chunk_sizes)
            for i in range(delta):
                chunk_sizes[i + 1] += 1
            if gpu0_bsz == 0:
                chunk_sizes = chunk_sizes[1:]
        else:
            return super().scatter(inputs, kwargs, device_ids)

        # print('bsz: ', bsz)
        # print('num_dev: ', num_dev)
        # print('gpu0_bsz: ', gpu0_bsz)
        # print('bsz_unit: ', bsz_unit)
        # print('chunk_sizes: ', chunk_sizes)
        return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)继承DataParallel类,将均分策略修改。
模型调用方式和batch size设置
#其中14是第一张卡的样本数,余下的卡样本数都是14+2
例如有4张卡,则batch size=14+16*3=62
model = BalancedDataParallel(14 // 2, model, dim=0)
七、解码方式探索

解码方式的话,如果做过生成的应该有些了解,这里就不做过多介绍了。
像beam seach解码不使用随机性,适合场景是比较固定的任务上。
如果是聊天问答或者创作性的任务,则需要随机性的解码方式,例如top k和top p。我做尝试top k略好于top p方式。
引入两个解码参数
temperature(温度因子):有利有弊,温度低可靠性好点,温度高创造性好点。
repetition_penalty(重复字系数):有利有弊,系数高生成重复词概率少,但是有些词是需要重复的,所以不宜太高,太低的话容易生成重复词,影响模型观感,这也是hugging face api上很容易生成奇怪回答的原因。

八、接下来优化的工作

1、模型现在对于生成事实类东西还无法置信,这可能也跟生成模型缺陷有关以及模型容量太小记忆能力有限。
2、加入超大通用知识图谱进行增强。
3、在大一点的模型上做尝试。
4、增大数据窗口,现在多轮窗口太小,效果不理想。
5、获取chatgpt更多的指示数据。

九、参考文献


  • https://arxiv.org/pdf/2203.02155.pdf
  • 《为什么现在的LLM都是Decoder-only的架构?》FAQ - 科学空间|Scientific Spaces
  • https://github.com/shibing624/textgen
  • https://huggingface.co/ClueAI/PromptCLUE-base
  • https://github.com/Link-Li/Balanced-DataParallel
回复

使用道具 举报

2

主题

7

帖子

11

积分

新手上路

Rank: 1

积分
11
发表于 2023-4-19 10:17:49 | 显示全部楼层
好文,感谢博主的分享[爱]
回复

使用道具 举报

0

主题

2

帖子

0

积分

新手上路

Rank: 1

积分
0
发表于 2023-4-19 10:18:31 | 显示全部楼层
[爱]嘿嘿
回复

使用道具 举报

0

主题

3

帖子

0

积分

新手上路

Rank: 1

积分
0
发表于 2023-4-19 10:19:07 | 显示全部楼层
大佬
回复

使用道具 举报

1

主题

8

帖子

14

积分

新手上路

Rank: 1

积分
14
发表于 2023-4-19 10:19:40 | 显示全部楼层
嘿嘿
回复

使用道具 举报

0

主题

3

帖子

0

积分

新手上路

Rank: 1

积分
0
发表于 2023-4-19 10:20:34 | 显示全部楼层
大佬能分享一下再训练的方法之类内容吗?
比如针对某个领域的语料进行优化。
回复

使用道具 举报

1

主题

3

帖子

4

积分

新手上路

Rank: 1

积分
4
发表于 2023-4-19 10:21:10 | 显示全部楼层
你可以针对于我那个开源模型微调下看看,主要你数据是指示学习吗?
回复

使用道具 举报

3

主题

8

帖子

14

积分

新手上路

Rank: 1

积分
14
发表于 2023-4-19 10:21:56 | 显示全部楼层
举个例子,给它喂鲁迅的文章,然后让它模仿鲁迅的文风来写东西。
或者我准备一个qa表,按照这个qa来答。
回复

使用道具 举报

0

主题

4

帖子

0

积分

新手上路

Rank: 1

积分
0
发表于 2023-4-19 10:22:45 | 显示全部楼层
先要预训练+指示学习
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

Archiver| 手机版| 小黑屋| 问答媒体

GMT+8, 2025-7-9 00:29 , Processed in 0.094690 second(s), 22 queries .

Powered by Discuz! X3.4

Copyright © 2020, LianLian.

快速回复 返回顶部 返回列表