#!/usr/bin/env python3
"""
FlashRT — Qwen3.6-27B NVFP4 OpenAI-compatible HTTP server.

Serves the /v1/chat/completions endpoint backed by the FlashRT NVFP4
inference path. Clients targeting the OpenAI API can swap their base
URL to this server without code changes.

Usage:

    pip install fastapi uvicorn

    # Required env: paired FP8 ckpt dir for the MTP head.
    export FLASHRT_QWEN36_MTP_CKPT_DIR=/path/to/qwen36_fp8_ckpt

    python examples/qwen36_openai_server.py \\
        --checkpoint /path/to/qwen36_nvfp4 \\
        --max-seq 32768 \\
        --port 8000 \\
        --K 6 \\
        --warmup-preset auto

    # Startup warmup pre-captures CUDA Graphs for bucketed
    # (prompt_len:max_tokens) shapes so the FIRST real request usually
    # hits warm graphs. Short buckets run a dummy generation; long
    # buckets default to decode-graph-only warmup, avoiding minutes of
    # synthetic 200K/256K prompt prefill. 128K+ buckets warm every early
    # decode position by default; tune with
    # FLASHRT_QWEN36_LONG_WARMUP_STRIDE/MAX_GRAPHS. Add --warmup
    # "32768:64,65536:64,131072:64,204800:64,262144:16" for an explicit
    # serving envelope, or --warmup-preset none to skip.

    # Test (non-streaming):
    curl http://localhost:8000/v1/chat/completions \\
        -H "Content-Type: application/json" \\
        -d '{
            "model": "qwen3.6-27b-nvfp4",
            "messages": [{"role": "user", "content": "Hello!"}],
            "max_tokens": 128,
            "stream": false
        }'

    # OpenAI Python client:
    #   from openai import OpenAI
    #   client = OpenAI(base_url="http://localhost:8000/v1", api_key="-")
    #   resp = client.chat.completions.create(
    #       model="qwen3.6-27b-nvfp4",
    #       messages=[{"role": "user", "content": "Hi"}],
    #       max_tokens=128,
    #   )
    #
    # Function calling uses the Qwen chat-template native tool format.
    # Pass OpenAI-shaped "tools"; the server parses model-emitted
    # <tool_call>{...}</tool_call> blocks into OpenAI "tool_calls".

Limits in v1 (see docs/qwen36_usage.md):
    * Batch size 1 (concurrent requests are serialized; do not run
      multiple workers against one GPU).
    * Greedy decode only — temperature / top_p / top_k / n / seed /
      stop / logit_bias are accepted but ignored.
    * Qwen thinking mode is disabled by default. Pass
      "enable_thinking": true in the JSON body to opt in.
    * stream=True returns the full response in a single content chunk
      (true token-by-token streaming requires a frontend modification).
"""

from __future__ import annotations

import argparse
import asyncio
import json
import logging
import os
import re
import sys
import time
import uuid
from typing import Any, Dict, List, Optional, Tuple

# ────────────────────────────────────────────────────────────────────
# Logger
# ────────────────────────────────────────────────────────────────────
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
)
log = logging.getLogger('qwen36_openai_server')

# Qwen tool-call format:
#   <tool_call>
#   {"name": "fn_name", "arguments": {...}}
#   </tool_call>
# NOTE: these markers must be non-empty. With an empty marker,
# str.find() matches at every position and ToolCallParser.parse()
# would never advance.
_TOOL_CALL_OPEN = '<tool_call>'
_TOOL_CALL_CLOSE = '</tool_call>'

# Qwen reasoning block. Compiled once because it is applied to every
# completion when thinking mode is off.
_THINK_OPEN = '<think>'
_THINK_CLOSE = '</think>'
_THINK_BLOCK_RE = re.compile(r'<think>.*?</think>\s*', re.DOTALL)


