指南

训练

你可以使用 TRIO 在你的 CPU 机器上开启大规模 LLM 后训练,而无需关注 Infra和环境带来的复杂性。

后训练分为监督微调(SFT)和强化学习(RL),其中 TRIO 执行 SFT 的逻辑如下:

TRIO 执行 RL 的逻辑如下:

如果我们抽象出共性,会发现数据处理(Datum)、前后反向计算(forward_backward)以及权重更新(optim_step)总是必不可少,接下来让我们详细拆解一下这三个模块,以及在 SFT 和 RL 中如何使用它们。

数据处理(Datum)

在将数据集传入LLM和损失函数之前,我们需要对它们进行处理。

Datum 是一个用于处理训练数据的封装,你需要将数据转换为Datum才能输入到forward_backward。它包含两部分:

  • model_input:提供给损失函数的 tokens
  • loss_fn_inputs:提供给损失函数的参数。不同的后训练任务需要不同的参数:
    • 对于 SFT损失函数(cross_entropy) ,你需要输入:

      • weights:每token损失权重,是一个由 0 和 1 组成的列表(0 = ignore,1 = compute loss),往往prompt的部分为0,output的部分为1
      • target_tokens:tokens 向右移一位,往往是 tokens[1:]
      datum = trio.Datum(
          model_input=trio.ModelInput.from_ints(tokens=input_tokens),
          loss_fn_inputs=dict(
              weights=weights,           
              target_tokens=target_tokens,
          )
      )
    • 对于 RL损失函数(importance_sampling、ppo),除了weightstarget_tokens,还需要输入参数:

      • logprobs:从rollout中拿到的sampling logprobs
      • advantages:奖励优势值
      rl_datum = trio.Datum(
          model_input=trio.ModelInput.from_ints(tokens=input_tokens),
          loss_fn_inputs=dict(
              weights=weights,
              target_tokens=target_tokens,
              logprobs=sampling_logprobs,
              advantages=advantages,
          )
      )

理解概念以后,我们来个更接近实际的例子。假如我们有一个数据集,想做 SFT 任务,组 Datum 的流程如下:

  1. 准备一个数据集:
examples = [
    {"input": "1+1", "output": "2"},
    {"input": "1+2", "output": "3"},
    {"input": "2*3", "output": "6"},
]
  1. 将数据集的每条数据转换为对应的input_tokenstarget_tokensweights,输入到Datum中,最后得到一个Datum组成的数据集processed_examples
def process_example(example, tokenizer):
    prompt = f"Formula: {example['input']}\nAnswer:"
    
    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
    prompt_weights = [0] * len(prompt_tokens)
    completion_tokens = tokenizer.encode(f" {example['output']}\n\n", add_special_tokens=False)
    completion_weights = [1] * len(completion_tokens)
    
    tokens = prompt_tokens + completion_tokens
    weights = prompt_weights + completion_weights
    
    
    input_tokens = tokens[:-1]
    target_tokens = tokens[1:]
    weights = weights[1:]
    
    return trio.Datum(
        model_input=trio.ModelInput.from_ints(tokens=input_tokens),  
        loss_fn_inputs=dict(weights=weights, target_tokens=target_tokens) 
    )

processed_examples = [process_example(ex, tokenizer) for ex in examples]
  1. 最后,将processed_examples传到training_clientforward_backward中,即可开启训练:
import numpy as np

for _ in range(6):
    fwdbwd_future = training_client.forward_backward(processed_examples, "cross_entropy")
    optim_future = training_client.optim_step(trio.AdamParams(learning_rate=1e-4))
    ...

计算积累梯度(forward_backward)

处理好数据集后,我们需要将Datum传入forward_backward计算积累梯度。

forward_backward正如它的名字一样,数据将经过LLM做一次前向传播(forward),再将前向结果、目标值与损失函数做一次反向传播(backward),从而得到更新权重所需的梯度。

forward_backward需要2个参数:

  • dataDatum组成的列表,每个元素都包含 input tokens 和 损失函数参数
  • loss_fn:计算梯度使用的损失函数,可以是 cross_entropyimportance_samplingppo 这种内置函数,也可以是自定义函数。具体数学定义可见 损失函数
fwdbwd_future = training_client.forward_backward(
    data=[datum],
    loss_fn="cross_entropy"
)
fwdbwd_future = fwdbwd_future.result()

