#!/usr/bin/env python3
"""FlashRT — Qwen3-8B-NVFP4 OpenAI-compatible HTTP server.

Provides /v1/chat/completions backed by the FlashRT NVFP4 path on
RTX 5090. Clients targeting the OpenAI API can swap their base URL to
this server without code changes.

Surface (v1):
  * /v1/chat/completions  (non-stream + stream:true with token-by-token
    SSE deltas)
  * /v1/models            (returns a single canonical model id)
  * /health
  * Tools / function calling via Qwen3 chat-template native support.
    The model emits <tool_call>{...}</tool_call> blocks; our streamer
    emits OpenAI-shape tool_calls deltas as the JSON closes.
  * Sampling: temperature / top_p / top_k / seed / stop / max_tokens.
    Greedy when temperature==0 (default), else multinomial after
    top_k+top_p truncation.

Limits (v1):
  * Batch size 1 — concurrent requests are serialised behind a single
    asyncio lock. Multi-tenant serving belongs in a higher layer.
  * Single graph-warmed shape ladder is captured at startup; first
    request at a new (prompt_len) shape pays a small one-time capture
    cost.

Usage::

    pip install fastapi uvicorn
    python examples/qwen3_openai_server.py \\
        --checkpoint /path/to/Qwen3-8B-Instruct-NVFP4 \\
        --port 8000 \\
        --warmup 32:128,128:256,256:256

    curl http://localhost:8000/v1/chat/completions \\
        -H 'Content-Type: application/json' \\
        -d '{"model":"qwen3-8b-nvfp4",
             "messages":[{"role":"user","content":"Hi"}],
             "max_tokens":64, "stream":true}'
"""
from __future__ import annotations

import argparse
import asyncio
import json
import logging
import os
import re
import sys
import time
import uuid
from typing import Any, Dict, List, Optional, Tuple

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
)
log = logging.getLogger('qwen3_openai_server')


# Qwen3-Instruct tool-call format: model emits
#   <tool_call>{"name": "fn_name", "arguments": {...}}</tool_call>
# anywhere in the assistant turn. We parse incrementally during stream.
_TOOL_CALL_OPEN = '<tool_call>'
_TOOL_CALL_CLOSE = '</tool_call>'

# A decode that ends in U+FFFD usually means the last token stopped in
# the middle of a multi-byte UTF-8 character. We hold such tokens back
# (up to this many) and re-decode them together with the next ones.
_REPLACEMENT_CHAR = '\ufffd'
_MAX_PENDING_TOKENS = 8


# ────────────────────────────────────────────────────────────────────
# Sampling
# ────────────────────────────────────────────────────────────────────


def _sample_token(
    logits,                     # (vocab,) bf16/fp32
    *,
    temperature: float,
    top_p: float,
    top_k: int,
    rng=None,
) -> int:
    """Pick the next token id from a 1-D logits vector.

    Greedy (argmax) when ``temperature <= 0``, ``top_k == 1`` or
    ``top_p <= 0`` (a nucleus of zero mass can only hold the single
    most likely token). Otherwise: temperature scaling, optional top-k
    truncation, optional top-p (nucleus) truncation, then one
    multinomial draw.

    Args:
        logits: 1-D tensor of next-token logits.
        temperature: softmax temperature; ``<= 0`` selects greedy.
        top_p: nucleus mass; values outside ``(0, 1)`` disable the
            nucleus filter (``<= 0`` selects greedy).
        top_k: keep only the k largest logits; ``0`` disables.
        rng: optional ``torch.Generator`` for reproducible draws.

    Returns:
        The sampled token id as a Python int.
    """
    import torch

    if temperature <= 0.0 or top_k == 1 or top_p <= 0.0:
        return int(logits.argmax(dim=-1).item())

    scaled = logits.float() / max(temperature, 1e-6)

    if top_k and 0 < top_k < scaled.numel():
        top_vals, top_idx = torch.topk(scaled, top_k)
        kept = torch.full_like(scaled, float('-inf'))
        kept.scatter_(0, top_idx, top_vals)
        scaled = kept

    if 0.0 < top_p < 1.0:
        sorted_v, sorted_idx = torch.sort(scaled, descending=True)
        cum = torch.softmax(sorted_v, dim=-1).cumsum(dim=-1)
        cutoff_mask = cum > top_p
        # Keep the first index that crosses top_p — shift right by 1 so
        # the crossing token stays in and at least one token survives.
        cutoff_mask[..., 1:] = cutoff_mask[..., :-1].clone()
        cutoff_mask[..., 0] = False
        sorted_v[cutoff_mask] = float('-inf')
        scaled = torch.full_like(scaled, float('-inf'))
        scaled.scatter_(0, sorted_idx, sorted_v)

    probs = torch.softmax(scaled, dim=-1)
    if rng is not None:
        return int(torch.multinomial(probs, 1, generator=rng).item())
    return int(torch.multinomial(probs, 1).item())


