Skip to main content
This guide walks through a small GRPO-style reinforcement-learning training loop that trains a LoRA adapter with Tinker while sampling every rollout from Sail. Tinker owns the optimizer; Sail serves each fresh checkpoint through SailTokenCompleter, so your rollouts run on Sail’s inference fleet. The example fine-tunes Kimi K2.6 to solve grade-school math word problems, rewarding correct final answers written in \boxed{}.

Prerequisites

  • Python 3.11+
  • A Sail API key, a Tinker API key, and a Hugging Face token (exported as HF_TOKEN below)
  • The packages below
# Install the SDKs and the dataset loader used in the code below.
pip install sail-sdk tinker tinker-cookbook datasets
# Credentials: replace the placeholder values with your own keys.
export SAIL_API_KEY=sk_your_key_here
export TINKER_API_KEY=your_tinker_key
export HF_TOKEN=your_huggingface_token

How one step fits together

Each training step is the same five moves:
  1. Snapshot the current LoRA weights as a Tinker sampler checkpoint.
  2. Resolve a signed URL to that checkpoint and wrap it in a SailTokenCompleter.
  3. Roll out a batch of grouped completions through that completer on Sail.
  4. Score the completions, set aside groups whose rewards are all identical (they carry no learning signal), and turn the rest into advantages and Tinker training data.
  5. Train one optimizer step on the Tinker client, then repeat with the updated weights.

1. Score a completion

An environment renders one prompt and scores one sampled completion. tinker-cookbook calls initial_observation to get the prompt tokens (and stop sequences), then step to score the tokens Sail sampled. The reward is 1 for a correct boxed answer, minus a 0.1 penalty when the response has no \boxed{} answer or does not terminate cleanly.
import re
from dataclasses import dataclass
from typing import Any

import tinker
from tinker_cookbook import renderers
from tinker_cookbook.rl.types import StepResult

# Appended to every question so the model puts its final answer in \boxed{},
# which is what MathEnv.step extracts and scores.
QUESTION_SUFFIX = " Write your final answer in \\boxed{} format."


def extract_boxed(text: str) -> str | None:
    match = re.search(r"\\boxed\s*\{([^{}]+)\}", text)
    return match.group(1) if match else None


def answer_matches(completion: str, reference: str) -> bool:
    r"""Return True if the completion's boxed answer equals the reference.

    The reference's own \boxed{} contents are used when present, otherwise the
    whole reference string. Both sides are lower-cased and stripped of commas
    and whitespace. Values that both parse as numbers are compared
    numerically, so "5.00" matches "5" and "1,000" matches "1000".

    A completion without a (non-blank) boxed answer never matches. Otherwise
    an empty reference would "match" a missing answer.

    NOTE(review): if the dataset's reference is a full worked solution rather
    than a bare final answer, it will almost never equal a boxed answer.
    Confirm the reference format for the dataset in use.
    """

    def norm(t: str | None) -> str:
        return re.sub(r"\s+", "", (t or "").strip().lower().replace(",", ""))

    got = norm(extract_boxed(completion))
    if not got:
        return False
    want = norm(extract_boxed(reference) or reference)
    if got == want:
        return True
    # Fall back to numeric comparison for equivalent spellings of a number.
    try:
        return float(got) == float(want)
    except ValueError:
        return False


class MathEnv:
    """One-shot math environment: render a single question, score a single reply.

    Args:
        question: Problem statement; QUESTION_SUFFIX is appended when prompting.
        answer: Reference answer handed to ``answer_matches``.
        renderer: tinker-cookbook renderer that builds prompts and parses replies.
    """

    def __init__(self, question: str, answer: str, renderer: Any) -> None:
        self.question, self.answer, self.renderer = question, answer, renderer

    async def initial_observation(self):
        """Return the rendered prompt and the renderer's stop sequences."""
        user_turn = {"role": "user", "content": self.question + QUESTION_SUFFIX}
        prompt = self.renderer.build_generation_prompt([user_turn])
        stop_sequences = self.renderer.get_stop_sequences()
        return prompt, stop_sequences

    async def step(self, action, *, extra=None):
        """Score the sampled tokens in ``action`` and end the episode.

        The correctness score is 1.0 for a matching answer. A 0.1 penalty is
        subtracted unless the reply both contains a \\boxed{} answer and
        terminated cleanly.
        """
        parsed, termination = self.renderer.parse_response(action)
        reply = renderers.get_text_content(parsed)
        # termination.is_clean is only consulted when a boxed answer exists.
        well_formatted = float(extract_boxed(reply) is not None and termination.is_clean)
        is_correct = float(answer_matches(reply, self.answer))
        format_penalty = 0.1 * (1.0 - well_formatted)
        return StepResult(
            reward=is_correct - format_penalty,
            episode_done=True,  # single-turn task: one reply per episode
            next_observation=tinker.ModelInput.empty(),
            next_stop_condition=[],
            metrics={"correct": is_correct, "format": well_formatted},
            logs={},
        )
