LESS:基于影响函数的高效指令数据选择方法,实现LLM能力定向提升
1. 项目概述LESS——为特定能力“定向投喂”数据在大型语言模型LLM的指令微调阶段我们常常面临一个经典难题数据并非越多越好而是越“对”越好。全量数据训练不仅耗费巨大的计算资源还可能因为数据中混杂了大量与目标能力无关甚至冲突的样本导致模型学习效率低下甚至发生“灾难性遗忘”。想象一下你想训练一个模型精通代码生成却用一份混杂了诗歌创作、客服对话和历史问答的数据集去微调效果可想而知。这正是普林斯顿NLP团队提出LESS方法的出发点如何从海量的指令数据中精准筛选出那些对提升模型某一项特定能力最有影响力的样本LESS的核心思想非常直观且有力它通过计算训练数据对目标验证任务即你想提升的能力的“影响力分数”来量化每个训练样本的价值。其理论基础源于经典的影响函数但LESS的创新之处在于它通过一系列巧妙的工程化设计如梯度投影、多检查点集成、LoRA高效训练将这套理论变成了一个可扩展、可实操的完整数据选择流水线。简单来说LESS帮你回答“在我这TB级的指令数据里到底哪些数据点能最有效地让我的模型在‘数学推理’或‘多语言问答’上表现更好”如果你正在为以下问题头疼那么LESS值得你深入关注1计算预算有限无法进行全量数据微调2希望针对性地提升模型在某个垂直领域或任务上的表现3担心混合数据微调会导致模型能力“平均化”或产生负迁移。接下来我将结合论文与代码拆解LESS的每一个步骤并分享在实际复现和应用中可能遇到的“坑”与技巧。2. LESS核心原理与设计思路拆解2.1 影响力估计从理论到实践的桥梁LESS方法的基石是影响函数。在统计学和机器学习中影响函数用于衡量移除或扰动一个训练数据点会对模型参数进而对模型在某个测试点上的预测产生多大影响。其数学形式通常涉及计算损失函数关于模型参数的Hessian矩阵的逆这对于现代大模型来说是完全不可行的计算负担。LESS巧妙地避开了直接计算Hessian逆这个“怪兽”。它采用了一种基于梯度内积的近似方法。核心直觉是如果一个训练样本的梯度方向与目标验证任务的梯度方向高度一致那么用这个样本训练模型就会将参数向有利于解决验证任务的方向推动即该样本对目标任务有正影响力。具体来说对于某个训练样本z_i和验证集D_val其影响力分数I(z_i, D_val)可以近似为I(z_i, D_val) ≈ -η * ∇_θ L(z_i, θ), Σ_{z_val in D_val} ∇_θ L(z_val, θ)其中η是学习率·, ·表示内积。这个公式的含义是计算训练样本损失梯度与验证集损失梯度之和的内积。内积值越大负得越少或正得越多说明该训练样本的梯度方向与验证任务的需求越“同向”影响力也就越大。注意这里有一个关键细节原始影响函数公式前通常有一个负号。在LESS的实现中它最终选择的是内积值最大的样本这对应的是最负的影响力即移除该样本会导致验证损失最大幅度增加或者说该样本对降低验证损失最有帮助。理解这一点对解读代码中的排序逻辑很重要。2.2 工程化挑战与LESS的解决方案直接将上述理论应用于LLM会面临三大挑战梯度维度灾难LLaMA-7B有70亿参数梯度是70亿维向量存储和计算内积的内存与计算开销无法承受。训练动态变化模型在训练过程中参数不断变化单个检查点下的梯度代表性不足。计算效率为海量训练数据逐个计算并存储高维梯度不现实。LESS的流水线设计正是为解决这些问题而生应对维度灾难梯度投影LESS不存储原始梯度而是将梯度投影到一个随机低维子空间例如8192维。这通过一个固定的随机投影矩阵实现能大概率保留向量间的角度关系即内积符号从而在极大降低存储成本的同时保持了影响力排序的可靠性。捕捉训练动态多检查点集成模型在训练不同阶段关注的数据特性不同。LESS在预热训练后采集多个检查点下的训练数据梯度并在计算影响力时进行加权平均从而得到一个更稳健、全面的影响力评估。提升计算效率LoRA与梯度缓存整个流程基于LoRA进行微调大幅减少了需要优化的参数量使得梯度计算和存储变得可行。同时LESS将训练数据的投影梯度预先计算并存储为“梯度数据存储库”在针对不同目标任务进行选择时只需计算一次目标任务的梯度然后进行高效的内积运算即可实现了“一次计算多次选择”。这套组合拳使得LESS从一个理论构想变成了一个能在实际LLM训练流程中无缝集成的实用工具。3. 实操部署与环境搭建详解3.1 环境配置与依赖安装LESS的代码库对环境有一定要求。以下是我在复现过程中总结的详细步骤和避坑指南。首先PyTorch的版本匹配是第一个关键点。原代码要求torch2.1.2。如果你使用的CUDA版本较新如12.1直接安装可能会失败。你需要从PyTorch官方历史版本页面找到与你的CUDA及系统匹配的安装命令。# 示例对于CUDA 12.1可以尝试安装2.1.2版本 pip install torch2.1.2 torchvision0.16.2 torchaudio2.1.2 --index-url https://download.pytorch.org/whl/cu121如果遇到兼容性问题一个更稳妥的方法是使用作者提供的requirements.txt中的版本它通常经过测试。安装完PyTorch后进入项目目录安装其他依赖。git clone https://github.com/princeton-nlp/LESS.git cd LESS pip install -r requirements.txt实操心得强烈建议在安装前创建一个新的conda或venv虚拟环境。这能避免与本地其他项目的包版本冲突。我曾因环境中transformer库版本不一致导致梯度计算脚本报出难以追踪的形状错误。最后以可编辑模式安装less包本身这样你可以直接修改源码并且任何导入less模块的脚本都能使用最新代码。pip install -e .3.2 数据准备与预处理LESS实验使用了四个指令微调数据集Flan v2, CoT, Dolly, Open Assistant。其数据处理脚本遵循了 open-instruct 项目的格式。最方便的方式是直接使用作者在Hugging Face上提供的已处理版本。# 从Hugging Face下载已处理的数据 git lfs install git clone https://huggingface.co/datasets/princeton-nlp/less_data ../data下载后你的../data目录结构应如下所示data/ ├── train/ │ ├── processed/ │ │ ├── flan_v2/ │ │ ├── cot/ │ │ ├── dolly/ │ │ └── oasst1/ │ └── raw/ # 可能包含原始数据 └── eval/ ├── mmlu/ ├── tydiqa/ └── bbh/注意事项务必检查数据格式。每个jsonl文件中的每一行应是一个字典至少包含instruction,input,output字段或类似的prompt-completion对。LESS的数据加载器期望这种格式。如果你有自己的数据集需要先预处理成相同格式。一个常见的错误是数据字段名不匹配导致脚本运行时找不到关键信息而静默失败或报错。4. LESS数据选择流水线分步实现4.1 第一步预热训练预热训练的目的不是得到一个好模型而是为了获得一个合理的参数起点并确定后续梯度采集的检查点。LESS使用全量数据的5%进行LoRA微调。export DATA_DIR../data export MODEL_PATHmeta-llama/Llama-2-7b-hf export PERCENTAGE0.05 export DATA_SEED3 export JOB_NAMEllama2-7b-p${PERCENTAGE}-lora-seed${DATA_SEED} ./less/scripts/train/warmup_lora_train.sh $DATA_DIR $MODEL_PATH $PERCENTAGE $DATA_SEED $JOB_NAME关键参数解析PERCENTAGE0.05: 使用5%的随机数据。这个比例是经验值足以让模型初步学习指令遵循的模式又不会开销太大。DATA_SEED3: 随机种子确保数据可复现。不同的种子会导致采样的5%数据不同可能会影响最终选择结果这是一个可以探索的超参数。JOB_NAME: 用于定义输出目录方便管理。脚本内部做了什么从四个数据集中随机采样指定百分比的数据并合并。使用QLoRA4-bit量化加载Llama-2-7B模型。在合并的数据上训练一个LoRA适配器通常作用于所有线性层。按照固定的间隔如每1000步保存检查点。踩坑记录预热训练的步数和检查点保存频率需要根据你的总数据量和batch size调整。原脚本可能默认训练一个epoch。你需要确保训练足够步数使模型损失明显下降同时保存足够多如4-5个均匀分布的检查点以供后续梯度采集。检查点太少会影响多检查点集成的效果。4.2 第二步构建梯度数据存储库这是LESS最耗资源的步骤但只需执行一次。我们需要为全部训练数据而不仅仅是预热用的5%在多个检查点下计算并存储其投影梯度。CKPT105 # 假设这是第一个检查点 TRAINING_DATA_NAMEdolly TRAINING_DATA_FILE../data/train/processed/dolly/dolly_data.jsonl GRADIENT_TYPEadam # 对于训练数据使用Adam优化器状态中的梯度估计 MODEL_PATH../out/llama2-7b-p0.05-lora-seed3/checkpoint-${CKPT} OUTPUT_PATH../grads/llama2-7b-p0.05-lora-seed3/${TRAINING_DATA_NAME}-ckpt${CKPT}-${GRADIENT_TYPE} DIMS8192 # 投影维度 ./less/scripts/get_info/get_train_lora_grads.sh $TRAINING_DATA_FILE $MODEL_PATH $OUTPUT_PATH $DIMS $GRADIENT_TYPE你需要对每个数据集flan_v2, cot, dolly, oasst1和每个选定的检查点如105, 211, 317, 420都运行此脚本。核心过程解读脚本加载指定检查点的模型包含LoRA权重。遍历指定数据文件的每一个样本。对于每个样本进行前向传播计算损失然后反向传播得到梯度。关键一步梯度投影。脚本会将模型参数的梯度只针对LoRA可训练参数已大幅减少拼接成一个巨大向量然后与一个预先生成的随机高斯矩阵维度为[总参数量, DIMS]相乘得到一个仅DIMS维的投影梯度向量。将这个低维向量连同样本的唯一标识如索引和损失值保存到OUTPUT_PATH目录下的.npy或.pkl文件中。重要提示GRADIENT_TYPEadam是一个精妙的处理。由于我们使用Adam优化器其更新方向不是纯梯度而是经过一阶矩和二阶矩估计修正后的方向。使用Adam状态能更好地模拟实际参数更新方向比原始SGD梯度更有意义。这是论文中的一个重要细节。4.3 第三步为目标任务选择数据现在假设我们想提升模型在tydiqa多语言问答任务上的表现。3.1 计算目标任务的梯度首先我们需要目标验证任务的投影梯度。流程与第二步类似但数据换成了验证集且GRADIENT_TYPE固定为sgd即原始梯度。TASKtydiqa CKPT105 MODEL_PATH../out/llama2-7b-p0.05-lora-seed3/checkpoint-${CKPT} OUTPUT_PATH../grads/llama2-7b-p0.05-lora-seed3/${TASK}-ckpt${CKPT}-sgd DATA_DIR../data DIMS4096 8192 # 可以计算多个维度的投影后续选择使用哪个 ./less/scripts/get_info/get_eval_lora_grads.sh $TASK $DATA_DIR $MODEL_PATH $OUTPUT_PATH $DIMS同样需要对每个检查点运行此脚本。3.2 计算影响力分数并选择Top-K数据这是LESS的“魔法”发生处。脚本会计算每个训练数据样本来自所有数据集、所有检查点与目标任务梯度之间的加权内积。DIM8192 # 选定使用的投影维度 GRADIENT_PATH../grads/llama2-7b-p0.05-lora-seed3/{}-ckpt{}-adam/dim${DIM} TRAIN_FILE_NAMESflan_v2 cot dolly oasst1 CKPTS105 211 317 420 CHECKPOINT_WEIGHTS1.6877e-05 1.2859e-05 7.7030e-06 2.5616e-06 # 对应检查点的平均学习率 VALIDATION_GRADIENT_PATH../grads/llama2-7b-p0.05-lora-seed3/{}-ckpt{}-sgd/dim${DIM} TARGET_TASK_NAMEStydiqa SELECTED_DATA_OUTPUT_PATH../selected_data ./less/scripts/data_selection/matching.sh $GRADIENT_PATH $TRAIN_FILE_NAMES $CKPTS $CHECKPOINT_WEIGHTS $VALIDATION_GRADIENT_PATH $TARGET_TASK_NAMES $SELECTED_DATA_OUTPUT_PATH权重CHECKPOINT_WEIGHTS的奥秘这里使用的是对应检查点时的平均学习率。为什么在影响函数的理论推导中数据点的影响力与学习率η线性相关。在训练后期学习率衰减相同梯度带来的参数更新变小因此其影响力也应该按比例缩放。使用学习率作为权重是对多检查点梯度进行合理加权平均的关键。运行后脚本会为tydiqa任务生成一个包含所有训练样本影响力分数的文件。3.3 提取选定数据最后根据影响力分数排序选出Top K%例如5%的样本。python3 -m less.data_selection.write_selected_data \ --target_task_names ${TARGET_TASK_NAMES} \ --train_file_names ${TRAIN_FILE_NAMES} \ --train_files ../data/train/processed/dolly/dolly_data.jsonl ../data/train/processed/oasst1/oasst1_data.jsonl \ --output_path $SELECTED_DATA_OUTPUT_PATH \ --percentage 0.05这个脚本会从原始训练文件中提取出影响力最高的那部分数据合并成一个新的jsonl文件例如../selected_data/tydiqa/top_p0.05.jsonl。4.4 第四步使用选定数据训练至此我们获得了一份“精华”数据。接下来就是用这份数据对基础模型进行指令微调。TARGET_TASK_NAMEtydiqa PERCENTAGE0.05 TRAIN_FILES../selected_data/${TARGET_TASK_NAME}/top_p${PERCENTAGE}.jsonl MODEL_PATHmeta-llama/Llama-2-7b-hf JOB_NAMEllama2-7b-less-p${PERCENTAGE}-lora ./less/scripts/train/lora_train.sh $TRAIN_FILES $MODEL_PATH $JOB_NAME你可以对比使用全量5%随机数据训练的效果和使用LESS选出的5%数据训练的效果。论文中的实验表明在多项评测上LESS选出的数据都能显著超越随机选择。5. 关键参数调优与经验分享LESS的流程中有几个关键超参数对最终效果有直接影响。5.1 投影维度DIMS的选择投影维度是平衡计算开销和精度的关键。论文实验了{512, 1024, 2048, 4096, 8192}等维度。维度越低存储和计算速度越快但梯度方向的信息损失可能越大可能导致影响力排序不准确。维度越高保真度越高但开销越大。经验值论文默认使用8192这是一个在效果和效率间取得较好平衡的点。对于更大的模型如70B你可能需要尝试更高的维度如16384但会显著增加内存消耗。一个可行的策略是先用小维度如4096跑一遍观察选出的数据是否合理再决定是否增加维度。5.2 预热训练与检查点选择预热数据比例 (PERCENTAGE)默认5%是一个安全的起点。如果你的总数据量极大数千万可以适当降低比例如1%。如果总数据量较小可以提高到10%。核心是让模型学到基本的指令遵循格式。检查点数量与位置论文选择了4个检查点。理想情况下检查点应覆盖训练的不同阶段早期、中期、中后期。你需要根据预热训练的总步长来规划。例如如果训练10,000步可以在[2000, 4000, 6000, 8000]步保存检查点。避免所有检查点都集中在训练开始或结束阶段。5.3 多目标任务的数据选择LESS天然支持为多个目标任务选择数据。你只需在TARGET_TASK_NAMES中指定多个任务例如TARGET_TASK_NAMEStydiqa mmlu bbh。脚本会计算每个训练样本对每个任务的影响力分数。最终的策略可以是并集选取对任一任务影响力高的样本。交集选取对所有任务影响力都高的样本可能数量很少。加权和为不同任务分配权重计算加权影响力总分。 原版脚本似乎采用的是并集策略。你可以修改选择逻辑来实现更复杂的策略。6. 常见问题排查与效能优化6.1 内存不足与计算优化构建梯度数据存储库是内存消耗最大的阶段尤其是处理大规模数据集时。梯度计算批处理原脚本可能是逐样本计算梯度。你可以修改get_train_lora_grads.sh及相关Python代码引入小批量计算。例如每次处理32或64个样本计算平均梯度作为该批次的代表。这能大幅减少磁盘I/O和循环开销是处理超大数据集的必备优化。但需注意批处理会损失样本级粒度是一种近似。分片存储不要试图将一个包含数百万样本的数据集的梯度存成一个巨大的文件。应该按检查点、按数据集分片存储例如每1万个样本存一个文件。这便于管理和后续的分布式读取。使用CPU离线计算如果GPU内存严重不足可以考虑将梯度投影的计算转移到CPU上进行。虽然慢但可行性高。在脚本中将投影矩阵与梯度的矩阵乘法操作放到CPU上执行然后再存回磁盘。6.2 脚本执行错误与调试路径错误LESS脚本大量使用相对路径。确保你在项目根目录LESS/下执行脚本并且DATA_DIR、MODEL_PATH、OUTPUT_PATH等环境变量指向的路径存在且有权访问。数据格式错误如果出现KeyError大概率是数据格式问题。使用一小部分数据打印出数据加载器读取的第一个样本检查字段名是否与代码中的instruction_key、input_key、output_key等匹配。版本冲突确保transformers,accelerate,peft等库的版本与requirements.txt一致。特别是peft库其API在近期版本有较大变动可能导致LoRA模型加载失败。6.3 效果不达预期如何分析如果你复现后发现LESS选出的数据训练效果不如随机选择可以从以下方面排查验证任务与训练数据的相关性LESS的前提是训练数据集中存在能提升目标任务的样本。如果任务本身非常冷门如某种极小众语言的语法分析而你的指令数据集中几乎没有相关数据那么LESS也无能为力。它只能“发现”金子不能“创造”金子。投影维度是否过低尝试将DIM从4096提升到8192或更高看影响力排序是否有显著变化。检查点代表性检查预热训练是否收敛损失曲线是否平滑下降检查点是否捕捉了有代表性的训练状态可以尝试增加预热训练数据比例或步数并保存更多检查点。影响力计算是否正确手动验证一个小样本。随机挑10个训练样本和10个验证样本用代码计算它们的内积看看分数最高的样本是否“看起来”更相关例如验证任务是数学题高分训练样本是否也包含数学推理。这是一种快速的定性检查。LESS为我们提供了一种数据驱动的、可解释的指令数据选择方法。它将宝贵的计算资源从“蛮力训练”转向了“精准投喂”。尽管其初始的梯度计算成本不低但对于需要反复在不同目标任务上微调模型或拥有超大规模指令池的团队来说构建一次梯度数据存储库即可长期、高效地服务于多种能力提升需求从长远看是极具性价比的投资。在实际应用中你可以从一个小规模试点开始例如在某个垂直领域如法律、医疗的指令数据上应用LESS验证其对你特定模型和任务的有效性再决定是否扩大到全量流程。