深度學習優化(hua)器算法巧思速覽
這一(yi)篇博(bo)文想寫很久了,一(yi)直沒有下筆,核心(xin)原因也是有一(yi)些(xie)待辦的(de)思(si)路(lu)在攻關驗(yan)證。
我(wo)們先(xian)從(cong)一個核(he)心的問題出發,
1. 為什么要研究優化器算法?
它的關聯問題:訓練為什么要調參(can)(can),調的是什么參(can)(can)?
如(ru)果(guo)就(jiu)這個問(wen)題去(qu)問(wen)各種大語言模(mo)型,它們能(neng)給出一堆的(de)理由。
但(dan)就博主(zhu)而言(yan),答案只有一個(ge):
干掉調參,解放生產力,榨干算力。
說到底就一個字"窮"。
在多(duo)年的(de)(de)研發生(sheng)涯里,對調參這個事深惡痛(tong)絕,為什么辛辛苦苦架(jia)構出來的(de)(de)模型(xing),一訓練就崩(beng),訓練收(shou)斂慢到龜速,這嚴重影(ying)響了開發進度,并且(qie)增加了很多(duo)不可抗力的(de)(de)消耗。
我相信有很(hen)(hen)多業(ye)內同行(xing),都有這種痛,訓練了很(hen)(hen)久,效果(guo)依(yi)舊(jiu)很(hen)(hen)差,泛化能力(li)也不行(xing),然后就開(kai)始(shi)苦惱,為什么自己沒有足夠的錢,足夠的算力(li)。
明明自己很好的思路,戛然而止,退而求其次。
早年間,博主(zhu)經常(chang)半夜醒來,看訓練的損失曲線,生(sheng)怕訓崩(beng)。就算沒有訓崩(beng),自己花費了(le)大量時間精力,卻沒有很好的回報(bao)。
一(yi)次又一(yi)次,是(shi)很打(da)擊信心的。
在付出了(le)(le)大量(liang)時間和人民幣之(zhi)后,博主終于從泥潭里爬出來了(le)(le),時光荏苒,這個困擾我九年(nian)的問題,畫上句號了(le)(le)。
那大(da)語(yu)言(yan)模型是怎么回答這個(ge)問題的。
核心就一句話:
"沒有新優化器,下一代模型根本訓不起來。"
-
從理論上看,它是在解(jie)決(jue)一(yi)個尚未被完全理解(jie)的(de)復雜高維優化問題,充滿挑(tiao)戰(zhan)與機遇。
解決基(ji)礎性訓練(lian)難題——讓模型"能學(xue)"
-
從工程上看,它是降低(di)AI研發成(cheng)本、推(tui)動技術普(pu)及的關鍵杠桿。
追求極致的效(xiao)率與效(xiao)益(yi)——讓模型"快學"且"省學"
-
從性能上看,它是提升模型最(zui)終準確性(xing)、魯棒性(xing)和泛化能力的決定性(xing)因素。
提升(sheng)模型的終極性能——讓(rang)模型"學好"
最終達到,拓展AI的技術邊界——讓"不可(ke)(ke)能"成為(wei)"可(ke)(ke)能"
當然就這(zhe)個問(wen)題,大(da)家可(ke)以自(zi)行去追(zhui)問(wen)各家的大(da)語言模型,給出的結論大(da)同小異。
2. 那博主為什么要寫這篇博文?
最基本的(de)(de)還是希望(wang)拋(pao)磚(zhuan)引(yin)玉,希望(wang)能有更(geng)多的(de)(de)同行在力大磚(zhuan)飛,燒錢的(de)(de)當下,不要放棄底層算法的(de)(de)研究。
同時為更多的(de)深度學習小(xiao)白提供一個新的(de)視角,學習并應用(yong)深度學習,溫故而知新。
3. 那什么是優化器算法?
優化(hua)器(qi)算法是驅動機器(qi)學(xue)習模型學(xue)習的(de)"引擎(qing)"。它的(de)核(he)心(xin)任(ren)務是:在(zai)訓練(lian)過程中,根(gen)據損(sun)(sun)失函(han)數計(ji)算出的(de)梯度(即(ji)方向),以某種策略更(geng)新模型的(de)參數,從而最小化(hua)損(sun)(sun)失函(han)數。
可以將訓練(lian)過(guo)程(cheng)想象(xiang)成(cheng)在復雜地形中尋找最低(di)點:
- 損失函數:代表地形的高度。
- 模型參數:代表我們在地形中的位置。
- 梯度:代表我們腳下最陡峭的下坡方向。
- 優化器:就是那個決定"往哪個方向走、走多大步、以及是否要考慮之前的慣性"的導航策略。
Adam (Adaptive Moment Estimation)
-
思想:目前最流行和默(mo)認的優化(hua)器之(zhi)一。它(ta)結合了Momentum和RMSProp的優點。
- 它計算梯度的一階矩(均值,提供動量)和二階矩(未中心化的方差,用于自適應調整學習率)。
- 然后對這兩個矩進行偏差校正,使其在訓練初期不那么偏向于0。
-
優點:
- 通常收斂速度快。
- 對超參數的選擇相對魯棒(默認參數通常就能工作得很好)。
- 能處理噪聲和稀疏梯度。
如果把Adam的(de)一階矩(ju)和(he)二階矩(ju)去掉,它就蛻變為SGD。
而隨機梯度下降(樸素SGD)是一(yi)種優化(hua)算法,通過隨機選取單個(ge)樣本來近似梯(ti)度,從而迭代更新(xin)模型參數(shu),收斂至最小值。
換句話說,樸素(su)SGD是一個(ge)沒有應用(yong)任何(he)先驗(yan)補充的(de)野蠻人,較于Adam的(de)平(ping)滑(hua)學(xue)習而言,它(ta)就像(xiang)一只無(wu)頭蒼蠅,到處亂撞,也不知道該撞多少次才能收斂至最小值。
4. Adam相較于樸素SGD,它做了哪些改進?
-
引入動量緩沖m,也就是一(yi)階矩(ju),指數(shu)加權平滑梯度,它積累了歷史梯度的方向趨(qu)勢。使得(de)樸素SGD的動蕩趨(qu)于平穩平滑。
-
引入自適應步長v,也就是二階(jie)矩,指數加(jia)權(quan)平均(jun)的平方,它積累了歷史梯度(du)平方的值(zhi)趨勢。
最終以 grad = m / sqrt(v) 作為(wei)目標梯度(du)進行更新(xin)。
對于動量一(yi)階矩,基本(ben)沒啥好(hao)說的,就是求歷史平(ping)均梯度,使得訓練平(ping)穩。
核心還是自適應步(bu)(bu)長v,對于頻繁更新、梯度大的參(can)數(shu),其(qi)二階矩估(gu)計值(zhi)大,因此實(shi)際更新步(bu)(bu)長會被調(diao)小(除以一個大數(shu)),避免"步(bu)(bu)子太大"而越過最優點(dian)。
對于不頻(pin)繁更新(xin)、梯度小的(de)參數,則給(gei)予更大的(de)相(xiang)對步長,鼓勵其更新(xin)。
所以Adam能加速較于樸素SGD訓練收(shou)斂(lian),二階矩功不(bu)可(ke)沒。
原本故事(shi)到這里(li),就接近(jin)完結(jie)了。
在真(zhen)實的場景下,我們發現Adam還是不夠(gou)好。
但它(ta)的(de)普及使得深度學習遍(bian)地(di)開花。
雖然仍是需要調參,但是不(bu)像之前那么"玄學(xue)"了。
當然(ran)在(zai)一(yi)些場景下(xia),例如GAN的(de)訓(xun)練,仍然(ran)有所爭議(yi)。
在博主的(de)實測(ce)下(xia),此文提(ti)及(ji)的(de)nSGDA確實比樸素(su)SGD穩(wen)健一些。
class nSGDA(torch.optim.Optimizer):
def __init__(
self,
params, # Model parameters
lr: Union[float, torch.Tensor] = 4e-5, # Learning rate (default: 4e-5)
# Coefficients used for computing running averages of gradient (default: 0.9)
momentum: float = 0.9,
# eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8)
eps: float = 1e-8,
weight_decay: float = 1e-2, # Weight decay (L2 penalty) (default:1e-2)
):
if lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if momentum < 0.0 or momentum >= 1.0:
raise ValueError("Invalid momentum value: {}".format(momentum))
if weight_decay < 0.0:
raise ValueError("Invalid weight decay: {}".format(weight_decay))
defaults = dict(
lr=lr,
momentum=momentum,
weight_decay=weight_decay,
eps=eps)
super().__init__(params, defaults)
def step(self, closure=None):
r"""Performs a single optimization step.
Arguments:
closure: A closure that reevaluates the model and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
momentum = group['momentum']
lr = group['lr']
weight_decay = group['weight_decay']
eps = group['eps']
one_minus_momentum = 1.0 - momentum
for p in group['params']:
if p.grad is None:
continue
if p.grad.is_sparse:
raise RuntimeError(
"current optimizer does not support sparse gradients")
state = self.state[p]
# State initialization
if len(state) == 0:
state["m"] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)
m = state['m']
bias_correction = 1.0 - momentum ** state["step"]
if weight_decay != 0:
p.grad = p.grad.add(p.data, alpha=weight_decay)
m.mul_(momentum).add_(p.grad, alpha=one_minus_momentum)
step_size = lr / torch.norm(m.div(bias_correction)).add_(eps).mul_(bias_correction)
p.data.add_(m, alpha=-step_size)
return loss
當你采用Adam調參訓練,總是跑(pao)崩或者(zhe)無(wu)法收斂,這個時候(hou),稍微嘗試一(yi)下(xia)nSGDA也未嘗不可。
而Adam二階(jie)矩的(de)存在也實實在在埋了一(yi)個雷 : “過沖”問題
本(ben)來“對于(yu)不頻(pin)繁更(geng)新、梯度(du)小的(de)(de)參數,則(ze)給予(yu)更(geng)大的(de)(de)相對步(bu)長,鼓勵其更(geng)新。”
是個很好的想法,
但(dan)是(shi)有(you)一個特例,那就是(shi)訓(xun)練到后(hou)期,梯度理(li)論上(shang)也會(hui)越來越小(xiao),這個時(shi)候也不應該鼓(gu)勵其更(geng)新。
有(you)可能一(yi)更新,跑飛了,這就是后來為什么存在早(zao)停(ting)(Early Stopping)策略(lve)的(de)根由之一(yi)。
如果繼續訓練(lian),有可(ke)能從次優解里(li)爬(pa)出來(lai),但是(shi)更多(duo)實(shi)際情(qing)況是(shi),若這里(li)就是(shi)最優解,
由(you)于激(ji)進(jin)地更新,反而會越(yue)跑(pao)越(yue)遠。
理想(xiang)的情況肯定(ding)是,訓練(lian)到最(zui)(zui)優解。最(zui)(zui)后停在最(zui)(zui)優解上,或(huo)者在最(zui)(zui)優解周(zhou)圍轉(zhuan)圈。
但這里有個悖論,
你憑什么(me)認為這里(li)是(shi)最優(you)解,而(er)不是(shi)次(ci)優(you)解,這個標準(zhun)怎么(me)界定判斷(duan)。
而(er)且由(you)于數(shu)據(ju)的稀(xi)缺性,我們希望模型在這(zhe)種情況(kuang)下,還能(neng)有更強大的泛化能(neng)力,即使它(ta)沒見過的數(shu)據(ju),也能(neng)適配到位。
也就是說,
理想上我們(men)既希望能求(qiu)到解的思路規律(lv),最(zui)好(hao)覆蓋更多的求(qiu)解路徑(jing),而不是一條最(zui)短的求(qiu)解路徑(jing)。
繞路沒(mei)問題,只(zhi)要(yao)這個(ge)繞路方式能(neng)提升泛化能(neng)力。
這就(jiu)是后(hou)來dropout盛行的原因之(zhi)一(yi),因為簡單有效。
讓一(yi)部分(fen)神經元(yuan)失(shi)活(huo),也(ye)能求(qiu)到解。
但(dan)是dropout這個技術(shu)思路,慎(shen)用,用得不好,反而會起反作用。
路漫漫其修遠兮(xi),一起努力(li)吧~
5. 后Adam家族時代,百家爭鳴
由于這(zhe)個話題(ti)展開(kai),真的(de)可以寫一(yi)本書了。
所以本文(wen)的核心(xin)是"速覽",博主(zhu)帶著大家看(kan)一看(kan)這后Adam的各種巧思。
相(xiang)關的算法(fa)實現,可(ke)以參考以下(xia)項目倉庫(ku):
PyTorch:
TensorFlow/Keras:
本文沒有提及(ji)的其他算法,自(zi)行(xing)移(yi)步查閱。
5.1 砍Adam的顯存
由于一(yi)階矩(ju)m和二階矩(ju)v都需要(yao)歷史(shi)平滑(hua),所(suo)以Adam至少要(yao)占用(yong)兩(liang)倍的可訓練模型參(can)數。
這樣一來,只要(yao)模型參數(shu)一大,那訓練的時候(hou) 1+2 = 3 至(zhi)少(shao)要(yao)存儲三(san)份(fen)權重。顯(xian)存很(hen)快就不夠(gou)用(yong)了。
所(suo)以(yi),針對這個問題,我(wo)們開始磨刀霍(huo)霍(huo)向二階矩v。
5.1.1 18年的Adafactor
社區(qu)比較知(zhi)名(ming)的實現(xian):
5.1.2 19年的SM3
官方實現:
Adafactor和(he)SM3都是(shi)分解近(jin)似的(de)做法。SM3的(de)實現較為(wei)復(fu)雜,所以基本上沒有(you)被推廣(guang)開(kai)來(lai)。所以很長一段(duan)時間都是(shi)Adafactor是(shi)主流。
但是Adafactor的實現稍微有些問(wen)題。
問題函數:
@staticmethod
def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
# copy from fairseq's adafactor implementation:
# //github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
return torch.mul(r_factor, c_factor)
_approx_sq_grad 這(zhe)個實現丟失了不少精度。
博(bo)主認為比(bi)較合理(li)的實現,是把sqrt放到最后計(ji)算,精度會高些(xie)。
@staticmethod
def _approx_sq_grad(row_exp_avg_sq, col_exp_avg_sq):
row_factor = row_exp_avg_sq.unsqueeze(-1)
row_factor = row_factor.mean(dim=-2, keepdim=True).div(row_factor)
col_factor = col_exp_avg_sq.unsqueeze(-2)
return row_factor.div(col_factor).sqrt_()
5.1.3 22年的Amos
在Adafactor和SM3之后很長一段時間,砍優化器顯存占用(yong)這(zhe)個事情(qing)似乎(hu)被遺忘了。
直到Amos的(de)(de)出現(xian),它進一步砍(kan)掉了v的(de)(de)顯存(cun)占用,直接采用了平方均值,美其名曰(yue)"信(xin)息共享"。
顯(xian)存不夠用,又想保住精度,可(ke)以考慮采用Amos,當然它較之Adam還有不少改(gai)進(jin)點(dian)。
5.1.4 24年損失作為學習率的奇思妙想
利用損失值(loss)本身來動態調整優化器的學習率,以此作為替代二階v實現更快(kuai)的收斂(lian)。
非常簡單的思路: “損失越大(da),學習(xi)率越大(da);損失越小(xiao),學習(xi)率越小(xiao)。”
由(you)于論文(wen)沒有給出開(kai)源實現,也沒有搜到第(di)三方實現。
參考論文(wen)的思想(xiang),實(shi)現了(le)該思路,代碼實(shi)現不(bu)完全對應論文(wen)內(nei)容,僅供參考學(xue)習。
# mypy: allow-untyped-defs
from typing import Tuple, Union
import torch
from torch import GradScaler
class AdaLo(torch.optim.Optimizer):
r"""
AdaLo: Adaptive Learning Rate Optimizer with Loss for Classification
paper: //www.sciencedirect.com/science/article/abs/pii/S0020025524015214
code: //github.com/cpuimage/AdaLo
usage:
for inputs, labels in dataloader:
def closure(inp=inputs, lbl=labels):
optimizer.zero_grad()
loss = criterion(model(inp), lbl)
loss.backward()
return loss
optimizer.step(closure)
Args:
params: Iterable of parameters to optimize or dicts defining
parameter groups.
lr: Learning rate (not used for step size calculation due to the adaptive learning rate mechanism; retained solely for API consistency)
betas: (beta1, beta2) coefficients for gradient momentum and loss-EMA smoothing respectively
weight_decay: L2 weight decay
kappa: loss scaling factor
eps: float. term added to the denominator to improve numerical stability.
mode: control learning rate adaptation mode ('adversarial' or 'compliant')
'adversarial': decrease learning rate when loss increases (conservative strategy)
'compliant': increase learning rate when loss increases (aggressive strategy)
"""
def __init__(self,
params,
lr: Union[float, torch.Tensor] = 1e-8,
betas: Tuple[float, float] = (0.9, 0.999),
weight_decay: float = 1e-2,
kappa: float = 3.0,
eps: float = 1e-8,
mode: str = 'adversarial'):
if lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if betas[0] < 0.0 or betas[0] >= 1.0:
raise ValueError("Invalid beta1 value: {}".format(betas[0]))
if betas[1] < 0.0 or betas[1] >= 1.0:
raise ValueError("Invalid beta2 value: {}".format(betas[1]))
if weight_decay < 0.0:
raise ValueError("Invalid weight decay: {}".format(weight_decay))
defaults = dict(lr=lr, beta1=betas[0], beta2=betas[1], weight_decay=weight_decay, kappa=kappa,
mode=mode, eps=eps)
super(AdaLo, self).__init__(params, defaults)
def step(self, closure=None, scaler: GradScaler = None, loss=None):
already_updated_by_scaler = False
if closure is not None:
with torch.enable_grad():
loss = closure()
if scaler is not None:
scaler.scale(loss).backward()
scaler.unscale_(self)
scaler.step(self, loss=loss)
scaler.update()
already_updated_by_scaler = True
if not already_updated_by_scaler:
for group in self.param_groups:
beta1 = group['beta1']
beta2 = group['beta2']
weight_decay = group['weight_decay']
kappa = group['kappa']
mode = group['mode']
eps = group['eps']
for p in group['params']:
if p.grad is None:
continue
if p.grad.is_sparse:
raise RuntimeError("current optimizer does not support sparse gradients")
state = self.state[p]
if len(state) == 0:
state['m'] = torch.zeros_like(p.data)
state['loss_ema'] = torch.tensor(0.0, device=p.device, dtype=p.dtype)
m = state['m']
loss_ema = state['loss_ema']
m.lerp_(p.grad, 1.0 - beta1)
if loss is not None:
scaled_loss = torch.log1p(loss.detach())
transformed_loss = (torch.tanh(-scaled_loss * 0.5) + 1.0) * 0.5
loss_ema.lerp_(transformed_loss, 1.0 - beta2)
if mode == 'adversarial':
lr_t = loss_ema.div(kappa).clamp_min_(eps)
else:
lr_t = (1.0 - loss_ema).div(kappa).clamp_min_(eps)
if weight_decay != 0:
p.data.mul_(1.0 - lr_t * weight_decay)
p.data.sub_(m * lr_t)
return loss
在一些場景下(xia)實測也是(shi)很穩健,lr = v = loss 不(bu)得不(bu)夸一下(xia)論文(wen)原作(zuo)者的奇(qi)思妙想。
PyTorch官方使用amp混合精度的時(shi)候,GradScaler.step里(li)有這么一句。
if "closure" in kwargs:
raise RuntimeError(
"Closure use is not currently supported if GradScaler is enabled."
)
也就是說閉包和amp混合當前不支(zhi)持一(yi)起(qi)用。
在AdaLo代碼倉(cang)庫里(li),博主演示(shi)怎么魔(mo)改實(shi)現(xian)閉(bi)包和(he)amp可(ke)以(yi)同時使用(yong),感興(xing)趣的(de)可(ke)以(yi)閱讀具體實(shi)現(xian)。
在實測(ce)過程(cheng)中,發(fa)現(xian) “損失(shi)越大(da),學習率越大(da);損失(shi)越小,學習率越小。”
這個(ge)做法(fa)在(zai)一些場景(jing)下比較激(ji)進,所以增加了(le)一個(ge)新的參數(shu)為mode可(ke)切換(huan)學習率適配模(mo)式,默認設為保守模(mo)式。
分別對應
- adversarial (保守模式):“損失越大,學習率越小;損失越小,學習率越大。”
- compliant (激進模式) :“損(sun)失越(yue)大,學(xue)習(xi)(xi)率(lv)越(yue)大;損(sun)失越(yue)小(xiao),學(xue)習(xi)(xi)率(lv)越(yue)小(xiao)。”
5.1.5 窮到極致,什么都能接受
如果顯存極度匱乏,手頭(tou)還(huan)挺緊,能訓練(lian)比什么都重要的話(hua)。
采用 非(fei)負矩陣分解(NNMF),將梯度權重轉換(huan)為(wei)最接近正方形的矩陣,分解為(wei)行列(lie)兩(liang)個向量。
雖然是有損的壓(ya)(ya)縮解壓(ya)(ya)操(cao)作,但在一些特定(ding)的場景能(neng)減少可觀的內存占用,在內存效(xiao)率和(he)優(you)化(hua)性能(neng)之間取得相對平(ping)衡。
核心算法如下:
@torch.no_grad()
def _unnmf(self, row_col: tuple) -> torch.Tensor:
return torch.outer(row_col[0], row_col[1])
@torch.no_grad()
def _nnmf(self, matrix: torch.Tensor, out) -> tuple:
shape = matrix.shape
torch.sum(matrix, dim=1, out=out[0])
torch.sum(matrix, dim=0, out=out[1])
if shape[0] < shape[1]:
scale = out[0].sum()
if scale != 0:
torch.div(out[0], scale, out=out[0])
else:
scale = out[1].sum()
if scale != 0:
torch.div(out[1], scale, out=out[1])
return out
5.2 Adam二階矩v為0的問題
導致v為(wei)0有(you)很多原因,在模(mo)型訓練的不同階段,由于噪聲也好(hao),精度也好(hao),會(hui)直接或(huo)者間接導致v為(wei)0。
前面提到 grad = m / sqrt(v)
早期Adam論文里的解決(jue)方案就是直接給v加上一個(ge)epsilon,一般(ban)設(she)為1e-8,避免除以0。
而后續經過不少團隊(dui)的(de)實(shi)踐發現這么做有點魯莽。
然(ran)后就有(you)人(ren)開始針對這(zhe)個問題進(jin)行修改。
但是林林總總,都是把epsilon移來移去,例(li)如梯度平方后就加上epsilon,再(zai)進行指數加權平均。
也有(you)采(cai)用softplus抑制分母(mu)過(guo)小的做法:
grad = m / softplus(sqrt(v))
這(zhe)個問題一直(zhi)到(dao)了(le)2024年,有新(xin)的進展。
方法很簡單,刪除epsilon,采用atan2。
grad = atan2(m, sqrt(v))
從(cong)數值穩定的(de)角度來說(shuo),atan2確實是穩定了(le)許多(duo),而且基本規避了(le)一些特殊情(qing)況下訓練(lian)跑崩,導致損失為nan的(de)情(qing)況。
Adam的betas默認(ren)參(can)數(shu)是(0.9,0.999) ,也(ye)有人覺得(de)這里也(ye)存在調參(can)適(shi)配(pei)問題(ti)。
刪(shan)除epsilon一般(ban)都可以理解,但(dan)把動量參數也干掉,做(zuo)成(cheng)自適應的(de)"膽(dan)大(da)妄為(wei)",也是挺(ting)絕的(de)。
不管成不成功,效果幾何,就這(zhe)魄(po)力,值(zhi)得我在此一提。
5.3 Adam的梯度長尾問題
這個很好理解,由于一階矩(ju)m和二階矩(ju)v都采用了指數平均,在不同程度(du)上(shang)也是導致梯度(du)長尾(wei)的誘(you)因(yin)之一。
因為求平(ping)均值(zhi)這個事(shi),就跟(gen)奧運比賽(sai)打分(fen)一樣,只用(yong)均值(zhi)很不(bu)公平(ping)。去掉一個最高分(fen),去掉一個最低分(fen),然后再算平(ping)均相對合(he)理(li)一些。
求(qiu)損(sun)失均(jun)值(zhi)的時候(hou)一樣(yang)存在(zai),博主(zhu)曾經設想過,也許求(qiu)損(sun)失的中位數是一個可行的做法,但也有一定的局限性(xing)。
沒有經(jing)過嚴(yan)格驗證的(de)求損失(shi)中位(wei)數思路的(de)實(shi)現,僅供參考:
def soft_median(losses, temperature=None):
if temperature is None:
temperature = max(0.1, 0.5 * losses.std())
if losses.numel() % 2 == 0:
losses = torch.cat([losses, losses.new_zeros(1)])
x_sorted, _ = torch.sort(losses)
n_loss = losses.shape[0]
median_idx = (n_loss - 1) * 0.5
idxs = torch.arange(n_loss, device=losses.device, dtype=losses.dtype)
weights = torch.softmax(-torch.abs(idxs - median_idx) / temperature, dim=0)
return torch.dot(weights, x_sorted)
同樣的,梯度(du)在訓練過(guo)程(cheng)中變化很大,一些長(chang)尾樣本帶(dai)來的貢獻就會被淹沒掉。
帶來的后果,不是(shi)過(guo)擬合,就(jiu)是(shi)泛化(hua)差(cha),能拿到次優解那是(shi)屬于幸運(yun)兒了。
這(zhe)個方向的研究多,也不(bu)多,因(yin)為很(hen)多長尾問題基本上不(bu)會考慮在優(you)化器里解決,一般會采用損失加(jia)權懲罰的思路來緩(huan)解。
這篇論文可以幫(bang)助進一步理解梯(ti)度長尾問題。
當然它不(bu)是一個(ge)主流的方案和思路,主流的方案更多的是采(cai)用(yong)元學習之類的做法,局限性也比較大。
那該如何直觀地洞察梯度長尾呢?
采用(yong)TensorBoard,對(dui)參數和梯度進行可(ke)視化,查(cha)看其直方圖,非常直觀(guan)。
示例如下:
參數直方圖:

從(cong)參(can)數權重的分布來看(kan),藍色左邊(bian)一(yi)直在拖尾(wei),紅色的左邊(bian)尾(wei)巴(ba)開(kai)始右移聚攏(long)。從(cong)參(can)數來看(kan),可以看(kan)到一(yi)些趨勢,但不(bu)夠(gou)直觀。
我們再來(lai)看其對應(ying)的梯度直方圖:

這就一目了然,左邊藍色(se)明顯存在(zai)梯度長尾(wei),而右邊紅色(se)的(de)梯度長尾(wei)逐漸(jian)開始消失,且(qie)紅色(se)更趨向于正態分(fen)布。
我們再看另一組圖:
![]() |
|
這是vae潛空(kong)間0-9十個(ge)數字的(de)聚(ju)類圖(tu)。
相關vae代碼示例(li)見:
圖二整(zheng)體聚(ju)合接(jie)近(jin)(jin)一個圓圈,而圖一接(jie)近(jin)(jin)橢圓。
這兩種情況(kuang),是(shi)圖二(er)還是(shi)圖一的模型(xing)權重泛化能力更勝(sheng)一籌呢(ni)。
答案是圖二,它的kl散(san)度損失(shi)更低。
真(zhen)實(shi)情境(jing)下長(chang)尾也(ye)可以(yi)是噪(zao)聲或(huo)標簽錯誤,所以(yi)擬合長(chang)尾也(ye)不(bu)是完(wan)全是一件好事情。
一切以實(shi)測效果為準,長尾梯度只是一個僅供參(can)考項(xiang)。
博主一(yi)直認為如果可(ke)以優雅解決長(chang)尾問題,那是新(xin)一(yi)輪(lun)的(de)曙(shu)光。
5.4 Adam的過擬合問題
由于Adam本身的(de)機(ji)制問題,
訓練損失下降極快 → 模型迅(xun)速進入插值(interpolation)區域(yu) → 參(can)數范數容易膨脹 → 邊界(jie)更復雜 → 泛化(hua)差。
當然長尾問題也是它導致(zhi)過擬合的原(yuan)因之一(yi)。
比較知名且使(shi)用廣(guang)泛的方案是l2正則化,即權重(zhong)衰減。
Adam 進化為 AdamW,也(ye)就是現在主流(liu)的優化器算法
它思路也是非常簡單粗暴,在每次更新時,從權重中減去一個固定的比例(weight * weight_decay),是正則也(ye)是先驗懲罰。
權(quan)重衰(shuai)減(jian)是(shi)一(yi)個很(hen)好的思路,但它帶來了一(yi)個新(xin)的問(wen)題。衰(shuai)減(jian)量設為多(duo)少才是(shi)合適(shi)的,也就(jiu)是(shi)說,懲罰力(li)度(du)該如何界定。
衰減(jian)過大,學(xue)習(xi)收斂緩慢(man),衰減(jian)過小,沒有起到作用。
隨后Scheduled (Stable) Weight Decay也被提(ti)出,但是(shi)應(ying)用不廣,鮮(xian)為人知。
它的(de)(de)思路也(ye)很簡單,通過匯(hui)總整個(ge)模型的(de)(de)參數信(xin)息(xi),按照參數權重占比(bi)估(gu)算出每一層的(de)(de)衰減(jian)權重。
而(er)有另(ling)一篇論文從另(ling)一個(ge)新穎的(de)角度提出了一個(ge)方案。
它的思路(lu)是在每次更新時(shi),從權重中減(jian)去一個單元范(fan)數權重,可以近(jin)似看做是為(wei)權重衰減(jian)提供了范(fan)數先驗(yan)。
而后,將正則化從“加性懲罰”轉變為“約束優化” Constrained Parameter Regularization (CPR)
CPR 作為替代權(quan)重衰減的(de)(de)替代方(fang)案,就是為了權(quan)重衰減的(de)(de)調參困局(ju),但請慎用(yong)。
5.5 學習率熱身與梯度裁剪
在說到(dao)Adam過擬合的時(shi)候,我們很容易就發現了一個問題。
在不(bu)同(tong)的(de)模型架(jia)構,訓(xun)練的(de)每(mei)個階段,每(mei)層權重(zhong)的(de)值域(yu)是(shi)不(bu)一(yi)樣的(de),而且這個值域(yu)隨著訓(xun)練的(de)增(zeng)加,也(ye)一(yi)直(zhi)在變化。
由于這個核心問題的(de)(de)(de)存在(zai),訓練早期梯(ti)度的(de)(de)(de)波動就會很大,這個時候(hou)通常就需(xu)要學(xue)習率(lv)調參,或者在(zai)模型內部加入歸一(yi)化層(ceng),目的(de)(de)(de)盡可能快地把每(mei)一(yi)層(ceng)的(de)(de)(de)值(zhi)域確立下(xia)來。
由此(ci)就引發出來學習(xi)率(lv)熱身以(yi)及梯度(du)裁剪相(xiang)關的思考。
學習率熱身相關的資(zi)料和論文也有很多,這里不展開細講(jiang)。
學(xue)習率規劃熱(re)身的基本邏輯都是(shi):
早(zao)期用極其小的學(xue)(xue)(xue)習率(lv)(lv)進行預熱訓練 → 中期慢(man)(man)慢(man)(man)地增大學(xue)(xue)(xue)習率(lv)(lv) → 后期再(zai)固定學(xue)(xue)(xue)習率(lv)(lv)或者慢(man)(man)慢(man)(man)減少學(xue)(xue)(xue)習率(lv)(lv)
雖然很(hen)傻,但是確實有效。
21年的(de)時候谷歌為了把歸一化層(ceng)刪掉,就提出了自適(shi)應梯度(du)裁(cai)剪方案(an)。
思路(lu)也(ye)很簡(jian)單,根據每層梯度和權重的(de)值域,按比(bi)例縮放(fang)當前(qian)的(de)梯度。
25年終于有人(ren)想要把學習率預熱刪掉(diao)。
思(si)路跟Scheduled (Stable) Weight Decay很像,只(zhi)不過(guo)這次是作(zuo)用在學習(xi)率上(shang)罷(ba)了(le)。
本質就是根(gen)據每(mei)層(ceng)權重梯(ti)度(du)比例算(suan)出(chu)來一個(ge)全局學習率的縮小率。由于每(mei)層(ceng)的激活(huo)函數(shu)不一樣,算(suan)出(chu)來一個(ge)全局縮小率,從(cong)邏(luo)輯上其實很牽(qian)強(qiang)。
當然除此(ci)之外(wai)還(huan)有(you)其他類似(si)的思(si)路,例如(ru):
梯度范數化
def gradient_normalization(grad, eps: float = 1e-8):
grad.div_(grad.norm(p=2) + eps)
層范數化縮放
def layer_norm_adaptation(grad, var):
w_norm = var.norm(p=2)
g_norm = grad.norm(p=2)
grad.mul_(torch.where(torch.greater(w_norm, 0),
torch.where(torch.greater(g_norm, 0), (w_norm / g_norm), 1.0),
1.0))
梯度中心化
def centralize_gradient(grad):
if grad.dim() > 1:
grad.data.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True))
林林總(zong)總(zong),大(da)同小異。
博主(zhu)根據自己的理解,也寫了個(ge)梯度(du)軟(ruan)裁剪,代碼如下。
@staticmethod
def _soft_clip(grad, var, epsilon=1e-12):
dim = None if (r := var.dim()) <= 1 else tuple(range(1, r))
var_norm = var.square().mean(dim=dim, keepdim=True).sqrt_().clamp_min_(epsilon)
grad_norm = grad.square().mean(dim=dim, keepdim=True).sqrt_().clamp_min_(epsilon)
clipped_norm = grad_norm.clamp_max(var_norm)
return grad.mul_(clipped_norm / grad_norm)
5.6 如何進一步加速訓練收斂
前面已(yi)經(jing)提到(dao)不(bu)少關于調參,穩定性問(wen)題(ti),但大多(duo)數人最(zui)關心的還是怎(zen)么加速訓練。
主(zhu)要的(de)思路,基(ji)本上就是根據(ju)上一步的(de)梯度信息,結合當前(qian)步的(de)梯度,在(zai)兩步之間求出一個合理的(de)方(fang)向,往這個方(fang)向再走(zou)一步。
這樣做有(you)個好處,就是(shi)可以結合上(shang)一步(bu)的位(wei)置進(jin)一步(bu)修正(zheng)方向,其(qi)實就是(shi)殘差(cha)加(jia)權的路(lu)子。
有(you)前后梯度交替(ti)的(de)做法(fa),自然也就有(you)參(can)數交替(ti)的(de)做法(fa)。
但這兩種做法都有一個弊端,就是需要多存(cun)一份參數,顯存(cun)又(you)要不夠用(yong)了。
當(dang)然如(ru)果不考慮(lv)顯存占(zhan)用問(wen)題,
也(ye)可以采用Grünwald-Letnikov(G-L)分(fen)數階(jie)導(dao)數,它利用分(fen)數階(jie)微積分(fen)的全局(ju)記憶,
將參(can)數(shu)更新的梯度替(ti)換為G-L分(fen)數(shu)階近(jin)似梯度,從而更好地利用(yong)過去的長期曲率信息。
在某些場景下(xia),算力充足,也是一種(zhong)選擇(ze)。
如(ru)果(guo)考慮顯存(cun)有限(xian)的話,
有一(yi)個折中的做法,Nesterov momentum,Adam升級(ji)為NAdam,它的思路也很(hen)簡單(dan)"先沿(yan)慣(guan)性走(zou)一(yi)步(bu),再看新梯度,沿(yan)修正(zheng)后的方向走(zou)",也就是從Adam的"看一(yi)步(bu)走(zou)一(yi)步(bu)"變成了"看一(yi)步(bu)想兩(liang)步(bu)"。
但是(shi)總感覺(jue)有點牽強,結合上面提到(dao)了各種巧思手段,隨即就有人想(xiang)到(dao)了梯度范(fan)數也是(shi)一種先驗。
對梯(ti)度范數進行(xing)指(zhi)數加權平(ping)均,根據這個(ge)信息,動態(tai)(tai)調整(zheng)(zheng)梯(ti)度,換言之也就(jiu)是動態(tai)(tai)調整(zheng)(zheng)學習(xi)率。
似(si)乎一切都在往更理想(xiang)的方(fang)向推進著。
看到(dao)這里,我相信有很多同學(xue)會問,加(jia)大學(xue)習(xi)率,難道不能加(jia)速訓練(lian)收斂嗎?
我(wo)的回答是(shi)(shi),能,只有一個前提條件,就是(shi)(shi)batch size足(zu)夠的大,且(qie)優(you)化器算法足(zu)夠的穩健。
因為看的(de)信息足夠多,用(yong)大學習率,直接邁大一步(bu),是(shi)肯定沒有問題(ti)的(de)。
這個(ge)博(bo)主(zhu)已經(jing)(jing)經(jing)(jing)過(guo)驗證(zheng),實測過(guo)了。
大多數(shu)情(qing)況下,我(wo)們看到訓練(lian)加速,損失飛快地降,不存在過擬合的話,絕大多數(shu)都是模型正在調(diao)整權重到對應的值域范圍。
假設你使用了Sigmoid激活函數(shu),輸入的(de)值(zhi)在 [-6,6]左右的(de)區間,對(dui)應的(de)輸出值(zhi)是(0.0025, 0.9975)。
也就是說在Sigmoid的前一層(ceng),至少是[-6,6]的值域,才有信息能(neng)往后(hou)傳(chuan)。
如果你在(zai)Sigmoid前面野蠻地采(cai)用了歸一(yi)化,卻不進行縮放加(jia)權,放大它的(de)值(zhi)。那這(zhe)個神經元基本上處于失(shi)活的(de)狀態。
所以理(li)想的情(qing)況下在進入Sigmoid前手動放大值域,也算是(shi)一(yi)種先(xian)驗,至于放大3.0,放大6.0那就看Sigmoid前一(yi)層到底做了(le)什么了(le)。
看到這里,我相信應該沒有人會問(wen)歸一化層到底應該加在哪(na)里合適(shi)了吧。
這里(li)只(zhi)是便于理解,舉(ju)了個小(xiao)例(li)子。
經(jing)常(chang)會有人問Muon這個(ge)基(ji)于矩陣正(zheng)交化(hua)的優化(hua)器,實測(ce)為(wei)什么沒有傳說中那么高效。
你(ni)都(dou)已經看到這里(li)了,Muon是個什么(me)玩(wan)意,你(ni)別跟我說(shuo),你(ni)心里(li)沒(mei)數。
以(yi)上,寫于2025.10.06。
商業轉載請聯系作者進行授權,非商業轉載請注明(ming)出處。
若有(you)各種其他問題可以通過以下方式聯系博主交流學習。
微信: Dbgmonks
QQ: 200759103
郵箱: gaozhihan@vip.qq.com
注: 不(bu)注明來意者一律拒絕(jue)。
License
This work is licensed under a .

