<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" xml:lang="cn"><generator uri="https://jekyllrb.com/" version="4.4.1">Jekyll</generator><link href="https://fatescript.github.io/feed.xml" rel="self" type="application/atom+xml"/><link href="https://fatescript.github.io/" rel="alternate" type="text/html" hreflang="cn"/><updated>2025-04-09T00:52:02+00:00</updated><id>https://fatescript.github.io/feed.xml</id><title type="html">Fatescript</title><subtitle>Fatescript&apos;s website </subtitle><entry><title type="html">那些年，我们没想过的数值稳定算法</title><link href="https://fatescript.github.io/blog/2025/numerical-stability/" rel="alternate" type="text/html" title="那些年，我们没想过的数值稳定算法"/><published>2025-04-06T15:59:00+00:00</published><updated>2025-04-06T15:59:00+00:00</updated><id>https://fatescript.github.io/blog/2025/numerical-stability</id><content type="html" xml:base="https://fatescript.github.io/blog/2025/numerical-stability/"><![CDATA[<p>深度学习模型的训练在本质上是通过一系列复杂的数值计算组合去逼近一个极度复杂的函数$f$，而机器本身对于数值的表达就带有精度上的误差。真实的世界往往会遭遇这样的情况：一个微小的浮点误差，导致了最终的梯度爆炸或消失，抑或是模型无法收敛，而这让找寻原因变得异常困难。</p> <p>对于大部分炼丹师来说，结构上的更改和巧思显然令人着迷，但是本文我们不讨论为什么attention要除以 $\sqrt{d_k}$，抑或是GPT为什么用pre-norm而不是post-norm这种，这些属于模型结构上的巧思，而这篇文章关注的是数学和工程层面的技巧。数值分析之父 James H. Wilkinson曾经表达过这样的观点：“数值计算的主要挑战在于如何管理误差的传播(propagation of rounding errors)”。而这些管理技巧又可以分成两个部分，即：</p> <ul> <li>使得数值计算的方法给出的结果更接近真实结果</li> <li>在计算结果不能更优的基础上，防止误差进一步扩散</li> </ul> <h3 id="前置概念">前置概念</h3> <p>为了防止你在后面看的太过云里雾里，脑海中想象不到在底层到底发生了什么，我们稍微温习一下<a href="https://en.wikipedia.org/wiki/IEEE_754">IEEE754标准</a>以及介绍一下衍生的数值问题场景。</p> <p>IEEE754标准在表示浮点数的时候，主要分为三个部分：符号位、指数位和尾数位。我们用一个32位的浮点数如何表示79进行举例：</p> <figure> <picture> <source class="responsive-img-srcset" media="(max-width: 480px)" srcset="/assets/blog/float_num-480.webp"/> <source class="responsive-img-srcset" media="(max-width: 800px)" srcset="/assets/blog/float_num-800.webp"/> <source class="responsive-img-srcset" media="(max-width: 1400px)" srcset="/assets/blog/float_num-1400.webp"/> <img src="/assets/blog/float_num.jpg" class="img-fluid rounded z-depth-1" width="auto" height="auto" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <p>​ 注意这里exponent是8位的，所以指数对应的数值是6 + 127 = 133，是规格化的浮点数。非规格化的数据我们暂不讨论，此处只是快速回顾一下浮点数的表示方法。</p> <p>不同精度的浮点数只是exponent和mantissa的位数不同，比如double/float64就是11位exponent和52位mantissa，float16是5位exponent和10位mantissa，bfloat16是8位exponent和7位mantissa，而float8有两种表示方法：e4m3（4位exponent，3位mantissa）和e5m2（5位exponent，2位mantissa）。</p> <p>了解了浮点数如何被表征，我们就可以讨论一下数值计算中常见的问题了，基本上可以分为三类：</p> <ol> <li><strong>overflow</strong>：很大的数据被round成了 $\infty$ 或者 $-\infty$ ，比如在IEEE754标准里面double的最大数值在1e308这个量级（可以用 <code class="language-plaintext highlighter-rouge">numpy.finfo(np.float64).max</code> 校验），如果你在python里用 1e308 + 1e308，得到的结果就是inf了</li> <li><strong>underflow</strong>：接近0的结果被round成了0，比如double的最小的正数是 $2^{-1074}$ ，这个-1074来源于 -1023(11 bit exp) - 51(52 bit mantissa)，大约为5e-324，所以如果在python里使用<code class="language-plaintext highlighter-rouge">9e-324 - 8e-324</code>，得到的结果就是0了，这里就是发生了underflow</li> <li><strong>loss of precision</strong>：因为浮点数的表示方法在数轴上是不均匀的（或者说是分段均匀的），所以天然存在一个问题：只能近似表示某些数据，这就涉及到近似表示带来的精度损失。比如在python里面，计算0.1 + 0.2，得到的结果会是0.30000000000000004，0.1 + 0.2 == 0.3返回的结果也会是False，也就是发生了精度损失。如果你进一步用 <code class="language-plaintext highlighter-rouge">struct</code> 查看hex数值，0.1 + 0.2的结果是0x3FD3333333333334，而0.3则是0x3FD3333333333333（这个数值在数轴上离0.3更近）</li> </ol> <h3 id="数值稳定的解法">数值稳定的解法</h3> <p>我有一位做HPC（high performance computing）的朋友总结过hpc领域提速的两板斧：<strong>减少计算量</strong>（比如卷积的<a href="https://arxiv.org/pdf/1509.09308">WinoGrad算法</a>）、<strong>减少IO</strong>（比如<a href="https://arxiv.org/pdf/2205.14135">flash attention</a>）。对应的，为了提高算法的数值稳定性，也有一些基本的解决套路，总体上，可以归纳为下面四类策略：</p> <ol> <li> <p><strong>重写数学公式</strong></p> <p>很多时候，数学公式在理论上是等价的，但在数值计算中可能存在极大的稳定性差异。我们在下一个部分给出了除法运算的例子，就是典型的重写公式（改变运算的组合顺序），本质上不改变算法逻辑。</p> </li> <li> <p><strong>使用其他算法</strong></p> <p>有些算法虽然数学上是等价的，但在实际计算中表现差异很大。例如在求方差的时候，既可以按照方差的定义先求期望，再求方差；也可以利用 $\text{Var}(X) = \text{E}(X^2) - \text{E}(X)^2$ 求方差。我们会在下一部分的normalization中对这两种方法进行详细的分析，这里我们只需要记住：不同的算法会有不同的数值稳定性。</p> </li> <li> <p><strong>提高精度或改变数值类型</strong></p> <p>默认情况下，模型使用的是 <code class="language-plaintext highlighter-rouge">float32</code> / <code class="language-plaintext highlighter-rouge">float16</code> 来进行训练，但在关键节点上，特别是梯度累加、参数更新等环节，如果使用低精度可能会导致误差累积甚至训练不收敛。所以有时候为了达到更好的稳定性，经常会在某些模块中采用更高的精度，或者临时将变量转换成高精度计算再转换回来。比如在混合精度训练中，往往会做loss scaling，或者在某些操作上autocast到 <code class="language-plaintext highlighter-rouge">float32</code> 以保证训练稳定。</p> </li> <li> <p><strong>限制输入范围</strong></p> <p>这个方法应该是短期walk around的时候最常用的方法，比如定位到出现出现不稳定的具体算子，然后对输入或者输出做一下clip，或者加一个epsilon，类似clip(x, min=1e-5) 或者 x = x + 1e-5这种写法。通常对于一些除以极小值或者log一个极小值的case能起到效果。</p> </li> </ol> <h3 id="那些年我们错过的算子">那些年，我们错过的算子</h3> <h4 id="div-forward">div forward</h4> <p>相信很多人很难理解为什么除法会存在数值稳定性的问题（因为这个实现基本来自硬件的指令），但是实际上在深度学习框架里面，数据的类型是很多样的，仅仅是data type，就衍生出来了<a href="https://github.com/pytorch/pytorch/blob/c65de03196ae3dbeb67ef38d43c4639b85a60ce4/aten/src/ATen/AccumulateType.h#L16">accumulate type</a>、saclar type（单个的数值，比如tensor + 5，5就是scalar。scalar在运算的时候会<a href="https://github.com/pytorch/pytorch/blob/c65de03196ae3dbeb67ef38d43c4639b85a60ce4/aten/src/ATen/ScalarOps.h#L31">转成tensor</a>）和promote type等概念，而除法的不稳定性，就来自于promote type。</p> <p>promote type是指两个不同类型的tensor（包含scalar，因为scalar可以隐式转成tensor）运算之后的结果应该是什么类型的。pytorch在进行数值计算（比如加减乘除）的时候，会对dtype采用下面的一个promote逻辑（参考 <a href="https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype">doc</a>）：</p> <ol> <li>如果scalar或者zero-dim tensor（zero dim tensor就是在scalar上套了一层tensor，比如torch.tensor(1.2)就是一个zero-dim tensor）比和它运算的tensor的精度高一个层级（complex &gt; floating &gt; integral &gt; boolean），那么结果的数据类型就是scalar的类型。比如int tensor 和浮点数做加法，结果就是float tensor；float tensor和complex做乘法，结果是complex tensor。</li> <li>如果scalar或者zero-dim tensor和参与运算的tensor是在一个精度层级下（比如都是float，但是一个是float32，一个是float16），那么返回的结果就是参与运算的tensor的dtype。比如float16的tensor和一个浮点数做加法，结果是float16的tensor；int16的tensor和python int做加法，结果是int16的tensor。</li> </ol> <p>python中的浮点数的精度是double；int则是通过大整数算法，以30bit（64bit系统）或者15bit（32bit系统）为一节的变长数组<a href="https://github.com/python/cpython/blob/e7980ba233bcbdb811e96bd5003c7d51a4e25155/Include/cpython/longintrepr.h#L64-L91">实现</a>。对于越界的int数据，torch会通过抛出异常进行制止（试试 <code class="language-plaintext highlighter-rouge">torch.tensor(1 &lt;&lt; 63)</code> ）。但是，浮点数和低精度的tensor的运算就没有那么直观了，比如下面的code改编自<a href="https://github.com/pytorch/pytorch/pull/41446">torch pr 41446</a>，不考虑promote type的情况下，很难理解为什么结果会是0。</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">tensor</span><span class="p">([</span><span class="mf">3388.</span><span class="p">]).</span><span class="nf">half</span><span class="p">()</span>
<span class="n">scale</span> <span class="o">=</span> <span class="mf">524288.0</span>
<span class="nf">print</span><span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="nf">div</span><span class="p">(</span><span class="n">scale</span><span class="p">))</span>  <span class="c1"># tensor([0.])
</span></code></pre></div></div> <p>当然上面的issue已经被修复了，修复的方法就是把<code class="language-plaintext highlighter-rouge">div(scale)</code> 换成 <code class="language-plaintext highlighter-rouge">mul(1 / scale)</code>。</p> <p><br/></p> <h4 id="div-backward">div backward</h4> <p>对于除法，除了forward过程中的promote type之外，在反向传播的过程也会引入数值稳定的问题。有趣的是，这个问题在<a href="https://github.com/tensorflow/tensorflow/pull/6562">tensorflow</a>和<a href="https://github.com/pytorch/pytorch/issues/43414">pytorch</a>里面都有对应的pr和issue讨论，而且早期的实现都是不够数值稳定的版本。</p> <p>对于除法运算 x / y，自动求导的结果是 $- \frac {x} {y^2}$ ，这也导致了早期的实现都是 <code class="language-plaintext highlighter-rouge">- x / (y * y)</code> 的形式，但是我们考虑接近0的数值y（比如1e-8）， y * y的结果往往很有可能出现underflow的问题；而对于比较大的数值y，y * y很有可能overflow。</p> <p>而如果先计算 x / y，则这个中间结果通常处在一个更良好的范围里，所以对于除法的反传函数， <code class="language-plaintext highlighter-rouge">- x / y / y</code> 通常是一个更加数值稳定的写法。这也是pytorch官方的<a href="https://github.com/pytorch/pytorch/blob/330c9577a3ce880c49475ca79e517b8741ff225b/torch/csrc/autograd/FunctionsManual.cpp#L648">实现方式</a>。</p> <p>如果想要这对两个写法有更直观的理解，可以试着跑一下下面的code</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">a</span><span class="p">,</span> <span class="n">b</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span> <span class="mf">1e-8</span>