每次计算的返回值中包含:

  1. loss_fn_outputs:每条样本的 logprobselementwise_loss
  2. metrics:计算得到的指标,比如损失值 loss:sum
  3. elapsed_time:处理耗时,单位为毫秒
fwdbwd_result.loss_fn_outputs
# [{'logprobs':..., 'elementwise_loss':...}, ...]

fwdbwd_result.metrics
# {'loss:sum': ...}

fwdbwd_result.elapsed_time
# 460.0

如果想要计算 per token loss,可以通过metrics

loss_sum = fwdbwd_result.metrics['loss:sum']
weights = np.concatenate([example.loss_fn_inputs['weights'].tolist() for example in processed_examples])
print(f"Loss: {loss_sum / weights.sum():.4f}")

也可以通过loss_fn_outputs里的logprobs

logprobs = np.concatenate([output['logprobs'].tolist() for output in fwdbwd_result.loss_fn_outputs])
weights = np.concatenate([example.loss_fn_inputs['weights'].tolist() for example in processed_examples])
print(f"Loss per token: {-np.dot(logprobs, weights) / weights.sum():.4f}")

这两者是完全等价的。

更新权重(optim_step)

拿到积累梯度后,接下来就是用优化器更新权重。

执行optim_step后,TRIO会根据forward_backward中积累的梯度,更新权重:

optim_future = training_client.optim_step(
    trio.AdamParams(learning_rate=1e-4)
    )
optim_future = optim_future.result()

optim_step 只有一个参数 adam_params,需要传入1个trio.AdamParams对象。

trio.AdamParams可以选择的参数有:

  • learning_rate:学习率,默认为1e-4
  • beta1:AdamW beta1,默认为0.9
  • beta2:AdamW beta2,默认为0.999
  • eps:AdamW epsilon,默认为1e-12
  • weight_decay:权重衰减系数,默认为0

监督微调(SFT)

下面是一个监督微调案例代码,目标是让模型回答 TRIO 是 1 个 AI Infra 产品:

import pytrio as trio
import numpy as np

# 1. 与TRIO建立连接
service_client = trio.ServiceClient()

# 2. 创建1个训练客户端
base_model = "Qwen/Qwen3-4B-Instruct-2507"
training_client = service_client.create_lora_training_client(
    base_model=base_model,
    rank=32,
)

# 3. 数据集-让LLM答对什么是trio
examples = [
    {"input": "what is trio", "output": "trio is emotionmachine's AI Infra products."},
    {"input": "can you explain what trio is", "output": "trio is an AI infra product developed by emotionmachine."},
    {"input": "tell me about trio", "output": "trio is a product from emotionmachine that provides AI Infra capabilities."},
]

# 4. 获取Tokenizer
print("Loading tokenizer...")
tokenizer = training_client.get_tokenizer()
print("Tokenizer finish")

# 5. 处理数据集,转换为训练需要的格式
def process_example(example: dict, tokenizer) -> trio.Datum:
    prompt = f"Question: {example['input']}\nAnswer:"

    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
    prompt_weights = [0] * len(prompt_tokens)
    
    completion_tokens = tokenizer.encode(f" {example['output']}\n\n", add_special_tokens=False)
    completion_weights = [1] * len(completion_tokens)

    tokens = prompt_tokens + completion_tokens
    weights = prompt_weights + completion_weights

    input_tokens = tokens[:-1]
    target_tokens = tokens[1:]
    weights = weights[1:]
    
    # 转换为trio训练需要的格式
    return trio.Datum(
        model_input=trio.ModelInput.from_ints(tokens=input_tokens),
        loss_fn_inputs=dict(weights=weights, target_tokens=target_tokens)
    )

processed_examples = [process_example(ex, tokenizer) for ex in examples]

# 6. 训练
print("Start Training")
for iter in range(15):
    fwdbwd_future = training_client.forward_backward(processed_examples, "cross_entropy")  # 前向反向计算
    optim_future = training_client.optim_step(trio.AdamParams(learning_rate=1e-4))  # Adam优化器更新

    fwdbwd_result = fwdbwd_future.result()
    optim_result = optim_future.result()

    logprobs = np.concatenate([output['logprobs'].tolist() for output in fwdbwd_result.loss_fn_outputs])
    weights = np.concatenate([example.loss_fn_inputs['weights'].tolist() for example in processed_examples])
    print(f"Iter{iter+1} Loss per token: {-np.dot(logprobs, weights) / weights.sum():.4f}")