# ────────────────────────────────────────────────────────────────────
# Stop-string + tool-call streaming parser
# ────────────────────────────────────────────────────────────────────


class StreamParser:
    """Incrementally split assistant tokens into:

      * "content"    (free text deltas)
      * "tool_calls" (parsed JSON objects emitted as OAI tool_call deltas)

    Plus stop-string detection for early termination.
    """

    def __init__(self, tokenizer, stop_strings: Optional[List[str]] = None):
        self.tok = tokenizer
        self._buffer = ''            # un-flushed text (may contain partial tags)
        self._content_pos = 0        # kept for compatibility; not read here
        self._in_tool = False
        self._tool_buffer = ''
        # An empty stop string matches at index 0 of any buffer and
        # would cut every response to nothing, so it is ignored.
        self._stop_strings = [
            s for s in (stop_strings or []) if isinstance(s, str) and s
        ]
        self._max_stop_len = max(
            (len(s) for s in self._stop_strings), default=0)
        self._tool_calls_emitted: List[dict] = []
        # OAI tool_call indexer.
        self._tool_call_idx = 0
        # Token ids whose decode ended in an incomplete character.
        self._pending_ids: List[int] = []

    def _decode_pending(self, new_token_ids: List[int], final: bool) -> str:
        """Decode pending + new token ids, holding back split characters."""
        self._pending_ids.extend(new_token_ids)
        if not self._pending_ids:
            return ''
        try:
            fragment = self.tok.decode(
                self._pending_ids, skip_special_tokens=False)
        except Exception:
            # Best effort: a tokenizer failure must not kill the stream,
            # but it should not pass unnoticed either.
            log.warning('token decode failed; dropping %d token(s)',
                        len(self._pending_ids), exc_info=True)
            self._pending_ids = []
            return ''
        incomplete = fragment.endswith(_REPLACEMENT_CHAR)
        if (incomplete and not final
                and len(self._pending_ids) < _MAX_PENDING_TOKENS):
            # Wait for the rest of the character; ids stay pending.
            return ''
        self._pending_ids = []
        return fragment

    def _emit_tool_call(self, raw: str, sink: List[dict]) -> None:
        """Parse one tool-call payload and record it if it is valid."""
        tc = self._parse_tool_call(raw.strip())
        if tc is not None:
            sink.append(tc)
            self._tool_calls_emitted.append(tc)

    def feed(
        self, new_token_ids: List[int], *, final: bool = False,
    ) -> Tuple[str, List[dict], bool]:
        """Decode new tokens and return (delta_text, new_tool_calls, stop_hit).

        delta_text: clean content delta (excluding tool-call wrappers).
        new_tool_calls: list of {index, id, type, function: {name, arguments}}
            objects newly closed in this feed.
        stop_hit: True iff any stop string was found.

        Args:
            new_token_ids: tokens to append to the running stream
                (may be empty on the final flush).
            final: True iff no more tokens will arrive (EOS / max_tokens /
                stop string already hit upstream). When set, the entire
                buffer is flushed — no partial-tag hold-back, no
                max-stop-string-len hold-back — and an unclosed tool
                call is parsed if its JSON is already complete.
        """
        self._buffer += self._decode_pending(new_token_ids, final)

        delta_text = ''
        new_tool_calls: List[dict] = []
        stop_hit = False

        # Stop-string detection — scan the FULL buffer (not just the
        # flushable head) for any user-supplied stop. If a stop is in
        # the buffer, truncate the buffer at the earliest match and mark
        # stop_hit. The stop string itself is dropped from the output
        # (OpenAI semantics).
        if self._stop_strings and not self._in_tool:
            hits = [
                pos
                for pos in (self._buffer.find(s) for s in self._stop_strings)
                if pos >= 0
            ]
            if hits:
                self._buffer = self._buffer[:min(hits)]
                stop_hit = True

        # The buffer may need a tail hold for two reasons:
        #   (a) the tail of `_buffer` could be a partial `<tool_call>`
        #       opening tag whose final chars haven't streamed yet;
        #   (b) the tail could complete a stop string on the next feed.
        # Hold-back size = max(len(open_tag), max(stop_string_lens)) - 1.
        # On `final=True` (or once stop_hit fired) the hold-back is 0.
        if final or stop_hit:
            hold = 0
        else:
            hold = max(len(_TOOL_CALL_OPEN), self._max_stop_len) - 1

        while True:
            if self._in_tool:
                # Search across the already-collected payload as well,
                # so a close tag split over two feeds is still found.
                # Only the last len(close_tag)-1 chars of the old
                # payload can take part in a new match.
                combined = self._tool_buffer + self._buffer
                search_from = max(
                    0, len(self._tool_buffer) - len(_TOOL_CALL_CLOSE) + 1)
                close_idx = combined.find(_TOOL_CALL_CLOSE, search_from)
                if close_idx < 0:
                    self._tool_buffer = combined
                    self._buffer = ''
                    break
                self._tool_buffer = ''
                self._buffer = combined[close_idx + len(_TOOL_CALL_CLOSE):]
                self._in_tool = False
                self._emit_tool_call(combined[:close_idx], new_tool_calls)
                continue

            open_idx = self._buffer.find(_TOOL_CALL_OPEN)
            if open_idx < 0:
                # No open tag in buffer — flush all but the hold-back tail.
                safe = max(0, len(self._buffer) - hold)
                if safe > 0:
                    delta_text += self._buffer[:safe]
                    self._buffer = self._buffer[safe:]
                break
            # Flush text before the open tag, then switch to tool mode;
            # the loop continues into the in_tool branch.
            delta_text += self._buffer[:open_idx]
            self._buffer = self._buffer[open_idx + len(_TOOL_CALL_OPEN):]
            self._in_tool = True

        if final and self._in_tool:
            # Generation ended without a close tag. Emit the call only
            # if its JSON is complete; truncated JSON is dropped.
            leftover, self._tool_buffer = self._tool_buffer, ''
            self._in_tool = False
            self._emit_tool_call(leftover, new_tool_calls)

        return delta_text, new_tool_calls, stop_hit

    def _parse_tool_call(self, raw: str) -> Optional[dict]:
        """Parse the JSON inside a <tool_call>...</tool_call> block.

        Qwen3 emits compact JSON like {"name":"f","arguments":{...}}.
        Some fine-tunes wrap it in code fences — handle both.

        Returns:
            An OpenAI-shape tool_call dict, or None when the payload is
            not valid JSON, is not a JSON object, or has no usable
            string ``name``. The tool-call index is only consumed for
            calls that are returned.
        """
        s = raw.strip()
        if s.startswith('```'):
            # strip code fence
            s = re.sub(r'^```[^\n]*\n', '', s)
            if s.endswith('```'):
                s = s[:-3]
            s = s.strip()
        try:
            obj = json.loads(s)
        except (ValueError, RecursionError):
            return None
        if not isinstance(obj, dict):
            # Valid JSON that is not an object (list, number, ...) has
            # no name/arguments to read.
            return None
        name = obj.get('name')
        if not isinstance(name, str) or not name:
            return None
        args = obj.get('arguments', obj.get('parameters', {}))
        if not isinstance(args, str):
            args = json.dumps(args, ensure_ascii=False)
        idx = self._tool_call_idx
        self._tool_call_idx += 1
        return {
            'index': idx,
            'id': f'call_{uuid.uuid4().hex[:24]}',
            'type': 'function',
            'function': {'name': name, 'arguments': args},
        }