class ToolCallParser:
    """Split Qwen tool-call blocks from a complete assistant response."""

    def __init__(self):
        # Running index assigned to each successfully parsed tool call.
        self._tool_call_idx = 0

    def parse(self, text: str) -> Tuple[str, List[dict]]:
        """Separate plain content from tool-call blocks.

        Args:
            text: the full decoded assistant response.

        Returns:
            ``(content, tool_calls)`` where ``content`` is ``text`` with
            every complete tool-call block removed and ``tool_calls`` is
            a list of OpenAI-shaped tool-call dicts. An opening marker
            without a matching closing marker is left in ``content``
            verbatim. Blocks whose payload cannot be parsed are dropped.
        """
        content_parts: List[str] = []
        tool_calls: List[dict] = []
        pos = 0
        while True:
            open_idx = text.find(_TOOL_CALL_OPEN, pos)
            if open_idx < 0:
                content_parts.append(text[pos:])
                break
            content_parts.append(text[pos:open_idx])
            raw_start = open_idx + len(_TOOL_CALL_OPEN)
            close_idx = text.find(_TOOL_CALL_CLOSE, raw_start)
            if close_idx < 0:
                # Unterminated block: keep it as ordinary content.
                content_parts.append(text[open_idx:])
                break
            tc = self._parse_tool_call(text[raw_start:close_idx].strip())
            if tc is not None:
                tool_calls.append(tc)
            pos = close_idx + len(_TOOL_CALL_CLOSE)
        return ''.join(content_parts), tool_calls

    def _parse_tool_call(self, raw: str) -> Optional[dict]:
        """Turn one tool-call payload into an OpenAI tool-call dict.

        Accepts a JSON object, optionally wrapped in a Markdown code
        fence, with a string ``name`` and either ``arguments`` or
        ``parameters``.

        Returns:
            The tool-call dict, or ``None`` when the payload is not
            valid JSON, is not a JSON object, or lacks a usable name.
        """
        s = raw.strip()
        if s.startswith('```'):
            s = re.sub(r'^```[^\n]*\n', '', s)
        if s.endswith('```'):
            s = s[:-3]
        s = s.strip()
        try:
            obj = json.loads(s)
        except (ValueError, RecursionError):
            log.warning('dropping tool call with invalid JSON payload')
            return None
        # json.loads may return a list/str/number; only an object can
        # carry name/arguments.
        if not isinstance(obj, dict):
            log.warning('dropping tool call: payload is not a JSON object')
            return None
        name = obj.get('name')
        if not isinstance(name, str) or not name:
            log.warning('dropping tool call: missing function name')
            return None
        args = obj.get('arguments', obj.get('parameters', {}))
        # OpenAI clients expect "arguments" to be a JSON-encoded string.
        if not isinstance(args, str):
            args = json.dumps(args, ensure_ascii=False)
        idx = self._tool_call_idx
        self._tool_call_idx += 1
        return {
            'index': idx,
            'id': f'call_{uuid.uuid4().hex[:24]}',
            'type': 'function',
            'function': {'name': name, 'arguments': args},
        }


