当前位置:静雅生活网 > 数码百科 >

够快!爆火的 ChatGPT 等价开源项目来了,网友:我担心跑不起来

导读:机器之心报道  编辑:杜伟、陈萍  感兴趣的小伙伴不妨一试。  最近一段时间,由 OpenAI 开发的 AI 聊天机器人程序 ChatGPT 横扫各大 AI 社区,大家对它

  机器之心报道

  编辑:杜伟、陈萍

  感兴趣的小伙伴不妨一试。

  最近一段时间,由 OpenAI 开发的 AI 聊天机器人程序 ChatGPT 横扫各大 AI 社区,大家对它的热情只增不减,不断挖掘其潜力。

  有些研究者坐不住了,开始琢磨怎样才能开发个等同于 ChatGPT 的开源软件。还没有行动的小伙伴这次参考示例来了,下面我们将要介绍的这个项目(PaLM + RLHF)就实现了这样的功能。

够快!爆火的 ChatGPT 等价开源项目来了,网友:我担心跑不起来

  项目地址:https://github.com/lucidrains/PaLM-rlhf-pytorch

  该项目是在 PaLM 架构之上实施 RLHF(人类反馈强化学习)。基本上等同于 ChatGPT,区别是使用了 PaLM。PaLM 是在谷歌的通用 AI 架构「Pathways」上训练而成的具有 5400 亿参数的大型语言模型。而 RLHF,是 ChatGPT 在 GPT 3.5 系列模型的基础上,引入「人工标注数据 + 强化学习」(RLHF)来不断微调预训练语言模型,旨在让大型语言模型(LLM)学会理解人类的命令,并学会根据给定的 prompt 给出最优的答案。

  想要了解 RLHF 更多内容,可以参考:https://huggingface.co/blog/rlhf

  正如网友所说的:「在 AI 领域中,每有一次专项突破,开发者们很快就会复现出一个开源版本。」

够快!爆火的 ChatGPT 等价开源项目来了,网友:我担心跑不起来

  不过该项目目前只包含训练架构和代码,没有预先训练好的权重。在使用说明上,文档也显示必须先要训练 PaLM。

够快!爆火的 ChatGPT 等价开源项目来了,网友:我担心跑不起来

  对此也有网友表示担心,表示:这不是一个开箱即用的项目,还只是一个架构,就像 shell 一样,需要昂贵的开销才能训练完成,没有机构能够像谷歌那样训练 PaLM。

够快!爆火的 ChatGPT 等价开源项目来了,网友:我担心跑不起来

  还有网友表示:「没有预训练权重是非常糟糕的,官方至少需要释放 50% 的稀疏权重,剩下的让开发者自己训练,才是最好的选择。」

够快!爆火的 ChatGPT 等价开源项目来了,网友:我担心跑不起来

  不过也有网友表示自己会去尝试:

够快!爆火的 ChatGPT 等价开源项目来了,网友:我担心跑不起来

  下面我们来看看这个项目是如何运行的。

  安装

  $ pip install palm-rlhf-pytorch

  用法

  首先训练 PaLM,就像任何其他自回归 transformer 一样。

  import torchfrom palm_rlhf_pytorch import PaLMpalm = PaLM ( num_tokens = 20000, dim = 512, depth = 12 ) .cuda ( ) seq = torch.randint ( 0, 20000, ( 1, 2048 ) ) .cuda ( ) loss = palm ( seq, return_loss = True ) loss.backward ( ) # after much training, you can now generate sequencesgenerated = palm.generate ( 2048 ) # ( 1, 2048 )

  接着使用精选的人类反馈来训练奖励模型。在原始论文中,在没有出现过拟合的情况下,无法从预训练 transformer 中获得微调的奖励模型。项目作者则提供了使用 LoRA 进行微调的选项。

  import torchfrom palm_rlhf_pytorch import PaLM, RewardModelpalm = PaLM ( num_tokens = 20000, dim = 512, depth = 12, causal = False ) reward_model = RewardModel ( palm, num_binned_output = 5 # say rating from 1 to 5 ) .cuda ( ) # mock dataseq = torch.randint ( 0, 20000, ( 1, 1024 ) ) .cuda ( ) prompt_mask = torch.zeros ( 1, 1024 ) .bool ( ) .cuda ( ) # which part of the sequence is prompt, which part is responselabels = torch.randint ( 0, 5, ( 1, ) ) .cuda ( ) # trainloss = reward_model ( seq, prompt_mask = prompt_mask, labels = labels ) loss.backward ( ) # after much trainingreward = reward_model ( seq, prompt_mask = prompt_mask )

  最后将 transformer 和奖励模型传递给 RLHFTrainer。

  import torchfrom palm_rlhf_pytorch import PaLM, RewardModel, RLHFTrainer# load your pretrained palmpalm = PaLM ( num_tokens = 20000, dim = 512, depth = 12 ) .cuda ( ) palm.load ( './path/to/pretrained/palm.pt' ) # load your pretrained reward modelreward_model = RewardModel ( palm, num_binned_output = 5 ) .cuda ( ) reward_model.load ( './path/to/pretrained/reward_model.pt' ) # ready your list of prompts for reinforcement learningprompts = torch.randint ( 0, 256, ( 50000, 512 ) ) .cuda ( ) # 50k prompts# pass it all to the trainer and traintrainer = RLHFTrainer ( palm = palm, reward_model = reward_model, prompt_token_ids = prompts ) trainer.train ( num_episodes = 50000 ) # then, if it succeeded...# generate say 10 samples and use the reward model to return the best oneanswer = trainer.generate ( 2048, prompt = prompts [ 0 ] , num_samples = 10 ) # ( <= 2048, )

  更多细节内容请参阅原项目。

  参考链接:https://twitter.com/rasbt/status/1608133663937495041

   THE END

  转载请联系本公众号获得授权

  投稿或寻求报道:content@jiqizhixin.com

版权声明:本文部分来自互联网,由小编精心所写,本文地址:http://www.zhubian88.cn/smbk/72169.html,如需转载,请注明出处!

联系我们

在线咨询:点击这里给我发消息

微信号:weixin888

工作日:9:30-18:30,节假日休息