指南

损失函数

TRIO 为监督学习和强化学习提供了内置的损失函数。

你可以通过将字符串传递给 forward_backward 来选择损失函数:

future = training_client.forward_backward(
    data,
    loss_fn="cross_entropy", 
    )
result = future.result()

内置损失函数

目前 TRIO 支持的内置损失函数如下:

损失函数适用场景说明
cross_entropy监督学习标准交叉熵损失,适用于分类任务。以模型输出的 logits 和目标标签计算负对数似然。
importance_sampling离线强化学习使用重要性采样对 off-policy 数据进行修正,通过行为策略与目标策略的概率比值对梯度加权。
ppo在线强化学习Proximal Policy Optimization 损失,通过裁剪概率比值限制策略更新幅度,提升训练稳定性。

自定义损失函数

对于内置损失函数之外的使用场景,TRIO 提供了更灵活(但速度较慢)的方法 forward_backward_custom 来计算更通用的损失函数。

forward_backward_custom接收数据和模型的logprobs,并返回loss以及可选的指标。

比如我们希望实现1个损失函数,它的逻辑是希望每个 logprob 尽可能接近 0(也就是概率接近 1):

loss=i(logpi)2\text{loss} = \sum_i (\log p_i)^2

实现代码为:

def logprob_squared_loss(data: list[trio.Datum], logprobs: list[torch.Tensor]) -> tuple[torch.Tensor, dict[str, float]]:
    flat_logprobs = torch.cat(logprobs)
    loss = (flat_logprobs ** 2).sum()
    return loss, {"logprob_squared_loss": loss.item()}

使用 forward_backward_custom 调用它:

future = training_client.forward_backward_custom(data, logprob_squared_loss)
result = future.result()
print(f"Loss: {result.loss}, Metrics: {result.metrics}")

我们改造一下 SFT 示例为自定义损失函数:

import pytrio as trio
import torch

# 1. 与TRIO建立连接
service_client = trio.ServiceClient()

# 2. 创建1个训练客户端
base_model = "Qwen/Qwen3-4B-Instruct-2507"
training_client = service_client.create_lora_training_client(
    base_model=base_model,
    rank=32,
)

# 3. 数据集-让LLM答对什么是trio
examples = [
    {"input": "what is trio", "output": "trio is emotionmachine's AI Infra products."},
    {"input": "can you explain what trio is", "output": "trio is an AI infra product developed by emotionmachine."},
    {"input": "tell me about trio", "output": "trio is a product from emotionmachine that provides AI Infra capabilities."},
]

# 4. 获取Tokenizer
print("Loading tokenizer...")
tokenizer = training_client.get_tokenizer()
print("Tokenizer finish")

# 5. 处理数据集,转换为训练需要的格式
def process_example(example: dict, tokenizer) -> trio.Datum:
    prompt = f"Question: {example['input']}\nAnswer:"

    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
    prompt_weights = [0] * len(prompt_tokens)
    
    completion_tokens = tokenizer.encode(f" {example['output']}\n\n", add_special_tokens=False)
    completion_weights = [1] * len(completion_tokens)

    tokens = prompt_tokens + completion_tokens
    weights = prompt_weights + completion_weights

    input_tokens = tokens[:-1]
    target_tokens = tokens[1:]
    weights = weights[1:]
    
    # 转换为trio训练需要的格式
    return trio.Datum(
        model_input=trio.ModelInput.from_ints(tokens=input_tokens),
        loss_fn_inputs=dict(weights=weights, target_tokens=target_tokens)
    )

processed_examples = [process_example(ex, tokenizer) for ex in examples]

# 6. 自定义损失函数
def logprob_squared_loss(data: list[trio.Datum], logprobs: list[torch.Tensor]) -> tuple[torch.Tensor, dict[str, float]]: 
    flat_logprobs = torch.cat(logprobs)
    loss = (flat_logprobs ** 2).sum()
    return loss, {"logprob_squared_loss": loss.item()}

# 7. 训练
print("Start Training")
for iter in range(15):
    fwdbwd_future = training_client.forward_backward_custom(processed_examples, logprob_squared_loss) 
    optim_future = training_client.optim_step(trio.AdamParams(learning_rate=1e-4))

    fwdbwd_result = fwdbwd_future.result()
    optim_result = optim_future.result()
    
    print(f"Iter{iter+1} Logprob_squared_loss: {fwdbwd_result.metrics['logprob_squared_loss']:.4f}")

# 7. 推理与评估
print("Start Sampling")
sampling_base_client = service_client.create_sampling_client(base_model=base_model)
training_client.save_state(name="Train")
sampling_sft_client = training_client.save_weights_and_get_sampling_client(name='what-is-trio')

prompt = trio.ModelInput.from_ints(tokenizer.encode("Question: what is trio\nAnswer:"))
params = trio.SamplingParams(max_tokens=20, temperature=0.0, stop=["\n"])