# ────────────────────────────────────────────────────────────────────
# Frontend wrapper
# ────────────────────────────────────────────────────────────────────
class Qwen36Engine:
    """Thin wrapper around Qwen36TorchFrontendRtx with chat-template
    rendering and a single-request lock (batch=1 only)."""

    def __init__(self, checkpoint: str, *, K: int, max_seq: int,
                 device: str, model_name: str):
        """Load the NVFP4 checkpoint and detect whether the MTP head
        (speculative decode) is available.

        Args:
            checkpoint: path passed to the frontend constructor.
            K: draft chain length used for speculative decode.
            max_seq: forwarded to the frontend as ``max_seq``.
            device: forwarded to the frontend as ``device``.
            model_name: identifier echoed in API responses.
        """
        # Imported lazily so the module can be imported without torch
        # or flash_rt installed.
        import torch
        from flash_rt.frontends.torch.qwen36_rtx import (
            Qwen36TorchFrontendRtx,
        )

        log.info('loading NVFP4 ckpt from %s ...', checkpoint)
        t0 = time.perf_counter()
        self.fe = Qwen36TorchFrontendRtx(
            checkpoint, quant='nvfp4', device=device, max_seq=max_seq,
        )
        log.info('loaded in %.1f s', time.perf_counter() - t0)
        self.K = int(K)
        self.model_name = model_name
        self.lock = asyncio.Lock()
        self._torch = torch

        if self.fe._weights.ptrs.get('mtp') is None:
            log.warning(
                'MTP head not loaded (FLASHRT_QWEN36_MTP_CKPT_DIR '
                'unset?) — speculative decode disabled. The server '
                'will fall back to single-token decode (~36 tok/s).')
            self.spec_enabled = False
        else:
            self.spec_enabled = True
            log.info('MTP head loaded; spec K=%d enabled', self.K)
        log.info(
            'long-ctx route=%s tq_stage_layers=%s tq_stage_cap=%s '
            'tq_hot_layers=%s tq_hot_cap=%s',
            getattr(self.fe, '_long_kv_cache_mode', 'bf16'),
            getattr(self.fe, '_tq_per_layer_stage_layers', 'n/a'),
            getattr(self.fe, '_tq_per_layer_stage_cap', 'n/a'),
            getattr(self.fe, '_tq_hot_stage_layers', 'n/a'),
            getattr(self.fe, '_tq_hot_stage_cap', 'n/a'),
        )

    def _dummy_input_ids(self, prompt_len: int):
        """Build exact-length CUDA token ids without tokenizer drift.

        The tensor is filled with a single repeated token id (the first
        id of the tokenized string ' warmup', or 1 if that is empty).

        Raises:
            ValueError: if ``prompt_len`` is not positive.
        """
        torch = self._torch
        prompt_len = int(prompt_len)
        if prompt_len <= 0:
            raise ValueError(f'prompt_len must be >0, got {prompt_len}')
        token_ids = self.fe._tokenizer(
            ' warmup', add_special_tokens=False).input_ids
        token = int(token_ids[0] if token_ids else 1)
        # NOTE(review): uses the default 'cuda' device rather than the
        # --device passed to the frontend — confirm for non-cuda:0 setups.
        return torch.full(
            (1, prompt_len), token, device='cuda', dtype=torch.long)

    def _effective_long_k(self, prompt_len: int) -> int:
        """Mirror frontend long TQ default K policy for logging."""
        if hasattr(self.fe, '_long_tq_effective_k'):
            return self.fe._long_tq_effective_k(prompt_len, self.K)
        return min(self.K, 6)

    def _is_long_ctx_shape(self, prompt_len: int, max_tok: int) -> bool:
        """Return whether a (prompt_len, max_tok) shape takes the
        long-context route.

        Delegates to the frontend's ``_should_use_long_ctx_route`` when
        it exists; otherwise falls back to comparing against the
        frontend's route thresholds (default 2048).
        """
        fe = self.fe
        if hasattr(fe, '_should_use_long_ctx_route'):
            return fe._should_use_long_ctx_route(prompt_len, max_tok)
        route_min = getattr(
            fe, '_long_ctx_route_min_seq',
            getattr(fe, '_short_ctx_spec_max_seq', 2048))
        bf16_cap = getattr(fe, '_short_ctx_spec_max_seq', route_min)
        return (
            getattr(fe, '_long_ctx_mode', False)
            and (prompt_len >= route_min
                 or prompt_len + max_tok > bf16_cap)
        )

    def warmup(self, shapes: List[Tuple[int, int]]) -> None:
        """Pre-capture CUDA Graphs for typical (prompt_len, max_tokens)
        shapes.

        Short-context buckets run dummy generations; long-context
        buckets default to decode-graph-only warmup to avoid paying full
        synthetic prompt prefill at startup. Set
        FLASHRT_QWEN36_SERVER_LONG_WARMUP to anything other than
        graphs/graph/decode_graphs/1/true/yes to force full dummy
        generations for long buckets as well.

        Without this, the FIRST request at each new
        (prompt_len, max_tokens) shape pays CUDA Graph capture latency.

        Args:
            shapes: list of (prompt_len, max_tokens) tuples to pre-warm.
                An empty list is a no-op. Shapes whose
                prompt_len + max_tokens exceeds the frontend's max_seq
                are skipped with a warning.
        """
        if not shapes:
            return
        torch = self._torch
        log.info('warmup: pre-capturing graphs for %d shape(s) ...',
                 len(shapes))
        long_graph_only = (
            os.environ.get(
                'FLASHRT_QWEN36_SERVER_LONG_WARMUP', 'graphs').lower()
            in ('graphs', 'graph', 'decode_graphs', '1', 'true', 'yes')
        )
        for prompt_len, max_tok in shapes:
            if prompt_len + max_tok > self.fe._user_max_seq:
                log.warning(
                    ' skip warmup shape=(prompt=%d, max_tok=%d): '
                    'exceeds max_seq=%d',
                    prompt_len, max_tok, self.fe._user_max_seq)
                continue
            t0 = time.perf_counter()
            torch.cuda.synchronize()
            is_long = self._is_long_ctx_shape(prompt_len, max_tok)
            if self.spec_enabled and is_long and long_graph_only:
                # Decode-graph-only path: no synthetic prompt prefill.
                warmed = self.fe.warmup_long_ctx_decode_graphs(
                    [(prompt_len, max_tok)], K=self.K)
                torch.cuda.synchronize()
                log.info(
                    ' warmup shape=(prompt=%d, max_tok=%d, eff_K=%s) '
                    'decode-graphs=%d in %.1f s',
                    prompt_len, max_tok,
                    self._effective_long_k(prompt_len),
                    len(warmed), time.perf_counter() - t0)
                continue
            input_ids = self._dummy_input_ids(prompt_len)
            if self.spec_enabled:
                _ = self.fe.generate_own_speculative_KN_nvfp4(
                    input_ids, max_new_tokens=max_tok, K=self.K)
            else:
                _ = self._single_token_decode(input_ids, max_tok)
            torch.cuda.synchronize()
            log.info(
                ' warmup shape=(prompt=%d, max_tok=%d, eff_K=%s) '
                'in %.1f s',
                prompt_len, max_tok,
                self._effective_long_k(prompt_len)
                if getattr(self.fe, '_long_ctx_mode', False) else self.K,
                time.perf_counter() - t0)
        log.info('warmup done — warmed buckets should avoid most '
                 'first-request CUDA Graph capture latency')

    def _render_chat(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]],
        *,
        enable_thinking: bool = False,
    ) -> str:
        """Apply Qwen's chat template to a list of messages.

        Messages whose ``content`` is ``None`` are copied with an empty
        string instead; the caller's dicts are never mutated.
        """
        normalized = []
        for m in messages:
            if m.get('content') is None:
                m = {**m, 'content': ''}
            normalized.append(m)
        return self.fe._tokenizer.apply_chat_template(
            normalized,
            tools=tools or None,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=enable_thinking,
        )

    async def generate(
        self,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]],
        max_tokens: int,
        *,
        enable_thinking: bool = False,
    ) -> Dict[str, Any]:
        """Run one chat-completion. Returns a dict with the new text
        and basic timing/stat fields.

        Requests are serialized through ``self.lock``. The GPU work runs
        synchronously inside the coroutine (there is no ``await`` while
        the lock is held).

        Returns:
            dict with keys ``text``, ``tool_calls``, ``prompt_tokens``,
            ``completion_tokens``, ``prefill_ms``, ``decode_ms``,
            ``wall_s``, ``decode_tok_per_s``, ``e2e_tok_per_s`` and
            ``route``.
        """
        torch = self._torch
        async with self.lock:
            prompt = self._render_chat(
                messages, tools, enable_thinking=enable_thinking)
            input_ids = self.fe._tokenizer(
                prompt, return_tensors='pt').input_ids.cuda()
            prompt_len = int(input_ids.shape[1])

            torch.cuda.synchronize()
            t0 = time.perf_counter()
            if self.spec_enabled:
                out = self.fe.generate_own_speculative_KN_nvfp4(
                    input_ids, max_new_tokens=max_tokens, K=self.K,
                )
            else:
                out = self._single_token_decode(input_ids, max_tokens)
            torch.cuda.synchronize()
            wall_s = time.perf_counter() - t0

            new_tokens = out[0, prompt_len:].tolist()
            raw_text = self.fe._tokenizer.decode(
                new_tokens, skip_special_tokens=False)
            # Cut at the first chat-turn boundary so anything generated
            # past the end of the assistant turn is discarded.
            for boundary in ('<|im_start|>', '<|im_end|>'):
                idx = raw_text.find(boundary)
                if idx >= 0:
                    raw_text = raw_text[:idx]
            for special in (
                self.fe._tokenizer.eos_token,
                self.fe._tokenizer.pad_token,
            ):
                if special:
                    raw_text = raw_text.replace(special, '')
            if not enable_thinking:
                # Drop complete reasoning blocks, then any stray
                # unmatched open/close tag left behind.
                raw_text = _THINK_BLOCK_RE.sub('', raw_text)
                raw_text = raw_text.replace(_THINK_OPEN, '').replace(
                    _THINK_CLOSE, '')
            text, tool_calls = ToolCallParser().parse(raw_text)
            text = text.strip()
            completion_tokens = len(new_tokens)
            prefill_ms = float(
                getattr(self.fe, '_long_ctx_prefill_ms', 0.0) or 0.0)
            decode_ms = float(
                getattr(self.fe, '_long_ctx_decode_ms', 0.0) or 0.0)
            decode_tok_per_s = (
                completion_tokens * 1000.0 / decode_ms
                if decode_ms > 0 else 0.0
            )
            e2e_tok_per_s = completion_tokens / wall_s if wall_s else 0.0
            return {
                'text': text,
                'tool_calls': tool_calls,
                'prompt_tokens': prompt_len,
                'completion_tokens': completion_tokens,
                'prefill_ms': prefill_ms,
                'decode_ms': decode_ms,
                'wall_s': wall_s,
                'decode_tok_per_s': decode_tok_per_s,
                'e2e_tok_per_s': e2e_tok_per_s,
                'route': getattr(self.fe, '_long_ctx_route', 'unknown'),
            }

    def _single_token_decode(self, input_ids, max_tokens):
        """Fallback when MTP is not loaded. Slower path (~36 tok/s).

        Feeds the prompt one token at a time through the decode kernel,
        then greedily (argmax) generates exactly ``max_tokens`` tokens;
        there is no early stop on EOS here.

        Returns:
            A (1, prompt_len + max_tokens) CUDA tensor of token ids.
        """
        torch = self._torch
        fe = self.fe
        fe.reset_state()
        if not hasattr(fe, '_rope_cos_table'):
            fe._build_rope_table()
        prompt_len = int(input_ids.shape[1])
        generated = list(input_ids[0].tolist())
        cur_pos = 0
        with torch.no_grad():
            # Prefill: one decode step per prompt token.
            for p in range(prompt_len):
                fe._static_token_id.copy_(input_ids[:, p:p + 1])
                cos, sin = fe._rope_cos_sin(cur_pos)
                fe.forward_own_decode_nvfp4(
                    fe._static_token_id, cos, sin, cur_pos)
                cur_pos += 1
            # Decode: pick the argmax of the latest logits, feed it back.
            for _ in range(max_tokens):
                tok = fe._logits_buf.argmax(
                    dim=-1, keepdim=True).view(1, 1)
                generated.append(int(tok.item()))
                fe._static_token_id.copy_(tok)
                cos, sin = fe._rope_cos_sin(cur_pos)
                fe.forward_own_decode_nvfp4(
                    fe._static_token_id, cos, sin, cur_pos)
                cur_pos += 1
        return torch.tensor([generated], device='cuda')


