链载Ai

标题: OpenAI o1 技术初探2:使用MCTS增强推理能力(基于代码实践的解读) [打印本页]

作者: 链载Ai    时间: 5 小时前
标题: OpenAI o1 技术初探2:使用MCTS增强推理能力(基于代码实践的解读)

在o1的整体框架篇中(https://zhuanlan.zhihu.com/p/773907223),我们从现有开源的论文和代码中(https://github.com/hijkzzz/Awesome-LLM-Strawberry),抽象出了o1可能的技术实现路径,如下图:

这里对于这张框架图我们不再做赘述,详情可以参见上面《框架篇》的文章链接。

我们之前说过,这是一张高度抽象的框架图,旨在说明o1官方技术报告中提到的“把更多算力花在inference阶段上,以提升模型的逻辑推理能力”的含义。而从本文开始,我们将以具体的算法去扩展这张框架图的细节。

今天我们要具体扩展的,就是框架图中的Inference部分(黄色块),从框架图可知,Inference部分一般有两个作用:

所以,Inference模块可以看作是o1实现中的一块积木。当你理解这块积木的目的、以及一些可能的实现方法后。你就可以按需要灵活把它组装在你心目中o1的任何一个环节。在网上关于o1的资料中,我们可能经常会看见“MCTS,self-play”这样的关键词,它其实就是这块黄色积木的一种实现方式。不过笔者认为,o1走的不是纯靠优化inference的路线(即上图中的framework1),更可能走的是post-training + inference路线(即上图中的framework3,因为o1的技术报告中提过它把算力也花在了RL阶段上)。但是无论如何,了解这块积木的实现总是必要的。

在这篇文章中,我们将以微软在今年开源的rStar这个工作为例(https://github.com/zhentingqi/rStar),全面从源码出发,来详细看下MCTS技术是如何运用在nlp的逻辑推理任务上的(毕竟我们对MCTS的主要了解都来自AlphaGO,我们肯定非常好奇它要如何运作在自然语言上,特别是这个前提下它的搜索空间是什么)。阅读本文不需要任何MCTS先验知识,文中会循序渐进地做介绍

一、为什么选择rStar

rStar的目的同样是提升模型的逻辑推理能力,但是它走的是上图中的framework1,也就是纯靠inference的搜索优化来实现目标,同时它选择的是MCTS而非PRM + search methods的方法。rStar作出这样选择的原因如下:

正因为rStar走的是纯Inference的路线,所以更便于我们从”一块积木”的视角来理解框架图中的黄色块。同时,利好小模型的场景也更适合资源有限的我们。最后,当然是rStar的代码完全开源,方便我们一探所有的细节,少一些自己的想象(rStar的论文其实写得比较精简,少了很多细节的描述,也一定程度上造成代码不太好读)

二、按照人的思考方式构造一棵搜索树

这里我们先不谈MCTS的任何概念,我们只看:对于某个问题,你会采用什么样的思维链来解决它?

假设我们有一个简单的问题:

user_question:
Ifthereare3carsintheparkinglotand2morecarsarrive,howmanycarsareintheparkinglot?

为了解决它,我们可能有如下思考方式(所有的思考方式都以字母A开头,表示Action)

2.1 A1(propose a one-step-thought)

我们会做过程的拆解,每次提出一个推理step,直到生成最后的答案。我们记这种思考方式为A1。例如:

A1(proposeaone-step-thought)

###Instruction:
Ifthereare3carsintheparkinglotand2morecarsarrive,howmanycarsareintheparkinglot?

###Response:
Let'sthinkstepbystep.
Step1:Startwiththenumberofcarsthatarealreadyintheparkinglot,whichis3cars.
Step2:Addthenumberofcarsthatarrive,whichis2cars.
Step3:Addthenumberstogether.thereare3cars+2cars=5carsintheparkinglot.
Step4:Theansweris5.

观察上面的steps,我们会发现:

2.2 A2(propose the remaining thought steps)

对于一些简单的问题,我们可能并不会步步思考。我们会一次性通过一些简单的推理后直接给出答案,例如:

###Instruction:
Ifthereare3carsintheparkinglotand2morecarsarrive,howmanycarsareintheparkinglot?

###Response:
Let'sthinkstepbystep.Thereareoriginally3cars.2morecarsarrive.3+2=5.Theansweris:5.

2.3 A3 (propose next sub-question along with its answer)

有时候,我们会把原始问题拆解成很多子问题,然后回答一个个子问题,最终给出答案,例如:

Question1:Ifthereare3carsintheparkinglotand2morecarsarrive,howmanycarsareintheparkinglot?
Question1.1:Howmanycarsarethereintheparkbefore?
Answer1.1:Thereare3carsintheparkbefore.
Question1.2:Howmanycarsarrivethen?
Answer1.2:2morecarsarrive.
Question1.3:Nowwecananswerthequestion:howmanycarsareintheparkinglot?
Answer1.3:Thereare3+2=5carsintheparkinglotnow.Theansweris5.

其中,Question1是原始问题,其余是拆解的子问题。其中,Question 1.3属于终结类型的子问题,因为回答它就等于回答了最终答案。这种拆解子问题的方式更适合用来解决困难问题,我们的例子比较简单,这里只是展现出一个形式。

2.4 A4 (Answer the sub-question again)

这种方式将和A3一起配套使用,例如,对于A3的Question1.1,你可能并不确定Answer1.1是否正确,这时你想重新再思考一次Answer1.1的答案。由于此时你只是对某一个子答案做修正,因此你可能采用A2(propose the remaining thought steps)的方式,做一些简单的推理,重新取得Answer1.1。此时相当于把Answer1.1用A2例子中的输出结果进行替代,这里不再给出具体例子。

2.5 A5(Rephrase the question/sub-question)

有时我们在做题时,通常会在大段的原始题目描述中,把关键信息提取出来,例如:condition1..., condition2等等。我们可以先通过这种方式改写原始题目/子题目,然后再做回答。这个比较好理解,同样也不再给出具体的示例。

2.6 整合:构造一颗搜索树

总结一下,目前为止,我们按照人类的思维方式,总结出了人类解决一个问题时可能采用的5种方法:

在代码操作中,我们会按2.1~2.5的示例,构造相应的prompt来指示模型执行不同的动作。下图给出了A1的prompt示例,更多例子大家可以参见源码中rStar/prompts部分:

当人解决问题时,可能会根据问题的难度,决定不同的解决模式,但是当我们采用模型进行搜索时,模型是很难预知问题难度的,所以我们总是希望:模型能够尽可能地把这些解决方式(Action)都探索一遍。

那么接下来,我们就配合着rStar的源码,一起来看下这棵搜索树长什么样子(这里我们不使用论文中的图,因为它缺少了太多细节,我们直接从源码出发重新绘制):

我们先看一些基本信息:

接下来我们来看图中的更多细节:

我们从根结点(第0层)出发,根结点是用户的原始问题,对于根结点来说:

接下来我们从第1层出发,以第1层为例,探索下不同类型的结点可以生成什么类型的子节点,以及最终可能的leaf node类型。只要搞清楚了第1层,其余层就可以类推了。

总结一下,到目前为止我们已经解决了:

但是,仍有一些重要但未解的问题:

为了解决这两个问题,现在我们可以请出MCTS这个算法了

三、使用MCTS搜索最佳推理路径

3.1 使用rollout构造搜索树

对于模型来说,现在它将从原始问题出发,构造一棵搜索树。我们先来看从根结点出发,模型构造搜索树的过程:

对于根结点来说:

这样一轮select + expand + simulate + backprop的步骤,就称为1次rollout。不难发现,在1次rollout过后,我们构造出了一部分搜索树(这里我们先只谈构造,不谈搜索,大家不要着急)

接下来我们执行第2轮rollout,继续构造我们的搜索树(这里不再画图了,我们直接从1st rollout的图例中想象一下):

这里额外再提一句,生成搜索树的每一层时,我们都需要用前面所有层的推理步骤作为上文,传递给模型做生成,大家可以自行阅读源码找到构造上文的更多细节,这里不再额外介绍。

好,到这里为止我们已经理清单轮rollout的概念了,以此循环往复,在执行若干轮rollouts(代码默认值为16)后,我们就有一棵相对完整的搜索树了,接下来我们就可以基于这棵树去找到一条最佳的推理路径了。但是在介绍具体的搜索方法之前,让我们再来看看,如何计算一个结点的UCT值(UCT值越大,该结点被探索的价值越大)。

3.2 计算结点的UCT值

一个结点的UCT值计算方式如下:

什么样的结点更具被访问的价值呢?从直觉上说,平均reward越大的结点,表现越好,应该更具访问价值。这就是Q/N在做的事情。而这一部分也被称为利用(exploit),也即我们直接利用当前的结点价值数据做决策。

但是,如果一个结点被访问的次数比较少(比如它的父结点被访问了几百次,它才被访问几次而已),这说明这个结点所在的路径可能有更多的“宝藏”还没被我们发现,因此我们也应该给这些结点更多的机会。这就是c*sqrt(N_parent/N)在做的事情。而这一部分也被称为探索(explore),也即我们给访问次数较少的路径更大的被探索的机会。

而人为设置的探索权重c,就起到控制explore和exploit程度的作用。一般来说,我们会对c采用一些"退火策略“即:

理解这一点,阅读代码中的相关部分就不难啦。

3.3 搜索最佳路径

(1)直接从树中搜索

在若干轮rollouts后,我们终于有了一棵相对完整的搜索树了,那么现在对于一个原始问题(根结点),我们该选择一条最佳的路径,帮助我们找到合理的推理过程和答案呢?

(2)使用discriminator

在3.3(1)中,我们讲解了直接从构造好的搜索树中选择最佳路径的方法。但是在rStar中,还提供了另一种巧思:借助一个discriminator。也就是我们构造的搜索树相当于一个generator,我们使用discriminator从generator的结果中找到最可信的那个,这和我们熟知的GAN非常相似。其中generator和discriminator都是小模型,但是不同的小模型。

我们来看详细的过程:

四、总结

在本文中,我们以rStar为例,从代码级别的角度,给出了o1(可能的)实现框架中Inference这块积木的一个实现方法。在写这篇文章时,我本来想放一些源码和注释的,但是考虑到它在公众号里太占篇幅,可读性不高,所以没有放出来。但是源码中最精华的部分已经在前面的讲解中了,可以大大降低大家读源码的难度。

有了对MCTS如何运用在nlp任务上的一些初步理解,接下来我们就可以按自己的兴趣,广泛探索这块黄色积木的各种实现方式啦(其实本质上都做得差不多)。在后面的系列中,我们将继续对框架做拆解,加入更多的积木。







欢迎光临 链载Ai (https://www.lianzai.com/) Powered by Discuz! X3.5