import re
from dataclasses import dataclass
from typing import Any
import tinker
from tinker_cookbook import renderers
from tinker_cookbook.rl.types import StepResult
# Appended to every question prompt so the model puts its answer in a
# \boxed{...} wrapper that `extract_boxed` can find; the leading space keeps it
# separated from the question text.
QUESTION_SUFFIX: str = " Write your final answer in \\boxed{} format."
def extract_boxed(text: str) -> str | None:
match = re.search(r"\\boxed\s*\{([^{}]+)\}", text)
return match.group(1) if match else None
def answer_matches(completion: str, reference: str) -> bool:
    """Return True when the boxed answer in *completion* matches *reference*.

    *reference* may be a bare answer (``"42"``) or a full solution containing
    a ``\\boxed{...}``; in the latter case only the boxed part is compared.

    Both sides are normalised before comparison: lower-cased, all whitespace
    removed, and commas acting as thousands separators (``1,000``) dropped.
    Other commas are kept so that, e.g., ``1,2`` does not collapse to ``12``.

    Args:
        completion: Model output to grade.
        reference: Ground-truth answer or solution text.

    Returns:
        False if *completion* has no boxed answer or the reference normalises
        to an empty string; otherwise whether the normalised strings are equal.
    """

    def norm(t: str) -> str:
        compact = re.sub(r"\s+", "", t.lower())
        # Drop a comma only between a digit and a group of exactly 3 digits.
        return re.sub(r"(?<=\d),(?=\d{3}(?!\d))", "", compact)

    predicted = extract_boxed(completion)
    if predicted is None:
        # No boxed answer: never credit it, even against an empty reference.
        return False
    target = norm(extract_boxed(reference) or reference)
    return bool(target) and norm(predicted) == target
class MathEnv:
    """Single-turn environment that grades one math question.

    The prompt is the question followed by ``QUESTION_SUFFIX``. The episode
    ends after one response, which earns 1.0 for a matching boxed answer. A
    penalty of ``format_coef`` is subtracted when the format check fails.
    """

    def __init__(
        self,
        question: str,
        answer: str,
        renderer: Any,
        *,
        format_coef: float = 0.1,
    ) -> None:
        """Store the problem, its reference answer, and grading settings.

        Args:
            question: Problem statement shown to the model.
            answer: Reference answer or solution (may contain ``\\boxed{}``).
            renderer: Object used to build the prompt, supply stop sequences,
                and parse the sampled response.
            format_coef: Penalty subtracted from the reward when the response
                has no boxed answer or did not terminate cleanly. Defaults to
                0.1, the previously hard-coded value.
        """
        self.question = question
        self.answer = answer
        self.renderer = renderer
        self.format_coef = format_coef

    async def initial_observation(self):
        """Return ``(prompt, stop_sequences)`` for the single user turn."""
        convo = [{"role": "user", "content": self.question + QUESTION_SUFFIX}]
        return self.renderer.build_generation_prompt(convo), self.renderer.get_stop_sequences()

    async def step(self, action, *, extra=None):
        """Grade *action* and end the episode.

        ``extra`` is accepted for interface compatibility and is not used.

        The reward is ``correct - format_coef * (1 - format)``, where both
        terms are 0.0 or 1.0. Examples:

        - Correct and well-formatted: 1.0.
        - Correct but badly formatted: ``1 - format_coef``.
        - Wrong and badly formatted: ``-format_coef``.
        """
        # NOTE(review): assumes parse_response returns (message, status) where
        # status exposes `is_clean`; confirm against the renderer version used.
        message, termination = self.renderer.parse_response(action)
        text = renderers.get_text_content(message)
        correct_format = float(extract_boxed(text) is not None and termination.is_clean)
        correct_answer = float(answer_matches(text, self.answer))
        return StepResult(
            reward=correct_answer - self.format_coef * (1.0 - correct_format),
            episode_done=True,
            next_observation=tinker.ModelInput.empty(),
            next_stop_condition=[],
            metrics={"correct": correct_answer, "format": correct_format},
            logs={},
        )