# 7. 推理与评估
print("Start Sampling")
sampling_base_client = service_client.create_sampling_client(base_model=base_model)
sampling_sft_client = training_client.save_weights_and_get_sampling_client(name='what-is-trio')

prompt = trio.ModelInput.from_ints(tokenizer.encode("Question: what is trio\nAnswer:"))
params = trio.SamplingParams(max_tokens=20, temperature=0.0, stop=["\n"])

future_base = sampling_base_client.sample(prompt=prompt, sampling_params=params, num_samples=1)
result_base = future_base.result()
future_sft = sampling_sft_client.sample(prompt=prompt, sampling_params=params, num_samples=1)
result_sft = future_sft.result()

print("Base Responses:")
print(f"{repr(result_base.sequences[0].text)}")

print("SFT Responses:")
print(f"{repr(result_sft.sequences[0].text)}")

强化学习(RL)

下面是一个强化学习案例代码,目标是让模型遵照格式正确回答数学题:

import re
import pytrio as trio
import numpy as np

# 1. 与TRIO建立连接
service_client = trio.ServiceClient()

# 2. 创建1个训练客户端
base_model = "Qwen/Qwen3-4B-Instruct-2507"
training_client = service_client.create_lora_training_client(
    base_model=base_model,
    rank=32,
)

# 3. 数据集-让LLM做简单数学题
dataset = [
    ("What is 2 + 3?", 5),
    ("What is 7 - 4?", 3),
    ("What is 6 * 8?", 48),
    ("What is 12 / 3?", 4),
    ("Solve for x: x + 5 = 9", 4),
    ("Solve for x: 2x = 10", 5),
    ("What is 3 squared?", 9),
    ("What is the square root of 81?", 9),
    ("What is 15 + 27?", 42),
    ("What is 100 - 58?", 42),
]

eval_dataset = [
    ("Solve for x: x + 7 = 12", 5),
    ("What is 9 * 7?", 63),
    ("What is 81 / 9?", 9),
    ("What is 14 + 28?", 42),
]

# 4. 获取Tokenizer
print("Loading tokenizer...")
tokenizer = training_client.get_tokenizer()
print("Tokenizer finish")

# 6. 从模型输出中解析数字答案
def parse_number(text: str):
    match = re.fullmatch(r"-?\d+(?:\.\d+)?", text.strip())
    return float(match.group()) if match else None

# 7. 奖励函数
def compute_reward(text: str, gold: float) -> float:
    pred = parse_number(text)
    if pred is None:
        return -1.0
    if abs(pred - gold) < 1e-6:
        return 2.0
    return -0.5

# 8. 转成numpy数组,方便后面统计loss
def to_np(x):
    return np.array(x.tolist() if hasattr(x, "tolist") else x, dtype=float)

# 9. 把一次采样结果处理成trio训练需要的Datum格式
def process_rollout(prompt_tokens, completion_tokens, completion_logprobs, reward_value):
    tokens = prompt_tokens + completion_tokens

    prompt_weights = [0] * len(prompt_tokens)
    completion_weights = [1] * len(completion_tokens)
    weights = prompt_weights + completion_weights

    old_logprobs = ([0.0] * len(prompt_tokens) + list(completion_logprobs))[:len(tokens)]
    old_logprobs += [0.0] * (len(tokens) - len(old_logprobs))

    input_tokens = tokens[:-1]
    target_tokens = tokens[1:]
    weights = weights[1:]
    old_logprobs = old_logprobs[1:]
    advantages = [reward_value] * (len(tokens) - 1)

    return trio.Datum(
        model_input=trio.ModelInput.from_ints(tokens=input_tokens),
        loss_fn_inputs=dict(
            weights=weights,
            target_tokens=target_tokens,
            logprobs=old_logprobs,
            advantages=advantages,
        ),
    )

# 11. RL训练
print("Start RL Training")