<span class="nf">print</span><span class="p">(</span><span class="n">a</span> <span class="o">/</span> <span class="p">(</span><span class="n">b</span> <span class="o">*</span> <span class="n">b</span><span class="p">))</span>  <span class="c1"># 9999999999999998.0
</span><span class="nf">print</span><span class="p">(</span><span class="n">a</span> <span class="o">/</span> <span class="n">b</span> <span class="o">/</span> <span class="n">b</span><span class="p">)</span>  <span class="c1"># 1e+16
</span></code></pre></div></div> <p><br/></p> <h4 id="prod">Prod</h4> <p>prod是一个tensor中的数值的累乘，也就是 $y = \prod_{i=1} x_i$，对于这个累乘中的任何一个元素 $x_i$ 求导，结果就是 $\text{grad} * y / x_i$ ，但是因为y是累乘结果，可能会存在非常大或者非常小的情况，因此和div一样，有underflow和overflow的风险，因此一个比较数值稳定的写法应该是 $\text{grad} * (y / x_i)$，这也是torch<a href="https://github.com/pytorch/pytorch/blob/330c9577a3ce880c49475ca79e517b8741ff225b/torch/csrc/autograd/FunctionsManual.cpp#L825-L826">官方的实现方式</a> 。</p> <p><br/></p> <h4 id="range">Range</h4> <p>range通常给出起始数值start、结束数值end，还有步长step，因为首先要确定range的size，这里就会涉及到数值稳定问题，<a href="https://github.com/pytorch/pytorch/commit/33cc71dc55db073ba46b065e24cff0d26156376f">这个torch commit</a>解决的就是这个问题，对应下面的两个写法：</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">range_size_a</span><span class="p">(</span><span class="n">xmin</span><span class="p">,</span> <span class="n">xmax</span><span class="p">,</span> <span class="n">step</span><span class="p">):</span>
    <span class="k">return</span> <span class="nf">int</span><span class="p">((</span><span class="n">xmax</span> <span class="o">-</span> <span class="n">xmin</span><span class="p">)</span> <span class="o">/</span> <span class="n">step</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>

<span class="k">def</span> <span class="nf">range_size_b</span><span class="p">(</span><span class="n">xmin</span><span class="p">,</span> <span class="n">xmax</span><span class="p">,</span> <span class="n">step</span><span class="p">):</span>
    <span class="k">return</span> <span class="nf">int</span><span class="p">((</span><span class="n">xmax</span> <span class="o">/</span> <span class="n">step</span> <span class="o">-</span> <span class="n">xmin</span> <span class="o">/</span> <span class="n">step</span><span class="p">)</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
</code></pre></div></div> <p>考虑到浮点数的加减法经常会涉及到精度损失，借助精度损失的经典案例0.1 + 0.2 不等于 0.3，我们可以轻松构造出下面的case：</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">start</span><span class="p">,</span> <span class="n">end</span><span class="p">,</span> <span class="n">step</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.1</span>
<span class="nf">print</span><span class="p">(</span><span class="nf">range_size_a</span><span class="p">(</span><span class="n">start</span><span class="p">,</span> <span class="n">end</span><span class="p">,</span> <span class="n">step</span><span class="p">))</span>  <span class="c1"># 3
</span><span class="nf">print</span><span class="p">(</span><span class="nf">range_size_b</span><span class="p">(</span><span class="n">start</span><span class="p">,</span> <span class="n">end</span><span class="p">,</span> <span class="n">step</span><span class="p">))</span>  <span class="c1"># 2
</span></code></pre></div></div> <p>毫无疑问，方法b会比方法a更加数值稳定。</p> <p><br/></p> <h4 id="线性插值">线性插值</h4> <p>线性插值 (Linear Interpolation，LERP)是计算两个数之间某个比例值的算法，在深度学习里面，比如数据增强里的<a href="https://arxiv.org/abs/1710.09412">mix-up</a>、模型<a href="https://github.com/arcee-ai/mergekit/blob/09bbb0ae282c6356567f05fe15a28055b9dc9390/mergekit/merge_methods/slerp.py#L94-L97">权重融合</a>、momentum update里是比较常见的一种方法，LERP函数的表达式是</p> \[\text{lerp}(a, b, t) = (1 - t) \cdot a + t \cdot b\] <p>其中a是起始值，b是结束值，t是插值因子（在0到1之间）。</p> <p>或许你看到这里会有一些疑问，这么简单的式子，感觉怎么写都不会有数值稳定性问题，但是实际上上面数学表达式的写法很容易导致运算过程不满足单调性。单调性是指当a大于（或小于）b的时候，那么更大的t插值出来的结果应该更大（或更小）。</p> <p>下面的例子就很好地表明采用数学表达式写法的问题，事实上也是<a href="https://github.com/pytorch/pytorch/pull/18871">torch pr 18871</a> 尝试修复的问题。</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">A</span><span class="p">,</span> <span class="n">B</span> <span class="o">=</span> <span class="mf">4000.0</span><span class="p">,</span> <span class="mf">4000.0</span>
<span class="n">t</span> <span class="o">=</span> <span class="mf">0.4247583667749129</span>  <span class="c1"># float.fromhex("0x1.b2f3db7800a39p-2")
</span><span class="nf">print</span><span class="p">((</span><span class="mi">1</span> <span class="o">-</span> <span class="n">t</span><span class="p">)</span> <span class="o">*</span> <span class="n">A</span> <span class="o">+</span> <span class="n">t</span> <span class="o">*</span> <span class="n">B</span><span class="p">)</span>   <span class="c1"># 4000.0000000000005
</span></code></pre></div></div> <p>解决方法也很简单，将函数换成分段表示：</p> \[\text{lerp}(A, B, t) =\begin{cases} A + (B - A) \times t, &amp; \text{if } t &lt; 0.5 \\B - (B - A) \times (1 - t), &amp; \text{otherwise}\end{cases}\] <p>这个写法的精妙在于：</p> <ol> <li>端点匹配。也就是在t=0和t=1的时候可以取到A和B的值。如果A和B两个浮点数很接近，那么B-A有可能会underflow到0；或者A和B相差很大，B-A结果容易等于B（或者A），都不可能做到恰好取到A和B的值。</li> <li>当A和B相等的时候，显然lerp的结果始终为固定值。</li> <li>保证单调性。毫无疑问这个函数是分段单调的，关键在于在连接处是否是单调的。考虑刚刚好比0.5小的浮点数s（在python中可以用<code class="language-plaintext highlighter-rouge">0.5 - math.ulp(0.5) / 2</code> 得到这个值），对于任意正浮点数u，我们有 $u \times s &lt; u / 2$ 恒成立，所以这个函数是可以保证单调性的。</li> </ol> <p>看到lerp的实现，很容易让人联想到二分查找里求median的时候的<a href="https://en.wikipedia.org/wiki/Binary_search#Implementation_issues">数值问题</a>，也就是将 <code class="language-plaintext highlighter-rouge">median = (high + low) / 2</code> 改成 <code class="language-plaintext highlighter-rouge">median = low + (high - low) / 2</code> 。torch也有<a href="https://github.com/pytorch/pytorch/commit/56840f0a81e4460089740d50d3768f37e79a17fc">commit</a>修过类似的问题。</p> <p><br/></p> <h4 id="normalization">normalization</h4> <p>在normalization中，求数据的均值和方差是非常基础的操作。我相信大部分人都会觉得求均值和方差就是简单套用下面的公式：</p> \[\begin{aligned} \text{E}(X) &amp;= \frac{1}{n} \sum_{i=1}^{n} x_i \\ \text{Var}(X) &amp;= \frac{1}{n} \sum_{i=1}^{n} (x_i - \text{E}(X))^2 \\ \text{Var}(X) &amp;= \text{E}(X^2) - \text{E}(X)^2 = \frac{1}{n} \left( \sum_{i=1}^{n} x_i^2 - \frac{\left( \sum_{i=1}^{n} x_i \right)^2}{n} \right) \end{aligned}\] <p>在求方差的时候，如果采用标准的方差定义，需要先确定数据的均值，这样的话，需要遍历数据两次（two-pass）：第一次计算出均值，第二次计算出方差，这样对IO不是很友好。如果采用方法2计算方差，只需要遍历一次数据且能够在线更新（online update），但是这样会有非常明显的数值稳定问题，一个典型的case就是在小方差大均值的数据。</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">x</span> <span class="o">=</span> <span class="p">[</span><span class="mf">1e8</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="mf">1e8</span> <span class="o">+</span> <span class="mi">2</span><span class="p">,</span> <span class="mf">1e8</span> <span class="o">+</span> <span class="mi">3</span><span class="p">]</span>
<span class="n">n</span> <span class="o">=</span> <span class="nf">len</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">mean</span> <span class="o">=</span> <span class="nf">sum</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">/</span> <span class="n">n</span>
<span class="n">sum_x</span> <span class="o">=</span> <span class="nf">sum</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">sum_x_square</span> <span class="o">=</span> <span class="nf">sum</span><span class="p">(</span><span class="n">d</span> <span class="o">**</span> <span class="mi">2</span> <span class="k">for</span> <span class="n">d</span> <span class="ow">in</span> <span class="n">x</span><span class="p">)</span>

<span class="n">unstable_var</span> <span class="o">=</span> <span class="p">(</span><span class="n">sum_x_square</span> <span class="o">-</span> <span class="p">(</span><span class="n">sum_x</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span> <span class="o">/</span> <span class="n">n</span><span class="p">)</span> <span class="o">/</span> <span class="n">n</span>
<span class="nf">print</span><span class="p">(</span><span class="n">unstable_var</span><span class="p">)</span>    <span class="c1"># 0.0
</span><span class="n">stable_var</span> <span class="o">=</span> <span class="nf">sum</span><span class="p">((</span><span class="n">d</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span> <span class="k">for</span> <span class="n">d</span> <span class="ow">in</span> <span class="n">x</span><span class="p">)</span> <span class="o">/</span> <span class="n">n</span>
<span class="nf">print</span><span class="p">(</span><span class="n">stable_var</span><span class="p">)</span>    <span class="c1"># 0.6666666666666666
</span></code></pre></div></div> <p>明明有方差，但是方差却被算成了0，核心原因就是非常大的数据做乘法之后很容易丢失精度。</p> <p>当然，还是有只需要一次遍历（one-pass）且数值稳定的算法的，就是<a href="https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm">welford</a>算法。</p> <p>既然是one-pass的在线更新，关键在于定义更新的方法。在第n次更新的时候，welford采取下面的做法（mean和M2初始化为0）：</p> \[\begin{aligned} \delta &amp;= x_n - \text{mean}_{n-1} \\ \text{mean}_n &amp;= \text{mean}_{n-1} + \frac{\delta}{n} \\ \text{M2}_n &amp;= \text{M2}_{n-1} + \delta \times (x - \text{mean}_n) \\ \text{Var}(X_n) &amp;= \frac{\text{M2}_n}{n} \end{aligned}\] <p>这个做法除了保证了方差的数值稳定性之外，在均值上的稳定性也更好。如果按照原始算法，每次都维护一个sum，在数据样本n变大之后，sum会越来越大，sum和 $x_n$之间的差值也越大，精度的损失也越多。</p> <p><br/></p> <h4 id="log1p">log1p</h4> <p>log1p(x)计算的是log(1 + x)的数值，这里主要是精度损失问题，比如1加上一个非常小的数值，1+x很容易被舍入为1，再叠加上log操作，会直接计算为0。而在 $(0, 1]$ 这个区间里面做log运算也是非常危险的：因为输入的变化可能非常微小，但是输出的变化却是在剧烈的震荡（导数说明了一切）。</p> <p>下面的例子展示了log1p的实现和原始实现的稳定性差异：</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="n">math</span>

<span class="n">x</span> <span class="o">=</span> <span class="mf">1e-10</span>
<span class="nf">print</span><span class="p">(</span><span class="n">math</span><span class="p">.</span><span class="nf">log</span><span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">x</span><span class="p">))</span>   <span class="c1"># 1.000000082690371e-10
</span><span class="nf">print</span><span class="p">(</span><span class="n">math</span><span class="p">.</span><span class="nf">log1p</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>     <span class="c1"># 9.999999999500001e-11
</span></code></pre></div></div> <p>log1p算是一个广为人知的专为数值稳定实现的算子，我们参考<a href="https://github.com/lattera/glibc/blob/895ef79e04a953cac1493863bcae29ad85657ee1/sysdeps/ieee754/dbl-64/s_log1p.c#L16">glibc的实现</a>简单解释一下，感兴趣的同学可以自己点进去看源码。</p> <p>首先考虑到IEEE754对于浮点数的表示方式，可以把1+x表示成如下形式：</p> \[1 + x = 2^k (1 + f), \quad \text{where} \quad \frac{1}{\sqrt{2}} &lt; 1+f &lt; \sqrt{2}\] <p>这样的话 $\log(1+x)$ 的结果就等价于 $k \log 2 + \log(1+f)$ ，又因为 1+f 在1附近的范围内，所以我们可以使用Taylor展开，设 $s = \frac{f}{2+f}$，则</p> \[\begin{aligned} \log(1+f) &amp;= \log(1+s) - \log(1-s) \newline \log(1+s) &amp;= s - \frac{s^2}{2} + \frac{s^3}{3} - \frac{s^4}{4} + \dots \newline \log(1-s) &amp;= -s - \frac{s^2}{2} - \frac{s^3}{3} - \frac{s^4}{4} + \dots \newline \log(1+f) &amp;= 2s + \frac{2}{3} s^3 + \frac{2}{5} s^5 + \frac{2}{7} s^7 + \dots = 2s + s \cdot R(s) \end{aligned}\] <p>其中， $R(s) = \frac{2}{3} s^2 + \frac{2}{5} s^4 + \frac{2}{7} s^6 + \dots$ ，可以用<a href="https://en.wikipedia.org/wiki/Remez_algorithm">Remez算法</a>做近似估计： $R(s) \approx Lp_1 s^2 + Lp_2 s^4 + Lp_3 s^6 + Lp_4 s^8 + Lp_5 s^{10} + Lp_6 s^{12} + Lp_7 s^{14}$ ，为了把rounding error控制在 $2^{-58.45}$ 以下，可以估算出来<a href="https://github.com/lattera/glibc/blob/895ef79e04a953cac1493863bcae29ad85657ee1/sysdeps/ieee754/dbl-64/s_log1p.c#L92-L98">具体的Lp数值</a>。</p> <p>当然，如果觉得这个算法过于复杂，其实也可以简单利用 $ x \to 0 $ 时的Taylor展开做近似（参考<a href="https://www.johndcook.com/blog/cpp_log_one_plus_x/">博客</a>）。</p> \[\ln(1 + x) = \sum_{n=1}^{\infty} (-1)^{n+1} \frac{x^n}{n} = x - \frac{x^2}{2} + \frac{x^3}{3} - \frac{x^4}{4} + \frac{x^5}{5} - \cdots\] <p>pytorch针对MPS backend，在float32下的实现就是做前两项的近似（具体参考<a href="https://github.com/pytorch/pytorch/blob/68dfd44e50f59c53698a24985039a27351862963/c10/metal/special_math.h#L604">torch实现</a>）。我们也可以沿着这个思路写一个python版本。因为python里面所有的浮点数内部的表示都是double，double的eps value的量级大概在1e-16，所以如果仍然估计到 $x^2$，比较的数值大概在1e-6 （因为 $(1e-6)^3 &lt; 1e-16$ ）</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="n">math</span>

<span class="k">def</span> <span class="nf">log1p</span><span class="p">(</span><span class="n">x</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">float</span><span class="p">:</span>
    <span class="k">if</span> <span class="nf">abs</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mf">1e-6</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">math</span><span class="p">.</span><span class="nf">log</span><span class="p">(</span><span class="mf">1.0</span> <span class="o">+</span> <span class="n">x</span><span class="p">)</span>

    <span class="nf">return </span><span class="p">(</span><span class="o">-</span><span class="mf">0.5</span> <span class="o">*</span> <span class="n">x</span> <span class="o">+</span> <span class="mf">1.0</span><span class="p">)</span> <span class="o">*</span> <span class="n">x</span>
</code></pre></div></div> <p>了解了log1p的算法之后，我们很自然地就会有一个新的问题：既然log1p要比log本身更加精确，那我在计算的时候为什么不用<code class="language-plaintext highlighter-rouge">log1p(x - 1)</code>来代替<code class="language-plaintext highlighter-rouge">log(x)</code>？</p> <p>原因也很简单：log1p是在x比较小(x的绝对值小于1)的情况下存在数值稳定性意义，当而x比较小的时候，x - 1操作本身就会引入精度损失，log1p再提高精度损失，结果就是<code class="language-plaintext highlighter-rouge">log1p(x - 1)</code> 和 <code class="language-plaintext highlighter-rouge">log(x)</code>的结果一样。所以没有必要多此一举。</p> <p>因为log1p是在x比较小的情况下存在数值稳定意义，而概率本身就满足数值比较小的定义。考虑求一个二分类的entropy，假设概率为prob，那么另一类概率就是1-prob，此时可以用<code class="language-plaintext highlighter-rouge">log1p(-prob)</code>来实现一个更加数值稳定的版本。</p> <p><br/></p> <h4 id="expm1">expm1</h4> <p>expm1是log1p的逆函数，计算的是exp(x) - 1的值。它的数值稳定性问题主要体现在x的数值很小（接近0）的情况，此时这个数值会渐近地趋向于1+x。</p> <p>下面的例子很好地展示了expm1的实现和原始实现在接近0的数值下的稳定性差异：</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">from</span> <span class="n">math</span> <span class="kn">import</span> <span class="n">exp</span><span class="p">,</span> <span class="n">expm1</span>

<span class="n">x</span> <span class="o">=</span> <span class="mf">1e-8</span>
<span class="nf">print</span><span class="p">(</span><span class="nf">exp</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>   <span class="c1"># 9.99999993922529e-09
</span><span class="nf">print</span><span class="p">(</span><span class="nf">expm1</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>     <span class="c1"># 1.0000000050000001e-08
</span></code></pre></div></div> <p>对照log1p，我们也从<a href="https://github.com/lattera/glibc/blob/895ef79e04a953cac1493863bcae29ad85657ee1/sysdeps/ieee754/dbl-64/s_expm1.c#L17">glibc实现</a>上简单讲一下这个函数的数值稳定写法的原理。</p> <p>和log1p类似，首先将x表示为下面的形式：</p> \[x = k \ln 2 + r, \quad \text{where} \quad |r| \leq 0.5 \ln 2\] <p>这样的话，我们有</p> \[e^x - 1 = \begin{cases} 2^k \cdot (e^r + 1) - 1, &amp; k &lt; -2 \text{ or } k &gt; 56 \\ 2^k \cdot (e^r - 1) + (2^k - 1), &amp; o.w. \end{cases}\] <p>因为r处在一个比较窄的范围内，所以我们可以用 $ x \to 0 $ 的Taylor展开直接解决战斗：</p> \[e^r - 1 = r + \frac{r^2}{2} + \frac{r^3}{6} + \dots\] <p>当然，glibc中的实现是使用下面的近似</p> \[\frac{r(e^r + 1)}{e^r - 1} = 2 + \frac{r^2}{6} - \frac{r^4}{360} + \dots = 2 + \frac{r^2}{6} \cdot R_1(r^2)\] <p>之后再通过Remez算法做近似估计，把rounding error控制在  $2^{-61}$ 之下，求出具体的<a href="https://github.com/lattera/glibc/blob/895ef79e04a953cac1493863bcae29ad85657ee1/sysdeps/ieee754/dbl-64/s_expm1.c#L127-L131">Q值</a>。</p> <p>当然，和log1p一样，我们也同样可以直接用Taylor展开做近似，参考<a href="https://www.johndcook.com/cpp_expm1.html">博客</a>，我们给出下面的python实现：</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="n">math</span>

<span class="k">def</span> <span class="nf">expm1</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
    <span class="k">if</span> <span class="nf">abs</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">&lt;</span> <span class="mf">1e-5</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">x</span> <span class="o">+</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">x</span> <span class="o">*</span> <span class="n">x</span>

    <span class="k">return</span> <span class="n">math</span><span class="p">.</span><span class="nf">exp</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">-</span> <span class="mf">1.0</span>
</code></pre></div></div> <p><br/></p> <h4 id="softmax及其变体">softmax及其变体</h4> <p>softmax函数是数值稳定的经典case，几乎每一个搞深度学习的工程师/研究员都应该或多或少地知道这里面的数值稳定技巧，在<a href="https://www.deeplearningbook.org/">deep learning(花书)</a>的第四章中特意介绍了它的数值稳定技巧。softmax函数的定义是：</p> \[\mathrm{softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}\] <p>如果直接使用 $x_i$ ，通常会导致 $\exp(x_i)$ 的数值过高而导致overflow，所以更加数值稳定的写法则是：先减去 $\max(x)$的数值，再做softmax操作。也就是：</p> \[\mathrm{softmax}(x_i) = \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}}\] <p>因为softmax的结果算出来是概率，所以通常为了求交叉熵，会有计算log softmax的过程，</p> <p>根据上面softmax的公式，我们有</p> \[\begin{aligned} \log (\mathrm{softmax}(x_i)) &amp;= \log \left( \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}} \right) \newline &amp;= \log \left( e^{x_i - \max(x)} \right) - \log \left( \sum_j e^{x_j - \max(x)} \right) \newline &amp;= x_i - \max(x) - \log \left( \sum_j e^{x_j - \max(x)} \right) \end{aligned}\] <p><a href="https://github.com/pytorch/pytorch/pull/21672/files">torch对于logsoftmax实现</a>就采用了这种写法，但是我们也可以引入LSE（log sum exp）来简化计算。LSE的定义是：</p> \[\mathrm{LSE}(x_1, \dots, x_n) = \log \left( \sum_{i=1}^{n} e^{x_i} \right)\] <p>当然，类比softmax，LSE也有一个数值稳定的写法，也就是</p> \[\mathrm{LSE}(x_1, \dots, x_n) = \max(x) + \log \left( \sum_{i=1}^{n} e^{x_i - \max(x)} \right)\] <p>对log-softmax应用LSE trick，就可以把表达式简化为：</p> \[\log \mathrm{softmax}(x_i) = x_i - \mathrm{LSE}(x_1, \dots, x_n)\] <p>LSE trick的使用非常广泛，涉及到softmax的简化计算时，通常会使用这个方法。一个典型的例子就是<a href="https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/csrc/flash_attn/src/flash_fwd_kernel.h#L1183">flash attention v2里面</a>的LSE trick（后续计算使用exp抵消log），省掉了一个对角矩阵的乘法运算。Pytorch中 <a href="https://github.com/pytorch/pytorch/blob/781d28e2655f88ae2fef827ed110f22ed553a0ab/aten/src/ATen/native/cuda/LossCTC.cu#L164">CTC loss</a>、<a href="https://github.com/alykhantejani/pytorch/blob/f7c6ba67afdd6885e1efb96540906364aa506f9c/torch/lib/THNN/generic/LogSigmoid.c#L13">log-sigmoid</a>里面也都用了这个trick。而且更有意思的事情是，LSE的导函数就是softmax。</p> <p>到这里，我们再深入看一下交叉熵损失，<a href="https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html">torch的doc</a>里面有提到，交叉熵损失等价于对输入做log-softmax，然后再使用nll (negative log likelihood) loss，这个实现方式比起直接套用交叉熵损失的写法，也会更加数值稳定一些。</p> <p><br/></p> <h4 id="softplus">softplus</h4> <p>softplus主要用于需要<strong>平滑非负输出</strong>的场景，有时候会作为relu的一个平滑替代，或者解决dying relu问题（虽然在deep learning book中<a href="https://www.deeplearningbook.org/contents/mlp.html">6.3.3的最后部分</a>明确提到了虽然softplus比relu处处可导、饱和程度低，但实际上并没有比relu好），这个激活函数在<a href="https://gist.github.com/hunter-heidenreich/9512636394a23721452046039dd52d90#file-vae-py-L30">VAE中比较常见一些</a>。</p> <p>softplus的数学表达式为：</p> \[\mathrm{softplus}(x) = \log(1 + e^x) = \mathrm{log1p}(e^x)\] <p>softplus的数值不稳定性来自于exp，因为指数增长地非常快，很容易overflow成为inf，所以诸如<a href="https://pytorch.org/docs/stable/generated/torch.nn.Softplus.html">pytorch里面</a>会设置一个阈值，当满足条件的时候会将softplus(x)的值设为x，比如在python中如果执行<code class="language-plaintext highlighter-rouge">math.log(1 + math.exp(34))</code>返回的结果就是34。 所以很自然地，softplus的数值稳定的写法是进行分段。</p> \[\text{softplus}(x) =\begin{cases} x, &amp; x &gt; 20 \quad (\text{避免overflow}) \\\log(1 + e^x), &amp; -20 \leq x \leq 20 \\e^x, &amp; x &lt; -20 \quad (\text{避免underflow})\end{cases}\] <p>但是实际上还有另外一个写法，<a href="https://github.com/MegEngine/MegEngine/blob/47952c075d868665e1116214bea760d786144081/imperative/python/megengine/xla/rules/elemwise.py#L404">megengine</a>和<a href="https://github.com/FlagOpen/FlagGems/blob/93713395eb1e64f13e8ad021b2026e47d7f9d64f/src/flag_gems/ops/log_sigmoid.py#L12">FlagGems</a>里面都采用了这样的写法，表达式上也优雅了许多，不再需要手动设置阈值：</p> \[\mathrm{softplus}(x) = \log(1 + e^{-|x|}) + \max(0, x) = \mathrm{log1p}(e^{-|x|}) + \max(0, x)\] <p>有了数值稳定版本的softplus，进而就可以衍生出来一些使用softplus的数值稳定写法，下面我来介绍一些比较常见的，这些方法通常和sigmoid相关。</p> <p>首先第一个就是log sigmoid，直接就可以用-softplus(-x)做替代。</p> \[\log\text{-}\mathrm{sigmoid}(x) = \log \sigma(x) = \log \left( \frac{1}{1 + e^{-x}} \right) = -\log(1 + e^{-x}) = -\mathrm{softplus}(-x)\] <p>其次是<a href="https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/distributions/transforms.py#L609">sigmoid transform</a>，用来做概率分布的变换。当涉及到两个概率分布的时候，通常会牵扯到ladj（log abs det jacobian）的概念，而在<a href="https://github.com/pytorch/pytorch/pull/19802">torch pr 19802</a>里面有提到过sigmoid transform的数值稳定问题。</p> <p>假设 $y = \sigma(x) = \frac{1}{1 + e^{-x}}$ 表示映射关系，那么逆变换则为 $x = \mathrm{logit}(y) = \log(y) - \log(1-y)$，ladj对应的表达式为</p> \[\begin{aligned} \log |\det J| &amp;= \log(\frac{d}{dx} \sigma(x)) \newline &amp;= \log(y) + \log(1 - y) = -\log(\frac{1}{y} + \frac{1}{1-y}), \quad \text{从}y\text{的视角} \newline &amp;= \log\sigma(x) + \log(1 - \sigma(x)) = -\mathrm{softplus}(-x) -\mathrm{softplus}(x), \quad \text{从}x\text{的视角} \end{aligned}\] <p>显然，softplus的写法（从x视角）要比log的写法（从y视角）更加数值稳定。因为当y接近1或者0的时候（考虑到sigmoid的饱和区间，还是很容易接近0或者1的），显然倒数的写法会更不稳定，而softplus的版本在更大的范围上性质良好。</p> <p>如果仔细看softplus函数的表达式，也可以将其表示为log1pexp函数，对应的也有log1mexp函数，也就是 $\log(1 - e^x)$ 。虽然有<a href="https://github.com/pytorch/pytorch/issues/39242">issue</a>提到这个函数，但是在pytorch里面并没有这个实现，这里有一个<a href="https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf">note</a>介绍log1mexp的数值稳定写法，里面也涉及到softplus的数值稳定写法，感兴趣的可以自己去看看，我们此处不再展开。</p> <p>softplus是sigmoid的好伙伴，就如同LSE和softmax之间的关系一样，softplus的导函数恰好也是sigmoid。</p> <p><br/></p> <h3 id="后记">后记</h3> <p>数值稳定性是我在过去的一段时间里比较好奇的一个问题，在之前<a href="https://fatescript.github.io/blog/2024/LLM-RAG/">讨论rag的博客</a>中，我曾经抛出过类似的问题给大模型，但是一直没有得到比较满意的回答。在经过了近一年的积累和检索一些issue/pr，我自己也能获得这个问题的相对满意的答案了。</p> <p>也许未来有一天，<a href="https://openai.com/index/introducing-deep-research/">deep research</a>会逐渐取代我的工作方式，但我衷心希望，AI不会剥夺人类探索和创造的乐趣。这次对数值稳定性的探索，让我想起小时候拆解家里老式机械钟表的过程，也许会迷失在齿轮的迷宫里，但在弹簧突然弹开的“咔嗒”声中，我知道我搞定了一切。</p> <p><strong>一种纯粹和天赐的快乐。</strong></p> <h3 id="reference">Reference</h3> <p>[1] <a href="https://arxiv.org/pdf/2202.03493">DeepStability: A Study of Unstable Numerical Methods and Their Solutions in Deep Learning</a><br/> [2] <a href="https://math.stackexchange.com/questions/907327/accurate-floating-point-linear-interpolation/1798323#1798323">stack exchange上对于lerp实现的讨论</a><br/> [3] <a href="https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm">welford算法</a><br/> [4] <a href="https://www.johndcook.com/blog/cpp_log_one_plus_x/">C++实现的简单版本log1p</a><br/> [5] <a href="https://www.johndcook.com/cpp_expm1.html">C++实现的简单版本expm1</a><br/> [6] <a href="https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/">The Log-Sum-Exp Trick</a><br/> [7] <a href="https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf">介绍log1mexp 数值稳定的note</a></p>]]></content><author><name></name></author><category term="deep-learning"/><category term="engineering"/><category term="code"/><summary type="html"><![CDATA[数值稳定：管理误差的艺术]]></summary></entry><entry><title type="html">2024年的记忆碎片</title><link href="https://fatescript.github.io/blog/2025/mem-in-2024/" rel="alternate" type="text/html" title="2024年的记忆碎片"/><published>2025-01-10T15:59:00+00:00</published><updated>2025-01-10T15:59:00+00:00</updated><id>https://fatescript.github.io/blog/2025/mem-in-2024</id><content type="html" xml:base="https://fatescript.github.io/blog/2025/mem-in-2024/"><![CDATA[<p>今年写的文章很少，但其实我的表达欲并没有减少，核心原因是行业的发展太过迅速，能留给自己的时间越来越少。所以这次把自己平时想要表达的东西，汇总成一个年末的blog，聊的东西会很多很杂，不像以前纯技术向的blog，本篇不涉及任何技术和细节，只聊头脑里的想法，权当作是对想法的“开源”。</p> <h5 id="家的实感"><strong>家的实感</strong></h5> <p>要说今年生活上最大的变化就是装修新家，虽然买房这件事已经过了很久，北京的房价在这一年也是逐渐下跌，但是每天结束了忙碌的工作，能够回到自己温馨的家里，还是会有一种安心的感觉。然而，装修毕竟是一项工程，面临的是对于现实的妥协和权衡。很感激老婆在这个过程中辛苦地和各方沟通，让我能把精力集中在繁忙的工作中；很感激xxr同学在装修过程中提供的物料资助，让本不富裕的小家庭在没有雪上加霜；很感激爸妈能够千里迢迢来北京帮忙收拾、砍价，给不懂街头智慧的小两口狠狠上了好几课。</p> <p>最后放一张我最喜欢的手办柜照片吧，也许将来有了小孩子，这些东西又会消失，但至少现在，他们因为我而聚集在这片空间里。</p> <figure> <picture> <source class="responsive-img-srcset" media="(max-width: 480px)" srcset="/assets/blog/figure-480.webp"/> <source class="responsive-img-srcset" media="(max-width: 800px)" srcset="/assets/blog/figure-800.webp"/> <source class="responsive-img-srcset" media="(max-width: 1400px)" srcset="/assets/blog/figure-1400.webp"/> <img src="/assets/blog/figure.jpg" class="img-fluid rounded z-depth-1" width="auto" height="auto" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <p>​ <strong>家于我而言是一种感觉和记忆，而房子只是承载这一切的钢筋和水泥</strong>。</p> <h5 id="写有信息增量的文字"><strong>写有信息增量的文字</strong></h5> <p>过去的一年毫无疑问是LLM（大语言模型）突飞猛进的一年，这也是我的blog没怎么更新的原因。一方面是因为工作的忙碌，抽出来时间写东西就已经很难了；另一方面，就是每当我静下心来开始写点东西，很快就会有类似的文章出来，次数多了，也逐渐总结了一些写blog的方法论：</p> <ol> <li>不写教程类的blog。虽然这个想法有点傲慢，但是我觉得一个仔细阅读文档就能搞清楚的内容，在大模型时代，存在感会越来越弱。blog应该做深度的串联和挖掘，而不是简单的陈述。</li> <li>不写自己没有实践过的内容。这个很好理解，没有足够的实践，很难谈上真正的理解。</li> <li>不写在一段时间内别人也会写的内容（除非自己能看的相对更深入）。如果对于其他人的观感来说，A做和B做没有差别，那么这样的blog就是纯粹比拼速度，显然不是我擅长的。</li> </ol> <p>虽然看起来方法很多，但是本质上只有一个原则：<strong>提供信息增量</strong>。在这个原则下，去年还是留下了一篇关于<a href="https://fatescript.github.io/blog/2024/LLM-RAG/">RAG</a>的博客，虽然只有一篇，但写的过程还是花费了不少精力的。对我来说，<strong>花费心力去把一件事情做好做扎实，要远胜泛泛地完成很多事情</strong>。</p> <h5 id="把手弄脏"><strong>把手弄脏</strong></h5> <p>今年技术的发展完全可以用“日新月异”来形容，而作为一个不“把手弄脏”（get your hands dirty）就不算了解的人，我也在不断地陷入兴奋和放下兴奋的情绪循环中。作为这种循环的副产物，<a href="https://github.com/FateScript/experiments">experiments</a> 也很自然地产生了。这个repo包含了一些我在好奇和探索中留下的practice code，如果你和我本质上是一类的人，翻看它你也许会会心一笑，或者下面这些碎片也许会唤起你自己的一些记忆</p> <ol> <li>因为好奇LLM训练的各种Parallel技巧，我拿python对照<a href="https://github.com/facebookincubator/gloo/">gloo</a>简单写了一个<a href="https://github.com/FateScript/experiments/blob/main/se/mpi/mpi.py">mpi</a>的接口。实现的过程中，先是更深刻地理解了ring all-reduce为什么等价于reduce-scatter + ring all-gather，接着惊叹于<a href="https://www.inf.ed.ac.uk/teaching/courses/ppls/BarrierPaper.pdf">dissemination barrier</a>算法设计的精妙。到最后，用这些mpi算子，写了<a href="https://arxiv.org/abs/2309.14509">DeepSpeed Ulysses</a>的numpy版本，当写的test case第一次通过的时候，仿佛又回到了大一时第一次看到“hello world”打印在屏幕上的时光。</li> <li>一次和朋友出去爬山，闲聊的时候他说了一句“其实并行排序算法也很有意思”，当时没放在心上，但在某一次坐地铁的时候，突然想起来这句话，于是搜到了nv的一篇<a href="https://developer.nvidia.com/gpugems/gpugems2/part-vi-simulation-and-numerical-algorithms/chapter-46-improved-gpu-sorting">教程</a>，后来花了一个星期在通勤路上拿着ipad琢磨<a href="https://github.com/FateScript/experiments/blob/main/algo/bitonic.py">bitonic sort</a>为什么能work，又应该怎么写。完成实现的那一刻，像是有阳光照在一个隐秘的角落，而角落里本该枯萎的植物发了芽。</li> </ol> <p>也许大多数人看到上面的记忆碎片会觉得很无趣，但我相信，一定会有人理解这种快乐，虽然不像Aha moment那么令人欢欣鼓舞，但也足够让人喜悦。<strong>打开引擎盖，重新造齿轮，然后看着新的齿轮也可以正确地运转，是令人快乐的！</strong></p> <h5 id="一剂解药"><strong>一剂解药</strong></h5> <p>这几年，其实有越来越多的人认识到了程序员工作的本质：技术工人。显然，这份工作并不掌握任何的生产资料，也不像体制内的大部分工作，能依附于一个巨大的造血系统。所以行业内的绝大多数人，本质上在通过出售自己的脑力和体力维持生活。即使转向了管理，只是延长了自己的职业生涯，本质属性仍未改变，所以很多人会焦虑，我也曾经是其中的一员。</p> <p>不焦虑不代表问题已经被解决了，而是这个问题没有那么重要了。也许下面这段话听起来有些鸡汤，但我认识到：<strong>在把事情做好之前，过早地担忧结果的好坏是没有意义的。在迎来结局之前，只能打好自己手里的牌。</strong>就像没有人会因为生命终将逝去，就放弃认真地活着。或许一个更为豁达的态度是：<strong>纵使生命终究虚无，我们仍可以在这片荒芜之地，尝试种出自己的参天大树。</strong>游戏科学CEO冯骥在B站的跨年晚会上留下了这么一首诗：“喜见料峭风，乐听奔雷水。天命不由己，未竟亦可为。”我想，内核是一样的。</p> <p>越过山丘，或许是另一个山丘。但是爬山的过程，就是对登山者最好的奖励。对我来说，找到好玩有趣的事情，剥开它的外壳，重新造一遍轮子，就像是在山上留下了自己的足迹。人生苦短，唯有这些探索与创造，见证了我们如何塑造自己。</p> <p>人们常说焦虑是对自由的眩晕，而好奇和长期是我开给自己的解药。</p> <h5 id="苦涩的教训"><strong>苦涩的教训</strong></h5> <p>既然说到了技术工人，那就来聊一聊工作。实话说，这一年来的工作体验实在说不上好，频繁地rush和救火，毫无规划的deadline，令我不禁怀疑身在何方。不过，这些糟糕的体验也不全是坏事，我也能比上份工作，更能看清更多的东西。除了典型的“世界是个草台班子”这种论调，我想聊点别的。</p> <p>首先从最高层看，“<strong>对于大多数公司来说，职场是政治生态的延伸</strong>”。决策层有着最高的话语权，把控整个方向，所以为了争夺话语权，政治斗争是从来不会少的，为了显示自己的重要性，总是有各种办法争取更多的人力（毕竟人力是最重要的资源之一）掌握资源的人自然掌握权力。裁员可能仅仅是权力的再分配，是历史的进程，和个人的努力与否无关。</p> <p>接着从中层看，其实他们的工作目标<strong>从来不是让公司的客户满意，或者扎实地完成好某个事情，而在于使那些控制他们加薪和晋升的人满意</strong>。虽然我不想承认，但这也许是职场上”苦涩的教训”。经济学领域有一个著名的Goodhart’s law，讲的是“一项指标一旦变成了目标，它将不再是个好指标”，这个”make boss happy”的“职场游戏”，也只是这一法则的另一种注脚。</p> <p>最后，从我自己的视角来看。其实工作的过程肯定是伴随着痛苦的，而免遭痛苦的方式是有很多的。但是在这里，我想放一句在《<a href="https://book.douban.com/subject/10555509/">看不见的城市</a>》摘录的话来与君共勉：免遭痛苦的办法有两种，对于许多人，第一种很容易：接受地狱，成为它的一部分，直至感觉不到它的存在；第二种有风险，要求持久的警惕和学习：<strong>在地狱里寻找非地狱的人和物，学会辨别他们，使他们存在下去，赋予他们空间</strong>。</p> <h5 id="良质和时间幻觉"><strong>“良质”和“时间幻觉”</strong></h5> <p>今年看书的时间被大大地压缩了，工作的疲劳除了消耗人的体力之外，也会很大程度上削减人的“审美体力”，很可能在工作已经很疲惫的时候，回家只想刷刷手机/短视频，玩一些不费脑子的小游戏。虽然我也有过这样的日子，不过今年还是很庆幸完整看完了一部分书籍。而我最喜欢的，是《<a href="https://book.douban.com/subject/6811366/">禅与摩托车维修艺术</a>》和《<a href="https://book.douban.com/subject/26980487/">悉达多</a>》。虽然看起来风马牛不相及，但是这两本书要表达的内核其实是完全一致的。</p> <p>《禅与摩托车维修艺术》的核心是“良质”，但是作者在书中拒绝给出任何关于良质的定义，所以下面聊的，是我的理解：从英文视角来看，良即good/high，质即quality。但是“良质”的含义其实并非简单的“高质量”这么简单，它其实在<strong>表达一种多数人心中关于“好”的准则和判断，它既是主观的，但也是客观的</strong>。</p> <p>举一个典型的例子，同样是中文的电影，也许有某个人会认为《<a href="https://movie.douban.com/subject/26322774/">逐梦演艺圈</a>》比《<a href="https://movie.douban.com/subject/1291546/">霸王别姬</a>》好看，这个是很主观的；但是如果你找一千个人打分然后取平均，《霸王别姬》一定会比《逐梦演艺圈》高，这个是很客观的。于是，我们就可以说《霸王别姬》比《逐梦演艺圈》更“良质”，但是“良质”并不是一种衡量尺度，而是你判断A比B更“良质”的内心感觉，阳明心学里面也管这个叫良知（好的认知）。多数人趋同的主观乃是一种客观，所以良质即是主观的，但也是客观的。</p> <p>如果你在参与某项工作，要判断它是否良质也很简单：如果你很自豪地告诉别人某项工作是你参与的，那这个工作就是良质的，否则就是“普质”的。《<a href="https://book.douban.com/subject/25956450/">人件</a>》里面有一个说法：工程师应当拥有否决发布未成熟产品的权利，其实就是在反对普质，追求良质。我相信大部分软件工程师会无比同意我这句话：花费时间精力去维护连自己都不关心的程序，实在是一种煎熬。</p> <p>《悉达多》则揭示了一个更为深邃的观念：”时间是一种幻觉，而所有的分界也是一种幻象”。我们常说less is more, slow is fast, empty is full. 这些看似矛盾的观照，实则指向统一。<strong>二元对立是人心刻意为之的划分</strong>。正如河流不因我们划分上游下游而有所断裂，时间的本质亦是浑然一体。在永恒的时间长河中，我们执着地要用年华、季节、今朝、明昔来丈量它、切割它，却忘了这些概念本就是人为的虚设。时间恰如永不停息的流水，不始于过去，不终于未来，而是永恒、完整地存在于此时此刻。所谓时间的概念，更像是人心投射的幻影。</p> <h5 id="向前看别回头"><strong>向前看，别回头</strong></h5> <p>整个季节将它结成了琥珀<br/> 块状的流淌，具体的光芒<br/> 在它背后是些遥远的事物<br/>            ——《漫长的季节》</p> <p>走在时代的分叉路口，没有人知道前方是坦途亦或猛兽。也许会有后来的登山者发现了我们留下的路标，笑笑说这是些很遥远的事物。回头去看，也许只是浮尘微粒在风口之下的聚散，又或者时间会将我们的想法凝成琥珀，但这些都不重要。</p> <p><strong>重要的是继续往前走，哪怕前方一无所有，哪怕前方无人等候</strong>。</p>]]></content><author><name></name></author><category term="reflection"/><category term="reflection"/><summary type="html"><![CDATA[往前走，别回头]]></summary></entry><entry><title type="html">当我谈RAG时我谈些什么</title><link href="https://fatescript.github.io/blog/2024/LLM-RAG/" rel="alternate" type="text/html" title="当我谈RAG时我谈些什么"/><published>2024-04-19T15:59:00+00:00</published><updated>2024-04-19T15:59:00+00:00</updated><id>https://fatescript.github.io/blog/2024/LLM-RAG</id><content type="html" xml:base="https://fatescript.github.io/blog/2024/LLM-RAG/"><![CDATA[<h4 id="tl-dr">TL; DR</h4> <ul> <li>大模型同时拥有多种完全矛盾的知识，也有自己本身的prior。</li> <li>关键token的采样结果非常影响模型的生成效果。</li> <li>除去直接提供答案，搜索本身是通过帮助模型”回想场景”或引导模型的prior来增强生成效果。</li> <li>分享一下自己写的token level<a href="https://github.com/FateScript/token_visualizer">可视化工具</a><sup>[1]</sup>，希望这个工具能够有所帮助。</li> </ul> <h4 id="intro">Intro</h4> <p>前一段时间在做一些RAG（retrieval augmented generation）相关的事情，如果你不了解RAG的流程，那么可以理解为：每一个query会经过意图模型判断“是否要进行搜索”，“搜索词是什么”等，经过搜索引擎检索文档提供材料后，辅助模型进行回答（本质上是一种比较灵活的知识注入）。具体的流程图示可以参考下图：</p> <div style="text-align: center"> <figure> <picture> <source class="responsive-img-srcset" media="(max-width: 480px)" srcset="/assets/blog/rag_pipeline-480.webp"/> <source class="responsive-img-srcset" media="(max-width: 800px)" srcset="/assets/blog/rag_pipeline-800.webp"/> <source class="responsive-img-srcset" media="(max-width: 1400px)" srcset="/assets/blog/rag_pipeline-1400.webp"/> <img src="/assets/blog/rag_pipeline.png" width="750" height="auto" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture><figcaption class="caption">RAG pipeline<sup>[2]</sup></figcaption> </figure> ​ </div> <p>RAG领域当然也有一些paper，但是总体上来看，普遍是建立在workflow上的改进，比如材料排序和过滤的技巧、根据模型置信度采取不同的prompt模板<sup><a href="https://arxiv.org/pdf/2401.15884.pdf">[3]</a></sup>（提高反事实能力）、每一轮结束后引入反思机制提高解决多跳问题<sup><a href="https://arxiv.org/pdf/2310.11511.pdf">[4]</a></sup>（本质上是一个agent）…这些文章更偏向tricks，虽然仍然有可以学习的insight，但这些文章没有解决我一直好奇的一个问题：<strong>本质上，搜索材料是如何帮助模型完成回答的</strong>？</p> <h4 id="rag的两种场景">RAG的两种场景</h4> <p>这个问题肯定要分成两个场景去看：首先，考虑模型本身就没有的知识，联网搜索只是单纯地提供材料，材料说啥就是啥，材料有错那模型回答同样有错，这点是毋庸置疑的，因为知识是存粹注入的。这个场景的一些典型的例子： “周杰伦为啥在卖煎饼？”、“电影《周处除三害》讲了什么故事”。模型在没有这部分知识的情况下，不可能回答正确，此时的RAG = 搜索引擎提供参考 + 模型有啥说啥。</p> <p>不过，很多时候我们的疑问在于第二种场景，也是这篇博客要讨论的内容：<strong>大模型本身已经学习了某些知识，但是因为数据本身的长尾性，学习的效果并不好</strong>。</p> <p>一个典型的例子就是reverse curse（模型知道A是B，但是并不知道B是A）<sup><a href="https://arxiv.org/pdf/2309.12288.pdf">[5]</a></sup>。我之前写过一个<a href="https://fatescript.github.io/blog/2023/LLM-markov-chain/">blog</a>介绍过markov chain视角下的LLM，沿着那个视角去进一步思考reverse curse：基于next-token prediction的做法，训练预料“A是B”能够帮助建模 A -&gt; B的状态转移，但是无法建模 B -&gt; A的状态转移概率。所以后续看到的一些文章都是通过类似permute<sup><a href="https://arxiv.org/pdf/2403.00758.pdf">[6]</a></sup>/reverse<sup><a href="https://arxiv.org/pdf/2403.13799.pdf">[7]</a></sup>的数据增强形式来帮助模型进行建模。但这两个做法还是停留在在“头疼医头，脚疼医脚”的范畴内，没有从根本上解决本质问题：在 P函数表示等价的语意时，建模 \(P(A \mid B)\) 不等价 \(P(B \mid A)\)。甚至 P函数表示等价语意也并不本质，举个例子，当A, B具有夫妻关系，\(P( X \mid Y )\) 定义为X是Y的子女的概率时，模型建模 \(P(C \mid A)\) 也不等价于建模 \(P(C \mid B)\)。</p> <p>回到场景本身，RAG本身是比较容易解决大部分长尾知识的，虽然这个能力来自于搜索引擎的帮助，但是我们的问题看起来更清楚了一些：<strong>长尾知识是模型在预训练阶段见过的了，那么搜索材料是如何帮助模型正确回答问题的呢</strong>？</p> <h4 id="蝴蝶效应之采样偏差">蝴蝶效应之采样偏差</h4> <p>模型的生成过程是概率的采样过程，不论模型的回答是正确还是错误，都有可能是采样偏差导致的问题。所以为了确定token level的概率，我写了一个token level的<a href="https://github.com/FateScript/token_visualizer">visualizer</a>，方便查看到底哪些token出现了偏差，偏差的概率又是多少。为了方便理解采样偏差，下面给了一个例子：</p> <p>原始的prompt是”背诵《赤壁赋》”，下面是截取的一部分可视化的结果（越红色部分表示token概率越低，越绿表明表明token概率越高）：</p> <div style="text-align: center"> <figure> <picture> <source class="responsive-img-srcset" media="(max-width: 480px)" srcset="/assets/blog/rag_sample-480.webp"/> <source class="responsive-img-srcset" media="(max-width: 800px)" srcset="/assets/blog/rag_sample-800.webp"/> <source class="responsive-img-srcset" media="(max-width: 1400px)" srcset="/assets/blog/rag_sample-1400.webp"/> <img src="/assets/blog/rag_sample.jpeg" width="750" height="auto" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture><figcaption class="caption">生成结果可视化</figcaption> </figure> ​ </div> <p>从回答内容来说，“月明星稀，乌鹊南飞”后面的回答是错误的，应该是“此非曹孟德之诗乎”，变成了“绕树三匝，何枝可依？”。但是如果看“绕”这个位置的token对应的概率，可以看到“此”的概率是88.8%，“绕”的概率只有9.3%，虽然模型此处的回答是错误的，但只能说是“运气不佳”采样到了错误的token，如果采样到了正确的token，模型其实是可以回答正确的。</p> <p>既然讲到了采样概率，我在这里也稍稍提出来一个猜想，稍稍地抛砖引玉一下，欢迎更有insight的人证明/证伪：<strong>RLHF阶段会使得一部分token（比如格式相关的token）的概率分布变得更sharp来达到与人类align的目的，本质上是一种对于模型的hotfix</strong>。之所以会有这种猜测是之前可视化过一些做了RLHF的模型，直观感受就是概率分布通常比较sharp。不过沿着这个想法继续下去，即使模型和人类align之后，其实本质上仍然保留着原始能力，只是看起来某些“危险发言”因为采样概率降低而看起来“不见”了。之前有个<a href="https://www.lesswrong.com/posts/qmQFHCgCyEEjuy5a7/lora-fine-tuning-efficiently-undoes-safety-training-from">blog</a>介绍了通过LoRA回撤alignment的效果，从这个角度来思考，完全是合理的。</p> <h4 id="rag的三轮测试">RAG的三轮测试</h4> <p>重新回到之前的问题上，搜索（或者任何形式的知识注入）到底是如何增强模型进行回答的呢？</p> <p>为了回答这个问题，我们首先需要设计一些case来进行检验，众所周知，因为GPT本身会更新训练语料，所以新的版本通常会比旧版本更新一些信息。而我们的case最好应当在更新信息时间范围之内，通过<a href="https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo">官方文档</a>来看，最好控制在Sep 2021到Apr 2023之间。除此之外，这个信息应该尽量广为人知一些，类似太阳系从九大行星更新到八大行星这种信息。</p> <p>第一个映入脑海的就是周杰伦有了三胎的消息，在google上搜索了一下，这个时间是在2022年5月。考虑到reverse curse，选择了“昆凌的孩子是谁（<code class="language-plaintext highlighter-rouge">Who are Hannah Quinlivan's children?</code>）”这个问题。</p> <p>注意，后文的几轮测试只是为了帮助建立直觉和认知，并非严格论证，仅仅是“管中窥豹”。</p> <p>UPDATE：在准备发布这个blog的前几天，我在arxiv翻到了一篇<a href="https://arxiv.org/pdf/2404.10198.pdf">文章</a><sup>[11]</sup>对本文的一部分观点做了定量分析，感兴趣的也可以去读一下。</p> <h5 id="第一轮基础测试">第一轮：基础测试</h5> <p>本章节中后续的测试会给出使用的pormpt和对应的可视化结果，所有的结果是在temperature设置成0.3的基础上得出。</p> <p>针对Prompt：</p> <blockquote> <p>Answer the following question: Who are Hannah Quinlivan’s children?</p> </blockquote> <p>下面是gpt-4（知识更新到2021年9月）的回答可视化：</p> <div style="text-align: center"> <figure> <picture> <source class="responsive-img-srcset" media="(max-width: 480px)" srcset="/assets/blog/rag_expr_1_1-480.webp"/> <source class="responsive-img-srcset" media="(max-width: 800px)" srcset="/assets/blog/rag_expr_1_1-800.webp"/> <source class="responsive-img-srcset" media="(max-width: 1400px)" srcset="/assets/blog/rag_expr_1_1-1400.webp"/> <img src="/assets/blog/rag_expr_1_1.png" width="750" height="auto" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture><figcaption class="caption">GPT-4回答</figcaption> </figure> ​ </div> <p>可以看到GPT认为昆凌有两个孩子，并且这件事情的概率几乎为100%，这个知识和当时的现实是高度匹配的。</p> <p>接着去问gpt-4-1106-preview同样的问题，这个模型的知识已经更新到2023年4月份了，按说回答应该有很大概率认为是有三个孩子，然而可视化的结果是下面这样：</p> <div style="text-align: center"> <figure> <picture> <source class="responsive-img-srcset" media="(max-width: 480px)" srcset="/assets/blog/rag_expr_1_2-480.webp"/> <source class="responsive-img-srcset" media="(max-width: 800px)" srcset="/assets/blog/rag_expr_1_2-800.webp"/> <source class="responsive-img-srcset" media="(max-width: 1400px)" srcset="/assets/blog/rag_expr_1_2-1400.webp"/> <img src="/assets/blog/rag_expr_1_2.png" width="750" height="auto" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture><figcaption class="caption">GPT-4-1106-preview回答</figcaption> </figure> ​ </div> <p>从回答来看，模型仍然认为昆凌有两个孩子，但是<code class="language-plaintext highlighter-rouge">three</code>的概率也有29% ，也就是在孩子的数量这件事情上，模型的知识只被修正了一部分。</p> <p>如果直接问周杰伦的孩子是谁呢？(后续问题都是使用gpt-4-1106-preview的结果)</p> <blockquote> <p>Answer the following question: Who are Jay Chou’s children?</p> </blockquote> <div style="text-align: center"> <figure> <picture> <source class="responsive-img-srcset" media="(max-width: 480px)" srcset="/assets/blog/rag_expr_1_3-480.webp"/> <source class="responsive-img-srcset" media="(max-width: 800px)" srcset="/assets/blog/rag_expr_1_3-800.webp"/> <source class="responsive-img-srcset" media="(max-width: 1400px)" srcset="/assets/blog/rag_expr_1_3-1400.webp"/> <img src="/assets/blog/rag_expr_1_3.png" width="750" height="auto" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> ​ </div> <p>可以看到有两个孩子的概率直接到了接近100%，通过第二个和第三个回答的对比来看，gpt-4-1106-preview更容易认为周杰伦有2个孩子，结合reverse curse的paper里面表达的内容，不难看出来：<strong>模型的建模是非对称的</strong>。通过对概率的可视化，现在你应该对这种非对称性有了更深的认识了。</p> <p>根据前面两个不同的模型的回答，我们基本上能得到下面的结论：</p> <ul> <li><strong>模型本身同时拥有多种完全矛盾的知识</strong>。在这个例子里面就是昆凌有两个/三个孩子，很多时候回答的正确与错误仅仅是取决于采样本身，采样到了<code class="language-plaintext highlighter-rouge">three</code>还是<code class="language-plaintext highlighter-rouge">two</code>就足以决定最终回答的正确与否。在关键token上，就是“一招不慎，满盘皆输”。</li> <li>矛盾的知识应该来自于训练数据的矛盾，极大概率是预训练数据里面同时存在“两个孩子”和“三个孩子”的数据。这种<strong>因为新数据出现导致旧数据失效的问题大概率没有被考虑到</strong>，OpenAI对于数据管理也没有做的那么面面俱到。</li> </ul> <h5 id="第二轮给点hint">第二轮：给点hint</h5> <p>前一个小节中我们主要考虑依靠模型本身的能力进行回答的场景，这一个小节我们主要考虑给定一些信息对于原始问题的影响。</p> <p>考虑到OpenAI大概率会使用维基百科的数据做预训练，于是我对应找到了周杰伦的<a href="https://en.wikipedia.org/wiki/Jay_Chou">维基百科</a>，回滚到了更新周杰伦有三个孩子的历史版本，找到了附近的一句话：<code class="language-plaintext highlighter-rouge">In November 2014, Chou confirmed his relationship with model Hannah Quinlivan.</code> 对于回答“Who are Hannah Quinlivan’s children”这个问题来说，这句话除了提供了周杰伦是昆凌老公以及结婚时间之外，并没有额外的信息量，而我们上一轮的测试中，即使直接问”Who are Jay Chou’s children“这个问题，模型也并不能回答正确。所以可以确认这句话基本没有提供格外的信息。</p> <p>此时prompt变成：</p> <blockquote> <p>Here is some text from wiki:<br/> ```<br/> In November 2014, Chou confirmed his relationship with model Hannah Quinlivan.<br/> ```<br/> Answer the following question: Who are Hannah Quinlivan’s children?</p> </blockquote> <p>但是模型的回答出乎我的预料：模型认为有3个孩子，概率也提升到97%：</p> <div style="text-align: center"> <figure> <picture> <source class="responsive-img-srcset" media="(max-width: 480px)" srcset="/assets/blog/rag_expr_2_1-480.webp"/> <source class="responsive-img-srcset" media="(max-width: 800px)" srcset="/assets/blog/rag_expr_2_1-800.webp"/> <source class="responsive-img-srcset" media="(max-width: 1400px)" srcset="/assets/blog/rag_expr_2_1-1400.webp"/> <img src="/assets/blog/rag_expr_2_1.png" width="750" height="auto" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> ​ </div> <p>为了排除本身话术（也就是“Here is some text from wiki”）的影响，又加入了一个完全无用的信息进行测试：</p> <blockquote> <p>Here is some text from wiki:<br/> ```<br/> 1 + 1 = 2<br/> ```<br/> Answer the following question: Who are Hannah Quinlivan’s children?</p> </blockquote> <div style="text-align: center"> <figure> <picture> <source class="responsive-img-srcset" media="(max-width: 480px)" srcset="/assets/blog/rag_expr_2_2-480.webp"/> <source class="responsive-img-srcset" media="(max-width: 800px)" srcset="/assets/blog/rag_expr_2_2-800.webp"/> <source class="responsive-img-srcset" media="(max-width: 1400px)" srcset="/assets/blog/rag_expr_2_2-1400.webp"/> <img src="/assets/blog/rag_expr_2_2.png" width="750" height="auto" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> ​ </div> <p>此时模型完全不能回答正确，且回答有两个孩子的概率反倒是提高到了98%，这说明提供有一定相关性的准确信息对于回答是有帮助的，而无关的信息可能有害的。</p> <p>为了和上一轮中“Who are Jay Chou’s children?”的问题做对照，这一轮也把问题中的昆凌换成周杰伦，同时提供完全相同的材料：</p> <blockquote> <p>Here is some text from wiki:<br/> ```<br/> In November 2014, Chou confirmed his relationship with model Hannah Quinlivan.<br/> ```<br/> Answer the following question: Who are Jay Chou’s children?</p> </blockquote> <div style="text-align: center"> <figure> <picture> <source class="responsive-img-srcset" media="(max-width: 480px)" srcset="/assets/blog/rag_expr_2_3-480.webp"/> <source class="responsive-img-srcset" media="(max-width: 800px)" srcset="/assets/blog/rag_expr_2_3-800.webp"/> <source class="responsive-img-srcset" media="(max-width: 1400px)" srcset="/assets/blog/rag_expr_2_3-1400.webp"/> <img src="/assets/blog/rag_expr_2_3.png" width="750" height="auto" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> ​ </div> <p>有趣的是，尽管是同样提供了没有什么信息量的数据，模型回答错误的概率（也就是认为是两个孩子的概率）已经从非常置信（第一轮接近100%的概率）变成没那么置信（本轮的69.5%）了，这也就表明，从维基百科精心选取的这段看似没有太多信息量的内容，确实能够让模型回答变得更加正确。</p> <p>对于模型来说，在不提供提供正确答案，只有context的情况下，也能正确回答问题（类似markov视角下的CoT）。在这种场景里来看，搜索返回的材料更像是帮助模型”回想”正确答案。这就像是玩“听前奏猜歌名”的游戏，告诉你“你在电影里听到过”也能很大程度上提升回答准确率一样。</p> <h5 id="第三轮直接提供正确错误答案">第三轮：直接提供正确/错误答案</h5> <p>前两轮中，一轮是考察baseline，一轮是考察没有问题答案的hint类型，这一轮中，我们就考察直接提供答案的场景。毕竟大部分情况下的RAG都是伴随着问题的答案的。</p> <p>首先来看直接提供正确答案的case：</p> <blockquote> <p>Here is some text from wiki:<br/> ```<br/> Hannah Quinlivan has three children.<br/> ```<br/> Answer the following question: Who are Jay Chou’s children?</p> </blockquote> <div style="text-align: center"> <figure> <picture> <source class="responsive-img-srcset" media="(max-width: 480px)" srcset="/assets/blog/rag_expr_3_1-480.webp"/> <source class="responsive-img-srcset" media="(max-width: 800px)" srcset="/assets/blog/rag_expr_3_1-800.webp"/> <source class="responsive-img-srcset" media="(max-width: 1400px)" srcset="/assets/blog/rag_expr_3_1-1400.webp"/> <img src="/assets/blog/rag_expr_3_1.png" width="750" height="auto" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> ​ </div> <p>不难看出来，在给出来正确答案的情况下，模型几乎百分百可以确认有三个孩子。同时，因为没有直接提供三个孩子的名字，根据模型回答来看模型本身是拥有这部分知识的。</p> <p>接着我们来看提供错误回答的case：直接在材料中说有五个孩子。</p> <blockquote> <p>Here is some text from wiki:<br/> ```<br/> Hannah Quinlivan has five children.<br/> ```<br/> Answer the following question: Who are Jay Chou’s children?</p> </blockquote> <div style="text-align: center"> <figure> <picture> <source class="responsive-img-srcset" media="(max-width: 480px)" srcset="/assets/blog/rag_expr_3_2-480.webp"/> <source class="responsive-img-srcset" media="(max-width: 800px)" srcset="/assets/blog/rag_expr_3_2-800.webp"/> <source class="responsive-img-srcset" media="(max-width: 1400px)" srcset="/assets/blog/rag_expr_3_2-1400.webp"/> <img src="/assets/blog/rag_expr_3_2.png" width="750" height="auto" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> ​ </div> <p>这个回答是最有趣的一个，模型认识到这个材料是反事实的，并且根据自己现有的知识直接回答有两个孩子。而且看起来模型除了<code class="language-plaintext highlighter-rouge">two</code> 这个token之外，<code class="language-plaintext highlighter-rouge">three</code> 的概率排在第二位。虽然概率不足1%，但是此处没有<code class="language-plaintext highlighter-rouge">four</code>和<code class="language-plaintext highlighter-rouge">five</code>的token，也可以表明模型并没有建模错误的知识，这也是反事实能力的来源。</p> <p>在反事实的样例中也可以看到，模型在默认状态下也仍然认为周杰伦只有两个孩子，这也强化了我们在第一轮的认知：虽然模型的训练语料已经更新到2023年，但是<strong>反映到模型来说，它还是更能接受周杰伦只有两个孩子的事实，这就是模型的prior</strong>。训练语料肯定也对prior产生了影响：<code class="language-plaintext highlighter-rouge">three</code>这个token仍有一定的概率出现，但是并不高。</p> <p>如果你做过sft/alignment的话，你也许会有类似的发现：<strong>预训练模型本身是有自己的知识的，有时候它会拒绝SFT阶段提供的知识</strong>。举个例子：如果pre-train模型认为周杰伦就是只有两个孩子，那你在SFT阶段即使添加了少量周杰伦有三个孩子的数据，在1个epoch的训练之后，任何类似的变体问题，<strong>甚至是原始的SFT数据直接作为输入去测试，模型并不能展现“看一遍就会”的超强拟合能力</strong>，而会在”二”和“三”的选择处出现分歧，通常符合原始知识的token “二”概率会更高一些。当然，如果你多训练一两个epoch，你会发现模型接受了这个知识，不过在变体问题上仍然表现不佳。</p> <h5 id="总结认知">总结认知</h5> <ul> <li>对于模型没有的知识，搜索返回的信息起到提供可能答案的作用。从这个角度来说，<strong>长窗口技术杀不死RAG，但是长窗口是对RAG中使用的trick（比如前文提到的chunk、filter、rerank）的一种降维打击。</strong>RAG本身立足于新的知识、隐私等考虑，这点上与长窗口的场景没有重合，但长窗口本身会使得RAG的pipeline越来越简单。</li> <li>对于模型已有的知识，RAG技术本不应该使用，因为一个理想中的“好”模型应该能够在见到知识（即使是长尾知识）后拥有很好的学习效果。但是，就如我们前面的case中指出的，现阶段比较靠谱的模型，即使强如GPT-4，仍然会存在知识冲突的问题（因为新知识对应的语料相比原始知识对应的语料，占比会更少）。在这种场景下，<strong>RAG本身起到的作用除了众所周知的提供答案、降低幻觉之外，更多的是帮助模型”回想起”正确的回答(修正prior)，或者是引导模型从多个冲突的知识中选择某一个知识回答</strong>。</li> </ul> <h4 id="沧海遗珠">沧海遗珠</h4> <ul> <li>token visualizer其实最早是为了分析推理过程中的幻觉而开发出来的一个工具，而幻觉本质上是高概率的错误token。很多bad case，关键的token采样正确了后面的回答就会自动修正了，所以每次面临badcase，我常常会有一种“明明我只要保持其它一切不变，只去修正某个token的概率就好了，但却不知如何去做”的无力感。</li> <li>作为一个偏好高密度信息的人，我对于大模型的期待在于：<strong>模型能够生成搜索引擎知识之外的内容</strong>，不然上限不会超越搜索引擎。目前的RAG或者说大模型的死穴之一就是：<strong>无法产生搜索引擎中存在，但无法检索到的深度知识。</strong>举一个最简单的例子：我试过问大模型“深度学习中，算子级别的数值稳定trick有哪些，列出具体的算子和对应的trick”或者类似意思的问题，没有任何一个模型能够回答出softmax、log-sum-exp这些内容，多数都在范范地谈论调整学习率、梯度裁剪和初始化相关的内容。但如果你直接问softmax的数值稳定技巧，所有的大模型都知道为防止overflow使用的减去最大值的trick。所以，能看出来：<strong>模型本身拥有每个单点知识，但是因为这些单点知识在网络上并没有人整合过，模型没有见过，搜索也搜索不到，所以很难生成出来</strong>。当大模型可以生成这种信息的时候，就可以说完全超越了搜索引擎。每个人就可以探索自己“不知道自己不知道”的领域了。</li> </ul> <h4 id="connect-the-dots">Connect the dots</h4> <p>关于RAG的内容在这个小节之前已经讨论完了，接下来我想分享的是在写visualizer时候的一些新的感悟，只对技术感兴趣的同学可以跳过这一小节。</p> <p>之前有一段时间在地铁上很无聊，加上计划开始写blog，所以“不务正业”去看了一些CSS和前端相关的内容。当时看起来这件事情与深度学习毫不相干，虽然学到了新的知识，但除了知识之外则是“毫无收益”：不会带来任何薪资上的涨幅，也很难和搞算法的同事聊这些东西，我一个CSS苦手也很难在未来某天靠这三脚猫的功夫谋一个饭碗。</p> <p>但是，当我在twitter上看到可视化ppl的<a href="https://twitter.com/thesephist/status/1617747154231259137">post</a><sup>[12]</sup>的时候，我很自然就想到了类似的方法也可以用来可视化推理过程，在搜索了gradio的文档之后，结合之前掌握到的一点粗浅的前端知识，我认识到了写一个类似的简单demo是完全可行的。如果没有那段“不务正业”的时光的话，或许我就会错过这个有趣的事情了。站在现在来看，好像所有之前的时刻都是为了这个时刻做准备，但是在学那些知识的时候，我并没有意识到这一点，也一定不会想到这一点。那一刻，我唯一能做的，只有做好手头的事情。</p> <p>突然地，我想起了乔布斯的那段著名的演讲，虽然很多人更喜欢最后的“Stay Hungry. Stay Foolish.”，但我想起的却是前面的一段：<strong>You can’t connect the dots looking forward, you can only connect them looking backwards. So you have to trust that the dots will somehow connect in your future</strong>. You have to trust in something - your gut, destiny, life, karma, whatever.</p> <div style="text-align: center"> <figure> <picture> <source class="responsive-img-srcset" media="(max-width: 480px)" srcset="/assets/blog/knowledge_experience-480.webp"/> <source class="responsive-img-srcset" media="(max-width: 800px)" srcset="/assets/blog/knowledge_experience-800.webp"/> <source class="responsive-img-srcset" media="(max-width: 1400px)" srcset="/assets/blog/knowledge_experience-1400.webp"/> <img src="/assets/blog/knowledge_experience.png" width="750" height="auto" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> ​ </div> <p><strong>相信那条线。</strong></p> <h4 id="reference">Reference</h4> <p><strong>[1]</strong> <a href="https://github.com/FateScript/token_visualizer">Token visualizer - github</a><br/> <strong>[2]</strong> <a href="https://arxiv.org/pdf/2312.10997.pdf">Retrieval-Augmented Generation for Large Language Models: A Survey</a><br/> <strong>[3]</strong> <a href="https://arxiv.org/pdf/2401.15884.pdf">Corrective Retrieval Augmented Generation</a><br/> <strong>[4]</strong> <a href="https://arxiv.org/pdf/2310.11511.pdf">Self-RAG: Learning to Retrieve, Generate, and Critique through Self-Reflection</a><br/> <strong>[5]</strong> <a href="https://arxiv.org/pdf/2309.12288.pdf">The Reversal Curse: LLMs trained on “A is B” fail to learn “B is A”</a><br/> <strong>[6]</strong> <a href="https://arxiv.org/pdf/2403.00758.pdf">Mitigating Reversal Curse in Large Language Models via Semantic-aware Permutation Training</a><br/> <strong>[7]</strong> <a href="https://arxiv.org/pdf/2403.13799.pdf">Reverse Training to Nurse the Reversal Curse</a><br/> <strong>[8]</strong> <a href="https://www.lesswrong.com/posts/qmQFHCgCyEEjuy5a7/lora-fine-tuning-efficiently-undoes-safety-training-from">LoRA undoes safety training - LessWrong</a><br/> <strong>[9]</strong> <a href="https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo">OpenAI models 信息 - 官方文档</a><br/> <strong>[10]</strong> <a href="https://en.wikipedia.org/wiki/Jay_Chou">周杰伦的维基百科</a><br/> <strong>[11]</strong> <a href="https://arxiv.org/pdf/2404.10198.pdf">How faithful are RAG models? Quantifying the tug-of-war between RAG and LLMs’ internal prior</a><br/> <strong>[12]</strong> <a href="https://twitter.com/thesephist/status/1617747154231259137">twitter上@thesephist的可视化工具</a></p>]]></content><author><name></name></author><category term="engineering"/><category term="LLM"/><summary type="html"><![CDATA[搜索(R)是如果增强生成(AG)的]]></summary></entry><entry><title type="html">送你一把大师剑</title><link href="https://fatescript.github.io/blog/2023/learning/" rel="alternate" type="text/html" title="送你一把大师剑"/><published>2023-12-29T15:59:00+00:00</published><updated>2023-12-29T15:59:00+00:00</updated><id>https://fatescript.github.io/blog/2023/learning</id><content type="html" xml:base="https://fatescript.github.io/blog/2023/learning/"><![CDATA[<p>这篇文章我们来聊聊上个blog没聊完的一个话题：学习。</p> <p>很多人看到这个词可能会“唯恐避之不及”，不过先别急着结束，本文不会介绍“如何成为一个时间管理大师”，这些东西是“技”而非“道”。比起介绍复杂的学习方法，我觉得《小王子》里面的一段话更符合我对于这篇博客的定位：<strong>如果你想造一艘船，不要抓一批人来搜集材料，不要指挥他们做这个做那个，你只要让他们渴望大海</strong>。</p> <p>甚至让人们渴望大海这个定位也不太对，人天性中其实是充满好奇心且乐于求知的，所以这篇博客应该是 make “渴望大海” great again ：）</p> <h4 id="实用对于学习的异化">“实用”对于“学习”的异化</h4> <p>就像是资本会异化“劳动”的概念，对于实用性（功名）的过度渴求也会扭曲人对于“学习”的理解。很多人对于学习的概念扔停留在“看专业书/考试/做题”这样的一个范围内，但实际上，不管是玩游戏、钓鱼、修家电，还是唱、跳、rap、篮球…本质上都是一个学习过程。<strong>一件事情，只要你能随着接触次数的增多，变得越来越“好”了，那就是在学习</strong>。</p> <p>前一段时间在Switch上玩<a href="https://www.douban.com/game/26973112/">空洞骑士</a>，从我的体验上来说，这种游玩本质上也是一种学习：拿着缓慢升级的骨钉，挑战越来越难的boss（问题），失败的次数越多，就越能积累起来“哦，看起来不能这么打boss（解决问题）”的经验，直到有一刻你放下手柄，高兴地手舞足蹈：“Oh, Yes! 我做到了！”。即使在游戏的后期，面对更难的boss而战败，也很容易淡然面对，因为历史的经验在告诉你：<strong>随着练习的次数越来越多，操作中的失误会越来越少，打败boss，只是花时间去练习的问题</strong>。</p> <p>所以，事实上的”学习”是非常宽泛的一件事，你今天比昨天转笔转的更顺畅了，投篮的姿势标准了一些，都是在学习，都是值得开心和骄傲的事情。</p> <h4 id="浇灌你的树根">浇灌你的“树根”</h4> <p>国内的文化是相对注重实用性的，即使是学习，也总是希望学习最实用的东西。实用本身并没有问题，但是很容易滋生“功利”的心态。很多正常的人类需求会因为“没有用”/“看不到明显的收益”而被压抑。比如最经典的“xxx好有啥用，又不能当饭吃”就是这种思想的产物。但是，实际上这种想法是很有毒的，很多事情，不应当以是否很快有回报来衡量价值。</p> <p>王阳明在《传习录》中有一段话：“立志用功，如种树然。方其根芽，犹未有干; 及其有干，尚未有枝；枝而后叶，叶而后花、实。<strong>初种根时，只管栽培灌溉，勿作枝想，勿作叶想，勿作花想，勿作实想。悬想何益？但不忘栽培之功，怕没有枝叶花实？</strong>”这句话在我看来是根治功利心态的良药。只要专心浇灌树根，就不用太担心树没有枝叶。</p> <p>何恺明博士在港中文的<a href="https://cutv.cpr.cuhk.edu.hk/detail/1572?lang=zh_tw&amp;t=%E4%BD%95%E6%84%B7%E6%98%8E%E5%8D%9A%E5%A3%AB-2023%E5%B9%B4%E7%8D%B2%E7%8D%8E%E4%BA%BA%E5%AD%B8%E8%A1%93%E5%A0%B1%E5%91%8A%E6%9C%83">技术报告</a>上也曾经表达过类似的理念，在回答“如何找到最佳论文级别的研究主题”时，他说道：“<strong>I don’t care about publications, I just care why this problem behaves like that. I just care how can I solve this problem</strong>.” 虽然很多人可能会认为这是给科研人员的一碗鸡汤，但我觉得只有根本上认可这个思想的人才能做出最杰出的科研成果。文章引用只是树上的果实，对问题的好奇心才是树的根。把事情做扎实了，后面的事情就是水到渠成的。而<strong>过度关心果的人得不到果，即使侥幸摘得果实，也难以长久。</strong></p> <h4 id="好奇心与长期主义">好奇心与长期主义</h4> <p>前面讲的是在心态上如何正确认识学习，如何看待学习的“果”。这个小节想要聊的，是浇灌“树根”。</p> <p>第一个要说的，就是学习的“树根”是什么。在我看来，<strong>所有学习的本质都是好奇心</strong>。而好奇心人皆有之，<strong>不同的人只是好奇不一样的东西</strong>。比如我老婆会好奇化妆师是怎么把假睫毛粘的那么好的，我会好奇一个算法是如何从原理上work的，经常打篮球的人会想知道如何进攻/防守，喜欢做饭的人会思考热锅冷油和热锅热油对烹饪食材有什么区别。所以，与其耗费精力去功利地学习你不好奇的东西，不如早日找到你真正好奇的事情，然后专心地把这件事情做好。</p> <p>第二个要说的，是应当按照怎样的方法浇灌树根，也即学习的原则。<strong>原则，其实就是我们面临选择时一定会遵循的行事方式</strong>。对于学习，原则很简单：<strong>长期主义</strong>。这个原则的伟大之处在于，一旦你深刻地理解了长期的力量，你就不再那么在乎短期的得失或状态的起伏。下面我们就举两个例子来看一下如何运用这个原则。</p> <ul> <li>如果你在短期高强度地学习之后的夜晚，考虑第二天要怎么度过，这时你有两个选择：选择自己喜欢的方式放松一下，或者第二天继续自律地坚持学习。很多“自律狂魔”可能会push自己仍然坚持学习，但基于长期主义的原则来看，放松一下也很好，因为学习并不是意味着短时间内的突击，不必追求一时的快。当然能自律坚持下去固然很好，但放松一下，本身也是学习的一环。苏轼去世前留下来的四个字“<strong>着力即差</strong>”，要表达的就包含了这层意思。通俗一点来说：<strong>太过用力的人，跑不远</strong>。</li> <li>假设你非常想要阅读一本书，但你的工作非常忙，以至于很难抽出一整块时间去看，那么应该怎么办？这个是我一个朋友的真实问题，他很好奇在工作很忙的情况下，我从哪里挤出来的时间去看书的。其实回答也很简单：即使你的工作再忙，一天读两页，坚持一年你也能看下一本700页的书籍了，而大部分书籍，甚至都只有两三百页。但是，在繁忙的工作节奏下，一年能真正好好看完一本书的，又有几个人呢？<strong>人往往会低估自己长期能完成的事情，又过于高估自己能在短期内能完成的事情</strong>。因为多数人希望的是在一个星期、半个月的时间内看完200页的一本书，这种短时高强度的学习，是注定无法长期坚持的。我很喜欢《<a href="https://book.douban.com/subject/6811366/">禅与摩托车维修艺术</a>》里面的一句话：<strong>当你做某件事的时候，一旦想要求快，就表示你再不关心它，而想去做别的事</strong>。所以抱怨没有时间去做某事的人，可能真的是很忙，但更可能的是，没有坚持做事应当长期的原则。如果觉得时间不够而去花精力去学习时间管理技巧，就是在错误的道路上越走越远了。</li> </ul> <p>用我最喜欢的一条生活原则来结束这一小节：<strong>Fast is slow, slow is fast. Less is more, more is less.</strong></p> <h4 id="见贤思齐">见贤思齐</h4> <p>因为工作的原因，在公司遇到了很多在学习上很棒的引路人，所以这个部分，我打算聊一聊对他们的观察，稍稍地“见贤思齐”一下。</p> <ul> <li> <p>如果你有幸找到一个master级别的程序员请教问题，你会发现他们的头脑才是真正强大的编译器。举一个实际发生的例子：</p> <p>我：“G大师，我做了a，然后抛出了一个诡异的错误，在google上搜索过了，好像没有人遇到过这个issue”</p> <p>G大师：“你做了a，那有发现b现象吗？”</p> <p>我：“没有，但是有c现象”</p> <p>G大师：“嗯，这就有点奇怪，你是不是除了做a之外还做了d”</p> <p>我：“嗯有的，还能发现e现象”</p> <p>G大师：“那就对了，你检查一下f，问题应该在f中”</p> <p>等我回到工位上，不过一会儿就能把问题给解决掉了。</p> <p>这个经历是非常具有启发性的，你会发现<strong>他的大脑就像是一个隐马尔可夫模型，有一些internal state存在他的脑海里，他的语言只是那个隐状态对应的输出而已</strong>。</p> <p>如果说我们应该从这个对话中学习到什么内容，核心绝对不在于知道了以后遇到b现象时去检查所有的f，这种东西只是单纯的经验，是观察到的数据。我更倾向于去想：<strong>如何通过一个数据，学习一个pattern，学习如何进行internal state的建模</strong>。很多书店里面介绍的“刻意练习”书籍，其实就是就是在讲对于internal state的觉察，不然只是机械地学到一个“遇到了d应该去检查f”的经验。这些经验只是树上的果实，并非树根。</p> </li> <li> <p>除去G大师，还有上手任何事情都超级快的L。当我还在翻文档的时候，L就已经写出来成型的demo了。有一次我问他：“你是怎么做到不管做啥都能这么快就开始上手的？”，他跟我分享了一些理念，对于当时的我算是醍醐灌顶，这里也稍稍分享一下：</p> <ul> <li><strong>任何形式的学习，都可以分成两部分，枯燥的和创造的</strong>。以Gilbert Strang的<a href="https://www.bilibili.com/video/BV18K4y1R7MP/">线性代数课程</a>为例，课程中介绍的思想是创造的，所以你学起来很快很上瘾，但是去解线性代数题目的过程是机械的、枯燥的，所以你实践起来去算特征值/特征向量会有一些痛苦。</li> <li>在具备基础之后，<strong>为了快速上手某个内容，学习任何新东西，抓住主线就行</strong>。对于领域里面的关键思想，集中精力很快就能学完。而旁枝末节和非本质性的知识内容，给实践去敲打就行，<strong>不要浪费精力去学习翻一下手册就能查到的小技巧</strong><em>。</em>比如你已经有了一定的编程基础，让你学习一个新的语言，就是把声明、循环、判断等这些基本的东西弄明白，而深入的内容，只有好奇心和实践能教给你。<strong>相比快速上手，长期的实践才是本质困难的东西</strong>。</li> <li>当然，快速学习完最核心的东西之后，你也可能会飞速地失去对于这个领域的好奇心，后续也没有什么实践的动力，这个时候放弃后续也没关系。<strong>如果这件事真的重要，它迟早会回来找你</strong>。</li> </ul> </li> </ul> <h4 id="后记">后记</h4> <p>这篇文章其实聊了很多理念/“道”层面的东西，所以读起来也许会有一些说教感，不过，超出这篇文章之外，下面的内容才是我核心想要表达的东西。</p> <p><strong>未必有所成才能好好活着，学习、梦想也不是什么了不起的东西</strong>，看着天空、压压马路、在肚子饿时，还能吃到爱人给你做的热气腾腾的饭，这些同样是人于世间的顶级快乐，可惜以前的我从未深刻地明白。</p> <p>此文献给我过世的姥姥，她是个不认识字的文盲，没有什么文化，也讲不出来啥大道理。<strong>从我记事的时候起，她就是一个只会嘟囔嘴的农村老太太。可是她也同样伟大</strong>。</p> <p><strong>存在着就是伟大</strong>。</p> <p><strong>Viva La Vida（生命万岁）</strong></p>]]></content><author><name></name></author><category term="reflection"/><category term="reflection"/><summary type="html"><![CDATA[于旷野中驰骋]]></summary></entry><entry><title type="html">我在贵司这三年</title><link href="https://fatescript.github.io/blog/2023/work-and-think/" rel="alternate" type="text/html" title="我在贵司这三年"/><published>2023-09-01T15:59:00+00:00</published><updated>2023-09-01T15:59:00+00:00</updated><id>https://fatescript.github.io/blog/2023/work-and-think</id><content type="html" xml:base="https://fatescript.github.io/blog/2023/work-and-think/"><![CDATA[<h5 id="前言">前言</h5> <p>前段时间刚刚离开了工作三年的前公司。离职之前，有幸能和诸多深交的好友们聊一聊对于工作的看法、经验与教训。在回家的地铁上，逐渐地产生了将这些思考写下的想法。内容很多，关于对技术的认知，关于工作环境，与君共享、共勉。</p> <h5 id="常怀谦虚与敬畏之心警惕技术自大主义"><strong>常怀谦虚与敬畏之心，警惕技术自大主义</strong></h5> <p>越是在技术环境非常友好的地方，就越容易找到top2毕业的天之骄子、ACM竞赛中的传说大佬和论文引用轻松过千的领域专家，这些“天之骄子”是公司水平的保证，但也不可避免地引入了某种程度的“文人相轻”与“恃才傲物”。</p> <p>这种心态往往是潜藏在这样的话语中：</p> <ul> <li>“用户是愚蠢的，体会不到我们设计的精妙”</li> <li>“这篇paper做的很简单，不知道为啥有这么大的影响力”</li> <li>“这么简单的事情，就不要发paper说出来了</li> <li>“这个结论我们早就知道了，只是觉得太简单没好意思往外说”</li> </ul> <p>当然说话的人的本意如何我们无法得知，但是或许换成下面的话，会更合适一些：</p> <ul> <li>“用户偶尔会犯傻，所以我们需要一些符合直觉的设计来减少这种事情的发生”</li> <li>“这篇paper做的简单又具有很大的影响力，那么一定有某种我没意识到重要的东西需要我去学习”</li> <li>“这篇文章阐述了一些对我来说显而易见的东西，说明我可能对社区的水平欠缺一定的了解”</li> <li>“对于这种已经知道的结论，虽然不能写成paper，但可以写成blog或者report，对于其他人也是有帮助的”</li> </ul> <p>这种言语的改变绝非简单的“高情商”，而是根植于一种谦逊的心态，深信“别人身上一定有一些我可以借鉴的地方”的想法。我之前遇到过一个实习生，他无意间透露的想法让我印象深刻：“resnet的想法很简单，我也能想到，只是我不知道如何包装而已”。但其实，resnet中提出的方法，后面藏着的对于事物抽丝剥茧式的认知，绝非简单一朝一夕、一拍脑袋能够想起来的。<strong>解决问题的方法不是本质的东西，对事物的深刻认知才是</strong>。这种深刻的认知会促使我们找到优雅的解决方法，至于<strong>什么是优雅，我觉得就是“简单而有效”</strong>。</p> <p>在专心做技术的人群中，”技术自大主义”是一种很容易出现的现象，而在技术氛围良好的环境中，更容易受这种思想的影响。甚至我现在也经常需要使用“弱小和无知不是生存的障碍，傲慢才是”来警醒自己。也许有些人或者公司的行为你站在高层视角并不认可，但是这并不妨碍对方在某些领域取得成功。<strong>就像是拼多多，你当然不必认可它玩弄人性的各种伎俩，但你要认可它对于在十八线县城中生活的人的深刻洞察</strong>。</p> <h5 id="工程是scale的艺术"><strong>工程是scale的艺术</strong></h5> <p>如果有人问工程的核心是什么，很多工程师的回答可能是“trade-off（权衡）”。但是在读了<a href="https://book.douban.com/subject/34875994/">Software Engineering at Google</a> 之后，我的答案就变成了“scale”。</p> <p>因为我英文水平有限，不知道这个“scale”应该如何翻译会比较信达雅一些，所以我觉得举一些例子更合适：</p> <ul> <li>规模上的scale：如果你做的工作只有一小部分人会用到，那就是scale很小，而如果很多同行/公司会用你写的代码（比如一些重要的开源软件如numpy），那就是scale很大。</li> <li>时间上的scale：如果你做的工作只在很短的时间被使用/产生影响，那就是scale很小，而如果你完成一次工作就可以在很长的时间内重复使用/产生影响，那就是scale很大。</li> </ul> <p>“scale”是一个非常神奇的东西，三个房子可能只是三个房子，三千个房子可能就可以称之为城镇了，城镇就会自然涌现出医院、学校、商厦和马路。同理，一个工程项目里面的一行代码，也不仅仅是一行代码那么简单了。</p> <p>这种基于“scale”的思考模式在很大程度会影响你的行为:</p> <ul> <li>当你意识到你频繁需要使用一个东西的时候，那么是时候优化你和他人的使用体验了，因为频繁的使用就意味着时间上的scale。比如你写了一个工具，但是隔三差五就会有人跑来这个工具是如何使用的，那么最好将使用的流程写成文档，因为文档本身是随着时间scale的，你写一次文档，就可以节省很多次的说明。再比如我一直维护了一份自己使用的<a href="https://github.com/FateScript/dotfiles">dotfiles</a>，原因之一就是因为在不同的机器上配置自己的环境就属于一个随着时间scale的东西，而在这些dotfiles中我又推崇vim-like的操作习惯，甚至在vscode和chrome浏览器中也要使用vim的插件，因为vim的理念非常契合scale的思考模式：相比画画（写代码），画家（程序员）在画布（编辑器）上停留、浏览与思考的时间更长（scale更大）。</li> <li>看待大型的工程项目会更偏向scale的视角，会考虑<strong>这个项目到底scale了什么，怎么做这种scale</strong>。举个例子，pytorch是深度学习领域重要的开源仓库，但是其核心就是对于可以反向传播的算子的scale，换言之，就是如何保证可以加入众多的算子，而整个自动求导系统不会崩溃。为了解决这个问题，torch团队设计了的精妙的Tensor和Autograd体系，引入了算子的registry / dispatch等机制。再比如商汤的mm系列的框架，研究的就是模型的scale，考虑的就是如何很容易地将不同的模型集成入仓库。</li> </ul> <h5 id="管理team-work与环境">管理、team work与环境</h5> <p>因为组织架构的调整，在前公司<del>被迫</del>换了几个不同的组，也算体会过不同的环境之间的差异。所以这一个部分我们聊一下组织与管理。毕竟个人的力量有限，再强的个体也无法处理所有的问题，写出大型codebase中的每一行代码。</p> <p>第一个要聊的是“什么是团队”，下面是我目前对于团队的理解：<strong>团队不是简单的人的集合，团队中个人的强并不代表团队的强大</strong>。按照上面提到的scale的说法，<strong>团队不是人的scale</strong>。就像是打篮球一样，巨星抱团未必能有立竿见影的效果，“化学反应”才是关键，竞技体育中的核心目的是赢球而不是某个人的得分高。在一个所有人都很厉害的团队，很容易出现互不认可然后“大路朝天，各走一边”的情况，有人戏称这种情况叫做“聚是一坨x，散是满天星”。担任这样的团队的leader，<strong>技术实力的强弱与否只是一个评价维度，相较来说，能够协调处理各方的冲突，使团队focus共同的目标，是leader更重要的能力</strong>。</p> <p>第二个要聊的是团队的环境。我工作期间有幸跟过一些从比较简单、原始的状态开始的项目。现在回看，该走的弯路、该踩的坑，一个都不会少。<strong>初期犯错，是为了后期犯更少、更严重的错</strong>。作为对比，有些技术非常好的leader可能会不认可他们管辖的人（通常也是执行具体任务的人）的提议，希望项目能够少走一些弯路。这种不认可有很多表现形式，可能是只给你灌输他认可的想法，偏离这种想法的路径都是歪门邪道；可能是默许你做一些自己的创新，但是对于你在新路径上的发现采取不管不顾的态度。虽然后者的态度能让员工保留一些自主性，但是在一个鼓励创新与探索的环境里，这两种态度都是十分有毒的，因为“试错，是走向正确道路的最佳实践，有一些学费是不得不交的。”对于leader来说，指导固然很关键，但是<strong>建立鼓励尝试、对错误宽容（但及时反思）的团队氛围的作用要远大于指导</strong>。好友xxr说过一句话：一个leader的最大成功就是，团队没有leader也能持续良好地运转下去。<strong>员工很多时候需要的是催化剂，而不是导航员</strong>。</p> <p>第三个要聊的是我的一些观察，可能不对，但是在我短暂的职业生涯中，能找到一些具体的例子。第一个观察是：<strong>团队内的信息流动速度，会很大程度上影响每个人的成长速度</strong>。乐于分享的组织，通常员工的个人素质会更高一些，员工之间的私人关系也会更好，也更容易产生化学反应。第二个观察是：<strong>越是有口碑的管理者，曾经的部下就越容易追随他/她</strong>。这种追随的形式可能很多，比如：在离职之后仍然保持一定强度的联系、缺乏人手的时候愿意主动帮忙推荐一些朋友。最后一个观察，或者说经验，就是：在一个团队里面，<strong>如果你总觉得一个事情需要有人去做，那么你也许是最适合的人选</strong>。比如你可能会觉得文档需要有人写了，那很可能你就是最适合写文档的人（说明你对缺乏文档这件事情的耐受度比较低）；或者你觉得频繁做某项任务太麻烦了，应该写一个工具简化流程，当你在团队里表达了这个想法之后，很可能最后就是交给你来完成。</p> <h5 id="后记">后记</h5> <p>本来到这里还想写一下过去3年学到的“如何学习”作为第四个部分的，但鉴于篇幅可能比较长，先写到这里，抽空再开一个blog聊一聊学习和成长的问题吧。</p> <p>无论前公司将来会走向何方，走了何种道路，我都会感恩在过去的时间遇到的每个人和事。关于谦卑、关于工程、关于团队，我又比三年前的我理解地更深了一些。</p> <p>可叹，有些时候，就是“<strong>人生南北多歧路，君向潇湘我向秦</strong>”。</p>]]></content><author><name></name></author><category term="reflection"/><category term="reflection"/><summary type="html"><![CDATA[人生短暂，故事长存]]></summary></entry><entry><title type="html">LLMs as Markov Chain</title><link href="https://fatescript.github.io/blog/2023/LLM-markov-chain/" rel="alternate" type="text/html" title="LLMs as Markov Chain"/><published>2023-06-03T15:59:00+00:00</published><updated>2023-06-03T15:59:00+00:00</updated><id>https://fatescript.github.io/blog/2023/LLM-markov-chain</id><content type="html" xml:base="https://fatescript.github.io/blog/2023/LLM-markov-chain/"><![CDATA[<h4 id="前言">前言</h4> <p>几个月之前，<a href="https://karpathy.ai/">Andrej Karpathy</a> 发布了一个<a href="https://twitter.com/karpathy/status/1645115622517542913">推特</a><sup>[1]</sup>，给出了一个看待语言模型(language model，下称LM)行为的新视角：LM可以看作有限状态的马尔可夫链（finite-state Markov Chain）。在最近一段时间和LLM（large language model）的交互过程中，以这个马尔可夫链的视角作为基础，笔者对于LLM的一些行为有了进一步的理解与认知。写这篇文章，一方面是为了分享Karpathy的观点，另一方面则是帮助大家从实践的视角进一步理解/预测语言模型的一些行为。</p> <p>本文会在第一个部分介绍LM为什么可以被看作是一个Markov chain；之后会从以这个视角进一步展开，聊一聊Markov chain视角下的Prompt Engineer、In-Context Learning以及一些LM展现出来的有趣特性。</p> <h4 id="karpathy的观点">Karpathy的观点</h4> <p>为了照顾一些初学者，这个部分会介绍地尽量详细一些，已经了解为什么LM可以被看作是Markov chain的读者可以跳过这个部分。</p> <p>想要体验最原汁原味的介绍可以移步Karpathy写的<a href="https://colab.research.google.com/drive/1SiF0KZJp75rUeetKOWqpsA8clmHP6jMg?usp=sharing">colab</a><sup>[2]</sup>。</p> <h5 id="context-length与tokenizer">context length与tokenizer</h5> <p>LM从本质上来看，就是接受一堆文字作为输出，然后不断预测下一个文字的模型。为了通俗一些，我们举一个例子，假设我们有一个窗口大小为4的LM，这个LM接受的输入给LM这样一段话“今天天气”，LM就会预测下一个字是“真”，接着我们把“真”放入“今天天气”后面，同时保持窗口大小不变，LM接受到的输入就是“天天气真”，LM就会预测下一个字是“好”，接着把好放在之前的句子后面，以此类推，最后我们就可以得到“今天天气真好。”的输出。</p> <p>因为LM的输入需要是固定的长度，为了统一，我们就会称这个固定长度为<strong>context length</strong>。上文的例子中的LM的context length的就是4，这也就意味着这个LM一次性接受4个词的输入，并且预测下一个词是什么。</p> <p>但是，LM是无法接受文字作为输入的，对于所有的LM来说，都需要<strong>tokenizer</strong>将文字输入转换成token。以<a href="https://arxiv.org/pdf/2302.13971.pdf">LLaMA</a>的<a href="https://github.com/facebookresearch/llama/blob/57b0eb62de0636e75af471e49e2f1862d908d9d8/llama/tokenizer.py#L13">tokenizer</a>为例子，在不考虑bos(begin of sentence，即&lt;s&gt; )和eos(end of sentence，即&lt;/s&gt;)符号的情况下，句子”Hello world”会被转换成 \([15043, 3186]\) 的输入，之后这个输入就可以被LM接收，从而预测下一个单词。</p> <p>下面的code给出了一个具体的示例来方便理解：</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">&gt;&gt;&gt;</span> <span class="n">tokenizer</span><span class="p">.</span><span class="nf">encode</span><span class="p">(</span><span class="sh">"</span><span class="s">Hello world</span><span class="sh">"</span><span class="p">,</span> <span class="n">bos</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">eos</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
<span class="o">&lt;&lt;&lt;</span> <span class="p">[</span><span class="mi">15043</span><span class="p">,</span> <span class="mi">3186</span><span class="p">]</span>

<span class="o">&gt;&gt;&gt;</span> <span class="n">tokenizer</span><span class="p">.</span><span class="nf">decode</span><span class="p">([</span><span class="mi">15043</span><span class="p">,</span> <span class="mi">3186</span><span class="p">])</span>
<span class="o">&lt;&lt;&lt;</span> <span class="sh">'</span><span class="s">Hello world</span><span class="sh">'</span>
</code></pre></div></div> <p>为了说明简单，我们后文中的token都采用数字来表示，这样我们就可以把LM的输入看作是一个数字序列，而LM的输出则是一个关于全部token的分布。而tokenizer能够处理的字符集的大小，我们称之为<strong>vocab_size</strong>。</p> <p>假设token只有0和1两种（vocab_size为2），context_length 为2，\(\rightarrow\) 表示数据流向，LM推理[1, 0]输入的过程可以表示为：</p> \[[1, 0] \rightarrow LM \rightarrow [P(0) = 40\%, P(1) = 60\%]\] <p>当然这里预测为0和1的概率是随便给的，只是为了方便理解。</p> <h5 id="vocab_size与context-length决定了马尔可夫链的状态空间">vocab_size与context length决定了马尔可夫链的状态空间</h5> <p>考虑一个最最简单的LM，我们称之为baby-GPT，这个LM的context length为3，token只有[0, 1]两种，那么这个LM的全部状态空间就可以表征为 \([0, 1]\) 的3次笛卡尔积， 也就是说这个baby-GPT的状态空间大小为 \({vocab\_size}^{context\_length} = 2^3 = 8\)。 具体来说，所有的状态空间为 \([0, 0, 0]\), \([0, 0, 1]\), \([0, 1, 0]\), \([0, 1, 1]\), \([1, 0, 0]\), \([1, 0, 1]\), \([1, 1, 0]\), \([1, 1, 1]\)。</p> <p>考虑一个baby-gpt的特定状态，此处我们以 \([0, 0, 1]\) 为例，将这个状态作为baby-gpt的输入， 对应的输出的形式则类似于 \([P(0) = 45\%, P(1) = 55\%]\)，代表下一个token是0或者1的概率。 将这个过程对应到马尔可夫链的角度，我们可以认为 \([0, 0, 1]\) 状态可以转移到 \([0, 1, 0]\) 和 \([0, 1, 1]\) 两个后继状态，转移概率分别为 \(45\%\) 和 \(55\%\) 。</p> <p>下图给出了baby-GPT在初始状态下的每个状态和对应的转移概率。</p> <div style="text-align: center"> <figure> <picture> <source class="responsive-img-srcset" media="(max-width: 480px)" srcset="/assets/blog/baby_gpt_init-480.webp"/> <source class="responsive-img-srcset" media="(max-width: 800px)" srcset="/assets/blog/baby_gpt_init-800.webp"/> <source class="responsive-img-srcset" media="(max-width: 1400px)" srcset="/assets/blog/baby_gpt_init-1400.webp"/> <img src="/assets/blog/baby_gpt_init.png" width="auto" height="auto" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture><figcaption class="caption">baby-gpt初始转移概率</figcaption> </figure> ​ </div> <h5 id="从markov-chain视角看训练">从Markov Chain视角看训练</h5> <p>假设开始训练这个baby-GPT，需要训练的数据序列为”111101111011110”，则baby-GPT实际的训练数据则为：</p> <p>训练数据 01: \([1, 1, 1] \rightarrow 1\)<br/> 训练数据 02: \([1, 1, 1] \rightarrow 0\)<br/> 训练数据 03: \([1, 1, 0] \rightarrow 1\)<br/> 训练数据 04: \([1, 0, 1] \rightarrow 1\)<br/> 训练数据 05: \([0, 1, 1] \rightarrow 1\)<br/> 训练数据 06: \([1, 1, 1] \rightarrow 1\)<br/> 训练数据 07: \([1, 1, 1] \rightarrow 0\)<br/> 训练数据 08: \([1, 1, 0] \rightarrow 1\)<br/> 训练数据 09: \([1, 0, 1] \rightarrow 1\)<br/> 训练数据 10: \([0, 1, 1] \rightarrow 1\)<br/> 训练数据 11: \([1, 1, 1] \rightarrow 1\)<br/> 训练数据 12: \([1, 1, 1] \rightarrow 0\)</p> <p>在正常训练了模型之后，我们可以得到一个训练好的baby-GPT的权重，此时baby-GPT的状态转移概率相对初始版本已经发生了变化，如下图所示：</p> <div style="text-align: center"> <figure> <picture> <source class="responsive-img-srcset" media="(max-width: 480px)" srcset="/assets/blog/baby_gpt_trained-480.webp"/> <source class="responsive-img-srcset" media="(max-width: 800px)" srcset="/assets/blog/baby_gpt_trained-800.webp"/> <source class="responsive-img-srcset" media="(max-width: 1400px)" srcset="/assets/blog/baby_gpt_trained-1400.webp"/> <img src="/assets/blog/baby_gpt_trained.png" width="auto" height="auto" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture><figcaption class="caption">baby-gpt训练后的转移概率</figcaption> </figure> ​ </div> <p>从上图不难看出，相比初始状态，训练后的baby-GPT在 \([0, 0, 1]\) 状态下更容易生成转移到 \([0, 1, 1]\) (概率从 \(55\%\) 提升到 \(78\%\) )。实际上整个baby-GPT相比初始状态，更容易预测下一个token是1，这也符合训练数据特点：1的数量远远大于0。</p> <p>到这里，基本上大家可以理解为什么LM本质上是一个Markov chain了，也能根据上面提供的例子理解数据是如何影响这个Markov chain的了。</p> <h4 id="lm-as-markov-chain的一些性质">LM as Markov chain的一些性质</h4> <p>本章节讨论LM作为Markov chain会有哪些有趣的性质，以及这些性质对LM的训练和使用有什么启发。</p> <p><strong><span style="color:red">声明：这些性质未必是LLM作为Markov chain一定存在的性质，更多是我个人的看法和符合直觉的思想实验，欢迎提出不一样的看法。 </span></strong></p> <h5 id="性质与启发">性质与启发</h5> <ul> <li>第一个性质肯定是<strong>稀疏性</strong>，这个也很直觉，虽然Markov chain的状态非常多，但是大部分状态之间几乎没有转移概率，因此这个Markov chain是非常稀疏的。<br/> 这个性质对LM的训练的启发在于：如果想要让LM在特定场景能够输出一些不常用的字，比如“你”字后面跟个“铋”字，单纯更多地使用“你铋”的话应该是治标不治本的，因为模型是根据context进行转移的，而“你”字只是最后一个token，以“你”字结尾的状态数过于巨大，并不是简单通过加几个训练样本就能解决的。</li> <li>第二个性质在于状态数的<strong>指数爆炸性</strong>，随着LM的vocab_size和context_length增大，Markov Chain中的状态数是几乎以指数倍增长的，在原始训练数据分布不变的情况下，模型建模的难度是减少的。但如果想要更好的效果，模型需要投喂的数据量可能也需要进行某种（感觉是指数的？）形式的scale。至少从直觉上来说，应该存在某种数据规模和context_length之间的scaling law。除此之外，状态的指数爆炸性也在某种程度上能解释为什么LLM会存在涌现能力，很可能状态数达到了某种足够多的状态之后，完成某个任务的知识的建模起来更加容易了。</li> <li>第三个性质是Markov chain中<strong>同构现象普遍存在</strong>。这个同构现象是将Markov chain看作一个图，而这个大图中的部分子图是同构的。比如考虑一个同时具有英文和中文能力的LM，“I want to go home”和“我想回家”在tokenizer看来是完全没有任何关系的两句话（因为tokenizer encode出来的结果完全不一样），但是我们如果站在Markov chain的视角去看这两句话在图里面的结构，很可能是非常相似的。<strong>不同语言的相似语义保证了这种同构现象的存在</strong>。<br/> 这个性质对于LM训练的启发在于：如果想要提升LM在某种语言（比如中文）的效果，单纯堆中文语料甚至不一定比中英混合语料更有效。解决A空间中的问题或许可以采用解决空间B中的问题 + 映射回A空间的方式。</li> </ul> <h5 id="展开看lm的特性">展开看LM的特性</h5> <p>首先想聊的是模型的能力(ability)。</p> <p>在最早看到CoT(Chain of Thought)相关的paper<sup>[4]</sup>以及“Let’s think step by step.”<sup>[5]</sup>的魔法提示词(Prompt)之后，我一度不是很理解：通过更改Prompt的方式，模型就比原来更有可能产生期望的输出结果，而且很可能模型在训练阶段都没怎么见到过这个Prompt。 这件事放在计算机视觉领域类比一下，就相当于找到了一个新的图像增强策略，这个增强策略在训练阶段没有使用过，但是却能够在所有的模型上有效提升效果。</p> <p>到这里就会引出一个新的问题：如何界定一个模型有解决某类问题的能力？ 从Markov chain的视角来看，问题本身就是这个链上的一个状态集合A（称之为问题状态集，之所以是个集合是因为同一个问题有很多表示形式，在链上的状态数必然不止一个），而我们期望的答案也是这个链上的一个状态集合B（称之为答案状态集）。只要在这个链上从A到B的转移概率不为0，那么我们就可以认为模型是具有解决这个问题的能力的。用公式表达就是：</p> \[P_{LLM}(B|A) &gt; 0 \Rightarrow {LLM有能力解决问题A}\] <p>所以说，如果模型在某种Prompt的提示下产生了期望的输出，那么我们就能认为模型本身是具有能力的，只不过被Prompt激发了出来。</p> <p>有趣的是，在计算机视觉领域，据我所知还没有类似Prompt这种可以激发单一模态的视觉模型能力的方法（大部分没有训练过的数据增强策略都对效果有负面影响）。</p> <p>其实这个视角同样可以套用到人的身上：<strong>如果一个人存在解决某个问题的可能性（解决问题的概率大于0），那我们就能认为这个人是有能力解决这个问题的</strong>。</p> <p>其次想要聊的是LM对于拥有更大信息量的数据的偏好性。</p> <p>这点其实也很好理解，同样长度的数据，如果LM在看过数据之后，对应的Markov chain中的转移概率没有发生太大的变化，那么这个数据训练与否对LM并没有太大的影响，反倒是一些能够改变Markov chain中转移概率的数据起到的作用更大。换句话说，在训练过程中，LM更倾向于受到具有更大信息量的数据的影响，因为这些数据可以帮助Markov chain建立状态之间的链接，修正状态之间的转移概率。</p> <p>套用到人身上，就是已知的信息看再多遍也很难有明显的提升，提升自己的能力靠的是寻求新的知识与挖掘看待旧知识的新视角。 而<strong>模型的预训练就像人类的学习一样，都是在初始链接的基础上，不断更新状态之间的转移概率，建立更强的状态间的连接</strong>。</p> <h4 id="新视角下的old-things">新视角下的Old things</h4> <p>这个部分我们会站在Markov chain的新视角来看待一些“旧事物”。</p> <h5 id="prompt-engineering">Prompt Engineering</h5> <p>在<a href="https://scholar.google.com/citations?user=dCa-pW8AAAAJ&amp;hl=en">Lilian Weng</a>介绍Prompt Enginerring的<a href="https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/">blog</a><sup>[3]</sup>以及<a href="https://www.promptingguide.ai/">Prompting Guide网站</a>中介绍了很多prompt engineering的方法， 有一些方法对新人来说也许有一些反直觉或者tricky，比如把问题中的”Q”换成”Question”、一开始给LM设定一些特定的角色玩cosplay、 给几个实际样例(few-shot)，但是考虑到前文所述我们仅仅是要从问题状态A集找到一条到回答状态集B的一条转移路径，这些方法也就不难理解。</p> <p>以把原始问题中的”Q”换成”Question”这个trick为例，其所表达的就是下面一个朴素的公式：</p> \[P_{LLM}(Answer|Q \; type \; Prompt) &lt; P_{LLM}(Answer|Question \; type \; Prompt)\] <p>从Markov chain视角来看上面的公式：“question”状态转移到答案状态的概率，要比和“q”状态转移到答案的概率更高。</p> <p><strong>通过在输入端改变问题，进而改变问题状态集，并且最终提升转移到答案状态集的概率，这是在Markov chain视角下对于Prompt Engineering的新视角</strong>。</p> <h5 id="in-context-learning">In-Context Learning</h5> <p>In-Context Learning（下称ICL），简单来说，就是类似下面的一种场景：</p> <pre><code class="language-C++">评论： 这个电影太烂了。 态度：消极。
评论： 我好喜欢这个电影。 态度：
</code></pre> <p>模型则会根据输入对应产生输出。</p> <pre><code class="language-C++">评论： 这个电影太烂了。 态度：消极。
评论： 我好喜欢这个电影。 态度：积极。
</code></pre> <p>在华盛顿大学和meta研究ICL为什么能work的<a href="https://arxiv.org/abs/2202.12837">paper</a><sup>[6]</sup>里（或者参考斯坦福大学的<a href="http://ai.stanford.edu/blog/understanding-incontext/">blog</a><sup>[7]</sup>）， 研究人员探究了一下到底是输入、输出还是输入-输出的匹配更加重要（参考下图）。</p> <p><img src="/assets/blog/icl.png" alt="ICL" width="700"/></p> <p>文章给出了一个非常有信息量的实验：<strong>输入-输出的匹配并没有想象中那么重要</strong>。也就是说，即使将原有标签随机修改，比如上面的示例修改成<code class="language-plaintext highlighter-rouge">评论： 这个电影太烂了。 态度：积极。</code>，模型仍然能够产生正确的输出。关键在于保持输入和输出本身的一致性。</p> <p>结合Markov chain来看待ICL：<strong>通过指定问题的输入和输出空间，使得LM在一个固定的子图上游走，使得模型更有可能产生正确的输出</strong>。最妙的是，根据实验来看，这个游走过程是不受之前的错误状态引导的。</p> <h5 id="cot">CoT</h5> <p>“Let’s think step by step.”<sup>[5]</sup>的魔法Prompt也被称为Zero-shot CoT（Chain of Thought），在使用了这样的prompt之后，模型更容易沿着分解问题的思路解决问题，从而在一些逻辑推理类的任务上产生分步输出，进而获取更接近真实答案的输出。</p> <p>在Markov Chain中，<strong>“Let’s think step by step” 和问题的中间步骤关联，而中间步骤状态相比没有任何输出的状态转移到答案的概率更高。</strong>这样来看，想出这个prompt也是很需要insight的。</p> <h4 id="random-but-not-random">Random, but not random</h4> <p>这个其实是我观察到的一个很有趣的现象，很多时候LLM是能够理解随机的，但是行为上却绝对做不到最真实的随机。其实从Markov chain的视角来看，这个事情是很容易理解的， 但是可能你去问一些ChatGPT的用户，他们或许也并不能回答这个问题：<strong><span style="color:red">如果要求ChatGPT完成如下的任务：“从A，B，C，D中随机选择一个”，那么ChatGPT这样的LM能否从做到统计意义上的随机？</span></strong></p> <p>答案很显然：<strong><span style="color:red">肯定不能，而且LM几乎确定不能原生解决这样的问题</span></strong>。要证明这个问题也很简单，以这个“ABCD”的例子 来说明，仅仅考虑当前Markov chain的状态S，考虑后续输出为“A”，“B”，“C”，“D”的四种状态A、B、C、D，要做到统计意义上的随机，Markov chain就一定需要满足下面的公式（不考虑temperature这些因素）：</p> \[P(A|S) = P(B|S) = P(C|S) = P(D|S) = 25\%\] <p>注意公式里面的ABCD只是一个状态的合集，也就是像“A”和“A.“都是A这个集合中的一个元素，所以说LM几乎确定不能解决这个问题。 但是如果引入插件的思想，由LM做控制器来判断需要执行<code class="language-plaintext highlighter-rouge">random.choice(["A", "B", "C", "D"])</code>函数，这个问题就非常容易解决了。</p> <h4 id="citation">Citation</h4> <p>如果觉得有帮助，欢迎引用这篇blog：</p> <div class="language-shell highlighter-rouge"><div class="highlight"><pre class="highlight"><code>@article<span class="o">{</span>wang2023LLM,
  title   <span class="o">=</span> <span class="s2">"LLMs as Markov Chain"</span>,
  author  <span class="o">=</span> <span class="s2">"Wang, Feng"</span>,
  journal <span class="o">=</span> <span class="s2">"fatescript.github.io"</span>,
  year    <span class="o">=</span> <span class="s2">"2023"</span>,
  month   <span class="o">=</span> <span class="s2">"Jun"</span>,
  url     <span class="o">=</span> <span class="s2">"https://fatescript.github.io/blog/2023/LLM-markov-chain/"</span>
<span class="o">}</span>
</code></pre></div></div> <h4 id="reference">Reference</h4> <p><strong>[1]</strong> <a href="https://twitter.com/karpathy/status/1645115622517542913">Karpathy的twiiter</a><br/> <strong>[2]</strong> <a href="https://t.co/8jdceMLpqy">介绍LLM as Markov Chain的Colab</a><br/> <strong>[3]</strong> <a href="https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/">Lilian Weng介绍Prompt Engineering的blog</a><br/> <strong>[4]</strong> <a href="https://arxiv.org/abs/2201.11903">Chain-of-Thought Prompting Elicits Reasoning in Large Language Models</a> <br/> <strong>[5]</strong> <a href="https://arxiv.org/abs/2205.11916">Large Language Models are Zero-Shot Reasoners</a><br/> <strong>[6]</strong> <a href="https://arxiv.org/abs/2202.12837">Rethinking the Role of Demonstrations: What Makes In-Context Learning Work?</a><br/> <strong>[7]</strong> <a href="http://ai.stanford.edu/blog/understanding-incontext/">How does in-context learning work?</a></p>]]></content><author><name></name></author><category term="engineering"/><category term="math"/><category term="LLM"/><summary type="html"><![CDATA[看待LLM的新视角]]></summary></entry><entry><title type="html">copybara：关于我只是想在仓库间做代码搬运这件事</title><link href="https://fatescript.github.io/blog/2022/copybara/" rel="alternate" type="text/html" title="copybara：关于我只是想在仓库间做代码搬运这件事"/><published>2022-09-14T15:59:00+00:00</published><updated>2022-09-14T15:59:00+00:00</updated><id>https://fatescript.github.io/blog/2022/copybara</id><content type="html" xml:base="https://fatescript.github.io/blog/2022/copybara/"><![CDATA[<h4 id="前言">前言</h4> <p><strong>在国内的技术社区，几乎没有任何blog阐述过有关source move问题的解决方案。</strong>一般来说，受到这个问题困扰的技术人员还是比较少的，但是相信随着国内开源社区的不断壮大，遇到类似问题的人会逐渐变多，本文仅仅是抛砖引玉地给出了我现在采用的一种解决方案和最佳实践，核心侧重于提供思路。</p> <h4 id="source-move难在哪儿">Source move难在哪儿</h4> <p>如果你是一个公司内部项目的maintainer，而这个项目又需要开源的话，那么如何进行代码的内外同步就是一个令人头疼的问题。</p> <p>这个问题之所以令人头疼，核心原因在于开源过程存在的一些限制和要求：</p> <ul> <li>公司内部的代码通常包含一些特殊的code，比如为公司某个产品专门设计的策略、内部管理使用的issue/jira/wiki link等内容，这部分code是无论如何不能泄漏出去的。</li> <li>因为git可以找回历史，所以开源出去的repo和公司内部的repo本质上还不是一个repo，最起码一部分的git commit object是不一样的</li> <li>内部的repo和外部的repo都需要进行迭代开发，所以在开发过程中保持多个repo的同步也是一个问题，不然一段时间之后就等着代码分叉吧…</li> <li>git在commit object中保存的是全量文件，而不是增量更新（这一点很多人都会产生误解），所以除了filter-branch这类操作之外，很难做到内部仓库删除掉一些文件就成了开源版本的仓库</li> </ul> <p>同步流程需要很容易集成进CI/CD中，尽量减少人力消耗</p> <p>上面的限制，决定了： <strong>用户需要在源代码层级上做代码的迁移</strong>。也就是对文件做一些读写操作，比如重新组织或者删除了内部的一部分code，就成了外部开源的code。</p> <p>因为这个流程本身基本上是文件的读写和对git object的操作，所以想象中造一个轮子应该不存在本质难的问题。不过本着“现成的轮子能满足需求就不自己造”的原则，我还是简单调研了一些开源的工具。</p> <h4 id="why-copybara">Why copybara</h4> <p>github上能找到的现成的工具只有两个：<a href="https://github.com/facebook/fbshipit">fbshipit</a>和<a href="https://github.com/google/copybara">copybara</a>，FAIR下面很多知名的codebase比如pytorch、detectron2都是用fbshipit做代码的同步，而google的<a href="https://opensource.google/documentation/reference/thirdparty/maintenance">open source best practice</a>中则提及了copybara这个工具。</p> <p>fbshipit本身支持的迁移方式比copybara多了一了hg，这是对我来说唯一的好处。而缺点则比较多：可以参考的文档比较少（甚至现在master上的文档应该是n个版本之前的）；配置文件使用hack（PHP的一个dialect）；example简直约等于没有。总体上，感觉fbshipit更像是一个fb内部使用的服务，如果自己要搞会比较麻烦。</p> <p>作为对比，copybara最大的缺点就是安装起来比较heavy，但是除此之外都要比fbshipit好得多：配置文件使用<a href="https://github.com/bazelbuild/starlark">starlark</a>（python的一个dialect）；文档虽然不多，但是够用；公司内部有 <a href="https://www.zhihu.com/people/megengine-bot">MegEngine Bot</a> 写的一些example作为参考。本着尽快上手的原则，就选择了copybara作为裁剪工具。</p> <h4 id="how-to-use-copybara">How to use copybara</h4> <p>copybara的本质是基于正则表达式做匹配，通过匹配规则来修改代码。所有对外的repo都需要有一个SoT(source of truth)，也就是唯一的truth。当同步出现了问题，需要做判定以谁为准的时候，SoT就是标准答案。</p> <p>考虑到一些可能的使用场景，我在本文的下个部分给出了一些实践中使用的transform，仅仅是提供一些参考，如果不是对细节很感兴趣的话可以直接跳过下个部分直接到best practice部分。而如果想要知道更细节的内容的话，可以参考这个手把手教你使用copybara的blog. 对code的transform</p> <h5 id="删除多行code">删除多行code</h5> <p>要删除多行code，就需要标记在何处开始，以及在何处结束。这里我们使用BEGIN/END-INTERNAL作为对应的标记。</p> <figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="c1"># BEGIN-INTERNAL
</span><span class="nf">internal_only_code</span><span class="p">()</span>
<span class="c1"># END-INTERNAL</span></code></pre></figure> <p>下面这个是官方提供的一个example，需要注意的是：从re的规则来看，它会把BEGIN-INTERNAL标记之前的空行一并删除掉。</p> <figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">core</span><span class="p">.</span><span class="nf">replace</span><span class="p">(</span>
    <span class="n">before</span> <span class="o">=</span> <span class="sh">"</span><span class="s">${x}</span><span class="sh">"</span><span class="p">,</span>
    <span class="n">after</span> <span class="o">=</span> <span class="sh">""</span><span class="p">,</span>
    <span class="n">multiline</span> <span class="o">=</span> <span class="bp">True</span><span class="p">,</span>
    <span class="n">regex_groups</span> <span class="o">=</span> <span class="p">{</span>
    	<span class="sh">"</span><span class="s">x</span><span class="sh">"</span><span class="p">:</span> <span class="sh">"</span><span class="s">(?m)</span><span class="se">\\</span><span class="s">n*^.*BEGIN-INTERNAL[</span><span class="se">\\</span><span class="s">w</span><span class="se">\\</span><span class="s">W]*?END-INTERNAL.*$</span><span class="sh">"</span><span class="p">,</span>
    <span class="p">},</span>
<span class="p">)</span></code></pre></figure> <h5 id="删除单行code">删除单行code</h5> <p>实际当中只删除一行code的情况还是比较常见的，为了读起来友好一些，使用一个DELETE-THIS-LINE作为标记，code读起来像是下面这种</p> <figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">f</span><span class="p">():</span>
    <span class="n">x</span> <span class="o">=</span> <span class="sh">"</span><span class="s">Hello</span><span class="sh">"</span>
    <span class="n">x</span> <span class="o">+=</span> <span class="sh">"</span><span class="s">world</span><span class="sh">"</span>  <span class="c1"># DELETE-THIS-LINE
