# Async
In scenarios with high throughput, high concurrency requirements, and complex training steps, asynchronous methods outperform synchronous ones. Specifically, an asynchronous method does not block the process after submitting a compute task; instead, execution continues locally, which improves training efficiency and throughput.
TRIO's training APIs generally come in both synchronous and asynchronous versions, and every asynchronous method name ends with `_async`. For the forward-backward pass, for example:

- Synchronous method: `forward_backward`
- Asynchronous method: `forward_backward_async`
## Usage
Taking the forward-backward pass as an example, the synchronous version is typically written as:

```python
for i in range(15):
    fwdbwd_future = training_client.forward_backward(...)
    fwdbwd_result = fwdbwd_future.result()
```

The asynchronous version looks like this:
```python
async def main():
    ...
    for i in range(15):
        fwdbwd_future = await training_client.forward_backward_async(...)
        fwdbwd_result = await fwdbwd_future
    ...

if __name__ == "__main__":
    asyncio.run(main())
```

Concretely, calling an asynchronous method involves two phases:
- **Submission phase**: submit the compute task via `await training_client.forward_backward_async(...)`. The call returns an `APIFutureResult` object immediately, which acts as a token representing the future result of that task.
  - Once you hold the `APIFutureResult`, you can continue doing other local work, such as preparing the scheduler or logging metrics.
  - Meanwhile, the compute task keeps running on the cloud training server.
- **Retrieval phase**: when the training result is needed, block explicitly with `await fwdbwd_future`.
  - Calling `await fwdbwd_future` blocks locally until the cloud-side task completes and its result is returned, after which execution continues.
  - This gives you precise control over when to wait for results, improving training efficiency and throughput.
In summary, a training loop built on asynchronous methods can use local and cloud compute resources in parallel, raising the throughput of the whole training process.
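The two phases above can be sketched with plain `asyncio`, independent of TRIO's API. In this illustrative sketch, `remote_compute` is a hypothetical stand-in for a cloud-side task such as a forward-backward pass:

```python
import asyncio


async def remote_compute(x: int) -> int:
    # Stand-in for a cloud-side training task (e.g. a forward-backward pass).
    await asyncio.sleep(0.05)
    return x * 2


async def main() -> list[str]:
    events = []
    # Phase 1 (submission): create_task starts the coroutine without blocking,
    # much like forward_backward_async returning an APIFutureResult token.
    future = asyncio.create_task(remote_compute(21))
    # Local work (scheduler prep, metric logging, ...) proceeds meanwhile.
    events.append("local work")
    # Phase 2 (retrieval): block only when the result is actually needed.
    events.append(f"result={await future}")
    return events


print(asyncio.run(main()))  # ['local work', 'result=42']
```

Because the local step runs before the result is awaited, the "local work" entry is recorded while the simulated remote task is still in flight.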
## Complete Example
Below is an example of using the asynchronous methods for SFT:
```python
import asyncio

import numpy as np
import pytrio as trio


async def main():
    # 1. Connect to TRIO
    service_client = trio.ServiceClient()

    # 2. Create a training client
    base_model = "Qwen/Qwen3-4B-Instruct-2507"
    training_client = await service_client.create_lora_training_client_async(
        base_model=base_model,
        rank=32,
    )

    # 3. Dataset: teach the LLM to answer "what is trio" correctly
    examples = [
        {"input": "what is trio", "output": "trio is emotionmachine's AI Infra products."},
        {"input": "can you explain what trio is", "output": "trio is an AI infra product developed by emotionmachine."},
        {"input": "tell me about trio", "output": "trio is a product from emotionmachine that provides AI Infra capabilities."},
    ]

    # 4. Get the tokenizer
    print("Loading tokenizer...")
    tokenizer = training_client.get_tokenizer()
    print("Tokenizer loaded")

    # 5. Process the dataset into the format required for training
    def process_example(example: dict, tokenizer) -> trio.Datum:
        prompt = f"Question: {example['input']}\nAnswer:"
        prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
        prompt_weights = [0] * len(prompt_tokens)
        completion_tokens = tokenizer.encode(f" {example['output']}\n\n", add_special_tokens=False)
        completion_weights = [1] * len(completion_tokens)
        tokens = prompt_tokens + completion_tokens
        weights = prompt_weights + completion_weights
        input_tokens = tokens[:-1]
        target_tokens = tokens[1:]
        weights = weights[1:]
        # Convert into the format trio expects for training
        return trio.Datum(
            model_input=trio.ModelInput.from_ints(tokens=input_tokens),
            loss_fn_inputs=dict(weights=weights, target_tokens=target_tokens),
        )

    processed_examples = [process_example(ex, tokenizer) for ex in examples]

    # 6. Training
    print("Start Training")
    print_task_queue = []
    for iter in range(15):
        fwdbwd_future = await training_client.forward_backward_async(processed_examples, "cross_entropy")  # forward-backward pass
        optim_future = await training_client.optim_step_async(trio.AdamParams(learning_rate=1e-4))  # Adam optimizer step

        async def print_loss_async(fwdbwd_future, optim_future, iter: int):  # print the loss asynchronously
            fwdbwd_result = await fwdbwd_future
            optim_result = await optim_future
            logprobs = np.concatenate([output["logprobs"].tolist() for output in fwdbwd_result.loss_fn_outputs])
            weights = np.concatenate([example.loss_fn_inputs["weights"].tolist() for example in processed_examples])
            loss = -np.dot(logprobs, weights) / weights.sum()
            print(f"Iter {iter + 1} Loss per token: {loss:.4f}")
            return loss

        print_task_queue.append(print_loss_async(fwdbwd_future, optim_future, iter))
    await asyncio.gather(*print_task_queue)  # wait for all printing tasks to finish

    # 7. Sampling and evaluation
    print("Start Sampling")
    sampling_base_client = await service_client.create_sampling_client_async(base_model=base_model)
    sampling_sft_client = await training_client.save_weights_and_get_sampling_client_async(name="what-is-trio")
    prompt = trio.ModelInput.from_ints(tokenizer.encode("Question: what is trio\nAnswer:"))
    params = trio.SamplingParams(max_tokens=20, temperature=0.0, stop=["\n"])
    future_base = await sampling_base_client.sample_async(prompt=prompt, sampling_params=params, num_samples=1)
    future_sft = await sampling_sft_client.sample_async(prompt=prompt, sampling_params=params, num_samples=1)
    result_base = await future_base
    result_sft = await future_sft
    print("Base Responses:")
    print(f"{repr(result_base.sequences[0].text)}")
    print("SFT Responses:")
    print(f"{repr(result_sft.sequences[0].text)}")


if __name__ == "__main__":
    asyncio.run(main())
```

## Supported Async Methods
### ServiceClient

| Sync | Async |
|---|---|
| create_lora_training_client | create_lora_training_client_async |
| create_sampling_client | create_sampling_client_async |
| create_training_client_from_state | create_training_client_from_state_async |
| create_training_client_from_state_with_optimizer | create_training_client_from_state_with_optimizer_async |

### SamplingClient

| Sync | Async |
|---|---|
| sample | sample_async |
| compute_logprobs | compute_logprobs_async |

### TrainingClient

| Sync | Async |
|---|---|
| forward | forward_async |
| forward_backward | forward_backward_async |
| forward_backward_custom | forward_backward_custom_async |
| optim_step | optim_step_async |
| save_state | save_state_async |
| save_weights_for_sampler | save_weights_for_sampler_async |
| create_sampling_client | create_sampling_client_async |
| save_weights_and_get_sampling_client | save_weights_and_get_sampling_client_async |
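As a schematic illustration of the sync/`_async` pairing (not TRIO's actual implementation), a client exposing both variants of one operation might be structured like this toy `MiniClient`, whose names are all hypothetical:

```python
import asyncio
from concurrent.futures import Future, ThreadPoolExecutor


class MiniClient:
    """Toy client exposing a sync/_async method pair, mirroring the naming convention above."""

    def __init__(self) -> None:
        self._pool = ThreadPoolExecutor(max_workers=1)

    def _compute(self, x: int) -> int:
        return x + 1  # stand-in for a remote call

    def step(self, x: int) -> Future:
        # Sync variant: returns a concurrent.futures.Future; .result() blocks.
        return self._pool.submit(self._compute, x)

    async def step_async(self, x: int) -> asyncio.Future:
        # Async variant: submission returns an awaitable future, so callers
        # write `fut = await client.step_async(x)` and later `await fut`.
        loop = asyncio.get_running_loop()
        return loop.run_in_executor(None, self._compute, x)


client = MiniClient()
print(client.step(1).result())  # 2


async def demo() -> int:
    fut = await client.step_async(2)  # submission phase
    return await fut                  # retrieval phase


print(asyncio.run(demo()))  # 3
```

The sync path blocks in `.result()`, while the async path separates submission from retrieval, matching the two-phase pattern described in the Usage section.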