A group builder fans one prompt out into group_size environments, so each prompt is sampled group_size times and the rewards can be compared within the group.
@dataclass(frozen=True)
class MathGroupBuilder:
    """Builds a group of identical MathEnv instances for one dataset row."""

    row: dict  # dataset row; make_envs reads its "question" and "answer" keys
    renderer: Any  # renderer handed to every MathEnv
    group_size: int  # number of environments (samples) per prompt

    async def make_envs(self):
        """Create ``group_size`` independent environments for this row."""
        envs = []
        for _ in range(self.group_size):
            envs.append(MathEnv(self.row["question"], self.row["answer"], self.renderer))
        return envs

    async def compute_group_rewards(self, trajectory_group, env_group):
        """Return a zero bonus (and empty metrics) for each trajectory.

        All reward is assigned per trajectory in MathEnv.step.
        """
        return [(0.0, {}) for _trajectory in trajectory_group]

    async def cleanup(self):
        """Nothing to release; environments hold no external resources."""

    def logging_tags(self):
        """Tags attached to this group's logged metrics."""
        return ["math"]

2. Serve the latest checkpoint on Sail

Before every rollout batch (once up front for the initial weights, then after each optimizer step), save the weights as a sampler checkpoint, resolve a signed archive URL, and build a SailTokenCompleter pointed at it. The adapter_config is the PEFT config matching the LoRA Tinker is training; in this example, examples/tinker_adapter_config.example.json is a ready-made one for Kimi K2.6 at rank 32. Passing ttl_seconds tells Tinker to expire the checkpoint, so per-step checkpoints clean themselves up instead of piling up. The signed URL gets the same lifetime, so it should outlast the rollouts that use it.
import sail


async def sail_policy(
    training_client,
    service_client,
    adapter_config,
    name,
    *,
    model: str = "moonshotai/Kimi-K2.6",
    ttl_seconds: int = 3600,
    max_tokens: int = 1024,
    temperature: float = 1.0,
    completion_window: str = "priority",
):
    """Snapshot the current LoRA weights and return a Sail completer serving them.

    Args:
        training_client: Tinker LoRA training client whose weights are saved.
        service_client: Tinker service client used to resolve the signed URL.
        adapter_config: PEFT adapter config matching the LoRA being trained.
        name: Checkpoint name, also passed to Sail as ``tinker_lora_name``
            (main passes a per-step unique name).
        model: Base model Sail serves the adapter on. It must match the base
            model the training client was created for.
        ttl_seconds: Lifetime of both the sampler checkpoint and its signed URL.
        max_tokens: Sampling limit forwarded to SailTokenCompleter.
        temperature: Sampling temperature forwarded to SailTokenCompleter.
        completion_window: Sail latency/cost tier for the rollouts.

    Returns:
        A ``sail.SailTokenCompleter`` bound to the freshly saved checkpoint.
    """
    # The first await submits the save and yields a future; the second waits
    # for the checkpoint to be written.
    save_future = await training_client.save_weights_for_sampler_async(
        name, ttl_seconds=ttl_seconds
    )
    save_result = await save_future
    # The signed URL shares the checkpoint's lifetime, so it should outlive
    # the rollouts sampled from this snapshot.
    signed_url = await sail.get_tinker_checkpoint_signed_url_async(
        service_client, save_result.path, ttl_seconds=ttl_seconds
    )
    return sail.SailTokenCompleter(
        model=model,
        max_tokens=max_tokens,
        temperature=temperature,
        completion_window=completion_window,
        tinker_lora_signed_url=signed_url,
        adapter_config=adapter_config,
        tinker_lora_name=name,
    )

3. The training loop

Wire it together: load the data, then each step snapshot → roll out on Sail → train.
import asyncio
import json
import random
import time

from datasets import load_dataset
from tinker_cookbook.rl.data_processing import assemble_training_data, compute_advantages
from tinker_cookbook.rl.rollouts import do_group_rollout
from tinker_cookbook.rl.train import train_step
from tinker_cookbook.tokenizer_utils import get_tokenizer