</span>    <span class="k">return</span> <span class="n">x</span></code></pre></figure> <p>对应的transform example：</p> <figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">core</span><span class="p">.</span><span class="nf">replace</span><span class="p">(</span>
    <span class="n">before</span> <span class="o">=</span> <span class="sh">"</span><span class="s">${line}</span><span class="sh">"</span><span class="p">,</span>
    <span class="n">after</span> <span class="o">=</span> <span class="sh">""</span><span class="p">,</span>
    <span class="n">multiline</span> <span class="o">=</span> <span class="bp">True</span><span class="p">,</span>
    <span class="n">regex_groups</span> <span class="o">=</span> <span class="p">{</span>
        <span class="sh">"</span><span class="s">line</span><span class="sh">"</span><span class="p">:</span> <span class="sh">"</span><span class="s">(?m)</span><span class="se">\\</span><span class="s">n^.*?DELETE-THIS-LINE.*$</span><span class="sh">"</span><span class="p">,</span>
    <span class="p">},</span>
<span class="p">)</span></code></pre></figure> <h5 id="增加单行code">增加单行code</h5> <p>除了删除单行，我们还需要增加某些单个行，实际中类似：</p> <figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">f</span><span class="p">():</span>
    <span class="c1"># ADD-THIS-LINE var = "Hello"
