损失函数
TRIO 为监督学习和强化学习提供了内置的损失函数。
你可以通过将字符串传递给 forward_backward 来选择损失函数:
future = training_client.forward_backward(
data,
loss_fn="cross_entropy",
)
result = future.result()内置损失函数
目前 TRIO 支持的内置损失函数如下:
| 损失函数 | 适用场景 | 说明 |
|---|---|---|
cross_entropy | 监督学习 | 标准交叉熵损失,适用于分类任务。以模型输出的 logits 和目标标签计算负对数似然。 |
importance_sampling | 离线强化学习 | 使用重要性采样对 off-policy 数据进行修正,通过行为策略与目标策略的概率比值对梯度加权。 |
ppo | 在线强化学习 | Proximal Policy Optimization 损失,通过裁剪概率比值限制策略更新幅度,提升训练稳定性。 |
自定义损失函数
对于内置损失函数之外的使用场景,TRIO 提供了更灵活(但速度较慢)的方法 forward_backward_custom 来计算更通用的损失函数。
forward_backward_custom接收数据和模型的logprobs,并返回loss以及可选的指标。
比如我们希望实现1个损失函数,它的逻辑是希望每个 logprob 尽可能接近 0(也就是概率接近 1):
实现代码为:
def logprob_squared_loss(data: list[trio.Datum], logprobs: list[torch.Tensor]) -> tuple[torch.Tensor, dict[str, float]]:
flat_logprobs = torch.cat(logprobs)
loss = (flat_logprobs ** 2).sum()
return loss, {"logprob_squared_loss": loss.item()}使用 forward_backward_custom 调用它:
future = training_client.forward_backward_custom(data, logprob_squared_loss)
result = future.result()
print(f"Loss: {result.loss}, Metrics: {result.metrics}")我们改造一下 SFT 示例为自定义损失函数:
import pytrio as trio
import torch
# 1. 与TRIO建立连接
service_client = trio.ServiceClient()
# 2. 创建1个训练客户端
base_model = "Qwen/Qwen3-4B-Instruct-2507"
training_client = service_client.create_lora_training_client(
base_model=base_model,
rank=32,
)
# 3. 数据集-让LLM答对什么是trio
examples = [
{"input": "what is trio", "output": "trio is emotionmachine's AI Infra products."},
{"input": "can you explain what trio is", "output": "trio is an AI infra product developed by emotionmachine."},
{"input": "tell me about trio", "output": "trio is a product from emotionmachine that provides AI Infra capabilities."},
]
# 4. 获取Tokenizer
print("Loading tokenizer...")
tokenizer = training_client.get_tokenizer()
print("Tokenizer finish")
# 5. 处理数据集,转换为训练需要的格式
def process_example(example: dict, tokenizer) -> trio.Datum:
prompt = f"Question: {example['input']}\nAnswer:"
prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
prompt_weights = [0] * len(prompt_tokens)
completion_tokens = tokenizer.encode(f" {example['output']}\n\n", add_special_tokens=False)
completion_weights = [1] * len(completion_tokens)
tokens = prompt_tokens + completion_tokens
weights = prompt_weights + completion_weights
input_tokens = tokens[:-1]
target_tokens = tokens[1:]
weights = weights[1:]
# 转换为trio训练需要的格式
return trio.Datum(
model_input=trio.ModelInput.from_ints(tokens=input_tokens),
loss_fn_inputs=dict(weights=weights, target_tokens=target_tokens)
)
processed_examples = [process_example(ex, tokenizer) for ex in examples]
# 6. 自定义损失函数
def logprob_squared_loss(data: list[trio.Datum], logprobs: list[torch.Tensor]) -> tuple[torch.Tensor, dict[str, float]]:
flat_logprobs = torch.cat(logprobs)
loss = (flat_logprobs ** 2).sum()
return loss, {"logprob_squared_loss": loss.item()}
# 7. 训练
print("Start Training")
for iter in range(15):
fwdbwd_future = training_client.forward_backward_custom(processed_examples, logprob_squared_loss)
optim_future = training_client.optim_step(trio.AdamParams(learning_rate=1e-4))
fwdbwd_result = fwdbwd_future.result()
optim_result = optim_future.result()
print(f"Iter{iter+1} Logprob_squared_loss: {fwdbwd_result.metrics['logprob_squared_loss']:.4f}")
# 7. 推理与评估
print("Start Sampling")
sampling_base_client = service_client.create_sampling_client(base_model=base_model)
training_client.save_state(name="Train")
sampling_sft_client = training_client.save_weights_and_get_sampling_client(name='what-is-trio')
prompt = trio.ModelInput.from_ints(tokenizer.encode("Question: what is trio\nAnswer:"))
params = trio.SamplingParams(max_tokens=20, temperature=0.0, stop=["\n"])
future_base = sampling_base_client.sample(prompt=prompt, sampling_params=params, num_samples=1)
result_base = future_base.result()
future_sft = sampling_sft_client.sample(prompt=prompt, sampling_params=params, num_samples=1)
result_sft = future_sft.result()
print("Base Responses:")
print(f"{repr(result_base.sequences[0].text)}")
print("SFT Responses:")
print(f"{repr(result_sft.sequences[0].text)}")输出结果为:
Tokenizer finish
Start Training
Iter1 Logprob_squared_loss: 2173.7051
...
Iter15 Logprob_squared_loss: 48.7835
Start Sampling
Base Responses:
' A trio is a musical ensemble consisting of three performers. The term can also refer to a group of'
SFT Responses:
' trio is emotionmachine emotionmachine emotionmachine emotionmachine emotionmachine emotionmachine emotionmachine emotionmachine emotionmachine'可以看到使用了自定义损失函数进行了训练(ps:logprob_squared_loss 只是个用于示例的损失函数,实际效果并不好,请勿使用到自己的训练中。)
forward_backward_custom 的工作原理
forward_backward_custom 允许用户基于 target token 的 logprobs 定义任意可微损失函数,同时无需将自定义函数序列化、上传或在服务器端执行。其核心思想是:将原始的非线性损失分解为一次前向计算和一次基于替代目标函数(surrogate objective)的前向反向计算。该替代目标函数虽然在线性形式上定义于 logprobs,但其对模型参数的梯度与原始损失完全一致。
数学形式
设模型参数为 params,target token 的 logprobs 为:
logprobs = compute_target_logprobs(params)用户定义的原始损失为:
loss = compute_loss_from_logprobs(logprobs)即:
其中:
- 表示模型参数;
- 表示目标 token 的 logprobs;
- 表示用户在客户端定义的任意可微损失函数。
为了在不执行 f 的情况下仍然得到正确梯度,TRIO 构造如下 surrogate loss:
surrogate_loss = (logprobs * logprob_grads).sum()
# where logprob_grads = dLoss/dLogprobs即:
其中 由客户端基于原始损失计算得到,并作为常数权重传回服务器。
根据链式法则,有:
而 surrogate loss 的梯度为:
因此:
这说明,尽管 surrogate loss 的形式不同于原始损失,其对模型参数产生的梯度是严格等价的。
执行流程
forward_backward_custom 在客户端与服务器之间分两阶段完成梯度计算:
-
准备数据 客户端构造
Datum对象列表,并准备目标 token 信息。 -
前向计算 服务器执行一次 forward,计算目标 token 的 logprobs。
-
客户端计算自定义损失 客户端使用返回的 logprobs 调用用户定义的
custom_fn(logprobs),得到标量损失。 -
客户端反向传播到 logprobs 客户端对该损失执行反向传播,得到 ,即每个 logprob 对最终损失的梯度。
-
服务器执行 surrogate forward-backward 服务器使用这些梯度作为权重,构造 surrogate loss:
并对其执行 forward-backward,从而得到与原始自定义损失完全一致的参数梯度。
为什么不需要上传自定义函数
在这一设计中,服务器只需要:
- 计算目标 token 的 logprobs;
- 接收客户端返回的 ;
- 对 surrogate objective 执行标准的梯度计算。
因此,用户定义的 Python 函数始终保留在客户端执行。TRIO 不会对其进行 pickle,也不会将其发送到服务器。
性能开销
由于 forward_backward_custom 需要额外执行一次 forward,其计算开销高于单次 forward_backward:
- FLOPs 约为单次
forward_backward的 1.5×; - 实际耗时 在某些情况下可达到 最多约 3×,这主要来自额外的前向计算以及 forward/backward 调度与客户端-服务器往返带来的实现开销。