# ────────────────────────────────────────────────────────────────────
# Engine
# ────────────────────────────────────────────────────────────────────


class Qwen3Engine:
    """Async wrapper around the Qwen3TorchFrontendRtx with streaming."""

    def __init__(self, *, checkpoint: str, device: str,
                 model_name: str, max_seq: int, max_q_seq: int):
        import torch
        from flash_rt.frontends.torch.qwen3_rtx import (
            Qwen3TorchFrontendRtx,
        )
        log.info('loading NVFP4 ckpt from %s ...', checkpoint)
        t0 = time.perf_counter()
        self.fe = Qwen3TorchFrontendRtx(
            checkpoint, device=device,
            max_seq=max_seq, max_q_seq=max_q_seq,
        )
        log.info('loaded in %.1f s', time.perf_counter() - t0)
        self.model_name = model_name
        # Input tensors and the sampling RNG must live on the device
        # the frontend was built for, so remember it.
        self.device = device
        self.lock = asyncio.Lock()
        self._torch = torch

    def warmup(self, shapes: List[Tuple[int, int]]) -> None:
        """Pre-capture decode + prefill graphs over each
        (prompt_len, max_tokens) shape so first real requests at those
        sizes hit warm graphs. A no-op when ``shapes`` is empty.
        """
        if not shapes:
            return
        torch = self._torch
        tokenizer = self.fe._tokenizer

        # Pre-capture all prefill bucket graphs once. Cheap (one
        # capture per bucket, ~5-15 ms each) and enables the
        # prefill_with_graph fast path for any request whose prompt
        # length fits the bucket ladder.
        t0 = time.perf_counter()
        self.fe.warmup_prefill_graphs()
        torch.cuda.synchronize()
        log.info('  warm prefill graphs (%d buckets) in %.1f s',
                 len(self.fe.prefill_buckets),
                 time.perf_counter() - t0)

        log.info('warmup: %d (prompt, max_tok) shape(s)', len(shapes))
        for prompt_len, max_tok in shapes:
            t0 = time.perf_counter()
            dummy_text = 'a ' * (max(1, prompt_len) - 1)
            input_ids = tokenizer(
                dummy_text, return_tensors='pt').input_ids.to(self.device)
            # Force the dummy prompt to exactly prompt_len tokens:
            # truncate if the tokenizer produced more, pad if fewer.
            if input_ids.shape[1] >= prompt_len:
                input_ids = input_ids[:, :prompt_len]
            else:
                pad = torch.full(
                    (1, prompt_len - input_ids.shape[1]),
                    tokenizer.pad_token_id or 0,
                    device=self.device, dtype=torch.long,
                )
                input_ids = torch.cat([input_ids, pad], dim=1)
            self.fe.reset_state()
            torch.cuda.synchronize()
            with torch.inference_mode():
                # Use the captured prefill graph if the prompt fits a
                # bucket; falls back to eager forward_prefill_nvfp4
                # internally otherwise. Either way leaves the KV cache
                # populated for the decode warmup that follows.
                self.fe.prefill_with_graph(input_ids)
                self.fe.warmup_decode_graphs(
                    prompt_len, prompt_len + max_tok,
                )
            torch.cuda.synchronize()
            log.info('  warm (P=%d, max_tok=%d) in %.1f s',
                     prompt_len, max_tok, time.perf_counter() - t0)

    def _render(self, messages: List[Dict[str, Any]],
                tools: Optional[List[Dict[str, Any]]]):
        """Apply the chat template (with optional tools) to messages.

        OpenAI lets `assistant.content` be `null` when `tool_calls` is
        set, but the Qwen3 chat template iterates `content` directly
        and crashes on `None`. Normalize by mapping `None` → '' before
        rendering — semantically equivalent (no text content). The
        caller's message dicts are copied, never mutated.
        """
        normalized = [
            {**m, 'content': ''} if m.get('content') is None else m
            for m in messages
        ]
        return self.fe._tokenizer.apply_chat_template(
            normalized,
            tools=tools or None,
            add_generation_prompt=True,
            tokenize=False,
        )

    async def stream_generate(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]],
        max_tokens: int,
        temperature: float,
        top_p: float,
        top_k: int,
        seed: Optional[int],
        stop: Optional[List[str]],
    ):
        """Async generator yielding (kind, payload) events:

          ('content', str)             — content delta
          ('tool_calls', list[dict])   — parsed tool_call deltas
          ('finish', reason: str, usage: dict)

        The whole generation runs under ``self.lock``, so requests are
        served one at a time.
        """
        torch = self._torch
        async with self.lock:
            prompt = self._render(messages, tools)
            input_ids = self.fe._tokenizer(
                prompt, return_tensors='pt').input_ids.to(self.device)
            P = int(input_ids.shape[1])

            rng = None
            if seed is not None:
                rng = torch.Generator(device=self.device)
                rng.manual_seed(int(seed))

            parser = StreamParser(self.fe._tokenizer, stop_strings=stop)
            eos = self.fe._tokenizer.eos_token_id

            t0 = time.perf_counter()
            self.fe.reset_state()
            with torch.inference_mode():
                # prefill_with_graph picks the smallest bucket >= P
                # and replays the captured graph; falls back to
                # eager forward_prefill_nvfp4 internally if P exceeds
                # the largest bucket. _logits_buf[:1] holds the next-
                # token logits either way.
                self.fe.prefill_with_graph(input_ids)
                ttft = time.perf_counter() - t0
                # Make sure decode graphs over [P, P+max_tokens) are
                # warm. Must stay inside inference_mode — torch 2.9+
                # rejects graph capture that touches inference tensors
                # (the prefill above marks _logits_buf / KV cache /
                # hidden buffers as inference tensors) from outside
                # an inference_mode context.
                self.fe.warmup_decode_graphs(P, P + max_tokens)

            new_tokens: List[int] = []
            cur_pos = P
            finish_reason = 'length'
            for step in range(max_tokens):
                # Sample from the current logits buffer.
                tok = _sample_token(
                    self.fe._logits_buf[0],
                    temperature=temperature,
                    top_p=top_p, top_k=top_k, rng=rng,
                )
                new_tokens.append(tok)

                # EOS check (engine-side, before emitting).
                if eos is not None and tok == eos:
                    delta, tcs, _ = parser.feed([], final=True)
                    if delta:
                        yield ('content', delta)
                    if tcs:
                        yield ('tool_calls', tcs)
                    finish_reason = (
                        'tool_calls'
                        if parser._tool_calls_emitted
                        and not parser._buffer.strip()
                        else 'stop'
                    )
                    break

                # Stream parse the new token.
                delta, tcs, stop_hit = parser.feed([tok])
                if delta:
                    yield ('content', delta)
                if tcs:
                    yield ('tool_calls', tcs)
                if stop_hit:
                    finish_reason = 'stop'
                    break

                # Advance KV cache via the warm decode graph.
                with torch.inference_mode():
                    self.fe.decode_step_with_graph(
                        torch.tensor([[tok]], device=self.device,
                                     dtype=torch.long),
                        cur_pos,
                    )
                cur_pos += 1

                # Yield to event loop so the SSE chunks can flush.
                if step % 8 == 0:
                    await asyncio.sleep(0)
            else:
                # Loop exhausted max_tokens.
                # Final flush of any buffered text.
                delta, tcs, _ = parser.feed([], final=True)
                if delta:
                    yield ('content', delta)
                if tcs:
                    yield ('tool_calls', tcs)

            wall = time.perf_counter() - t0
            usage = {
                'prompt_tokens': P,
                'completion_tokens': len(new_tokens),
                'total_tokens': P + len(new_tokens),
                'ttft_ms': round(ttft * 1000, 1),
                'wall_s': round(wall, 3),
                'tok_per_s': round(len(new_tokens) / wall, 1) if wall else 0,
            }
            yield ('finish', finish_reason, usage)