future_base = sampling_base_client.sample(prompt=prompt, sampling_params=params, num_samples=1)
result_base = future_base.result()
future_sft = sampling_sft_client.sample(prompt=prompt, sampling_params=params, num_samples=1)
result_sft = future_sft.result()

print("Base Responses:")
print(f"{repr(result_base.sequences[0].text)}")

print("SFT Responses:")
print(f"{repr(result_sft.sequences[0].text)}")

输出结果为:

Tokenizer finish
Start Training
Iter1 Logprob_squared_loss: 2173.7051
...
Iter15 Logprob_squared_loss: 48.7835
Start Sampling
Base Responses:
' A trio is a musical ensemble consisting of three performers. The term can also refer to a group of'
SFT Responses:
' trio is emotionmachine emotionmachine emotionmachine emotionmachine emotionmachine emotionmachine emotionmachine emotionmachine emotionmachine'

可以看到使用了自定义损失函数进行了训练(ps:logprob_squared_loss 只是个用于示例的损失函数,实际效果并不好,请勿使用到自己的训练中。)

forward_backward_custom 的工作原理

forward_backward_custom 允许用户基于 target token 的 logprobs 定义任意可微损失函数,同时无需将自定义函数序列化、上传或在服务器端执行。其核心思想是:将原始的非线性损失分解为一次前向计算和一次基于替代目标函数(surrogate objective)的前向反向计算。该替代目标函数虽然在线性形式上定义于 logprobs,但其对模型参数的梯度与原始损失完全一致。

数学形式

设模型参数为 params,target token 的 logprobs 为:

logprobs = compute_target_logprobs(params)

用户定义的原始损失为:

loss = compute_loss_from_logprobs(logprobs)

即:

L(θ)=f(z(θ))L(\theta) = f(z(\theta))

其中:

  • θ\theta 表示模型参数;
  • z(θ)z(\theta) 表示目标 token 的 logprobs;
  • ff 表示用户在客户端定义的任意可微损失函数。

为了在不执行 f 的情况下仍然得到正确梯度,TRIO 构造如下 surrogate loss:

surrogate_loss = (logprobs * logprob_grads).sum()
# where logprob_grads = dLoss/dLogprobs

即:

L~(θ)=izi(θ)Lzi\tilde{L}(\theta) = \sum_i z_i(\theta) \cdot \frac{\partial L}{\partial z_i}

其中 Lzi\frac{\partial L}{\partial z_i} 由客户端基于原始损失计算得到,并作为常数权重传回服务器。

根据链式法则,有:

Lθ=iLziziθ\frac{\partial L}{\partial \theta} = \sum_i \frac{\partial L}{\partial z_i} \cdot \frac{\partial z_i}{\partial \theta}

而 surrogate loss 的梯度为:

L~θ=iLziziθ\frac{\partial \tilde{L}}{\partial \theta} = \sum_i \frac{\partial L}{\partial z_i} \cdot \frac{\partial z_i}{\partial \theta}

因此:

L~θ=Lθ\frac{\partial \tilde{L}}{\partial \theta} = \frac{\partial L}{\partial \theta}

这说明,尽管 surrogate loss 的形式不同于原始损失,其对模型参数产生的梯度是严格等价的。

执行流程

forward_backward_custom 在客户端与服务器之间分两阶段完成梯度计算:

  1. 准备数据 客户端构造 Datum 对象列表,并准备目标 token 信息。

  2. 前向计算 服务器执行一次 forward,计算目标 token 的 logprobs。

  3. 客户端计算自定义损失 客户端使用返回的 logprobs 调用用户定义的 custom_fn(logprobs),得到标量损失。

  4. 客户端反向传播到 logprobs 客户端对该损失执行反向传播,得到 Llogprobs\frac{\partial L}{\partial \text{logprobs}},即每个 logprob 对最终损失的梯度。

  5. 服务器执行 surrogate forward-backward 服务器使用这些梯度作为权重,构造 surrogate loss:

    ilogprobsiLlogprobsi\sum_i \text{logprobs}_i \cdot \frac{\partial L}{\partial \text{logprobs}_i}

    并对其执行 forward-backward,从而得到与原始自定义损失完全一致的参数梯度。

为什么不需要上传自定义函数

在这一设计中,服务器只需要:

  • 计算目标 token 的 logprobs;
  • 接收客户端返回的 Llogprobs\frac{\partial L}{\partial \text{logprobs}}
  • 对 surrogate objective 执行标准的梯度计算。

因此,用户定义的 Python 函数始终保留在客户端执行。TRIO 不会对其进行 pickle,也不会将其发送到服务器。

性能开销

由于 forward_backward_custom 需要额外执行一次 forward,其计算开销高于单次 forward_backward

  • FLOPs 约为单次 forward_backward1.5×
  • 实际耗时 在某些情况下可达到 最多约 3×,这主要来自额外的前向计算以及 forward/backward 调度与客户端-服务器往返带来的实现开销。

On this page