>  기사  >  기술 주변기기  >  PromptPG: 강화 학습이 대규모 언어 모델을 만날 때

PromptPG: 강화 학습이 대규모 언어 모델을 만날 때

王林
王林앞으로
2023-04-07 14:51:031144검색

수학적 추론은 인간 지능의 핵심 능력이지만, 추상적 사고와 논리적 추론은 여전히 ​​기계에게 큰 도전 과제입니다. GPT-3 및 GPT-4와 같은 대규모 사전 훈련된 언어 모델은 텍스트 기반 수학적 추론(예: 수학적 단어 문제)에서 상당한 진전을 이루었습니다. 그러나 현재 이러한 모델이 표 형식 데이터와 같은 이질적인 정보와 관련된 보다 복잡한 문제를 처리할 수 있는지 여부는 불분명합니다. 이러한 격차를 메우기 위해 UCLA와 Allen Institute for Artificial Intelligence(AI2)의 연구원들은 올바른 답을 얻기 위해 텍스트와 표 형식 데이터에 대한 수학적 추론을 모두 요구하는 38,431개의 개방형 도메인 문제로 구성된 데이터 세트인 TabMWP(표 형식 수학 단어 문제)를 출시했습니다. 답변. TabMWP의 각 질문은 구조화된 형식의 이미지, 텍스트 또는 테이블을 포함하는 컨텍스트와 연결됩니다.

연구원들은 TabMWP의 Few-shot GPT-3를 포함하여 다양한 사전 훈련된 모델을 평가했습니다. 기존 연구에서 밝혀진 바와 같이 Few-shot GPT-3는 상황에 맞는 예시 선택에 크게 의존하므로 예시를 무작위로 선택할 때 성능이 상당히 불안정해집니다. 이러한 불안정성은 TabMWP와 같은 복잡한 추론 문제를 처리할 때 더욱 심각합니다. 이 문제를 해결하기 위해 저자는 강화학습에서 사례 선택을 상황적 산적 문제로 변환하고 Policy Gradient를 사용하여 소량의 사례에서 최적을 선택하는 방법을 학습하는 정책 네트워크를 학습시키는 PromptPG 방법을 제안했습니다. 훈련 데이터 -컨텍스트 예시. 실험 결과는 그들이 제안한 PromptPG 방법이 질문에 대한 답변에서 최적 기준(Few-shot CoT GPT-3)을 5.31% 초과하고, 그들의 방법이 무작위로 선택된 상황 내 예제에 비해 문제를 크게 줄인다는 것을 보여줍니다. 이 유형의 방법의 안정성.

PromptPG: 강화 학습이 대규모 언어 모델을 만날 때


  • 논문 링크: https://arxiv.org/abs/2209.14610
  • 코드 링크: https://github.com/lupantech/PromptPG
  • 프로젝트 홈페이지: https://promptpg.github.io
  • 데이터 시각화: https://promptpg.github.io/explore

1 TabMWP 데이터 세트

다음은 TabMWP 데이터 세트의 두 가지 예에서. 하나는 숫자 답변이 포함된 자유 텍스트 질문이고, 다른 하나는 텍스트 답변이 포함된 객관식 질문입니다. 보시다시피, 각 질문은 단계별 추론을 포함하는 솔루션을 제공합니다. TabMWP의 문제를 해결하려면 시스템이 테이블 조회와 다단계 수학적 추론을 모두 수행할 수 있어야 합니다. 아래 그림의 예를 들어 "(트레이시가 세 가지 종류의 빵을 산다면) 그녀가 지출할 금액은 얼마입니까?"에 대답하려면 먼저 표에서 세 가지 종류의 빵에 해당하는 가격을 찾은 다음 비용을 계산해야 합니다. 각 종류의 빵을 구입하고 그 비용을 합산하여 최종 비용을 구합니다.

PromptPG: 강화 학습이 대규모 언어 모델을 만날 때