</span>    <span class="k">pass</span></code></pre></figure> <p>因为python中缩进是有语意的，所以我们在使用re进行匹配的时候，就需要考虑空格带来的影响。对应的transform如下：</p> <figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">core</span><span class="p">.</span><span class="nf">replace</span><span class="p">(</span>
	<span class="n">before</span> <span class="o">=</span> <span class="sh">"</span><span class="s">${indent}${symbol}${code}</span><span class="sh">"</span><span class="p">,</span>
    <span class="n">after</span> <span class="o">=</span> <span class="sh">"</span><span class="s">${indent}${code}</span><span class="sh">"</span><span class="p">,</span>
    <span class="n">regex_groups</span> <span class="o">=</span> <span class="p">{</span>
    	<span class="sh">"</span><span class="s">indent</span><span class="sh">"</span><span class="p">:</span> <span class="sh">"</span><span class="s">(?m)^</span><span class="se">\\</span><span class="s">s*</span><span class="sh">"</span><span class="p">,</span>
        <span class="sh">"</span><span class="s">symbol</span><span class="sh">"</span><span class="p">:</span> <span class="sh">"</span><span class="s">#.*ADD-THIS-LINE</span><span class="se">\\</span><span class="s">s*</span><span class="sh">"</span><span class="p">,</span>
        <span class="sh">"</span><span class="s">code</span><span class="sh">"</span><span class="p">:</span> <span class="sh">"</span><span class="se">\\</span><span class="s">S.*$</span><span class="sh">"</span><span class="p">,</span>
    <span class="p">},</span>