# ────────────────────────────────────────────────────────────────────
# HTTP layer
# ────────────────────────────────────────────────────────────────────


def build_app(engine: 'Qwen3Engine'):
    """Build the FastAPI app exposing ``engine`` through OpenAI-style routes.

    Routes: GET /v1/models, GET /health, POST /v1/chat/completions
    (JSON response, or SSE when ``stream`` is true). Malformed request
    fields are rejected with HTTP 400.
    """
    import math

    from fastapi import FastAPI, HTTPException
    from fastapi.responses import StreamingResponse

    app = FastAPI(title='FlashRT Qwen3-8B NVFP4 OpenAI-compatible server')

    def read_number(req: Dict[str, Any], key: str, default, cast):
        """Read a numeric field; missing/null → default, garbage → 400."""
        raw = req.get(key)
        if raw is None:
            return default
        try:
            value = cast(raw)
        except (TypeError, ValueError, OverflowError):
            raise HTTPException(400, f'{key} must be a number') from None
        if isinstance(value, float) and not math.isfinite(value):
            raise HTTPException(400, f'{key} must be a finite number')
        return value

    @app.get('/v1/models')
    async def list_models():
        return {
            'object': 'list',
            'data': [{
                'id': engine.model_name,
                'object': 'model',
                'created': int(time.time()),
                'owned_by': 'flash-vla',
            }],
        }

    @app.get('/health')
    async def health():
        return {'status': 'ok', 'model': engine.model_name}

    @app.post('/v1/chat/completions')
    async def chat_completions(req: Dict[str, Any]):
        messages = req.get('messages')
        if not isinstance(messages, list) or not messages:
            raise HTTPException(400, 'messages is required (non-empty list)')
        for m in messages:
            if not isinstance(m, dict):
                raise HTTPException(400, 'each message must be an object')
            role = m.get('role')
            if role not in ('system', 'user', 'assistant', 'tool'):
                raise HTTPException(400, f'unsupported role: {role!r}')

        tools = req.get('tools')        # OAI tools spec
        # Missing, null and 0 all fall back to 256.
        max_tokens = read_number(req, 'max_tokens', 256, int) or 256
        if max_tokens < 0:
            raise HTTPException(400, 'max_tokens must be positive')
        stream = bool(req.get('stream', False))
        temperature = read_number(req, 'temperature', 0.0, float)
        top_p = read_number(req, 'top_p', 1.0, float)
        top_k = read_number(req, 'top_k', 0, int)
        seed = read_number(req, 'seed', None, int)

        stop = req.get('stop')
        if isinstance(stop, str):
            stop = [stop]
        elif stop is None:
            stop = []
        elif not isinstance(stop, list):
            raise HTTPException(400, 'stop must be string or list')
        if not all(isinstance(s, str) for s in stop):
            raise HTTPException(400, 'stop must be string or list of strings')
        # Empty stop strings match everywhere; drop them.
        stop = [s for s in stop if s]

        completion_id = f'chatcmpl-{uuid.uuid4().hex[:24]}'
        created = int(time.time())

        if not stream:
            content = ''
            tool_calls: List[dict] = []
            finish = 'stop'
            usage: dict = {}
            async for ev in engine.stream_generate(
                messages, tools, max_tokens,
                temperature, top_p, top_k, seed, stop,
            ):
                if ev[0] == 'content':
                    content += ev[1]
                elif ev[0] == 'tool_calls':
                    tool_calls.extend(ev[1])
                elif ev[0] == 'finish':
                    _, finish, usage = ev
            msg: dict = {'role': 'assistant', 'content': content or None}
            if tool_calls:
                msg['tool_calls'] = tool_calls
            log.info(
                'non-stream done: %s -> %s tok in %ss (%s tok/s)',
                usage.get('prompt_tokens'),
                usage.get('completion_tokens'),
                usage.get('wall_s'), usage.get('tok_per_s'),
            )
            return {
                'id': completion_id,
                'object': 'chat.completion',
                'created': created,
                'model': engine.model_name,
                'choices': [{
                    'index': 0,
                    'message': msg,
                    'finish_reason': finish,
                }],
                'usage': usage,
            }

        # ── Streaming SSE ──
        def sse_chunk(delta: dict, finish_reason=None, usage=None) -> str:
            """Serialise one chat.completion.chunk as an SSE data line."""
            chunk = {
                'id': completion_id,
                'object': 'chat.completion.chunk',
                'created': created,
                'model': engine.model_name,
                'choices': [{
                    'index': 0,
                    'delta': delta,
                    'finish_reason': finish_reason,
                }],
            }
            if usage is not None:
                chunk['usage'] = usage
            return f'data: {json.dumps(chunk)}\n\n'

        async def gen():
            # Emit role first.
            yield sse_chunk({'role': 'assistant', 'content': ''})

            tc_seen = False
            async for ev in engine.stream_generate(
                messages, tools, max_tokens,
                temperature, top_p, top_k, seed, stop,
            ):
                if ev[0] == 'content':
                    yield sse_chunk({'content': ev[1]})
                elif ev[0] == 'tool_calls':
                    tc_seen = True
                    for tc in ev[1]:
                        yield sse_chunk({'tool_calls': [tc]})
                elif ev[0] == 'finish':
                    _, finish, usage = ev
                    # Once any tool call was streamed, report
                    # 'tool_calls' in place of 'stop' / 'length'.
                    final_reason = (
                        'tool_calls'
                        if tc_seen and finish in ('stop', 'length')
                        else finish
                    )
                    yield sse_chunk({}, final_reason, usage)
                    yield 'data: [DONE]\n\n'
                    log.info(
                        'stream done: %s -> %s tok in %ss (%s tok/s)',
                        usage.get('prompt_tokens'),
                        usage.get('completion_tokens'),
                        usage.get('wall_s'), usage.get('tok_per_s'),
                    )
                    return

        return StreamingResponse(gen(), media_type='text/event-stream')

    return app


