Training
You can use TRIO to launch large-scale LLM post-training from your CPU machine, without having to deal with the complexity of infrastructure and environments.
Post-training falls into two categories: supervised fine-tuning (SFT) and reinforcement learning (RL). TRIO executes SFT as follows:

(diagram of the SFT flow)

TRIO executes RL as follows:

(diagram of the RL flow)
If we abstract out what they share, we find that data processing (Datum), the forward/backward pass (forward_backward), and the weight update (optim_step) are always required. Let's break down these three modules in detail and see how they are used in SFT and RL.
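Before diving into the real API, here is the shared skeleton as a toy numpy sketch. Everything in it is illustrative (a two-matrix "model", plain SGD in place of Adam); it only mirrors the shape of the three modules, not TRIO's implementation:

```python
import numpy as np

# Toy numpy analogue of the three modules (illustrative only).
rng = np.random.default_rng(0)
vocab, dim = 10, 8
emb = rng.normal(scale=0.1, size=(vocab, dim))   # frozen "embedding"
proj = rng.normal(scale=0.1, size=(dim, vocab))  # the weights we train

# 1. Data processing: input tokens, shifted targets, 0/1 loss weights
tokens = [1, 2, 3, 4]
input_tokens, target_tokens = tokens[:-1], tokens[1:]
weights = np.array([0.0, 1.0, 1.0])  # ignore the first position

# 2. forward_backward: forward pass -> weighted cross-entropy -> gradient
def forward_backward(input_tokens, target_tokens, weights):
    h = emb[input_tokens]                                   # (T, dim)
    logits = h @ proj                                       # (T, vocab)
    logits = logits - logits.max(axis=1, keepdims=True)
    logp = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    idx = np.arange(len(target_tokens))
    loss = -(weights * logp[idx, target_tokens]).sum()
    dlogits = np.exp(logp)                                  # softmax probs
    dlogits[idx, target_tokens] -= 1.0
    dlogits *= weights[:, None]
    return loss, h.T @ dlogits                              # grad w.r.t. proj

# 3. optim_step: apply the accumulated gradient
loss_before, grad = forward_backward(input_tokens, target_tokens, weights)
proj -= 0.1 * grad
loss_after, _ = forward_backward(input_tokens, target_tokens, weights)
print(loss_before, loss_after)  # the update lowers the loss
```

SFT and RL differ only in how step 1 produces the weights, targets, and extra loss inputs; steps 2 and 3 are the same loop.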
Data processing (Datum)
Before feeding a dataset into the LLM and the loss function, we need to process it.
Datum is a wrapper for training data; you must convert your data into Datum objects before passing them to forward_backward. It has two parts:
- model_input: the tokens fed to the model
- loss_fn_inputs: the arguments supplied to the loss function. Different post-training tasks need different arguments:
  - For the SFT loss function (cross_entropy), you need to supply:
    - weights: per-token loss weights, a list of 0s and 1s (0 = ignore, 1 = compute loss); the prompt part is typically 0 and the output part 1
    - target_tokens: the tokens shifted by one position, typically tokens[1:]
datum = trio.Datum(
    model_input=trio.ModelInput.from_ints(tokens=input_tokens),
    loss_fn_inputs=dict(
        weights=weights,
        target_tokens=target_tokens,
    )
)
  - For the RL loss functions (importance_sampling, ppo), in addition to weights and target_tokens, you also need:
    - logprobs: the sampling logprobs obtained from the rollout
    - advantages: the reward advantage values
rl_datum = trio.Datum(
    model_input=trio.ModelInput.from_ints(tokens=input_tokens),
    loss_fn_inputs=dict(
        weights=weights,
        target_tokens=target_tokens,
        logprobs=sampling_logprobs,
        advantages=advantages,
    )
)
With the concepts in place, let's look at a more realistic example. Suppose we have a dataset and want to run an SFT task; assembling the Datum objects works as follows:
- Prepare a dataset:
examples = [
{"input": "1+1", "output": "2"},
{"input": "1+2", "output": "3"},
{"input": "2*3", "output": "6"},
]
- Convert each entry of the dataset into its corresponding input_tokens, target_tokens, and weights, feed them into a Datum, and you end up with processed_examples, a dataset of Datum objects:
def process_example(example, tokenizer):
    prompt = f"Formula: {example['input']}\nAnswer:"
    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
    prompt_weights = [0] * len(prompt_tokens)
    completion_tokens = tokenizer.encode(f" {example['output']}\n\n", add_special_tokens=False)
    completion_weights = [1] * len(completion_tokens)
    tokens = prompt_tokens + completion_tokens
    weights = prompt_weights + completion_weights
    input_tokens = tokens[:-1]
    target_tokens = tokens[1:]
    weights = weights[1:]
    return trio.Datum(
        model_input=trio.ModelInput.from_ints(tokens=input_tokens),
        loss_fn_inputs=dict(weights=weights, target_tokens=target_tokens)
    )
processed_examples = [process_example(ex, tokenizer) for ex in examples]
- Finally, pass processed_examples to the training_client's forward_backward to start training:
import numpy as np
for _ in range(6):
    fwdbwd_future = training_client.forward_backward(processed_examples, "cross_entropy")
    optim_future = training_client.optim_step(trio.AdamParams(learning_rate=1e-4))
    ...
Computing accumulated gradients (forward_backward)
Once the dataset is processed, we pass the Datum objects into forward_backward to compute accumulated gradients.
As its name suggests, forward_backward runs the data through the LLM in a forward pass, then back-propagates through the loss computed from the forward outputs and the targets, yielding the gradients needed to update the weights.
forward_backward takes two arguments:
- data: a list of Datum objects, each containing input tokens and the loss-function arguments
- loss_fn: the loss function used to compute gradients; this can be a built-in such as cross_entropy, importance_sampling, or ppo, or a custom function. See Loss Functions for the precise mathematical definitions.
fwdbwd_future = training_client.forward_backward(
data=[datum],
loss_fn="cross_entropy"
)
fwdbwd_result = fwdbwd_future.result()
Each call's return value contains:
- loss_fn_outputs: the logprobs and elementwise_loss of each sample
- metrics: computed metrics, such as the loss value loss:sum
- elapsed_time: processing time in milliseconds
fwdbwd_result.loss_fn_outputs
# [{'logprobs':..., 'elementwise_loss':...}, ...]
fwdbwd_result.metrics
# {'loss:sum': ...}
fwdbwd_result.elapsed_time
# 460.0
To compute the per-token loss, you can use metrics:
loss_sum = fwdbwd_result.metrics['loss:sum']
weights = np.concatenate([example.loss_fn_inputs['weights'].tolist() for example in processed_examples])
print(f"Loss: {loss_sum / weights.sum():.4f}")
Or you can use the logprobs in loss_fn_outputs:
logprobs = np.concatenate([output['logprobs'].tolist() for output in fwdbwd_result.loss_fn_outputs])
weights = np.concatenate([example.loss_fn_inputs['weights'].tolist() for example in processed_examples])
print(f"Loss per token: {-np.dot(logprobs, weights) / weights.sum():.4f}")
The two computations are exactly equivalent.
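You can convince yourself of the equivalence with synthetic numbers. The sketch below assumes, consistent with the cross_entropy description above, that each target token contributes -logprob * weight to loss:sum:

```python
import numpy as np

# Synthetic per-token outputs, standing in for a real forward_backward result.
logprobs = np.array([-0.1, -2.3, -0.5, -1.2, -0.05, -3.0])
weights = np.array([0.0, 0.0, 1.0, 1.0, 1.0, 1.0])  # prompt masked out

# Route 1: via metrics -- loss:sum is the weighted sum of -logprob
loss_sum = np.sum(weights * -logprobs)
per_token_via_metrics = loss_sum / weights.sum()

# Route 2: via loss_fn_outputs -- dot the logprobs with the weights directly
per_token_via_logprobs = -np.dot(logprobs, weights) / weights.sum()

print(per_token_via_metrics, per_token_via_logprobs)  # identical values
```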
Updating weights (optim_step)
Once the gradients have been accumulated, the next step is to update the weights with the optimizer.
When you call optim_step, TRIO updates the weights using the gradients accumulated by forward_backward:
optim_future = training_client.optim_step(
trio.AdamParams(learning_rate=1e-4)
)
optim_result = optim_future.result()
optim_step takes a single argument, adam_params, which must be a trio.AdamParams object.
The optional parameters of trio.AdamParams are:
- learning_rate: learning rate, default 1e-4
- beta1: AdamW beta1, default 0.9
- beta2: AdamW beta2, default 0.999
- eps: AdamW epsilon, default 1e-12
- weight_decay: weight decay coefficient, default 0
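For intuition, here is a single textbook AdamW-style update written out in numpy with the default values listed above (first step, so the moment states start at zero). This is only a sketch of the standard update rule, not TRIO's internals:

```python
import numpy as np

# Defaults from the parameter list above
learning_rate, beta1, beta2, eps, weight_decay = 1e-4, 0.9, 0.999, 1e-12, 0.0

w = np.array([1.0, -2.0])    # current weights
g = np.array([0.5, -0.25])   # gradient accumulated by forward_backward
m = np.zeros_like(w)         # first-moment (mean) state
v = np.zeros_like(w)         # second-moment (variance) state
t = 1                        # step counter

m = beta1 * m + (1 - beta1) * g
v = beta2 * v + (1 - beta2) * g ** 2
m_hat = m / (1 - beta1 ** t)            # bias-corrected moments
v_hat = v / (1 - beta2 ** t)
w = w - learning_rate * (m_hat / (np.sqrt(v_hat) + eps) + weight_decay * w)
print(w)
```

With zero weight_decay and bias correction at t=1, the first step moves each weight by roughly learning_rate in the direction opposite its gradient.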
Supervised fine-tuning (SFT)
Below is a supervised fine-tuning example whose goal is to teach the model that TRIO is an AI Infra product:
import pytrio as trio
import numpy as np
# 1. Connect to TRIO
service_client = trio.ServiceClient()
# 2. Create a training client
base_model = "Qwen/Qwen3-4B-Instruct-2507"
training_client = service_client.create_lora_training_client(
base_model=base_model,
rank=32,
)
# 3. Dataset: teach the LLM what trio is
examples = [
{"input": "what is trio", "output": "trio is emotionmachine's AI Infra products."},
{"input": "can you explain what trio is", "output": "trio is an AI infra product developed by emotionmachine."},
{"input": "tell me about trio", "output": "trio is a product from emotionmachine that provides AI Infra capabilities."},
]
# 4. Get the tokenizer
print("Loading tokenizer...")
tokenizer = training_client.get_tokenizer()
print("Tokenizer loaded")
# 5. Process the dataset into the format training expects
def process_example(example: dict, tokenizer) -> trio.Datum:
    prompt = f"Question: {example['input']}\nAnswer:"
    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
    prompt_weights = [0] * len(prompt_tokens)
    completion_tokens = tokenizer.encode(f" {example['output']}\n\n", add_special_tokens=False)
    completion_weights = [1] * len(completion_tokens)
    tokens = prompt_tokens + completion_tokens
    weights = prompt_weights + completion_weights
    input_tokens = tokens[:-1]
    target_tokens = tokens[1:]
    weights = weights[1:]
    # Convert into the Datum format trio training expects
    return trio.Datum(
        model_input=trio.ModelInput.from_ints(tokens=input_tokens),
        loss_fn_inputs=dict(weights=weights, target_tokens=target_tokens)
    )
processed_examples = [process_example(ex, tokenizer) for ex in examples]
# 6. Train
print("Start Training")
for iter in range(15):
    fwdbwd_future = training_client.forward_backward(processed_examples, "cross_entropy")  # forward/backward pass
    optim_future = training_client.optim_step(trio.AdamParams(learning_rate=1e-4))  # Adam optimizer update
    fwdbwd_result = fwdbwd_future.result()
    optim_result = optim_future.result()
    logprobs = np.concatenate([output['logprobs'].tolist() for output in fwdbwd_result.loss_fn_outputs])
    weights = np.concatenate([example.loss_fn_inputs['weights'].tolist() for example in processed_examples])
    print(f"Iter{iter+1} Loss per token: {-np.dot(logprobs, weights) / weights.sum():.4f}")
# 7. Inference and evaluation
print("Start Sampling")
sampling_base_client = service_client.create_sampling_client(base_model=base_model)
sampling_sft_client = training_client.save_weights_and_get_sampling_client(name='what-is-trio')
prompt = trio.ModelInput.from_ints(tokenizer.encode("Question: what is trio\nAnswer:"))
params = trio.SamplingParams(max_tokens=20, temperature=0.0, stop=["\n"])
future_base = sampling_base_client.sample(prompt=prompt, sampling_params=params, num_samples=1)
result_base = future_base.result()
future_sft = sampling_sft_client.sample(prompt=prompt, sampling_params=params, num_samples=1)
result_sft = future_sft.result()
print("Base Responses:")
print(f"{repr(result_base.sequences[0].text)}")
print("SFT Responses:")
print(f"{repr(result_sft.sequences[0].text)}")
Reinforcement learning (RL)
Below is a reinforcement learning example whose goal is to get the model to answer math questions correctly in the required format:
import re
import pytrio as trio
import numpy as np
# 1. Connect to TRIO
service_client = trio.ServiceClient()
# 2. Create a training client
base_model = "Qwen/Qwen3-4B-Instruct-2507"
training_client = service_client.create_lora_training_client(
base_model=base_model,
rank=32,
)
# 3. Dataset: simple math questions for the LLM
dataset = [
("What is 2 + 3?", 5),
("What is 7 - 4?", 3),
("What is 6 * 8?", 48),
("What is 12 / 3?", 4),
("Solve for x: x + 5 = 9", 4),
("Solve for x: 2x = 10", 5),
("What is 3 squared?", 9),
("What is the square root of 81?", 9),
("What is 15 + 27?", 42),
("What is 100 - 58?", 42),
]
eval_dataset = [
("Solve for x: x + 7 = 12", 5),
("What is 9 * 7?", 63),
("What is 81 / 9?", 9),
("What is 14 + 28?", 42),
]
# 4. Get the tokenizer
print("Loading tokenizer...")
tokenizer = training_client.get_tokenizer()
print("Tokenizer loaded")
# 6. Parse a numeric answer from the model output
def parse_number(text: str):
    match = re.fullmatch(r"-?\d+(?:\.\d+)?", text.strip())
    return float(match.group()) if match else None
# 7. Reward function
def compute_reward(text: str, gold: float) -> float:
    pred = parse_number(text)
    if pred is None:
        return -1.0
    if abs(pred - gold) < 1e-6:
        return 2.0
    return -0.5
# 8. Convert to a numpy array for the loss statistics below
def to_np(x):
    return np.array(x.tolist() if hasattr(x, "tolist") else x, dtype=float)
# 9. Turn one sampled rollout into the Datum format trio training expects
def process_rollout(prompt_tokens, completion_tokens, completion_logprobs, reward_value):
    tokens = prompt_tokens + completion_tokens
    prompt_weights = [0] * len(prompt_tokens)
    completion_weights = [1] * len(completion_tokens)
    weights = prompt_weights + completion_weights
    old_logprobs = ([0.0] * len(prompt_tokens) + list(completion_logprobs))[:len(tokens)]
    old_logprobs += [0.0] * (len(tokens) - len(old_logprobs))
    input_tokens = tokens[:-1]
    target_tokens = tokens[1:]
    weights = weights[1:]
    old_logprobs = old_logprobs[1:]
    advantages = [reward_value] * (len(tokens) - 1)
    return trio.Datum(
        model_input=trio.ModelInput.from_ints(tokens=input_tokens),
        loss_fn_inputs=dict(
            weights=weights,
            target_tokens=target_tokens,
            logprobs=old_logprobs,
            advantages=advantages,
        ),
    )
# 11. RL training
print("Start RL Training")
for iter in range(15):
    sampler = training_client.save_weights_and_get_sampling_client(name=f"rl-math-sampler-iter{iter}")
    processed_examples = []
    rewards = []
    correct = 0
    total = 0
    for question, gold in dataset:
        prompt_tokens = tokenizer.encode(f"Question: {question}\nReturn only the final numeric answer.\nAnswer:", add_special_tokens=True)
        future_sample = sampler.sample(
            prompt=trio.ModelInput.from_ints(prompt_tokens),
            sampling_params=trio.SamplingParams(max_tokens=8, temperature=0.7),
            num_samples=4,
        )
        sample_result = future_sample.result()
        for sequence in sample_result.sequences:
            reward_value = compute_reward(sequence.text, float(gold))
            pred = parse_number(sequence.text)
            rewards.append(reward_value)
            total += 1
            correct += pred is not None and abs(pred - gold) < 1e-6
            completion_tokens = tokenizer.encode(sequence.text, add_special_tokens=False)
            if completion_tokens:
                processed_examples.append(
                    process_rollout(
                        prompt_tokens=prompt_tokens,
                        completion_tokens=completion_tokens,
                        completion_logprobs=sequence.logprobs,
                        reward_value=reward_value,
                    )
                )
    print(
        f"Iter{iter+1} | Reward: {np.mean(rewards):.4f} | "
        f"Acc: {correct / max(total, 1):.4f} | Samples: {len(processed_examples)}"
    )
    fwdbwd_future = training_client.forward_backward(processed_examples, "importance_sampling")
    optim_future = training_client.optim_step(trio.AdamParams(learning_rate=1e-5))
    fwdbwd_result = fwdbwd_future.result()
    optim_result = optim_future.result()
    logprobs = np.concatenate([to_np(output["logprobs"]) for output in fwdbwd_result.loss_fn_outputs])
    weights = np.concatenate([to_np(example.loss_fn_inputs["weights"]) for example in processed_examples])
    old_logprobs = np.concatenate([to_np(example.loss_fn_inputs["logprobs"]) for example in processed_examples])
    advantages = np.concatenate([to_np(example.loss_fn_inputs["advantages"]) for example in processed_examples])
    mask = weights > 0
    loss = -np.sum(np.exp(logprobs[mask] - old_logprobs[mask]) * advantages[mask]) / mask.sum()
    print(f"Iter{iter+1} IS Loss: {loss:.4f}\n")
# 13. Inference and evaluation
print("Start Evaluation")
sampling_base_client = service_client.create_sampling_client(base_model=base_model)
sampling_rl_client = training_client.save_weights_and_get_sampling_client(name="math-rl-final")
for question, gold in eval_dataset:
    prompt = trio.ModelInput.from_ints(
        tokenizer.encode(f"Question: {question}\nReturn only the final numeric answer.\nAnswer:", add_special_tokens=True)
    )
    future_base = sampling_base_client.sample(prompt=prompt, sampling_params=trio.SamplingParams(max_tokens=8, temperature=0.0), num_samples=1)
    future_rl = sampling_rl_client.sample(prompt=prompt, sampling_params=trio.SamplingParams(max_tokens=8, temperature=0.0), num_samples=1)
    result_base = future_base.result()
    result_rl = future_rl.result()
    base_text = result_base.sequences[0].text.strip()
    rl_text = result_rl.sequences[0].text.strip()
    print("=" * 60)
    print(f"Q: {question} | Gold: {gold}")
    print(f"Base: {repr(base_text)} -> {parse_number(base_text)}")
    print(f"RL: {repr(rl_text)} -> {parse_number(rl_text)}")
The training output looks like this:
Iter1 | Reward: -0.5375 | Acc: 0.1500 | Samples: 40
Iter1 IS Loss: 0.7964
Iter2 | Reward: -0.5375 | Acc: 0.1500 | Samples: 40
Iter2 IS Loss: 0.7938
...
Iter15 | Reward: 1.3250 | Acc: 0.7750 | Samples: 40
Iter15 IS Loss: -0.7395
Start Evaluation
============================================================
Q: Solve for x: x + 7 = 12 | Gold: 5
Base: '5\n\nQuestion: Solve for x' -> None
RL: '5' -> 5.0
============================================================
Q: What is 9 * 7? | Gold: 63
Base: '63. 63.' -> None
RL: '63' -> 63.0
============================================================
Q: What is 81 / 9? | Gold: 9
Base: '9\n\nQuestion: What is' -> None
RL: '9' -> 9.0
============================================================
Q: What is 14 + 28? | Gold: 42
Base: '42.' -> None
RL: '42' -> 42.0
As you can see, the RL-trained model follows the required format much better than before.