# ────────────────────────────────────────────────────────────────────
# OpenAI-compatible HTTP layer (FastAPI)
# ────────────────────────────────────────────────────────────────────
def build_app(engine: Qwen36Engine):
    """Build the FastAPI app exposing /v1/models, /v1/chat/completions
    and /health on top of ``engine``."""
    from fastapi import FastAPI, HTTPException
    from fastapi.responses import StreamingResponse

    app = FastAPI(title='FlashRT Qwen3.6 NVFP4 OpenAI-compatible server')

    @app.get('/v1/models')
    async def list_models():
        return {
            'object': 'list',
            'data': [{
                'id': engine.model_name,
                'object': 'model',
                'created': int(time.time()),
                'owned_by': 'flash-rt',
            }],
        }

    @app.post('/v1/chat/completions')
    async def chat_completions(req: Dict[str, Any]):
        messages = req.get('messages')
        if not messages or not isinstance(messages, list):
            raise HTTPException(400, 'messages is required')
        tools = req.get('tools')
        # Accept the newer "max_completion_tokens" name as a fallback;
        # a missing/zero value keeps the historical default of 256.
        raw_max_tokens = (
            req.get('max_tokens')
            or req.get('max_completion_tokens')
            or 256
        )
        try:
            max_tokens = int(raw_max_tokens)
        except (TypeError, ValueError):
            raise HTTPException(
                400, 'max_tokens must be an integer') from None
        if max_tokens < 1:
            raise HTTPException(400, 'max_tokens must be >= 1')
        stream = bool(req.get('stream', False))
        enable_thinking = bool(req.get('enable_thinking', False))

        # Validate roles — Qwen template accepts OpenAI chat roles.
        for m in messages:
            if not isinstance(m, dict):
                raise HTTPException(
                    400, 'each message must be an object')
            role = m.get('role')
            if role not in ('system', 'user', 'assistant', 'tool'):
                raise HTTPException(
                    400, f'unsupported role: {role!r}')
            content = m.get('content')
            # Assistant turns that only carry tool calls have no content.
            if content is None and role == 'assistant':
                continue
            if not isinstance(content, str):
                raise HTTPException(
                    400, 'message.content must be a string')

        result = await engine.generate(
            messages, tools, max_tokens,
            enable_thinking=enable_thinking,
        )

        completion_id = f'chatcmpl-{uuid.uuid4().hex[:24]}'
        created = int(time.time())
        log.info(
            'chat.completions: prompt=%d completion=%d route=%s '
            'prefill=%.1fms + decode=%.1fms wall=%.1fms '
            'decode_tok/s=%.1f e2e_tok/s=%.1f',
            result['prompt_tokens'], result['completion_tokens'],
            result['route'], result['prefill_ms'], result['decode_ms'],
            result['wall_s'] * 1000.0,
            result['decode_tok_per_s'], result['e2e_tok_per_s'],
        )

        usage = {
            'prompt_tokens': result['prompt_tokens'],
            'completion_tokens': result['completion_tokens'],
            'total_tokens': (result['prompt_tokens']
                             + result['completion_tokens']),
        }
        finish_reason = 'tool_calls' if result['tool_calls'] else 'stop'

        if stream:
            # We don't have token-by-token streaming yet (v1 limit);
            # emit the full message in one delta then [DONE]. Clients
            # that target streaming will see one big chunk.
            def make_chunk(delta, chunk_finish_reason):
                return {
                    'id': completion_id,
                    'object': 'chat.completion.chunk',
                    'created': created,
                    'model': engine.model_name,
                    'choices': [{
                        'index': 0,
                        'delta': delta,
                        'finish_reason': chunk_finish_reason,
                    }],
                }

            async def gen():
                first_delta = {'role': 'assistant'}
                if result['text']:
                    first_delta['content'] = result['text']
                yield f'data: {json.dumps(make_chunk(first_delta, None))}\n\n'
                for tc in result['tool_calls']:
                    chunk = make_chunk({'tool_calls': [tc]}, None)
                    yield f'data: {json.dumps(chunk)}\n\n'
                last = make_chunk({}, finish_reason)
                last['usage'] = usage
                yield f'data: {json.dumps(last)}\n\n'
                yield 'data: [DONE]\n\n'

            return StreamingResponse(gen(), media_type='text/event-stream')

        # OpenAI returns content=null when the turn is tool calls only.
        content = (
            result['text']
            if result['text'] or not result['tool_calls'] else None
        )
        message = {
            'role': 'assistant',
            'content': content,
        }
        if result['tool_calls']:
            message['tool_calls'] = result['tool_calls']
        return {
            'id': completion_id,
            'object': 'chat.completion',
            'created': created,
            'model': engine.model_name,
            'choices': [{
                'index': 0,
                'message': message,
                'finish_reason': finish_reason,
            }],
            'usage': usage,
        }

    @app.get('/health')
    async def health():
        return {
            'status': 'ok',
            'model': engine.model_name,
            'spec_enabled': engine.spec_enabled,
            'K': engine.K,
            'long_kv_cache': getattr(
                engine.fe, '_long_kv_cache_mode', 'bf16'),
            'tq_stage_layers': getattr(
                engine.fe, '_tq_per_layer_stage_layers', None),
            'tq_stage_cap': getattr(
                engine.fe, '_tq_per_layer_stage_cap', None),
            'tq_hot_stage_layers': getattr(
                engine.fe, '_tq_hot_stage_layers', None),
            'tq_hot_stage_cap': getattr(
                engine.fe, '_tq_hot_stage_cap', None),
        }

    return app