아래 표의 통계에서 볼 수 있듯이 TabMWP 데이터 세트에는 38,431개의 표 형식 수학 문제가 포함되어 있습니다. 질문의 74.7%는 자유 텍스트 질문이었고 25.3%는 객관식 질문이었습니다. TabMWP는 총 28,876개의 고유 질문, 6,153개의 고유 답변, 35,442개의 고유 답변을 보유하고 있어 질문 분포의 다양성이 풍부함을 나타냅니다. 질문의 평균 길이는 22.1 단어, 답변의 평균 길이는 49.5 단어로 TabMWP의 어휘 풍부도를 나타냅니다. TabMWP의 특징은 각 문제에 테이블 컨텍스트가 수반되며, 테이블 컨텍스트 없이는 문제를 해결할 수 없다는 것입니다. TabMWP에는 총 37,644개의 다양한 테이블이 있으며, 평균 테이블 크기는 5.9행과 2.2열, 12.9셀, 최대 54셀입니다. 이러한 통계는 TabMWP의 테이블에도 다양성이 풍부하다는 것을 보여줍니다.

PromptPG: 강화 학습이 대규모 언어 모델을 만날 때

TabMWP 데이터 세트에는 두 가지 질문 유형과 다섯 가지 답변 유형이 있습니다.

PromptPG: 강화 학습이 대규모 언어 모델을 만날 때

TabMWP의 모든 질문에는 이미지, 반구조화된 텍스트, 구조화된 세 가지 형식으로 표현되는 표 형식의 컨텍스트가 있습니다. 이는 다양한 유형의 추론 모델을 개발할 가능성을 열어줍니다.

PromptPG: 강화 학습이 대규모 언어 모델을 만날 때

기존 데이터 세트와 비교하여 TabMWP는 질문에 답하기 위해 테이블 ​​이해와 수학적 추론 능력이 모두 필요합니다. 또한 TabMWP는 각 질문에 대해 상세한 다단계 추론 프로세스를 갖추고 있으며 이는 데이터 세트 크기, 테이블 유형, 질문 유형 및 답변 유형에서 분명한 이점을 가지고 있습니다. 이 문서의 지식을 최대한 활용하면 TabMWP는 개방형 도메인 표 형식 시나리오의 최초 수학적 추론 데이터 세트입니다.

PromptPG: 강화 학습이 대규모 언어 모델을 만날 때

2. PromptPG 방법

수학적 단어 문제 해결에서 GPT-3과 같은 대규모 사전 학습 모델의 성공을 고려하여 저자는 먼저 Few-shot GPT-3을 사용하여 TabMWP에 대한 벤치마크를 설정했습니다. . 그들은 훈련 세트와 테스트 예시에서 일부 상황별 예시를 무작위로 선택하여 GPT-3가 답을 예측하도록 유도하는 프롬프트를 형성합니다. 그러나 최근 연구에 따르면 무작위 선택을 기반으로 한 이러한 종류의 소수 학습은 상황에 따른 예시 선택에 따라 매우 불안정하게 수행될 수 있습니다. 다양한 유형과 형식의 테이블이 포함된 TabMWP와 같은 복잡한 추론 문제를 처리할 때는 무작위 선택이 훨씬 덜 효과적일 수 있습니다.

이 문제를 해결하기 위해 저자는 향상된 방법을 제안합니다. Policy Gradient를 통한 신속한 학습, PromptPG라는 소량의 학습 데이터에서 상황에 맞는 예를 선택하는 학습입니다. 그림 2에서 볼 수 있듯이 정책 네트워크는 후보 풀(후보 예제)에서 상황에 가장 적합한 예제를 찾는 방법을 학습하며, 최적화 목표는 GPT와 상호 작용할 때 주어진 훈련 예제(훈련 예제)의 예측을 최대화하는 것입니다. -3 환경 상. 예시 선택을 위한 정책 네트워크는 고정 매개변수를 기반으로 하는 BERT 언어 모델과 학습 가능한 매개변수가 있는 단일 레이어 신경망입니다. 최적화 학습을 완료한 후 PromptPG는 다양한 시험 문제에 대한 후보 사례에서 다양한 최적 사례를 동적으로 선택하여 GPT-3의 추론 성능을 극대화할 수 있습니다.

PromptPG: 강화 학습이 대규모 언어 모델을 만날 때