def main():
    """Parse CLI flags, load and warm the engine, then serve with uvicorn."""
    p = argparse.ArgumentParser()
    p.add_argument('--checkpoint', required=True,
                   help='Path to NVFP4 ckpt dir.')
    p.add_argument('--port', type=int, default=8000)
    p.add_argument('--host', default='0.0.0.0')
    p.add_argument('--max-seq', type=int, default=2048)
    p.add_argument('--max-q-seq', type=int, default=128,
                   help='Max prompt prefill length (in tokens).')
    p.add_argument('--device', default='cuda:0')
    p.add_argument('--model-name', default='qwen3-8b-nvfp4')
    p.add_argument('--warmup', default='32:128,128:256',
                   help='Comma-separated "P:max_tok" shapes to warm.')
    args = p.parse_args()

    warm: List[Tuple[int, int]] = []
    for spec in args.warmup.split(','):
        spec = spec.strip()
        if not spec:
            continue
        try:
            pl, mt = spec.split(':')
            warm.append((int(pl), int(mt)))
        except ValueError:
            sys.exit(f'invalid --warmup spec: {spec!r}')

    # Fail on a missing server dependency before the slow model load.
    try:
        import uvicorn
    except ImportError:
        sys.exit('uvicorn is required: pip install fastapi uvicorn')

    engine = Qwen3Engine(
        checkpoint=args.checkpoint,
        device=args.device,
        model_name=args.model_name,
        max_seq=args.max_seq,
        max_q_seq=args.max_q_seq,
    )
    if warm:
        engine.warmup(warm)

    app = build_app(engine)
    uvicorn.run(app, host=args.host, port=args.port, log_level='warning')


if __name__ == '__main__':
    main()