# Base model for the LoRA training client and the tokenizer. The Sail completer
# built in sail_policy must use the same model name.
MODEL = "moonshotai/Kimi-K2.6"


def mean_reward(groups):
    """Average total reward over every trajectory in every group (0.0 if none)."""
    total = 0.0
    count = 0
    for group in groups:
        for value in group.get_total_rewards():
            total += float(value)
            count += 1
    return total / count if count else 0.0


def split_degenerate_groups(groups):
    """Partition groups into ``(kept, degenerate)``.

    A group is degenerate when it has rewards and they are all identical; its
    group-relative advantages would all be zero. Groups with no rewards are
    kept. Relative order is preserved in both lists.
    """
    kept, degenerate = [], []
    for group in groups:
        distinct_rewards = {float(r) for r in group.get_total_rewards()}
        bucket = degenerate if len(distinct_rewards) == 1 else kept
        bucket.append(group)
    return kept, degenerate


async def main():
    """Train the LoRA with GRPO-style RL, sampling every rollout on Sail.

    Each step:
    - snapshots the current weights and serves them on Sail;
    - samples ``groups_per_step`` prompts x ``group_size`` completions;
    - drops groups whose rewards are all identical (zero advantage);
    - runs one Tinker optimizer step on the remaining data.

    Raises:
        RuntimeError: if every rollout group in a step is degenerate.
    """
    # Use a context manager so the config file handle is closed immediately.
    with open(
        "examples/tinker_adapter_config.example.json", encoding="utf-8"
    ) as config_file:
        adapter_config = json.load(config_file)

    service_client = tinker.ServiceClient()
    # rank must agree with the adapter config (the example file is rank 32).
    training_client = await service_client.create_lora_training_client_async(
        base_model=MODEL, rank=32
    )
    renderer = renderers.get_renderer("kimi_k26", tokenizer=get_tokenizer(MODEL))

    ds = load_dataset("microsoft/orca-math-word-problems-200k", split="train")
    train_rows = [
        {"question": r["question"], "answer": r["answer"]} for r in ds.select(range(2000))
    ]

    run_id = f"orca-math-{int(time.time())}"
    num_steps = 7
    group_size = 4
    groups_per_step = 16  # 16 prompts x 4 samples = 64 rollouts per step
    learning_rate = 1e-5

    for step in range(num_steps):
        # 1-2. Snapshot the current weights and serve them on Sail.
        policy = await sail_policy(
            training_client, service_client, adapter_config, f"{run_id}-step-{step}"
        )

        # 3. Roll out a batch of grouped completions through Sail, concurrently.
        rows = random.sample(train_rows, groups_per_step)
        builders = [MathGroupBuilder(row=r, renderer=renderer, group_size=group_size) for r in rows]
        groups = await asyncio.gather(*(do_group_rollout(b, policy) for b in builders))

        # 4-5. GRPO-style advantages, then one Tinker optimizer step.
        # The mean reward is reported over all groups, including degenerate ones.
        reward = mean_reward(groups)
        training_groups, degenerate_groups = split_degenerate_groups(groups)
        degenerate_pct = 100.0 * len(degenerate_groups) / len(groups) if groups else 0.0
        if not training_groups:
            raise RuntimeError("all rollout groups were degenerate")

        advantages = compute_advantages(training_groups)
        data, _ = assemble_training_data(training_groups, advantages)
        await train_step(
            data_D=data,
            training_client=training_client,
            learning_rate=learning_rate,
            num_substeps=1,
            loss_fn="importance_sampling",
            metrics={},
        )
        print(
            f"Step {step:2d} | reward: {reward:.3f} | "
            f"degenerate: {degenerate_pct:.0f}% | datums: {len(data)}"
        )


# Only start training when run as a script, not when the module is imported.
if __name__ == "__main__":
    asyncio.run(main())
That’s the whole loop. Tinker holds the optimizer state and applies each update; Sail samples every rollout from the checkpoint you just wrote. To scale the rollout batch, raise group_size and groups_per_step; the rollouts are issued concurrently with asyncio.gather, so Sail runs the completions in parallel. The loop already prints the mean reward each step; for finer-grained progress, also log the per-rollout correct and format metrics that MathEnv.step reports.

Next steps

  • Tinker — the SailTokenCompleter reference: parameters, LoRA modes, and constraints.
  • LoRAs — upload and serve a PEFT adapter directly, without a Tinker training loop.
  • Completion Windows — control the latency/cost tier your rollouts run on.