import argparse
import asyncio
from tqdm import tqdm
from openai import DefaultAioHttpClient
from openai import AsyncOpenAI
async def main(
    num_requests: int,
    input_tokens: int,
    max_output_tokens: int,
    model: str,
):
    """Submit background responses in bulk, poll them, and print their output.

    Lists the models reported by the endpoint, submits ``num_requests``
    background requests (at most 200 in flight at a time), polls every
    created response about once per second until it reaches a terminal
    status or the attempt budget runs out, then prints the text output of
    each response.

    Args:
        num_requests: Number of requests to submit.
        input_tokens: How many times the filler ``"word "`` is repeated at
            the end of each prompt; values <= 0 add no filler.
        max_output_tokens: Forwarded as ``max_output_tokens`` on each request.
        model: Model identifier forwarded on each request.
    """
    base_url = "https://api.sailresearch.com/v1"
    api_key = "YOUR_KEY_HERE"  # placeholder value; replace before running

    # Statuses after which a response will not change any more. Polling only
    # for "completed" would keep re-polling failed/cancelled/incomplete
    # responses until the timeout.
    # NOTE(review): values follow the OpenAI Responses API status enum;
    # confirm this endpoint uses the same set.
    terminal_statuses = {"completed", "failed", "cancelled", "incomplete"}

    async with AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        http_client=DefaultAioHttpClient(),
    ) as client:
        # List supported models
        models = await client.models.list()
        supported_models = [m.id for m in models.data]
        print("Supported Models:")
        print(supported_models)

        # Submit requests concurrently
        sem = asyncio.Semaphore(200)

        async def submit(i):
            filler = ""
            if input_tokens > 0:
                filler = " " + ("word " * input_tokens)
            content = f"TASK {i}: What is a fun fact about the number {i}? Then, find the word at index {i} in the following sequence of words: {filler}"
            async with sem:
                response = await client.responses.create(
                    model=model,
                    input=[{"role": "user", "content": content}],
                    max_output_tokens=max_output_tokens,
                    background=True,
                )
            pbar_submit.update(1)
            return response.id

        # The context manager closes the bar even if gathering is interrupted.
        with tqdm(total=num_requests, desc="Submitting requests") as pbar_submit:
            # return_exceptions=True: one failed submission must not discard
            # the IDs of the requests that were created successfully.
            submitted = await asyncio.gather(
                *(submit(i) for i in range(num_requests)),
                return_exceptions=True,
            )

        submit_errors = [r for r in submitted if isinstance(r, BaseException)]
        response_ids = [r for r in submitted if not isinstance(r, BaseException)]
        if submit_errors:
            print(
                f"{len(submit_errors)} submissions failed; "
                f"first error: {submit_errors[0]!r}"
            )
        print(f"Created {len(response_ids)} response IDs")

        # Poll for completions
        finished = {}
        poll_sem = asyncio.Semaphore(200)

        async def poll(response_id):
            for _ in range(3600):  # 1 hour timeout (one attempt per second)
                async with poll_sem:
                    response = await client.responses.retrieve(response_id)
                if response.status in terminal_statuses:
                    finished[response_id] = response
                    pbar.update(1)
                    return
                await asyncio.sleep(1)
            print(f"Timed out waiting for {response_id}")

        with tqdm(total=len(response_ids), desc="Polling responses") as pbar:
            polled = await asyncio.gather(
                *(poll(rid) for rid in response_ids),
                return_exceptions=True,
            )

        poll_errors = [r for r in polled if isinstance(r, BaseException)]
        if poll_errors:
            print(
                f"{len(poll_errors)} polls failed; "
                f"first error: {poll_errors[0]!r}"
            )

        for response_id in response_ids:
            resp = finished.get(response_id)
            if resp is None:
                # Timed out or polling raised: there is no response to print.
                print(f"\n\n{response_id} did not reach a terminal status")
                continue
            output_text = "".join(
                c.text
                for item in (resp.output or [])
                for c in getattr(item, "content", []) or []
                if getattr(c, "type", None) == "output_text"
            )
            if resp.status == "completed":
                print(f"\n\n{response_id} completed with output:\n{output_text}")
            else:
                print(
                    f"\n\n{response_id} ended with status {resp.status!r}; "
                    f"output so far:\n{output_text}"
                )
if __name__ == "__main__":
    # Command-line options as (flag, type, default) rows; argparse derives
    # the destination names (num_requests, input_tokens, max_output_tokens,
    # model), which are exactly the keyword parameters of main().
    cli_options = (
        ("--num-requests", int, 20000),
        ("--input-tokens", int, 50),
        ("--max-output-tokens", int, 4000),
        ("--model", str, "zai-org/GLM-5"),
    )
    parser = argparse.ArgumentParser()
    for flag, kind, fallback in cli_options:
        parser.add_argument(flag, type=kind, default=fallback)
    args = parser.parse_args()

    # Expand the parsed namespace into main()'s keyword arguments.
    asyncio.run(main(**vars(args)))