def _parse_warmup_shapes(spec_csv: str) -> List[Tuple[int, int]]:
    """Parse a ``--warmup`` value into (prompt_len, max_tokens) tuples.

    Args:
        spec_csv: comma-separated "prompt_len:max_tokens" entries.
            Empty entries are ignored.

    Returns:
        The parsed shapes, in input order.

    Exits the process (``sys.exit``) on a malformed entry, a
    non-positive prompt_len, or a negative max_tokens, so that bad
    input fails before the model is loaded.
    """
    shapes: List[Tuple[int, int]] = []
    if not spec_csv.strip():
        return shapes
    for spec in spec_csv.split(','):
        spec = spec.strip()
        if not spec:
            continue
        try:
            pl, mt = spec.split(':')
            prompt_len, max_tok = int(pl), int(mt)
            if prompt_len <= 0 or max_tok < 0:
                raise ValueError(spec)
        except ValueError:
            sys.exit(f'invalid --warmup spec: {spec!r} '
                     '(expected "prompt_len:max_tokens")')
        shapes.append((prompt_len, max_tok))
    return shapes


def _warmup_preset_shapes(preset: str, max_seq: int) -> List[Tuple[int, int]]:
    """Return startup graph-warm buckets that fit inside max_seq.

    The default buckets use 64 generated tokens because that captures
    the common early decode range where user-visible cold latency is
    most painful without making server startup spend minutes on long
    synthetic short-prompt completions. Add explicit --warmup entries
    for larger completion caps.

    Args:
        preset: one of auto, short, long, all, or none/off/false/0
            (case-insensitive; empty means auto). ``auto`` and ``all``
            currently select the same candidate list.
        max_seq: buckets with prompt_len + max_tokens above this value
            are filtered out.
    """
    preset = (preset or 'auto').lower()
    if preset in ('none', 'off', 'false', '0'):
        return []
    short_buckets = [(8, 64), (128, 64), (512, 64), (1024, 64)]
    long_buckets = [
        (2048, 64),
        (4096, 64),
        (8192, 64),
        (16384, 64),
        (32768, 64),
        (65536, 64),
        (131072, 64),
        (204800, 64),
        (262144, 16),
    ]
    candidates_by_preset = {
        'auto': short_buckets + long_buckets,
        'short': short_buckets,
        'long': long_buckets,
        'all': short_buckets + long_buckets,
    }
    if preset not in candidates_by_preset:
        sys.exit(
            f'invalid --warmup-preset {preset!r}; expected '
            'auto, short, long, all, or none')
    max_seq = int(max_seq)
    return [(p, n) for p, n in candidates_by_preset[preset]
            if p + n <= max_seq]


