Fast enough! The popular ChatGPT equivalent open source project is here, netizen: I'm worried that I won't be able to run it
Recently, ChatGPT, an AI chatbot developed by OpenAI, has swept across the major AI communities. Enthusiasm for it keeps growing, and users continue to explore what it can do.
Some researchers could not sit still and began to wonder how to build an open source equivalent of ChatGPT. For those who have not yet started, here is a reference example: the project introduced below, PaLM RLHF, implements exactly that.
Project address: https://github.com/lucidrains/PaLM-rlhf-pytorch
This project implements RLHF (reinforcement learning from human feedback) on top of the PaLM architecture. The approach is basically the same as ChatGPT's; the difference is that PaLM is used as the underlying language model. PaLM is a large language model with 540 billion parameters, trained with Google's general-purpose AI architecture "Pathways". RLHF is the technique ChatGPT applies on top of the GPT-3.5 series of models: the pretrained language model is fine-tuned with reinforcement learning on human-labeled data, so that the large language model (LLM) learns to understand human instructions and to give the best answer for a given prompt.
If you want to know more about RLHF, you can refer to: https://huggingface.co/blog/rlhf
As one netizen put it: "In the field of AI, every time there is a major breakthrough, developers soon reproduce it as an open source version."
However, the project currently contains only the training architecture and code, with no pretrained weights. The usage instructions also state that PaLM must be trained first.
Some netizens expressed concern about this: it is not an out-of-the-box project, only an architecture, like an empty shell, and training it is expensive. No organization can train PaLM the way Google can.
Another netizen said: "Not having pretrained weights is a real drawback. The authors should release at least 50% of the sparse weights and let developers train the rest themselves. That would be the best option."
However, some netizens said they would give it a try.
Let's take a look at how this project works.
$ pip install palm-rlhf-pytorch
First train PaLM, just like any other autoregressive transformer.
import torch
from palm_rlhf_pytorch import PaLM

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12
).cuda()

seq = torch.randint(0, 20000, (1, 2048)).cuda()

loss = palm(seq, return_loss = True)
loss.backward()

# after much training, you can now generate sequences

generated = palm.generate(2048) # (1, 2048)
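Training "like any other autoregressive transformer" means next-token prediction: at every position the model sees the tokens so far and is scored on the token that follows. The sketch below illustrates the usual shift-by-one convention in plain Python. The helper name `shift_for_next_token` is hypothetical and not part of `palm_rlhf_pytorch`; the library presumably does something equivalent internally when `return_loss = True` is passed.

```python
def shift_for_next_token(token_ids):
    """Split one token sequence into (inputs, targets) for next-token prediction."""
    if len(token_ids) < 2:
        raise ValueError("need at least two tokens to form an input/target pair")
    inputs = token_ids[:-1]   # everything except the last token
    targets = token_ids[1:]   # the same sequence shifted left by one
    return inputs, targets


tokens = [11, 42, 7, 99, 3]   # made-up token ids, for illustration only
inputs, targets = shift_for_next_token(tokens)

for position, (seen, expected) in enumerate(zip(inputs, targets)):
    print(f"position {position}: last token seen {seen}, token to predict {expected}")
```

The training loss is then typically the cross-entropy between the model's prediction at each position and the corresponding entry of `targets`.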
The reward model is then trained on curated human feedback. In the original paper, the authors could not fine-tune a reward model from a pretrained transformer without overfitting. The project author nevertheless provides the option to fine-tune with LoRA.
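For context, LoRA keeps a pretrained weight matrix frozen and learns only a low-rank update to it, so far fewer parameters are trained. The sketch below is just the arithmetic for a single square projection of width 512 (the `dim` used in the examples here). The rank of 8 and the helper names are assumptions made for illustration, and nothing here shows how the project itself exposes LoRA.

```python
def full_finetune_params(d_in, d_out):
    """Trainable parameters when the whole d_out x d_in weight matrix is updated."""
    return d_in * d_out


def lora_params(d_in, d_out, rank):
    """Trainable parameters for a low-rank update B @ A,
    where A has shape (rank, d_in) and B has shape (d_out, rank)."""
    return rank * d_in + d_out * rank


dim = 512   # model width used in the article's examples
rank = 8    # assumed LoRA rank, chosen only for illustration

full = full_finetune_params(dim, dim)   # 512 * 512 = 262144
lora = lora_params(dim, dim, rank)      # 8 * 512 + 512 * 8 = 8192
ratio = lora / full                     # 8192 / 262144 = 0.03125

print(f"full fine-tuning: {full} trainable parameters per projection")
print(f"LoRA (rank {rank}): {lora} trainable parameters per projection")
print(f"fraction trained with LoRA: {ratio}")
```

With or without LoRA, the project's reward model example looks like this: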
import torch
from palm_rlhf_pytorch import PaLM, RewardModel

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12,
    causal = False
)

reward_model = RewardModel(
    palm,
    num_binned_output = 5 # say rating from 1 to 5
).cuda()

# mock data

seq = torch.randint(0, 20000, (1, 1024)).cuda()
prompt_mask = torch.zeros(1, 1024).bool().cuda() # which part of the sequence is prompt, which part is response
labels = torch.randint(0, 5, (1,)).cuda()

# train

loss = reward_model(seq, prompt_mask = prompt_mask, labels = labels)
loss.backward()

# after much training

reward = reward_model(seq, prompt_mask = prompt_mask)
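Note that the mock labels above are drawn from 0 to 4, while the comment speaks of ratings from 1 to 5, so a human rating has to be shifted down by one before it is used as a label. The sketch below shows that mapping and one plausible way a binned rating could be scored: as a 5-way classification with cross-entropy. This is an assumption about the general technique, not a description of the library's internals, and the function names are hypothetical.

```python
import math


def rating_to_label(rating, num_bins=5):
    """Map a human rating in 1..num_bins to a zero-based class label."""
    if not 1 <= rating <= num_bins:
        raise ValueError(f"rating must be between 1 and {num_bins}")
    return rating - 1


def softmax(logits):
    """Convert raw scores into probabilities (numerically stable)."""
    peak = max(logits)
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    return [value / total for value in exps]


def binned_reward_loss(logits, label):
    """Cross-entropy between the predicted bin distribution and the true bin."""
    return -math.log(softmax(logits)[label])


# made-up scores for the 5 rating bins; the model favours the fourth bin
logits = [0.1, 0.2, 0.3, 2.0, 0.4]

loss_if_rated_4 = binned_reward_loss(logits, rating_to_label(4))
loss_if_rated_1 = binned_reward_loss(logits, rating_to_label(1))

print("loss when the human rating is 4:", loss_if_rated_4)
print("loss when the human rating is 1:", loss_if_rated_1)
```

The loss is lower when the human rating agrees with the bin the model already favours, which is what pushes the reward model toward human judgments during training.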
Finally, pass the transformer and the reward model to RLHFTrainer.
import torch
from palm_rlhf_pytorch import PaLM, RewardModel, RLHFTrainer

# load your pretrained palm

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12
).cuda()

palm.load('./path/to/pretrained/palm.pt')

# load your pretrained reward model

reward_model = RewardModel(
    palm,
    num_binned_output = 5
).cuda()

reward_model.load('./path/to/pretrained/reward_model.pt')

# ready your list of prompts for reinforcement learning

prompts = torch.randint(0, 256, (50000, 512)).cuda() # 50k prompts

# pass it all to the trainer and train

trainer = RLHFTrainer(
    palm = palm,
    reward_model = reward_model,
    prompt_token_ids = prompts
)

trainer.train(num_episodes = 50000)

# then, if it succeeded...
# generate say 10 samples and use the reward model to return the best one

answer = trainer.generate(2048, prompt = prompts[0], num_samples = 10) # (<= 2048,)
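The last step in the example, generating several samples and letting the reward model pick the best one, is commonly called best-of-n sampling. The sketch below shows the idea in plain Python. `best_of_n`, `toy_generate`, and `toy_score` are hypothetical stand-ins for the trained model and reward model; they are not part of `palm_rlhf_pytorch`, and the toy scoring rule is arbitrary.

```python
import random


def best_of_n(generate, score, prompt, num_samples):
    """Draw num_samples candidates for a prompt and return the highest-scoring one."""
    if num_samples < 1:
        raise ValueError("num_samples must be at least 1")
    candidates = [generate(prompt) for _ in range(num_samples)]
    return max(candidates, key=score)


# toy stand-ins so the sketch runs without a trained model
rng = random.Random(0)


def toy_generate(prompt):
    """Pretend language model: append four random token ids to the prompt."""
    return prompt + [rng.randrange(10) for _ in range(4)]


def toy_score(sequence):
    """Pretend reward model: an arbitrary rule that prefers larger token ids."""
    return sum(sequence)


prompt = [1, 2, 3]
answer = best_of_n(toy_generate, toy_score, prompt, num_samples=10)
print("selected continuation:", answer[len(prompt):])
```

Raising `num_samples` gives the reward model more candidates to choose from, at the cost of one extra generation per sample.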