다음은 PromptPG의 학습 알고리즘입니다.

PromptPG: 강화 학습이 대규모 언어 모델을 만날 때

3. 실험 및 분석

PromptPG: 강화 학습이 대규모 언어 모델을 만날 때

사전 훈련 및 미세 조정

표 3은 TabMWP 데이터 세트에 대한 PromptPG의 결과와 다양한 벤치마크를 비교합니다. 유사한 매개변수 양을 가진 표 형식 데이터에 대한 사전 훈련으로 인해 TAPEX가 UnifiedQA보다 더 나은 성능을 발휘하는 것을 볼 수 있습니다. TAPEX와 UnifiedQA 모두 모델의 매개변수 수를 늘리면 예측 정확도가 향상될 수 있습니다. 또한 TabMWP에서 모델을 미세 조정하면 예측 정확도를 크게 향상시킬 수도 있습니다.

대규모 언어 모델

GPT-3은 미세 조정 없이도 미세 조정된 UnifiedQA 및 TAPEX 모델과 유사한 정확도를 달성할 수 있습니다(제로샷 GPT-3). Few-shot GPT-3 모델이 상황 내 예시 2개를 GPT-3 힌트로 무작위로 선택하면 Zero-shot GPT-3에 비해 0.17% 더 향상될 수 있습니다. Few-shot GPT-3가 최종 답변(Few-shot-CoT GPT-3)을 생성하기 전에 여러 중간 단계를 생성함으로써 연구원들은 62.92%의 정확도로 최적의 기준 모델을 얻을 수 있었습니다.

PromptPG

본 글에서 제안하는 PromptPG는 문맥 내 예시를 무작위로 선택하는 것과는 달리, 보다 적절한 문맥 내 예시를 선택하기 위해 Policy Gradient를 통해 정책 네트워크를 훈련시켰으며 TabMWP에서 가장 높은 예측 결과(68.23%)를 달성했으며 평균 예측을 달성했습니다. 정확도 비율은 최고 기준 모델(Few-shot-CoT GPT-3)을 5.31% 초과합니다. 특히 PromptPG는 거의 모든 문제 유형, 답변 유형, 문제 난이도에 대한 예측 정확도에서 탁월한 성능을 보여줍니다. 그럼에도 불구하고 PromptPG는 90.22%의 휴먼 성능으로 아직 개선의 여지가 많습니다.

Ablation Experiment

PromptPG: 강화 학습이 대규모 언어 모델을 만날 때

표 4는 TabMWP의 모든 입력 요소(질문 텍스트, 표 정보, 옵션 정보)가 질문에 올바르게 답하는 데 중요하다는 것을 보여줍니다. 모든 문제 요소를 입력 정보로만 사용했을 때 Zero-shot GPT-3는 상대적으로 가장 높은 평균 예측 정확도(59.50%)를 달성했습니다.

다른 사례 선택

PromptPG: 강화 학습이 대규모 언어 모델을 만날 때

비교 실험으로 연구자들은 다른 사례 선택으로 다른 방법도 비교했습니다. 표 5에서 볼 수 있듯이 테스트 문제와 동일한 질문 유형 또는 답변 유형을 선택하면 모델이 보다 관련성이 높은 예를 찾고 답변의 정확도를 높이는 데 도움이 될 수 있습니다. 가장 복잡한 예를 선택한다고 해서 답변 정확도가 지속적으로 향상되는 것은 아닙니다. 후보 예제 중에서 가장 좋은 두 예제를 고정적으로 선택하면 정확도가 약간 향상되고 분산이 줄어들 수 있습니다. 의미론적으로 테스트 문제에 가장 가까운 예제를 선택하면 PromptPG 방법에 가장 가까운 정확도를 얻을 수 있습니다. 전반적으로 PromptPG는 예측 정확도를 향상하고 예측 분산을 줄이는 이점을 완전히 입증했습니다.

아래 그림은 PromptPG 선택 예시와 최종 예측 결과를 보여줍니다. PromptPG 방법은 시험문제와 수학적 능력이 유사한 예시를 선택함으로써 Few-shot GPT-3의 추론 성능을 향상시킬 수 있음을 알 수 있다.