<span class="p">)</span></code></pre></figure> <h5 id="增加多行code">增加多行code</h5> <p>这个case有一些复杂，如果为了省事的话可以反向思考：只要能把多行的注释删除掉就行了。不过增加多行code的写法让原来的code显得很冗长，不是特别推荐。 具体的例子可以参考下面的example：</p> <figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="c1"># BEGIN-INTERNAL
</span><span class="sh">"""</span><span class="s">
# END-INTERNAL
external_only_code1()
external_only_code2()
# BEGIN-INTERNAL
</span><span class="sh">"""</span>
<span class="c1"># END-INTERNAL</span></code></pre></figure> <h5 id="删除移动文件">删除/移动文件</h5> <p>如果仅仅是删除文件，只需要在dest file list中使用exclude排除文件即可，而移动文件本身还是对应core中的一个操作，参考如下code：</p> <figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">core</span><span class="p">.</span><span class="nf">move</span><span class="p">(</span><span class="sh">"</span><span class="s">foo/bar_internal</span><span class="sh">"</span><span class="p">,</span> <span class="sh">"</span><span class="s">bar</span><span class="sh">"</span><span class="p">)</span></code></pre></figure> <h4 id="处理外部pr">处理外部PR</h4> <p>因为裁剪出去的code是开源的版本，自然免不了需要处理PR（Pull Request）的问题。可以预见，如果在外部的repo上合并了一个PR，就会和SoT原则发生冲突。而对于内部仓库使用copybara之后，就会强制覆盖外部的commit，相当于git push -f操作，这对于任何一个项目来说来说都是不能接受的。</p> <p>官方推荐的流程其实是向下面这种，但是实际中我们采取了一个不太一样的解决方案。</p> <div class="language-text highlighter-rouge"><div class="highlight"><pre class="highlight"><code>  +--------------------+             +--------------------+
  |                    |             |                    |
  |  External Repo     |             |    External PR     +&lt;---+ contributor
  |                    |             |                    |      opens a PR
  |                    |             |                    |
  +--------^-----------+             +--------+-----------+
           |                                  |
    New commits are                  Changes shadowed as an
    pushed via copybara              internal PR via copybara
           |                                  |
  +--------+-----------+             +--------v-----------+
  |                    |             |                    |
  |   Internal Repo    +&lt;------------+  Internal PR       |
  |                    |   CI runs   |                    |
  |                    |   &amp;         +--------------------+
  +--------------------+   Team member reviews and merges