def _dedupe_shapes(shapes: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
    """Remove duplicate shapes while keeping first-seen order."""
    out: List[Tuple[int, int]] = []
    seen = set()
    for shape in shapes:
        if shape not in seen:
            out.append(shape)
            seen.add(shape)
    return out


def main():
    """Parse CLI arguments, load the engine, warm up, and serve."""
    p = argparse.ArgumentParser()
    p.add_argument('--checkpoint', required=True,
                   help='Path to NVFP4 main ckpt (compressed-tensors).')
    p.add_argument('--port', type=int, default=8000)
    p.add_argument('--host', default='0.0.0.0')
    p.add_argument('--K', type=int, default=6,
                   help='MTP draft chain length per spec cycle. '
                        'Default 6 (peak for short generations on RTX 5090).')
    p.add_argument('--max-seq', type=int, default=32768,
                   help='KV cache + scratch dim. Increase for long ctx.')
    p.add_argument('--device', default='cuda:0')
    p.add_argument('--model-name', default='qwen3.6-27b-nvfp4',
                   help='Identifier returned by /v1/models and echoed '
                        'in completion responses.')
    p.add_argument(
        '--warmup-preset', default='auto',
        help='Startup graph warmup preset: auto, short, long, all, or '
             'none. auto warms short buckets plus long buckets that fit in '
             '--max-seq. Use all with --max-seq 262208+ to include 256K.')
    p.add_argument(
        '--warmup', default='',
        help='Comma-separated list of "prompt_len:max_tokens" shapes '
             'to additionally pre-capture at startup. These are appended to '
             '--warmup-preset. Set --warmup-preset none and --warmup "" to '
             'skip all startup warmup.')
    args = p.parse_args()

    # Validate warmup arguments before the (slow) checkpoint load.
    warmup_shapes = _dedupe_shapes(
        _warmup_preset_shapes(args.warmup_preset, args.max_seq)
        + _parse_warmup_shapes(args.warmup)
    )

    if 'FLASHRT_QWEN36_MTP_CKPT_DIR' not in os.environ:
        log.warning(
            'FLASHRT_QWEN36_MTP_CKPT_DIR is not set — speculative '
            'decode will be disabled and tok/s will fall to ~36. See '
            'docs/qwen36_usage.md for the FP8 ckpt requirement.')

    try:
        import uvicorn
    except ImportError:
        sys.exit('uvicorn is required: pip install uvicorn fastapi')

    engine = Qwen36Engine(
        checkpoint=args.checkpoint,
        K=args.K,
        max_seq=args.max_seq,
        device=args.device,
        model_name=args.model_name,
    )
    if warmup_shapes:
        engine.warmup(warmup_shapes)
    app = build_app(engine)
    uvicorn.run(app, host=args.host, port=args.port, log_level='warning')


if __name__ == '__main__':
    main()