PromptPG: 강화 학습이 대규모 언어 모델을 만날 때

성공적인 예측의 예

다음은 자유 텍스트 질문에 대한 PromptPG의 정답입니다. 이 질문에서는 평균을 찾기 위해 표에 8개의 숫자를 더하고 나누어야 합니다.

PromptPG: 강화 학습이 대규모 언어 모델을 만날 때

아래 예에서 모델은 세금 신고서를 이해하고 세금 공제 후 급여를 계산하도록 요청받습니다.

PromptPG: 강화 학습이 대규모 언어 모델을 만날 때

다음은 객관식 문제에 대한 PromptPG의 정확한 예측을 보여줍니다. 주어진 테이블에는 총 9개의 행과 6개의 열이 있습니다. 모델은 테이블에서 대상 셀을 성공적으로 찾고 다단계 추론을 수행하여 정답을 예측합니다.

PromptPG: 강화 학습이 대규모 언어 모델을 만날 때

아래 예에서 모델은 Ariana가 돈이 충분한지 확인하기 위해 예산과 총 비용을 비교해야 합니다.

PromptPG: 강화 학습이 대규모 언어 모델을 만날 때

예측 실패의 예

다음은 자유 텍스트 문제에 대한 PromptPG의 잘못된 예측을 보여줍니다. 모델이 로즈 쿼츠의 잘못된 가격을 검색하여 세 가지 항목의 총 비용을 잘못 계산했습니다.

PromptPG: 강화 학습이 대규모 언어 모델을 만날 때

다음 예에서 질문은 추상적인 줄기와 잎 표를 제공합니다. 모델은 이 도메인별 테이블을 이해할 수 없었고 잘못된 답을 얻을 수 있는 고급 논리적 추론 기능이 부족했습니다.

PromptPG: 강화 학습이 대규모 언어 모델을 만날 때

다음 예를 보면 기존 모델에는 숫자 정렬 기능이 없는 것 같습니다.

PromptPG: 강화 학습이 대규모 언어 모델을 만날 때

다음 예에서는 질문에 언급된 현재 시간과 정확히 일치하는 시간이 표에 나타나지 않으므로 모델이 다음 정류장의 출발 시간을 정확하게 찾을 수 없습니다.

PromptPG: 강화 학습이 대규모 언어 모델을 만날 때

다음 예에서는 모델이 긴 숫자 계열에 대한 산술 연산을 정확하게 완료하는 것이 어렵습니다.

PromptPG: 강화 학습이 대규모 언어 모델을 만날 때

4. 결론 및 전망

저자는 표 형식의 수학적 문제를 해결하기 위한 최초의 대규모 데이터 세트인 TabMWP를 제안했습니다. TabMWP에는 2가지 질문 유형과 5가지 답변 유형을 포함하여 38,431개의 개방형 도메인 질문이 포함되어 있으며 각 질문에는 다단계 솔루션 프로세스가 표시되어 있습니다. 저자는 최첨단 QA 및 TableQA 방법을 사용하고 사전 훈련 및 미세 조정 설정에서 TabMWP에 대한 포괄적인 실험을 수행했으며 대규모 사전 훈련된 언어 모델 GPT-3을 사용하여 평가했습니다. 저자는 또한 GPT-3 모델을 프롬프트하기 위한 훈련 데이터에서 최적의 인스턴스를 선택하기 위해 Policy Gradient 학습을 사용하는 새로운 강화 학습 방법인 PromptPG를 제안합니다. 실험 결과에 따르면 PromptPG는 기존 기준선보다 훨씬 뛰어난 성능을 발휘하고 무작위 선택에 비해 예측의 성능 불안정성을 줄이는 것으로 나타났습니다.

위 내용은 PromptPG: 강화 학습이 대규모 언어 모델을 만날 때의 상세 내용입니다. 자세한 내용은 PHP 중국어 웹사이트의 기타 관련 기사를 참조하세요!

성명:
이 기사는 51cto.com에서 복제됩니다. 침해가 있는 경우 admin@php.cn으로 문의하시기 바랍니다. 삭제