</code></pre></div></div> <h5 id="patch-integrate">patch integrate</h5> <p>我们第一个需要解决的问题是如何将外部PR引入到内部并且顺利裁剪。从功能上来说，copybara其实是支持从外向内的流程的（定义一个从外部向内部的workflow就可以了，也就是上图中的内容），但实际上engine组 @MegEngine Bot 已经趟出来一个更方便的方法：通过打patch（git format-patch）然后am（git am）的方式将外部的PR引入到内部的仓库中。因为外部文件本身是transform之后产生的，这个过程中会偶尔有一些conflict需要处理，不过总体来说不会有太大的问题，问题通常出现在integrate之后的对外裁剪过程。</p> <p>如果PR的target branch已经包含了对应的commit，那么github/gitlab平台会自动标记PR为merged状态。但是判断两个commit是否相同的逻辑是commit object的sha1 hash是否相同，而这个hash由很多因素决定，比如source tree、commit message等（详情可以参考这个gist），而copybara在裁剪的时候会默认在commit message中生成GitOrigin-RevId（也就是对应的内部commit，参考下图），还会修改对应的时间戳信息，这就导致了commit的hash发生了变化。如果此时直接把裁剪后的branch push到github，PR就不会自动merge，但是PR中的文件diff已经没了。 ​</p> <div style="text-align: center"> <figure> <picture> <source class="responsive-img-srcset" media="(max-width: 480px)" srcset="/assets/blog/copybara_commit-480.webp"/> <source class="responsive-img-srcset" media="(max-width: 800px)" srcset="/assets/blog/copybara_commit-800.webp"/> <source class="responsive-img-srcset" media="(max-width: 1400px)" srcset="/assets/blog/copybara_commit-1400.webp"/> <img src="/assets/blog/copybara_commit.png" width="auto" height="auto" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture><figcaption class="caption">copybara commit message</figcaption> </figure> ​ </div> <h5 id="fake-merge">fake merge</h5> <p>为了能merge一些外部的PR，copybara本身会根据特定的label判定commit是不是patch integrate（这个label默认是COPYBARA_INTEGRATE_REVIEW，详情参考文档），在commit message中包含了label的情况下就不会生成GitOrigin-RevId，同时会根据commit message自动merge对外的PR，这样相当于在原始commit后增加一个merge commit，在merge commit中会包含GitOrigin-RevId等内容，外部的branch就会包含PR的commit，PR也就会变成merged状态。</p> <h5 id="其他解决思路">其他解决思路</h5> <p>特意去看了一下pytorch和detectron2处理外部PR的方法，发现对于外部提交，pytorch/detectron2全部都会close掉，之后由bot告知contributor对应的commit id。这就是因为引入了新的commit message和更改了commit时间戳，导致无法和外部commit hash对齐，只能全部close掉。这种做法的最大好处就是不会生成merge commit，整个source tree看起来就非常干净。</p> <h4 id="大概率会踩的坑">大概率会踩的坑</h4> <p>前面基本上把使用copybara的一些常用的方法介绍了一下，这里插一些集成copybara到CI/CD过程中遇到的坑，期望能够节省使用者的时间。</p> <ul> <li>在CI/CD中copybara如果没有上传成功，就会认为是异常退出（exit code为非0数值）。比如你重跑了一下workflow，job就会神奇地fail掉。在我第一次把copybara workflow加入到CI/CD中的时候，找了半天CI/CD异常退出的bug。最后为了让exit code为0，写成了如下形式：</li> </ul> <figure class="highlight"><pre><code class="language-shell" data-lang="shell">copybara copy.bara.sky <span class="o">||</span> <span class="nb">echo</span> <span class="s2">"copybara failed"</span></code></pre></figure> <ul> <li>默认情况下，copybara裁剪的代码是从init commit到和git的远端同步的部分，所以当你本地commit了code后直接运行copybara并不能对代码进行裁剪。如果更新了sky文件但是不生效，多半也是因为忘了push到远端了。</li> </ul> <h4 id="吐槽">吐槽</h4> <p>copybara基本上满足了我对于source move的诉求，但是在有些场景下，使用起来还是不太方便，所以作为用户，在这里小小吐槽一下（我看开发团队bandwidth不太够的样子，就不发issue骚扰了）</p> <p>copybara本身基于re做匹配，这一点我认为是合理的，但是在处理匹配的代码的时候，完全可以做的更加动态一些。</p> <p>考虑正常代码的裁剪过程，其遵循如下的一个模式：找到匹配的pattern -&gt; 处理该pattern -&gt; 返回处理后的结果。处理pattern的过程可能是很动态的，而且这一步本质上就是对于字符串的各种变换方式，应该允许用户自己使用函数定义，既然starlark本身就是python的一个dialect，那么在其中写一些python的处理逻辑也是很正常的诉求。</p> <p>比如下面这种自定义transformation的写法：</p> <figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">f</span><span class="p">(</span><span class="n">start</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">end</span><span class="p">):</span>
    <span class="c1"># transform code