for iter in range(15):
    sampler = training_client.save_weights_and_get_sampling_client(name=f"rl-math-sampler-iter{iter}")
    processed_examples = []
    rewards = []
    correct = 0
    total = 0

    for question, gold in dataset:
        prompt_tokens = tokenizer.encode(f"Question: {question}\nReturn only the final numeric answer.\nAnswer:", add_special_tokens=True)

        future_sample = sampler.sample(
            prompt=trio.ModelInput.from_ints(prompt_tokens),
            sampling_params=trio.SamplingParams(max_tokens=8, temperature=0.7),
            num_samples=4,
        )
        sample_result = future_sample.result()

        for sequence in sample_result.sequences:
            reward_value = compute_reward(sequence.text, float(gold))
            pred = parse_number(sequence.text)

            rewards.append(reward_value)
            total += 1
            correct += pred is not None and abs(pred - gold) < 1e-6

            completion_tokens = tokenizer.encode(sequence.text, add_special_tokens=False)

            if completion_tokens:
                processed_examples.append(
                    process_rollout(
                        prompt_tokens=prompt_tokens,
                        completion_tokens=completion_tokens,
                        completion_logprobs=sequence.logprobs,
                        reward_value=reward_value,
                    )
                )

    print(
        f"Iter{iter+1} | Reward: {np.mean(rewards):.4f} | "
        f"Acc: {correct / max(total, 1):.4f} | Samples: {len(processed_examples)}"
    )

    fwdbwd_future = training_client.forward_backward(processed_examples, "importance_sampling")
    optim_future = training_client.optim_step(trio.AdamParams(learning_rate=1e-5))

    fwdbwd_result = fwdbwd_future.result()
    optim_result = optim_future.result()

    logprobs = np.concatenate([to_np(output["logprobs"]) for output in fwdbwd_result.loss_fn_outputs])
    weights = np.concatenate([to_np(example.loss_fn_inputs["weights"]) for example in processed_examples])
    old_logprobs = np.concatenate([to_np(example.loss_fn_inputs["logprobs"]) for example in processed_examples])
    advantages = np.concatenate([to_np(example.loss_fn_inputs["advantages"]) for example in processed_examples])

    mask = weights > 0
    loss = -np.sum(np.exp(logprobs[mask] - old_logprobs[mask]) * advantages[mask]) / mask.sum()
    print(f"Iter{iter+1} IS Loss: {loss:.4f}\n")

# 13. 推理与评估
print("Start Evaluation")

sampling_base_client = service_client.create_sampling_client(base_model=base_model)
sampling_rl_client = training_client.save_weights_and_get_sampling_client(name="math-rl-final")

for question, gold in eval_dataset:
    prompt = trio.ModelInput.from_ints(
        tokenizer.encode(f"Question: {question}\nReturn only the final numeric answer.\nAnswer:", add_special_tokens=True)
    )

    future_base = sampling_base_client.sample(prompt=prompt, sampling_params=trio.SamplingParams(max_tokens=8, temperature=0.0), num_samples=1)
    future_rl = sampling_rl_client.sample(prompt=prompt, sampling_params=trio.SamplingParams(max_tokens=8, temperature=0.0), num_samples=1)
    
    result_base = future_base.result()
    result_rl = future_rl.result()
    
    base_text = result_base.sequences[0].text.strip()
    rl_text = result_rl.sequences[0].text.strip()

    print("=" * 60)
    print(f"Q: {question} | Gold: {gold}")
    print(f"Base: {repr(base_text)} -> {parse_number(base_text)}")
    print(f"RL:   {repr(rl_text)} -> {parse_number(rl_text)}")

训练结果如下:

Iter1 | Reward: -0.5375 | Acc: 0.1500 | Samples: 40
Iter1 IS Loss: 0.7964

Iter2 | Reward: -0.5375 | Acc: 0.1500 | Samples: 40
Iter2 IS Loss: 0.7938

...

Iter15 | Reward: 1.3250 | Acc: 0.7750 | Samples: 40
Iter15 IS Loss: -0.7395

Start Evaluation
============================================================
Q: Solve for x: x + 7 = 12 | Gold: 5
Base: '5\n\nQuestion: Solve for x' -> None
RL:   '5' -> 5.0
============================================================
Q: What is 9 * 7? | Gold: 63
Base: '63. 63.' -> None
RL:   '63' -> 63.0
============================================================
Q: What is 81 / 9? | Gold: 9
Base: '9\n\nQuestion: What is' -> None
RL:   '9' -> 9.0
============================================================
Q: What is 14 + 28? | Gold: 42
Base: '42.' -> None
RL:   '42' -> 42.0

可以看到经过 RL 训练之后的模型相比先前更符合格式要求。

On this page