Guide

Async

In scenarios with high throughput demands, high concurrency, and complex training steps, async methods outperform their synchronous counterparts. Specifically, an async method does not block the process after submitting a compute task; it keeps executing other work locally, which improves training efficiency and throughput.

TRIO's training API typically provides both a synchronous and an asynchronous version of each method; every async method name ends with _async. Taking forward-backward as an example:

  • Sync method: forward_backward
  • Async method: forward_backward_async

Usage

Taking forward-backward as an example, the synchronous version is typically written as:

for i in range(15):
    fwdbwd_future = training_client.forward_backward(...)
    fwdbwd_result = fwdbwd_future.result()

The asynchronous version is written as:

import asyncio

async def main():
    ...
    for i in range(15):
        fwdbwd_future = await training_client.forward_backward_async(...)
        fwdbwd_result = await fwdbwd_future
    ...

if __name__ == "__main__":
    asyncio.run(main())

Concretely, calling an async method has two phases:

  1. Submission phase: awaiting training_client.forward_backward_async(...) submits the compute task and immediately returns an APIFutureResult object, a token that represents the task's future result

    1. With the APIFutureResult in hand, the local process can carry on with other work, such as preparing the scheduler or logging metrics
    2. Meanwhile, the compute task keeps running on the cloud training servers
  2. Result-retrieval phase: when the training result is needed, awaiting fwdbwd_future explicitly waits for the result to come back

    1. During await fwdbwd_future, the local coroutine suspends, and only resumes the subsequent code after the cloud task completes and delivers its result
    2. This makes it possible to control exactly when to wait for training results, improving training efficiency and throughput
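The two phases above can be sketched with plain asyncio. This is a minimal mock, not TRIO's API: submit_compute is a hypothetical stand-in for a *_async call, and the returned asyncio.Task plays the role of the APIFutureResult.

```python
import asyncio

async def submit_compute(x: int) -> "asyncio.Task[int]":
    # Phase 1: submission returns immediately with a future-like handle;
    # the "remote" work keeps running in the background.
    async def remote_work() -> int:
        await asyncio.sleep(0.05)  # simulated remote compute
        return x * x
    return asyncio.create_task(remote_work())

async def main() -> int:
    future = await submit_compute(7)        # phase 1: submit, get a handle at once
    local_log = ["prepared scheduler"]      # local work overlaps the remote task
    result = await future                   # phase 2: suspend until the result arrives
    return result

print(asyncio.run(main()))  # 49
```

The key property is that submit_compute returns before the result exists; only the final await pays the waiting cost.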

In summary, a training loop built on async methods can use local and cloud compute resources concurrently, raising the throughput of the whole training process.
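The throughput gain comes from overlapping many in-flight tasks. The sketch below illustrates this with plain asyncio; fake_forward_backward is a hypothetical stand-in for a remote call, not part of TRIO:

```python
import asyncio
import time

async def fake_forward_backward(step: int) -> float:
    await asyncio.sleep(0.05)  # stands in for remote compute time
    return 1.0 / (step + 1)    # pretend per-step loss

async def main() -> float:
    start = time.perf_counter()
    # Submit every step up front; nothing awaits a result yet.
    tasks = [asyncio.create_task(fake_forward_backward(s)) for s in range(10)]
    losses = await asyncio.gather(*tasks)  # collect all results at once
    elapsed = time.perf_counter() - start
    # All 10 "remote" calls overlapped, so total time is ~1 call, not 10.
    assert elapsed < 0.5
    return losses[0]

print(asyncio.run(main()))  # 1.0
```

Ten sequential calls would take about 0.5 s; overlapped, the whole batch finishes in roughly the time of one call.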

Complete Example

Below is an example of using async methods for supervised fine-tuning (SFT):

import pytrio as trio
import numpy as np
import asyncio

async def main():
    # 1. Connect to TRIO
    service_client = trio.ServiceClient()

    # 2. Create a training client
    base_model = "Qwen/Qwen3-4B-Instruct-2507"
    training_client = await service_client.create_lora_training_client_async(
        base_model=base_model,
        rank=32,
    )

    # 3. Dataset: teach the LLM to answer what trio is
    examples = [
        {"input": "what is trio", "output": "trio is emotionmachine's AI Infra products."},
        {"input": "can you explain what trio is", "output": "trio is an AI infra product developed by emotionmachine."},
        {"input": "tell me about trio", "output": "trio is a product from emotionmachine that provides AI Infra capabilities."},
    ]

    # 4. Get the tokenizer
    print("Loading tokenizer...")
    tokenizer = training_client.get_tokenizer()
    print("Tokenizer loaded")

    # 5. Process the dataset into the format training expects
    def process_example(example: dict, tokenizer) -> trio.Datum:
        prompt = f"Question: {example['input']}\nAnswer:"

        prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
        prompt_weights = [0] * len(prompt_tokens)

        completion_tokens = tokenizer.encode(f" {example['output']}\n\n", add_special_tokens=False)
        completion_weights = [1] * len(completion_tokens)

        tokens = prompt_tokens + completion_tokens
        weights = prompt_weights + completion_weights

        # Shift by one: the model predicts token t+1 from the first t tokens
        input_tokens = tokens[:-1]
        target_tokens = tokens[1:]
        weights = weights[1:]

        # Convert to the format trio training expects
        return trio.Datum(
            model_input=trio.ModelInput.from_ints(tokens=input_tokens),
            loss_fn_inputs=dict(weights=weights, target_tokens=target_tokens)
        )

    processed_examples = [process_example(ex, tokenizer) for ex in examples]  # preprocess all examples

    # 6. Train
    print("Start Training")
    print_task_queue = []
    for step in range(15):
        fwdbwd_future = await training_client.forward_backward_async(processed_examples, "cross_entropy")  # forward-backward pass
        optim_future = await training_client.optim_step_async(trio.AdamParams(learning_rate=1e-4))  # Adam optimizer update

        async def print_loss_async(fwdbwd_future, optim_future, step: int):  # print the loss asynchronously
            fwdbwd_result = await fwdbwd_future
            optim_result = await optim_future

            logprobs = np.concatenate([output['logprobs'].tolist() for output in fwdbwd_result.loss_fn_outputs])
            weights = np.concatenate([example.loss_fn_inputs['weights'].tolist() for example in processed_examples])
            loss = -np.dot(logprobs, weights) / weights.sum()
            print(f"Iter{step+1} Loss per token: {loss:.4f}")
            return loss

        print_task_queue.append(print_loss_async(fwdbwd_future, optim_future, step))

    await asyncio.gather(*print_task_queue)  # wait for all print tasks to finish
    
    # 7. Sampling and evaluation
    print("Start Sampling")
    sampling_base_client = await service_client.create_sampling_client_async(base_model=base_model)
    sampling_sft_client = await training_client.save_weights_and_get_sampling_client_async(name='what-is-trio')

    prompt = trio.ModelInput.from_ints(tokenizer.encode("Question: what is trio\nAnswer:"))
    params = trio.SamplingParams(max_tokens=20, temperature=0.0, stop=["\n"])

    future_base = await sampling_base_client.sample_async(prompt=prompt, sampling_params=params, num_samples=1)
    future_sft = await sampling_sft_client.sample_async(prompt=prompt, sampling_params=params, num_samples=1)
    result_base = await future_base
    result_sft = await future_sft

    print("Base Responses:")
    print(f"{repr(result_base.sequences[0].text)}")

    print("SFT Responses:")
    print(f"{repr(result_sft.sequences[0].text)}")

if __name__ == "__main__":
    asyncio.run(main())
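The one-token shift inside process_example can be checked in isolation. The toy token IDs below are made up for illustration; no tokenizer is needed:

```python
# Toy check of the one-token shift used in process_example
tokens  = [10, 11, 12, 13]   # prompt + completion token IDs
weights = [0, 0, 1, 1]       # loss only on completion tokens

input_tokens  = tokens[:-1]  # model sees tokens 0..n-2
target_tokens = tokens[1:]   # and predicts tokens 1..n-1
weights       = weights[1:]  # weights stay aligned with the targets

print(input_tokens, target_tokens, weights)
# [10, 11, 12] [11, 12, 13] [0, 1, 1]
```

Each target is the token that follows the corresponding input position, which is why both target_tokens and weights drop their first element.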

Supported Async Methods

ServiceClient

| Sync | Async |
| --- | --- |
| create_lora_training_client | create_lora_training_client_async |
| create_sampling_client | create_sampling_client_async |
| create_training_client_from_state | create_training_client_from_state_async |
| create_training_client_from_state_with_optimizer | create_training_client_from_state_with_optimizer_async |

SamplingClient

| Sync | Async |
| --- | --- |
| sample | sample_async |
| compute_logprobs | compute_logprobs_async |

TrainingClient

| Sync | Async |
| --- | --- |
| forward | forward_async |
| forward_backward | forward_backward_async |
| forward_backward_custom | forward_backward_custom_async |
| optim_step | optim_step_async |
| save_state | save_state_async |
| save_weights_for_sampler | save_weights_for_sampler_async |
| create_sampling_client | create_sampling_client_async |
| save_weights_and_get_sampling_client | save_weights_and_get_sampling_client_async |