</span>
<span class="n">core</span><span class="p">.</span><span class="nf">dynamic_process</span><span class="p">(</span>                                                   
    <span class="n">before</span> <span class="o">=</span> <span class="sh">"</span><span class="s">${start}${x}${end}</span><span class="sh">"</span><span class="p">,</span>
    <span class="n">regex_groups</span> <span class="o">=</span> <span class="p">{</span><span class="sh">"</span><span class="s">start</span><span class="sh">"</span><span class="p">:</span> <span class="n">start_regex</span><span class="p">,</span> <span class="sh">"</span><span class="s">x</span><span class="sh">"</span><span class="p">:</span> <span class="n">x_regex</span><span class="p">,</span> <span class="sh">"</span><span class="s">end</span><span class="sh">"</span><span class="p">:</span> <span class="n">end_regex</span><span class="p">},</span>
    <span class="n">func</span> <span class="o">=</span> <span class="n">f</span><span class="p">,</span>
<span class="p">)</span></code></pre></figure> <h4 id="best-practice">best practice</h4> <p>在最后，结合engine团队的反馈和我自己的一些实践，给一些目前的best practice。</p> <ul> <li>尽可能做到仅仅需要删除内部独有的文件就可以让外部的code正常跑起来，这样copybara是配置起来最简单的，而且在code review的时候会少很多心智负担。如果能通过refactor把内部和外部的code区分的比较干净，本身也说明项目的复杂度相对比较低</li> <li>项目中copybara的标记过多是一个red flag，表明耦合度可能过高。如果一个文件中出现了过多的copybara中使用的标记，那么就应该考虑是否要将文件分拆。另外过多的copybara标记也更加可能导致开发人员贡献了一些code但是最终没有做code transform的现象，为后期工作埋雷。</li> <li>裁剪后的版本最好有一个单独的repo可以查看，这样release之前更容易发现问题，及时在内部修复。</li> <li>在CI/CD中最好diff一下move前后的内容，这一步也是为了给code review减少负担，防止reviewer在merge某个PR之后发现copybara做错了，再补交一个commit等类似现象等出现。</li> </ul> <h4 id="参考资料">参考资料</h4> <ol> <li><a href="https://github.com/google/copybara">copybara github</a></li> <li><a href="https://github.com/google/copybara/blob/master/docs/reference.md">copybara reference</a></li> <li><a href="https://github.com/bazelbuild/starlark">starlark</a></li> <li><a href="https://kubesimplify.com/moving-code-between-git-repositories-with-copybara">copybara intro</a></li> <li><a href="https://github.com/Olivr/copybara-action">copybara action</a></li> <li><a href="https://opensource.google/documentation/reference/thirdparty/maintenance">open source best practice</a></li> </ol>]]></content><author><name></name></author><category term="engineering"/><category term="code"/><summary type="html"><![CDATA[前言 在国内的技术社区，几乎没有任何blog阐述过有关source move问题的解决方案。一般来说，受到这个问题困扰的技术人员还是比较少的，但是相信随着国内开源社区的不断壮大，遇到类似问题的人会逐渐变多，本文仅仅是抛砖引玉地给出了我现在采用的一种解决方案和最佳实践，核心侧重于提供思路。]]></summary></entry><entry><title type="html">Tensor是如何让你的内存/显存泄漏的</title><link href="https://fatescript.github.io/blog/2022/tensor-memory-leak/" rel="alternate" type="text/html" title="Tensor是如何让你的内存/显存泄漏的"/><published>2022-05-20T15:59:00+00:00</published><updated>2022-05-20T15:59:00+00:00</updated><id>https://fatescript.github.io/blog/2022/tensor-memory-leak</id><content type="html" xml:base="https://fatescript.github.io/blog/2022/tensor-memory-leak/"><![CDATA[<h4 id="前言">前言</h4> <p>本文适合算法研究员/工程师阅读，如果你遇到奇怪的内存泄漏问题，说不定本文能帮你找到答案，解答疑惑。 虽然在大部分场景下，程序的内存泄漏都和数据息息相关。但是读完本文你就会了解，没有被正确使用的Tensor也会导致内存和显存的泄漏。</p> <h4 id="起源">起源</h4> <p>某次组会的时候，同事报告了一个很好玩的issue：我司某组的一个codebase出现了奇怪的泄漏现象，奇怪的点有以下几个方面：<br/> （1）不同的模型，内存/显存泄漏的现象不一样。比如A模型和B模型泄露的速度是不一样的<br/> （2）训练同一个模型的时候，如果在dataset中增加了数据量，相比不加数据，会在更早的epoch就把内存泄漏完。<br/> 是不是听起来现象非常离谱，本着”code never lies“的世界观，我开始探求这个现象的真正原因。</p> <h4 id="复现">复现</h4> <p>要想解决一个大的问题，首先就要降低问题的复杂度。最小复现代码是我们找问题的基础，而这个写最小复现代码的过程其实也是遵循了一定套路的，此处一并分享给大家：</p> <ul> <li>如果突然出现了历史上没有出现过的问题（比如在某个版本之后突然内存开始泄漏了），用git bisect找到 first bad commit（前提项目管理的比较科学，不会出现很多feature杂糅在一个commit里面；还有就是git checkout之后复现问题的成本不高）。如果bisect大法失效，考虑下面的复现流程。</li> <li>首先排除data的问题，也就是只创建一个dataloader，让这个loader不停地供数据，看看内存会不会涨（通常data是一系列对不上点、内存泄漏的重灾区）。</li> <li>其次排除训练的问题，找一个固定数据，不停地让网络训练固定数据进行训练/推理，看看是否发生泄漏。这一步主要是检查模型、优化器等组件的问题（通常模型本身不会发生泄漏，这一步经常能查出来一些自定义op的case）</li> <li>最后就是检查一些外围组件了。比如各种自己写的utils/misc的内容。这块通常不是啥重灾区。</li> </ul> <p>最后给出来我的最小复现（loguru可以换成print）：</p> <figure class="highlight"><pre><code class="language-python" data-lang="python"><table class="rouge-table"><tbody><tr><td class="gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
</pre></td><td class="code"><pre><span class="kn">import</span> <span class="n">torch</span>
<span class="kn">from</span> <span class="n">loguru</span> <span class="kn">import</span> <span class="n">logger</span>
<span class="kn">import</span> <span class="n">os</span>
<span class="kn">import</span> <span class="n">psutil</span>


<span class="k">def</span> <span class="nf">log_device_usage</span><span class="p">(</span><span class="n">count</span><span class="p">,</span> <span class="n">use_cuda</span><span class="p">):</span>
    <span class="n">mem_Mb</span> <span class="o">=</span> <span class="n">psutil</span><span class="p">.</span><span class="nc">Process</span><span class="p">(</span><span class="n">os</span><span class="p">.</span><span class="nf">getpid</span><span class="p">()).</span><span class="nf">memory_info</span><span class="p">().</span><span class="n">rss</span> <span class="o">/</span> <span class="mi">1024</span> <span class="o">**</span> <span class="mi">2</span>
    <span class="n">cuda_mem_Mb</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cuda</span><span class="p">.</span><span class="nf">memory_allocated</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">/</span> <span class="mi">1024</span> <span class="o">**</span> <span class="mi">2</span> <span class="k">if</span> <span class="n">use_cuda</span> <span class="k">else</span> <span class="mi">0</span>
    <span class="n">logger</span><span class="p">.</span><span class="nf">info</span><span class="p">(</span>
        <span class="sa">f</span><span class="sh">"</span><span class="s">iter </span><span class="si">{</span><span class="n">count</span><span class="si">}</span><span class="s">, mem: </span><span class="si">{</span><span class="nf">int</span><span class="p">(</span><span class="n">mem_Mb</span><span class="p">)</span><span class="si">}</span><span class="s">Mb, gpu mem:</span><span class="si">{</span><span class="nf">int</span><span class="p">(</span><span class="n">cuda_mem_Mb</span><span class="p">)</span><span class="si">}</span><span class="s">Mb</span><span class="sh">"</span>
    <span class="p">)</span>


<span class="k">def</span> <span class="nf">leak</span><span class="p">():</span>
    <span class="n">use_cuda</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cuda</span><span class="p">.</span><span class="nf">is_available</span><span class="p">()</span>
    <span class="n">val</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">rand</span><span class="p">(</span><span class="mi">1</span><span class="p">).</span><span class="nf">cuda</span><span class="p">()</span> <span class="k">if</span> <span class="n">use_cuda</span> <span class="k">else</span> <span class="n">torch</span><span class="p">.</span><span class="nf">rand</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">count</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="n">log_iter</span> <span class="o">=</span> <span class="mi">20000</span>
    <span class="nf">log_device_usage</span><span class="p">(</span><span class="n">count</span><span class="p">,</span> <span class="n">use_cuda</span><span class="p">)</span>
    <span class="k">while</span> <span class="bp">True</span><span class="p">:</span>
        <span class="n">value</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">rand</span><span class="p">(</span><span class="mi">1</span><span class="p">).</span><span class="nf">cuda</span><span class="p">()</span> <span class="k">if</span> <span class="n">use_cuda</span> <span class="k">else</span> <span class="n">torch</span><span class="p">.</span><span class="nf">rand</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
        <span class="n">val</span> <span class="o">+=</span> <span class="n">value</span><span class="p">.</span><span class="nf">requires_grad_</span><span class="p">()</span>
        <span class="n">count</span> <span class="o">+=</span> <span class="mi">1</span>
        <span class="k">if</span> <span class="n">count</span> <span class="o">%</span> <span class="n">log_iter</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
            <span class="nf">log_device_usage</span><span class="p">(</span><span class="n">count</span><span class="p">,</span> <span class="n">use_cuda</span><span class="p">)</span>


<span class="k">if</span> <span class="n">__name__</span> <span class="o">==</span> <span class="sh">"</span><span class="s">__main__</span><span class="sh">"</span><span class="p">:</span>
    <span class="nf">leak</span><span class="p">()</span>
</pre></td></tr></tbody></table></code></pre></figure> <p>试着运行一下，你就会发现你的内存和显存开始起飞了（内存泄漏的比显存更快一些），泄漏到一定程度，整个程序就会卡死，过一段时间就会被kill掉。作为对比，将<code class="language-plaintext highlighter-rouge">requires_grad_()</code>删掉（或者在后面加上<code class="language-plaintext highlighter-rouge">detach()</code>），你就可以看到没有泄漏发生的log了。</p> <p>写完了复现之后，同事问了我俩问题，大家也可以提前思考一下：</p> <ol> <li>为啥这个程序会出现内存/显存泄漏？</li> <li>为啥明明在gpu上的tensor会泄漏内存？</li> </ol> <h4 id="探索">探索</h4> <p>首先第二个问题很好理解，<strong>因为虽然在概念上，torch中的tensor是在gpu上的，但是也只是数据的storage在gpu上，除了在显存上存储的数据，tensor的一些其他信息（比如shape，stride和output_nr等）肯定也是要占据一定内存的。所以在cuda available的时候，内存和显存都会泄漏。</strong></p> <p>那么第一个问题是因为啥呢？我一时间也难以想明白，于是我打算直接通过torch的源码去找问题的答案。这个过程略长一些，想要看结论的读者可以直接跳到解惑部分。如果对torch内部的东西稍微感兴趣，可以继续看下去。 因为torch里面有很多code是生成出来的（有机会我们可以讲一讲torch的code gen），所以我们需要先编译一下torch（我用的commit hash是2367face）。因为写torch的cuda extension的时候，要使用Tensor就会需要 include &lt;ATen/ATen.h&gt;，以此为线索我最后定位到了一个叫做TensorBody.h的文件，通过fzf在torch/include/ATen/core下的TensorBody.h文件中找到了inplace add的定义，源码如下（torch中inplace都是在原来的名字后面加_，比如add和add_）。</p> <figure class="highlight"><pre><code class="language-c--" data-lang="c++"><span class="kr">inline</span> <span class="n">at</span><span class="o">::</span><span class="n">Tensor</span> <span class="o">&amp;</span> <span class="n">Tensor</span><span class="o">::</span><span class="n">add_</span><span class="p">(</span><span class="k">const</span> <span class="n">at</span><span class="o">::</span><span class="n">Tensor</span> <span class="o">&amp;</span> <span class="n">other</span><span class="p">,</span> <span class="k">const</span> <span class="n">at</span><span class="o">::</span><span class="n">Scalar</span> <span class="o">&amp;</span> <span class="n">alpha</span><span class="p">)</span> <span class="k">const</span> <span class="p">{</span>
    <span class="k">return</span> <span class="n">at</span><span class="o">::</span><span class="n">_ops</span><span class="o">::</span><span class="n">add__Tensor</span><span class="o">::</span><span class="n">call</span><span class="p">(</span><span class="k">const_cast</span><span class="o">&lt;</span><span class="n">Tensor</span><span class="o">&amp;&gt;</span><span class="p">(</span><span class="o">*</span><span class="k">this</span><span class="p">),</span> <span class="n">other</span><span class="p">,</span> <span class="n">alpha</span><span class="p">);</span>
<span class="p">}</span></code></pre></figure> <p>再通过<a href="https://github.com/ggreer/the_silver_searcher">ag</a>找<code class="language-plaintext highlighter-rouge">add__Tensor</code>的定义，最后在torch/csrc/autograd/generated文件夹下面的VariableTypeEverything.cpp文件找到定义。这个文件其实是多个VariableType_{0,1,2,3}.cpp开头的文件拼接成的。在VariableType_3.cpp中我们可以找到<code class="language-plaintext highlighter-rouge">add__Tensor</code>的定义。此处我们精简一下和我们的case相关的部分方便理解。</p> <figure class="highlight"><pre><code class="language-c--" data-lang="c++"><table class="rouge-table"><tbody><tr><td class="gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
</pre></td><td class="code"><pre><span class="n">at</span><span class="o">::</span><span class="n">Tensor</span> <span class="o">&amp;</span> <span class="n">add__Tensor</span><span class="p">(</span><span class="n">c10</span><span class="o">::</span><span class="n">DispatchKeySet</span> <span class="n">ks</span><span class="p">,</span> <span class="n">at</span><span class="o">::</span><span class="n">Tensor</span> <span class="o">&amp;</span> <span class="n">self</span><span class="p">,</span> <span class="k">const</span> <span class="n">at</span><span class="o">::</span><span class="n">Tensor</span> <span class="o">&amp;</span> <span class="n">other</span><span class="p">,</span> <span class="k">const</span> <span class="n">at</span><span class="o">::</span><span class="n">Scalar</span> <span class="o">&amp;</span> <span class="n">alpha</span><span class="p">)</span> <span class="p">{</span>
    <span class="k">auto</span><span class="o">&amp;</span> <span class="n">self_</span> <span class="o">=</span> <span class="n">unpack</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="s">"self"</span><span class="p">,</span> <span class="mi">0</span><span class="p">);</span>
    <span class="k">auto</span><span class="o">&amp;</span> <span class="n">other_</span> <span class="o">=</span> <span class="n">unpack</span><span class="p">(</span><span class="n">other</span><span class="p">,</span> <span class="s">"other"</span><span class="p">,</span> <span class="mi">1</span><span class="p">);</span>
    <span class="k">auto</span> <span class="n">_any_requires_grad</span> <span class="o">=</span> <span class="n">compute_requires_grad</span><span class="p">(</span> <span class="n">self</span><span class="p">,</span> <span class="n">other</span> <span class="p">);</span>
 
    <span class="p">(</span><span class="kt">void</span><span class="p">)</span><span class="n">_any_requires_grad</span><span class="p">;</span>
    <span class="n">check_inplace</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">_any_requires_grad</span><span class="p">);</span>
    <span class="n">c10</span><span class="o">::</span><span class="n">optional</span><span class="o">&lt;</span><span class="n">at</span><span class="o">::</span><span class="n">Tensor</span><span class="o">&gt;</span> <span class="n">original_self</span><span class="p">;</span>
    <span class="n">std</span><span class="o">::</span><span class="n">shared_ptr</span><span class="o">&lt;</span><span class="n">AddBackward0</span><span class="o">&gt;</span> <span class="n">grad_fn</span><span class="p">;</span>
    <span class="k">if</span> <span class="p">(</span><span class="n">_any_requires_grad</span><span class="p">)</span> <span class="p">{</span>
      <span class="n">grad_fn</span> <span class="o">=</span> <span class="n">std</span><span class="o">::</span><span class="n">shared_ptr</span><span class="o">&lt;</span><span class="n">AddBackward0</span><span class="o">&gt;</span><span class="p">(</span><span class="k">new</span> <span class="n">AddBackward0</span><span class="p">(),</span> <span class="n">deleteNode</span><span class="p">);</span>
      <span class="n">grad_fn</span><span class="o">-&gt;</span><span class="n">set_next_edges</span><span class="p">(</span><span class="n">collect_next_edges</span><span class="p">(</span> <span class="n">self</span><span class="p">,</span> <span class="n">other</span> <span class="p">));</span>
      <span class="n">grad_fn</span><span class="o">-&gt;</span><span class="n">other_scalar_type</span> <span class="o">=</span> <span class="n">other</span><span class="p">.</span><span class="n">scalar_type</span><span class="p">();</span>
      <span class="n">grad_fn</span><span class="o">-&gt;</span><span class="n">alpha</span> <span class="o">=</span> <span class="n">alpha</span><span class="p">;</span>
      <span class="n">grad_fn</span><span class="o">-&gt;</span><span class="n">self_scalar_type</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">scalar_type</span><span class="p">();</span>
    <span class="p">}</span>
    <span class="p">{</span>
      <span class="n">at</span><span class="o">::</span><span class="n">AutoDispatchBelowAutograd</span> <span class="n">guard</span><span class="p">;</span>
      <span class="n">at</span><span class="o">::</span><span class="n">redispatch</span><span class="o">::</span><span class="n">add_</span><span class="p">(</span><span class="n">ks</span> <span class="o">&amp;</span> <span class="n">c10</span><span class="o">::</span><span class="n">after_autograd_keyset</span><span class="p">,</span> <span class="n">self_</span><span class="p">,</span> <span class="n">other_</span><span class="p">,</span> <span class="n">alpha</span><span class="p">);</span>
    <span class="p">}</span>
    <span class="k">if</span> <span class="p">(</span><span class="n">grad_fn</span><span class="p">)</span> <span class="p">{</span>
        <span class="n">rebase_history</span><span class="p">(</span><span class="n">flatten_tensor_args</span><span class="p">(</span> <span class="n">self</span> <span class="p">),</span> <span class="n">grad_fn</span><span class="p">);</span>
    <span class="p">}</span>
    <span class="k">return</span> <span class="n">self</span><span class="p">;</span>
<span class="p">}</span>
</pre></td></tr></tbody></table></code></pre></figure> <p>这里我们顺便来看一下<code class="language-plaintext highlighter-rouge">add__Tensor</code>函数在干啥，<a href="https://github.com/pytorch/pytorch/blob/v1.10.1/torch/csrc/autograd/VariableTypeManual.cpp#L43-L64"><code class="language-plaintext highlighter-rouge">unpack</code></a>方法其实就是对tensor的一个检查，unpack后的code简单来说就是计算一下input tensor是否需要梯度（这个会影响到前向过程对于输出tensor的grad_fn的设置），如果需要梯度，就会进行图的构建（也就是设置tensor对应的一些属性），之后用dispatcher发送add的kernel，完成tensor的加法运算。torch中其他的op如sub，sigmoid等都是遵循一样的逻辑（因为torch里面前向过程创建图的逻辑是完全一样的，和具体的op类型无关，所以这些op才可以通过代码生成出来）。</p> <p>解释完了函数的逻辑，我们来重新看一下泄漏的问题。</p> <p>如果我们注释掉<code class="language-plaintext highlighter-rouge">grad_fn-&gt;set_next_edges(collect_next_edges( self, other ));</code> 或 <code class="language-plaintext highlighter-rouge">rebase_history(flatten_tensor_args( self ), grad_fn);</code> 这两行code中的任意一行，那么都不会出现内存/显存泄漏的现象，由此我们有理由怀疑是在构建动态图的过程中产生了内存泄漏的。</p> <p>又因为<code class="language-plaintext highlighter-rouge">rebase_history</code>是后面才被调用的，所以<code class="language-plaintext highlighter-rouge">set_next_edges</code>过程肯定只是出现泄漏的一个诱因，真正发生泄漏的位置肯定在后调用的位置，由此我们进一步来看<a href="https://github.com/pytorch/pytorch/blob/v1.10.1/torch/csrc/autograd/VariableTypeUtils.h#L90-L110"><code class="language-plaintext highlighter-rouge">rebase_history</code></a>的实际代码<a href="https://github.com/pytorch/pytorch/blob/v1.10.1/torch/csrc/autograd/variable.cpp#L142-L166">实现</a>。从源码逻辑来看，大部分是检查和确保一些属性的逻辑，核心在于<code class="language-plaintext highlighter-rouge">set_gradient_edge(self, std::move(gradient_edge));</code>这一句。由此，我们来看<code class="language-plaintext highlighter-rouge">set_gradient_edges</code>的逻辑，当然，为方便理解，下面的code做了一些精简（全部code的参考链接： <a href="https://github.com/pytorch/pytorch/blob/v1.10.1/torch/csrc/autograd/variable.cpp#L234-L247"><code class="language-plaintext highlighter-rouge">set_gradient_edge</code></a>，<a href="https://github.com/pytorch/pytorch/blob/v1.10.1/torch/csrc/autograd/variable.cpp#L133-L140"><code class="language-plaintext highlighter-rouge">materialize_autograd_meta</code></a>，<a href="https://github.com/pytorch/pytorch/blob/v1.10.1/torch/csrc/autograd/variable.cpp#L311-L315"><code class="language-plaintext highlighter-rouge">get_auto_grad_meta</code></a>）</p> <figure class="highlight"><pre><code class="language-cpp" data-lang="cpp"><span class="kt">void</span> <span class="nf">set_gradient_edge</span><span class="p">(</span><span class="k">const</span> <span class="n">Variable</span><span class="o">&amp;</span> <span class="n">self</span><span class="p">,</span> <span class="n">Edge</span> <span class="n">edge</span><span class="p">)</span> <span class="p">{</span>
  <span class="k">auto</span><span class="o">*</span> <span class="n">meta</span> <span class="o">=</span> <span class="n">materialize_autograd_meta</span><span class="p">(</span><span class="n">self</span><span class="p">);</span>
  <span class="n">meta</span><span class="o">-&gt;</span><span class="n">grad_fn_</span> <span class="o">=</span> <span class="n">std</span><span class="o">::</span><span class="n">move</span><span class="p">(</span><span class="n">edge</span><span class="p">.</span><span class="n">function</span><span class="p">);</span>
  <span class="n">meta</span><span class="o">-&gt;</span><span class="n">output_nr_</span> <span class="o">=</span> <span class="n">edge</span><span class="p">.</span><span class="n">input_nr</span><span class="p">;</span>
<span class="p">}</span>

<span class="n">AutogradMeta</span><span class="o">*</span> <span class="nf">materialize_autograd_meta</span><span class="p">(</span><span class="k">const</span> <span class="n">at</span><span class="o">::</span><span class="n">TensorBase</span><span class="o">&amp;</span> <span class="n">self</span><span class="p">)</span> <span class="p">{</span>
  <span class="k">auto</span> <span class="n">p</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">unsafeGetTensorImpl</span><span class="p">();</span>
  <span class="k">if</span> <span class="p">(</span><span class="o">!</span><span class="n">p</span><span class="o">-&gt;</span><span class="n">autograd_meta</span><span class="p">())</span> <span class="p">{</span>
    <span class="n">p</span><span class="o">-&gt;</span><span class="n">set_autograd_meta</span><span class="p">(</span><span class="n">std</span><span class="o">::</span><span class="n">make_unique</span><span class="o">&lt;</span><span class="n">AutogradMeta</span><span class="o">&gt;</span><span class="p">());</span>
  <span class="p">}</span>
  <span class="k">return</span> <span class="nf">get_autograd_meta</span><span class="p">(</span><span class="n">self</span><span class="p">);</span>
<span class="p">}</span>

<span class="n">AutogradMeta</span><span class="o">*</span> <span class="nf">get_autograd_meta</span><span class="p">(</span><span class="k">const</span> <span class="n">at</span><span class="o">::</span><span class="n">TensorBase</span><span class="o">&amp;</span> <span class="n">self</span><span class="p">)</span> <span class="p">{</span>
  <span class="k">return</span> <span class="k">static_cast</span><span class="o">&lt;</span><span class="n">AutogradMeta</span><span class="o">*&gt;</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">unsafeGetTensorImpl</span><span class="p">()</span><span class="o">-&gt;</span><span class="n">autograd_meta</span><span class="p">());</span>
<span class="p">}</span></code></pre></figure> <p>看到这里，基本上熟悉pytorch中对于图定义的同学大概就能知道是什么原因了。关于pytorch中forward过程构建图的原理，可以参考官网的<a href="https://pytorch.org/blog/computational-graphs-constructed-in-pytorch/">blog</a>，作为一个基础概念，我们只需要了解：<strong>动态图就是在forward过程中进行图的“创建”，在backward过程完成图的“销毁”。</strong></p> <p>现在让我们回到数据结构中Graph（图）的概念。在一个自动求导系统中，我们可以将Graph中的<a href="https://github.com/pytorch/pytorch/blob/v1.10.1/torch/csrc/autograd/edge.h#L14">Edge</a>（边）简单地理解为一个tensor，Graph中<a href="https://github.com/pytorch/pytorch/blob/v1.10.1/torch/csrc/autograd/function.h#L99">Node</a>（节点）的概念理解为算子。比如在torch里写 <code class="language-plaintext highlighter-rouge">c = a + b</code>，其实就是表示有一个a 表示的Edge和一个b代表的Edge连接到一个add的Node（节点）上，这个Node又会连接到一个叫做c的Edge上（下面是一个用<a href="https://github.com/mermaid-js/mermaid">mermaid</a>画的一个示意图，其中Edge用矩形表示，Node用圆表示。不难看出，add就是一个入度为2，出度为1的Node）。</p> <div style="text-align: center"> <figure> <picture> <source class="responsive-img-srcset" media="(max-width: 480px)" srcset="/assets/blog/tensor_graph-480.webp"/> <source class="responsive-img-srcset" media="(max-width: 800px)" srcset="/assets/blog/tensor_graph-800.webp"/> <source class="responsive-img-srcset" media="(max-width: 1400px)" srcset="/assets/blog/tensor_graph-1400.webp"/> <img src="/assets/blog/tensor_graph.png" width="auto" height="auto" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture><figcaption class="caption">c = a + b的图表示</figcaption> </figure> ​ </div> <p>既然我们有了图，那么就需要有一些结构保存一部分基本的图信息，这些基本图信息会在自动求导（autograd）的时候使用。在torch中，AutogradMeta就是包含了诸如tensor的autograd历史、hooks等信息的结构，而导致我们内存/显存泄漏的罪魁祸首也正是这个<a href="https://github.com/pytorch/pytorch/blob/v1.10.1/torch/csrc/autograd/variable.h#L190">AutogradMeta</a>。 现在，我们已经知道memory实际上泄漏的是啥了。跳回我们写的code，结合gc机制，想一想问题1你是否知道了答案。</p> <h4 id="解惑">解惑</h4> <p>至此，我们基本上就可以把问题1解释清楚了：<strong>在Tensor的requires_grad为True的时候，Tensor的每次运算都会导致需要保存一份AutogradMeta信息，对应的Tensor也会被加入到计算图中。即使表面上来看你只是做了一些inplace add的操作，但是其实在torch内部，那个临时的Tensor已经进入到了图里，成为了图的一个Edge，且引用计数 + 1，自然是要占据空间的。如果你的Tensor不requires_grad，那么就是只是进行运算，不会有Meta等信息存在，那个暂时生成的Tensor就会引用计数清0被gc了，自然也不会有内存泄漏了。</strong> 除了问题1之外，结合上面介绍的内容，我们也能理解，下面一段非常pythonic的code在pytorch里面并不科学的原因。</p> <figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="n">total_loss</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">data</span> <span class="ow">in</span> <span class="n">dataloader</span><span class="p">:</span>
    <span class="n">loss</span> <span class="o">=</span> <span class="nf">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
    <span class="n">total_loss</span> <span class="o">+=</span> <span class="n">loss</span>
<span class="n">total_loss</span><span class="p">.</span><span class="nf">backward</span><span class="p">()</span></code></pre></figure> <p>现在，让我们从最小复现代码回归到codebase，其实我给出的复现里面的代码中的value就是loss，很多时候炼丹师会想要看一下loss的均值/最大值等统计信息，经常会用一个meter保存历史信息，也就对应了复现代码里面的val。 很多奇怪的现象到此也就说的通了，比如不同模型泄漏速度不一样，就是因为不同的模型loss的数量是不一样的，泄漏的速度自然也是不一样的；再比如增加数据会使得同一个模型在更早的epoch到达OOM状态，是因为当数据增加的时候一个epoch内的iter数就会变多，自然会有在更早的epoch把内存泄漏完的现象；曾经能训练的模型加了数据之后也有可能因此变得无法训练。</p> <h4 id="后记">后记</h4> <p>也许下面这句话对炼丹师来说听起来有些反直觉，但我觉得还是有必要声明一下：<strong>无论python前端中tensor看起来是如何动态地进行运算，概念上计算图中的每个节点都无法被inplace修改。</strong></p> <p>在理解了本文要介绍的原理后，我们也可以轻易写一些reviewer看起来好像没啥问题的泄漏程序了（逃</p> <figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">leak</span><span class="p">():</span>
    <span class="n">use_cuda</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cuda</span><span class="p">.</span><span class="nf">is_available</span><span class="p">()</span>
    <span class="n">val</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">rand</span><span class="p">(</span><span class="mi">1</span><span class="p">).</span><span class="nf">cuda</span><span class="p">()</span> <span class="k">if</span> <span class="n">use_cuda</span> <span class="k">else</span> <span class="n">torch</span><span class="p">.</span><span class="nf">rand</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">val</span><span class="p">.</span><span class="nf">requires_grad_</span><span class="p">()</span>  <span class="c1"># 比如这个requires_grad_是在某个地方偷偷加的
</span>    <span class="n">count</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="n">log_iter</span> <span class="o">=</span> <span class="mi">20000</span>
    <span class="nf">log_device_usage</span><span class="p">(</span><span class="n">count</span><span class="p">,</span> <span class="n">use_cuda</span><span class="p">)</span>
    <span class="k">while</span> <span class="bp">True</span><span class="p">:</span>
        <span class="n">val</span> <span class="o">+=</span> <span class="mi">1</span>  <span class="c1"># 这个1在torch里面会表示为一个cpu tensor
</span>        <span class="k">if</span> <span class="n">count</span> <span class="o">%</span> <span class="n">log_iter</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
            <span class="nf">log_device_usage</span><span class="p">(</span><span class="n">count</span><span class="p">,</span> <span class="n">use_cuda</span><span class="p">)</span>
        <span class="n">count</span> <span class="o">+=</span> <span class="mi">1</span></code></pre></figure> <p>为了更好的表示上述代码在执行过程中发生了什么，我用<a href="https://github.com/3b1b/manim">manim</a>写了一个动画来提供更直观的解释，放在结尾也是希望读者能在读完文章后，稍微让头脑休息一下吧：）</p> <figure> <picture> <source class="responsive-img-srcset" media="(max-width: 480px)" srcset="/assets/blog/manim.gif-480.webp"/> <source class="responsive-img-srcset" media="(max-width: 800px)" srcset="/assets/blog/manim.gif-800.webp"/> <source class="responsive-img-srcset" media="(max-width: 1400px)" srcset="/assets/blog/manim.gif-1400.webp"/> <img src="/assets/blog/manim.gif" class="img-fluid rounded z-depth-1" width="auto" height="auto" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <p>​</p>]]></content><author><name></name></author><category term="deep-learning"/><category term="engineering"/><category term="code"/><summary type="html"><![CDATA[前言 本文适合算法研究员/工程师阅读，如果你遇到奇怪的内存泄漏问题，说不定本文能帮你找到答案，解答疑惑。 虽然在大部分场景下，程序的内存泄漏都和数据息息相关。但是读完本文你就会了解，没有被正确使用的Tensor也会导致内存和显存的泄漏。]]></summary></entry><entry><title type="html">关于炼丹，你是否知道这些细节？</title><link href="https://fatescript.github.io/blog/2021/details-of-deep-learning-engineering/" rel="alternate" type="text/html" title="关于炼丹，你是否知道这些细节？"/><published>2021-12-28T15:59:00+00:00</published><updated>2021-12-28T15:59:00+00:00</updated><id>https://fatescript.github.io/blog/2021/details-of-deep-learning-engineering</id><content type="html" xml:base="https://fatescript.github.io/blog/2021/details-of-deep-learning-engineering/"><![CDATA[<p>本文算是我工作一年多以来的一些想法和经验，最早发布在旷视研究院内部的论坛中，本着开放和分享的精神发布在我的知乎专栏中，如果想看干货的话可以直接跳过动机部分。另外，后续在这个专栏中，我会做一些关于原理和设计方面的一些分享，希望能给领域从业人员提供一些看待问题的不一样的视角。</p> <h4 id="动机">动机</h4> <p>前段时间走在路上，一直在思考一个问题：我的时间开销很多都被拿去给别人解释一些在我看起来显而易见的问题了，比如<a href="https://github.com/Megvii-BaseDetection/cvpods">cvpods</a>里面的一些code写法问题（虽然这在某些方面说明了文档建设的不完善），而这变相导致了我实际工作时间的减少，如何让别人少问一些我觉得答案显而易见的问题？如何让别人提前规避一些不必要的坑？只有解决掉这样的一些问题，我才能从一件件繁琐的小事中解放出来，把精力放在我真正关心的事情上去。</p> <p>其实之前同事有跟我说过类似的话，每次带一个新人，都要告诉他：你的实现需要注意这里blabla，还要注意那里blabla。说实话，我很佩服剑锋同学带intern的细致和知无不言，但我本性上并不喜欢每次花费时间去解释一些我觉得显而易见的问题，所以我打算写一个帖子，把我踩过的坑和留下来的经验broadcast出去。希望能够方便别人，同时也节约我的时间。</p> <p>加入旷视以来，个人一直在做一些关于框架相关的内容，所以内容主要偏向于模型训练之类的工作。因为<strong>我无法想象知识在别人脑海中的样子（the curse of knowledge），所以只能选取被问的最多的，和我觉得最应该知道的</strong>。</p> <p>准备好了的话，我们就启航出发（另，这篇blog会长期进行更新）。</p> <h4 id="坑经验">坑/经验</h4> <h5 id="data模块">Data模块</h5> <ol> <li>python图像处理用的最多的两个库是opencv和Pillow（PIL），但是两者读取出来的图像并不一样，<strong>opencv读取的图像格式的三个通道是BGR形式的，但是PIL是RGB格式的</strong>。这个问题看起来很小，但是衍生出来的坑可以有很多，最常见的场景就是数据增强和预训练模型中。比如有些数据增强的方法是基于channel维度的，比如megengine里面的<a href="https://github.com/MegEngine/MegEngine/blob/4d72e7071d6b8f8240edc56c6853384850b7407f/imperative/python/megengine/data/transform/vision/transform.py#L937">HueTransform</a>，在<a href="https://github.com/MegEngine/MegEngine/blob/4d72e7071d6b8f8240edc56c6853384850b7407f/imperative/python/megengine/data/transform/vision/transform.py#L958">这一行</a>显然是需要确保图像是BGR的，但是经常会有人只看有Transform就无脑用了，从来没有考虑过这些问题。</li> <li>接上条，RGB和BGR的另一个问题就是导致预训练模型载入后训练的方式不对，最常见的场景就是预训练模型的input channel是RGB的（例如torch官方来的预训练模型），然后你用cv2做数据处理，最后还忘了convert成RGB的格式，那么就是会有问题。这个问题应该很多炼丹的同学没有注意过，我之前写<a href="https://github.com/FateScript/CenterNet-better">CenterNet-better</a>就发现<a href="https://github.com/xingyizhou/CenterNet">CenterNet</a>存在这么一个问题，要知道当时这可是一个有着3k多star的仓库，但是从来没有人意识到有这个问题。当然，依照我的经验，如果你训练的iter足够多，即使你的channel有问题，对于结果的影响也会非常小。不过，既然能做对，为啥不注意这些问题一次性做对呢？</li> <li>torchvision中提供的模型，都是输入图像经过了ToTensor操作train出来的。也就是说最后在进入网络之前会统一除以255从而将网络的输入变到0到1之间。torchvision的<a href="https://pytorch.org/vision/stable/models.html">文档</a>给出了他们使用的mean和std，也是0-1的mean和std。如果你使用torch预训练的模型，但是输入还是0-255的，那么恭喜你，在载入模型上你又会踩一个大坑（要么你的图像先除以255，要么mean和std都要乘以255）。</li> <li>ToTensor之后接数据处理的坑。上一条说了ToTensor之后图像变成了0到1的，但是一些数据增强对数值做处理的时候，是针对标准图像，很多人ToTensor之后接了这样一个数据增强，最后就是练出来的丹是废的（心疼电费QaQ）。</li> <li>数据集里面有一个图特别诡异，只要train到那一张图就会炸显存（CUDA OOM），别的图训练起来都没有问题，应该怎么处理？通常出现这个问题，首先判断数据本身是不是有问题。如果数据本身有问题，在一开始生成Dataset对象的时候去掉就行了。如果数据本身没有问题，只不过因为一些特殊原因导致显存炸了（比如检测中图像的GT boxes过多的问题），可以catch一个CUDA OOM的error之后将一些逻辑放在CPU上，最后retry一下，这样只是会慢一个iter，但是训练过程还是可以完整走完的。</li> <li>pytorch中dataloader的坑。有时候会遇到pytorch num_workers=0（也就是单进程）没有问题，但是多进程就会报一些看不懂的错的现象，这种情况通常是因为torch到了ulimit的上限，更核心的原因是<strong>torch的dataloader不会释放文件描述符</strong>（参考<a href="https://github.com/pytorch/pytorch/issues/973">issue</a>）。可以ulimit -n 看一下机器的设置。跑程序之前修改一下对应的数值。</li> <li>opencv和dataloader的神奇联动。很多人经常来问为啥要写cv2.setNumThreads(0)，其实是因为cv2在做resize等op的时候会用多线程，当torch的dataloader是多进程的时候，多进程套多线程，很容易就卡死了（具体哪里死锁了我没探究很深）。除了setNumThreads之外，通常还要加一句cv2.ocl.setUseOpenCL(False)，原因是cv2使用opencl和cuda一起用的时候通常会拖慢速度，加了万事大吉，说不定还能加速。</li> <li>dataloader会在epoch结束之后进行类似重新加载的操作，复现这个问题的code放在后面的 code复现部分了。这个问题算是可以说是一个高级bug/feature了，可能导致的问题之一就是炼丹师在本地的code上进行了一些修改，然后训练过程直接加载进去了。解决方法也很简单，让你的sampler源源不断地产生数据就好，这样即使本地code有修改也不会加载进去。</li> </ol> <h5 id="module模块">Module模块</h5> <ol> <li> <p>BatchNorm在训练和推断的时候的行为是不一致的。这也是新人最常见的错误（类似的算子还有dropout，这里提一嘴，<strong>pytorch的dropout在eval的时候行为是Identity</strong>，之前有遇到过实习生说dropout加了没效果，直到我看了他的code： x = F.dropout(x, p=0.5) ）</p> </li> <li>BatchNorm叠加分布式训练的坑。<strong>在使用DDP（DistributedDataParallel）进行训练的时候，每张卡上的BN统计量是可能不一样的，仔细检查broadcast_buffer这个参数</strong>。DDP的默认行为是在forward之前将rank0 的 buffer做一次broadcast（broadcast_buffer=True），但是一些常用的开源检测仓库是将broadcast_buffer设置成False的（参考：<a href="https://github.com/facebookresearch/detectron2/blob/f50ec07cf220982e2c4861c5a9a17c4864ab5bfd/tools/plain_train_net.py#L206">mmdet</a> 和 <a href="https://github.com/facebookresearch/detectron2/blob/f50ec07cf220982e2c4861c5a9a17c4864ab5bfd/tools/plain_train_net.py#L206">detectron2</a>，我猜是在检测任务中因为batchsize过小，统一用卡0的统计量会掉点）<strong>这个问题在一边训练一边测试的code中更常见</strong>，比如说你train了5个epoch，然后要分布式测试一下。一般的逻辑是将数据集分到每块卡上，每块卡进行inference，最后gather到卡0上进行测点。但是<strong>因为每张卡统计量是不一样的，所以和那种把卡0的模型broadcast到不同卡上测试出来的结果是不一样的。这也是为啥通常训练完测的点和单独起了一个测试脚本跑出来的点不一样的原因</strong>（当然你用SyncBN就不会有这个问题）。</li> <li>Pytorch的SyncBN在1.5之前一直实现的有bug，所以存在使用SyncBN结果掉点的问题。</li> <li>用了多卡开多尺度训练，明明尺度更小了，但是速度好像不是很理想？这个问题涉及到多卡的原理，因为分布式训练的时候，在得到新的参数之后往往需要进行一次同步。假设有两张卡，卡0的尺度非常小，卡1的尺度非常大，那么就会出现卡0始终在等卡1，于是就出现了虽然有的尺度变小了，但是整体的训练速度并没有变快的现象（木桶效应）。解决这个问题的思路就是<strong>尽量把负载拉均衡一些</strong>。</li> <li>多卡的小batch模拟大batch（梯度累积）的坑。假设我们在单卡下只能塞下batchsize = 2，那么为了模拟一个batchsize = 8的效果，通常的做法是forward / backward 4次，不清理梯度，step一次（当然考虑BN的统计量问题这种做法和单纯的batchsize=8肯定还是有一些差别的）。在多卡下，因为调用loss.backward的时候会做grad的同步，所以说前三次调用backward的时候需要加ddp.no_sync的context manager（不加的话，第一次bp之后，各个卡上的grad此时会进行同步），最后一次则不需要加。当然，我看很多仓库并没有这么做，我只能理解他们就是单纯想做梯度累积（BTW，加了<a href="https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html?highlight=no_sync#torch.nn.parallel.DistributedDataParallel.no_sync">ddp.no_sync</a>会使得程序快一些，毕竟加了之后bp过程是无通讯的）。</li> <li><strong>浮点数的加法其实不遵守交换律的</strong>，这个通常能衍生出来GPU上的运算结果不能严格复现的现象。可能一些非计算机软件专业的同学并不理解这一件事情，直接自己开一个python终端体验可能会更好：</li> </ol> <figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="nf">print</span><span class="p">(</span><span class="mf">1e100</span> <span class="o">+</span> <span class="mf">1e-4</span> <span class="o">+</span> <span class="o">-</span><span class="mf">1e100</span><span class="p">)</span>  <span class="c1"># ouptut: 0
</span><span class="nf">print</span><span class="p">(</span><span class="mf">1e100</span> <span class="o">+</span> <span class="o">-</span><span class="mf">1e100</span> <span class="o">+</span> <span class="mf">1e-4</span><span class="p">)</span>  <span class="c1"># output: 0.0001</span></code></pre></figure> <h5 id="训练模块">训练模块</h5> <ol> <li>FP16训练/混合精度训练。使用Apex训练混合精度模型，在保存checkpoint用于继续训练的时候，除了model和optimizer本身的state_dict之外，还需要保存一下amp的state_dict，这个在<a href="https://nvidia.github.io/apex/amp.html#checkpointing">amp的文档</a>中也有提过。（当然，经验上来说忘了保存影响不大，会多花几个iter search一个loss scalar出来）</li> <li>多机分布式训练卡死。 @zhangsongyang 遇到的一个坑。场景是rlaunch申请了两个8卡机，然后机器1和机器2用前4块卡做通讯（local rank最大都是4）。可以初始化process group，在使用DDP的时候会卡死。原因在于pytorch在做DDP的时候会猜测一个rank，参考<a href="https://github.com/pytorch/pytorch/blob/0d437fe6d0ef17648072eb586484a4a5a080b094/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L1622-L1630">code</a>。对于上面的场景，第二个机器上因为存在卡5到卡8，而对应的rank也是5到8，所以DDP就会认为自己需要同步的是卡5到卡8，于是就卡死了。</li> <li>在使用AMP的时候，使用Adam/AdamW优化器之后NaN，之前没有任何异常现象，通常是optimizer里面的eps的问题，调整一下eps的数值就好了（比如1e-3），因为默认的eps是1e-8，在fp16下浮点运算容易出NaN</li> <li><strong>梯度为0</strong> 和 <strong>参数是否更新</strong> 没有必然关系。因为grad并不是最终的参数更新量，最终的参数更新量是在optimizer里面进行计算的。一个最简单的例子就是设置了weight decay不为0，当optimizer的weight decay不为0 的时候，最终的参数更新量都会加上 <code class="language-plaintext highlighter-rouge">lr * wd * param</code> ，所以 grad为0并不等价于参数量不会更新。一些可以refer的<a href="https://github.com/MegEngine/MegEngine/blob/d404ed184d/imperative/python/megengine/optimizer/sgd.py#L72-L73">code</a>（此处以megengine为例，pytorch仅仅是把逻辑写成了cpp来加速）</li> </ol> <h4 id="复现code">复现Code</h4> <h5 id="data部分">Data部分</h5> <figure class="highlight"><pre><code class="language-python" data-lang="python"><table class="rouge-table"><tbody><tr><td class="gutter gl"><pre class="lineno">1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
</pre></td><td class="code"><pre><span class="kn">from</span> <span class="n">torch.utils.data</span> <span class="kn">import</span> <span class="n">DataLoader</span>
<span class="kn">from</span> <span class="n">torch.utils.data</span> <span class="kn">import</span> <span class="n">Dataset</span>
<span class="kn">import</span> <span class="n">tqdm</span>
<span class="kn">import</span> <span class="n">time</span>


<span class="k">class</span> <span class="nc">SimpleDataset</span><span class="p">(</span><span class="n">Dataset</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">length</span><span class="o">=</span><span class="mi">400</span><span class="p">):</span>
        <span class="n">self</span><span class="p">.</span><span class="n">length</span> <span class="o">=</span> <span class="n">length</span>
        <span class="n">self</span><span class="p">.</span><span class="n">data_list</span> <span class="o">=</span> <span class="nf">list</span><span class="p">(</span><span class="nf">range</span><span class="p">(</span><span class="n">length</span><span class="p">))</span>

    <span class="k">def</span> <span class="nf">__getitem__</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">index</span><span class="p">):</span>
        <span class="n">data</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">data_list</span><span class="p">[</span><span class="n">index</span><span class="p">]</span>
        <span class="n">time</span><span class="p">.</span><span class="nf">sleep</span><span class="p">(</span><span class="mf">0.1</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">data</span>

    <span class="k">def</span> <span class="nf">__len__</span><span class="p">(</span><span class="n">self</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">self</span><span class="p">.</span><span class="n">length</span>


<span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="n">local_rank</span><span class="p">):</span>
    <span class="n">dataset</span> <span class="o">=</span> <span class="nc">SimpleDataset</span><span class="p">()</span>
    <span class="n">dataloader</span> <span class="o">=</span> <span class="nc">DataLoader</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_workers</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
    <span class="n">iter_loader</span> <span class="o">=</span> <span class="nf">iter</span><span class="p">(</span><span class="n">dataloader</span><span class="p">)</span>
    <span class="n">max_iter</span> <span class="o">=</span> <span class="mi">100000</span>
    <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">.</span><span class="nf">tqdm</span><span class="p">(</span><span class="nf">range</span><span class="p">(</span><span class="n">max_iter</span><span class="p">)):</span>
        <span class="k">try</span><span class="p">:</span>
            <span class="n">_</span> <span class="o">=</span> <span class="nf">next</span><span class="p">(</span><span class="n">iter_loader</span><span class="p">)</span>
        <span class="k">except</span> <span class="nb">StopIteration</span><span class="p">:</span>
            <span class="nf">print</span><span class="p">(</span><span class="sh">"</span><span class="s">Refresh here !!!!!!!!</span><span class="sh">"</span><span class="p">)</span>
            <span class="n">iter_loader</span> <span class="o">=</span> <span class="nf">iter</span><span class="p">(</span><span class="n">dataloader</span><span class="p">)</span>
            <span class="n">_</span> <span class="o">=</span> <span class="nf">next</span><span class="p">(</span><span class="n">iter_loader</span><span class="p">)</span>
            

<span class="k">if</span> <span class="n">__name__</span> <span class="o">==</span> <span class="sh">"</span><span class="s">__main__</span><span class="sh">"</span><span class="p">:</span>
    <span class="kn">import</span> <span class="n">torch.multiprocessing</span> <span class="k">as</span> <span class="n">mp</span>
    <span class="n">mp</span><span class="p">.</span><span class="nf">spawn</span><span class="p">(</span><span class="n">train</span><span class="p">,</span> <span class="n">args</span><span class="o">=</span><span class="p">(),</span> <span class="n">nprocs</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">daemon</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
</pre></td></tr></tbody></table></code></pre></figure> <p>当程序运行起来的时候，可以在Dataset里面的__getitem__里面加一个print，在refresh之后，就会print内容（看到现象是不是觉得自己以前炼的丹可能有问题了呢）。</p> <h4 id="碎碎念">碎碎念</h4> <p>一口气写了这么多条也有点累了，后续有踩到新坑的话我也会继续更新这篇文章的。毕竟写这篇文章是希望工作中不再会有人踩类似的坑 &amp; 炼丹的人能够对深度学习框架有意识（虽然某种程度上来讲这算是个心智负担）。</p> <p>如果说今年来什么事情是最大的收获的话，那就是理解了一个开放的生态是可以迸发出极强的活力的，也希望能看到更多的人来分享自己遇到的问题和解决的思路。毕竟探索的答案只是一个副产品，过程本身才是最大的财宝。</p>]]></content><author><name></name></author><category term="deep-learning"/><category term="engineering"/><category term="code"/><summary type="html"><![CDATA[本文算是我工作一年多以来的一些想法和经验，最早发布在旷视研究院内部的论坛中，本着开放和分享的精神发布在我的知乎专栏中，如果想看干货的话可以直接跳过动机部分。另外，后续在这个专栏中，我会做一些关于原理和设计方面的一些分享，希望能给领域从业人员提供一些看待问题的不一样的视角。]]></summary></entry></feed>