{ "cells": [ { "cell_type": "markdown", "id": "c109c0e7-1aad-42ab-88d8-0990559b59e5", "metadata": {}, "source": [ "\n", "\n", "\n", "\n", "\n", "
\n", "\n", "Supplementary code for the Build a Reasoning Model (From Scratch) book by Sebastian Raschka
\n", "
Code repository: https://github.com/rasbt/reasoning-from-scratch\n", "
\n", "
\n", "\n", "
\n" ] }, { "cell_type": "markdown", "id": "88c613ef-f4e5-49c3-b19d-3cf36dce0bf1", "metadata": {}, "source": [ "# Appendix C: Qwen3 LLM Source Code" ] }, { "cell_type": "markdown", "id": "90adc11a-3ef1-45b9-bb5e-162b7db3817c", "metadata": {}, "source": [ "Packages that are being used in this notebook:" ] }, { "cell_type": "code", "execution_count": 1, "id": "f9a28d80-f79b-4b86-a788-36ab97645e31", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "reasoning_from_scratch version: 0.1.13\n", "torch version: 2.10.0\n", "tokenizers version: 0.22.2\n" ] } ], "source": [ "from importlib.metadata import version\n", "\n", "used_libraries = [\n", " \"reasoning_from_scratch\", # for download functions\n", " \"torch\",\n", " \"tokenizers\"\n", "]\n", "\n", "for lib in used_libraries:\n", " print(f\"{lib} version: {version(lib)}\")" ] }, { "cell_type": "markdown", "id": "9cc4df40-b386-4bab-9722-a707d4088ca6", "metadata": {}, "source": [ "- While this is a \"from scratch\" book, as mentioned in the main chapters, the \"from scratch\" part refers to the reasoning techniques, not the LLM itself\n", "- Implementing an LLM from scratch is a whole book in itself; this is the topic of my [Build A Large Language Model (From Scratch)](https://github.com/rasbt/LLMs-from-scratch) book\n", "- However, for readers who are curious to see the code implementation that we use in this Build A Reasoning Model (From Scratch) book, this appendix lists the source code of the Qwen3 model that we are importing from the book's `reasoning_from_scratch` Python package via:\n", "\n", "```python\n", "from reasoning_from_scratch.qwen3 import Qwen3Model, Qwen3Tokenizer\n", "```\n", "\n", "- Note that the architecture code is very similar to GPT-2, which is covered in [Build A Large Language Model (From Scratch)](https://github.com/rasbt/LLMs-from-scratch)\n", "- While this book does not require familiarity with GPT-2, this appendix contains additional comparisons to GPT-2 for those readers familiar with it (I wrote this code by porting the GPT-2 model from my other book, bit by bit, over to the Qwen3 architecture)" ] }, { "cell_type": "markdown", "id": "484015ae-51ee-4b9d-963b-f4782625c601", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "e6e6d3a8-afe4-467c-8777-0da998b3deaa", "metadata": {}, "source": [ " \n", "## C.1 Root mean square layer normalization (RMSNorm)" ] }, { "cell_type": "markdown", "id": "0dbb6bf8-454a-4355-8bfc-494d9d84f3e6", "metadata": {}, "source": [ "" ] }, { "cell_type": "code", "execution_count": 2, "id": "65e0e4d6-195f-46c5-b53c-146d51a5f3f9", "metadata": {}, "outputs": [], "source": [ "import torch.nn as nn\n", "\n", "\n", "class RMSNorm(nn.Module):\n", " def __init__(\n", " self,\n", " emb_dim,\n", " eps=1e-6,\n", " bias=False,\n", " qwen3_compatible=True,\n", " ):\n", " super().__init__()\n", " self.eps = eps\n", " self.qwen3_compatible = qwen3_compatible\n", " self.scale = nn.Parameter(torch.ones(emb_dim))\n", " self.shift = nn.Parameter(torch.zeros(emb_dim)) if bias else None\n", "\n", " def forward(self, x):\n", " input_dtype = x.dtype\n", "\n", " if self.qwen3_compatible:\n", " x = x.to(torch.float32)\n", "\n", " variance = x.pow(2).mean(dim=-1, keepdim=True)\n", " norm_x = x * torch.rsqrt(variance + self.eps)\n", " norm_x = norm_x * self.scale\n", "\n", " if self.shift is not None:\n", " norm_x = norm_x + self.shift\n", "\n", " return norm_x.to(input_dtype)" ] }, { "cell_type": "markdown", "id": "ec6898a0-5c5d-4de9-9c23-5240b61fd0ef", "metadata": 
{}, "source": [ " \n", "## C.2 Feed forward module" ] }, { "cell_type": "markdown", "id": "0284da21-8b89-4b7f-9f24-c197710f756a", "metadata": {}, "source": [ "" ] }, { "cell_type": "code", "execution_count": 3, "id": "b1dae31d-2811-468f-9b27-6ff8fcbfc6c3", "metadata": {}, "outputs": [], "source": [ "class FeedForward(nn.Module):\n", " def __init__(self, cfg):\n", " super().__init__()\n", " self.fc1 = nn.Linear(\n", " cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"],\n", " bias=False\n", " )\n", " self.fc2 = nn.Linear(\n", " cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"],\n", " bias=False\n", " )\n", " self.fc3 = nn.Linear(\n", " cfg[\"hidden_dim\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"],\n", " bias=False\n", " )\n", "\n", " def forward(self, x):\n", " x_fc1 = self.fc1(x)\n", " x_fc2 = self.fc2(x)\n", " x = nn.functional.silu(x_fc1) * x_fc2\n", " return self.fc3(x)" ] }, { "cell_type": "markdown", "id": "27728b46-39d7-4fec-ae54-928f7cbc4c35", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "412349f4-6c07-4aaf-9aea-0460a4528ca6", "metadata": {}, "source": [ " \n", "## C.3 Rotary position embeddings (RoPE)" ] }, { "cell_type": "code", "execution_count": 4, "id": "b558bb3c-695f-444d-9f17-7def2b11629a", "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "\n", "def compute_rope_params(head_dim, theta_base=10_000, context_length=4096,\n", " dtype=torch.float32):\n", " assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n", " inv_freq = 1.0 / (theta_base ** (\n", " torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float()\n", " / head_dim\n", " ))\n", " positions = torch.arange(context_length, dtype=dtype)\n", " angles = positions[:, None] * inv_freq[None, :]\n", " angles = torch.cat([angles, angles], dim=1)\n", "\n", " cos = torch.cos(angles)\n", " sin = torch.sin(angles)\n", "\n", " return cos, sin\n", "\n", "\n", "def apply_rope(x, cos, sin, offset=0):\n", " # x: (batch_size, num_heads, seq_len, head_dim)\n", " batch_size, num_heads, seq_len, head_dim = x.shape\n", " assert head_dim % 2 == 0, \"Head dimension must be even\"\n", "\n", " # Split x into first half and second half\n", " x1 = x[..., : head_dim // 2] # First half\n", " x2 = x[..., head_dim // 2:] # Second half\n", "\n", " # Adjust sin and cos shapes, shape: (1, 1, seq_len, head_dim)\n", " cos = cos[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0) \n", " sin = sin[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0)\n", "\n", " rotated = torch.cat((-x2, x1), dim=-1)\n", " x_rotated = (x * cos) + (rotated * sin)\n", "\n", " # It's ok to use lower-precision after applying cos and sin rotation\n", " return x_rotated.to(dtype=x.dtype)" ] }, { "cell_type": "markdown", "id": "2691ddd6-e8cd-4ddf-bd0f-691eb75462a9", "metadata": {}, "source": [ " \n", "## C.4 Grouped query attention (GQA)" ] }, { "cell_type": "markdown", "id": "659d38e1-f47d-4600-86a2-18ffd8177ed0", "metadata": {}, "source": [ "" ] }, { "cell_type": "code", "execution_count": 5, "id": "f848b39b-4c4e-4bf5-acfc-efcb4e616b6a", "metadata": {}, "outputs": [], "source": [ "class GroupedQueryAttention(nn.Module):\n", " def __init__(self, d_in, num_heads, num_kv_groups, head_dim=None,\n", " qk_norm=False, dtype=None):\n", " super().__init__()\n", " assert num_heads % num_kv_groups == 0\n", "\n", " self.num_heads = num_heads\n", " self.num_kv_groups = num_kv_groups\n", " self.group_size = num_heads // num_kv_groups\n", "\n", " if head_dim is None:\n", " assert d_in % num_heads == 0\n", " head_dim 
{ "cell_type": "markdown", "id": "2691ddd6-e8cd-4ddf-bd0f-691eb75462a9", "metadata": {}, "source": [ " \n", "## C.4 Grouped query attention (GQA)" ] },
{ "cell_type": "code", "execution_count": 5, "id": "f848b39b-4c4e-4bf5-acfc-efcb4e616b6a", "metadata": {}, "outputs": [], "source": [ "class GroupedQueryAttention(nn.Module):\n", "    def __init__(self, d_in, num_heads, num_kv_groups, head_dim=None,\n", "                 qk_norm=False, dtype=None):\n", "        super().__init__()\n", "        assert num_heads % num_kv_groups == 0\n", "\n", "        self.num_heads = num_heads\n", "        self.num_kv_groups = num_kv_groups\n", "        self.group_size = num_heads // num_kv_groups\n", "\n", "        if head_dim is None:\n", "            assert d_in % num_heads == 0\n", "            head_dim = d_in // num_heads\n", "\n", "        self.head_dim = head_dim\n", "        self.d_out = num_heads * head_dim\n", "\n", "        self.W_query = nn.Linear(\n", "            d_in, self.d_out, bias=False, dtype=dtype\n", "        )\n", "        self.W_key = nn.Linear(\n", "            d_in, num_kv_groups * head_dim, bias=False, dtype=dtype\n", "        )\n", "        self.W_value = nn.Linear(\n", "            d_in, num_kv_groups * head_dim, bias=False, dtype=dtype\n", "        )\n", "\n", "        self.out_proj = nn.Linear(self.d_out, d_in, bias=False, dtype=dtype)\n", "\n", "        if qk_norm:\n", "            self.q_norm = RMSNorm(head_dim, eps=1e-6)\n", "            self.k_norm = RMSNorm(head_dim, eps=1e-6)\n", "        else:\n", "            self.q_norm = self.k_norm = None\n", "\n", "    def forward(self, x, mask, cos, sin, start_pos=0, cache=None):\n", "        b, num_tokens, _ = x.shape\n", "\n", "        queries = self.W_query(x)  # (b, num_tokens, num_heads * head_dim)\n", "        keys = self.W_key(x)  # (b, num_tokens, num_kv_groups * head_dim)\n", "        values = self.W_value(x)  # (b, num_tokens, num_kv_groups * head_dim)\n", "\n", "        queries = queries.view(b, num_tokens, self.num_heads,\n", "                               self.head_dim).transpose(1, 2)\n", "        keys_new = keys.view(b, num_tokens, self.num_kv_groups,\n", "                             self.head_dim).transpose(1, 2)\n", "        values_new = values.view(b, num_tokens, self.num_kv_groups,\n", "                                 self.head_dim).transpose(1, 2)\n", "\n", "        if self.q_norm:\n", "            queries = self.q_norm(queries)\n", "        if self.k_norm:\n", "            keys_new = self.k_norm(keys_new)\n", "\n", "        queries = apply_rope(queries, cos, sin, offset=start_pos)\n", "        keys_new = apply_rope(keys_new, cos, sin, offset=start_pos)\n", "\n", "        if cache is not None:\n", "            prev_k, prev_v = cache\n", "            keys = torch.cat([prev_k, keys_new], dim=2)\n", "            values = torch.cat([prev_v, values_new], dim=2)\n", "        else:\n", "            start_pos = 0  # reset RoPE\n", "            keys, values = keys_new, values_new\n", "        next_cache = (keys, values)\n", "\n", "        # Expand K and V to match number of heads\n", "        keys = keys.repeat_interleave(self.group_size, dim=1)\n", "        values = values.repeat_interleave(self.group_size, dim=1)\n", "\n", "        attn_scores = queries @ keys.transpose(2, 3)\n", "        attn_scores = attn_scores.masked_fill(mask, -torch.inf)\n", "        attn_weights = torch.softmax(\n", "            attn_scores / self.head_dim**0.5, dim=-1\n", "        )\n", "\n", "        context = (attn_weights @ values).transpose(1, 2)\n", "        context = context.reshape(b, num_tokens, self.d_out)\n", "        return self.out_proj(context), next_cache" ] },
{ "cell_type": "markdown", "id": "0c140c2e-667f-4d99-8711-ef9ca02a468f", "metadata": {}, "source": [ " \n", "## C.5 Transformer block" ] },
{ "cell_type": "code", "execution_count": 6, "id": "e5ee80a5-cfb4-4406-b1e9-e161d7c27536", "metadata": {}, "outputs": [], "source": [ "class TransformerBlock(nn.Module):\n", "    def __init__(self, cfg):\n", "        super().__init__()\n", "        self.att = GroupedQueryAttention(\n", "            d_in=cfg[\"emb_dim\"],\n", "            num_heads=cfg[\"n_heads\"],\n", "            head_dim=cfg[\"head_dim\"],\n", "            num_kv_groups=cfg[\"n_kv_groups\"],\n", "            qk_norm=cfg[\"qk_norm\"],\n", "            dtype=cfg[\"dtype\"]\n", "        )\n", "        self.ff = FeedForward(cfg)\n", "        self.norm1 = RMSNorm(cfg[\"emb_dim\"], eps=1e-6)\n", "        self.norm2 = RMSNorm(cfg[\"emb_dim\"], eps=1e-6)\n", "\n", "    def forward(self, x, mask, cos, sin, start_pos=0, cache=None):\n", "        # Shortcut connection for attention block\n", "        shortcut = x\n", "        x = self.norm1(x)\n", "        x, next_cache = self.att(\n", "            x, mask, cos, sin, start_pos=start_pos, cache=cache\n", "        ) # Shape [batch_size, num_tokens, 
emb_size]\n", " x = x + shortcut # Add the original input back\n", "\n", " # Shortcut connection for feed-forward block\n", " shortcut = x\n", " x = self.norm2(x)\n", " x = self.ff(x)\n", " x = x + shortcut # Add the original input back\n", "\n", " return x, next_cache" ] }, { "cell_type": "markdown", "id": "77afe014-aa3c-4852-b139-0a0e603f965d", "metadata": {}, "source": [ " \n", "## C.6 Main model code" ] }, { "cell_type": "code", "execution_count": 7, "id": "3f309641-3088-451f-81fa-4092589e88c8", "metadata": {}, "outputs": [], "source": [ "class Qwen3Model(nn.Module):\n", " def __init__(self, cfg):\n", " super().__init__()\n", "\n", " # Main model parameters\n", " self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"],\n", " dtype=cfg[\"dtype\"])\n", "\n", " self.trf_blocks = nn.ModuleList(\n", " [TransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])]\n", " )\n", " self.final_norm = RMSNorm(cfg[\"emb_dim\"])\n", " self.out_head = nn.Linear(\n", " cfg[\"emb_dim\"], cfg[\"vocab_size\"],\n", " bias=False, dtype=cfg[\"dtype\"]\n", " )\n", "\n", " # Reusable utilities\n", " if cfg[\"head_dim\"] is None:\n", " head_dim = cfg[\"emb_dim\"] // cfg[\"n_heads\"]\n", " else:\n", " head_dim = cfg[\"head_dim\"]\n", " cos, sin = compute_rope_params(\n", " head_dim=head_dim,\n", " theta_base=cfg[\"rope_base\"],\n", " context_length=cfg[\"context_length\"]\n", " )\n", " self.register_buffer(\"cos\", cos, persistent=False)\n", " self.register_buffer(\"sin\", sin, persistent=False)\n", " self.cfg = cfg\n", " self.current_pos = 0 # Track current position in KV cache\n", "\n", " def forward(self, in_idx, cache=None):\n", " # Forward pass\n", " tok_embeds = self.tok_emb(in_idx)\n", " x = tok_embeds\n", "\n", " num_tokens = x.shape[1]\n", " if cache is not None:\n", " pos_start = self.current_pos\n", " pos_end = pos_start + num_tokens\n", " self.current_pos = pos_end\n", " mask = torch.triu(\n", " torch.ones(\n", " pos_end, pos_end, device=x.device, dtype=torch.bool\n", " ),\n", " diagonal=1\n", " )[pos_start:pos_end, :pos_end]\n", " else:\n", " pos_start = 0 # Not strictly necessary but helps torch.compile\n", " mask = torch.triu(\n", " torch.ones(num_tokens, num_tokens, device=x.device,\n", " dtype=torch.bool),\n", " diagonal=1\n", " )\n", " # Prefill (no cache): mask starts as (num_tokens, num_tokens)\n", " # Cached decoding: mask starts as (num_tokens, prev_k_number_tokens + num_tokens)\n", " #\n", " # We add two leading dimensions so the mask becomes\n", " # (1, 1, num_tokens, num_tokens) during prefill and\n", " # (1, 1, num_tokens, total_key_tokens) during cached decoding.\n", " # These extra dimensions let PyTorch broadcast the same mask\n", " # across all batches and attention heads when applying it to\n", " # attn_scores of shape (batch, num_heads, num_tokens, total_key_tokens).\n", " mask = mask[None, None, :, :]\n", "\n", " for i, block in enumerate(self.trf_blocks):\n", " blk_cache = cache.get(i) if cache else None\n", " x, new_blk_cache = block(x, mask, self.cos, self.sin,\n", " start_pos=pos_start,\n", " cache=blk_cache)\n", " if cache is not None:\n", " cache.update(i, new_blk_cache)\n", "\n", " x = self.final_norm(x)\n", " logits = self.out_head(x.to(self.cfg[\"dtype\"]))\n", " return logits\n", "\n", " def reset_kv_cache(self):\n", " self.current_pos = 0" ] }, { "cell_type": "markdown", "id": "019cf805-fe13-4b17-9dda-25fe1edece76", "metadata": {}, "source": [ "" ] }, { "cell_type": "code", "execution_count": 8, "id": "67c795ef-e04c-4992-938c-b83692bd4935", "metadata": {}, 
"outputs": [], "source": [ "QWEN_CONFIG_06_B = {\n", " \"vocab_size\": 151_936, # Vocabulary size\n", " \"context_length\": 40_960, # Context length that was used to train the model\n", " \"emb_dim\": 1024, # Embedding dimension\n", " \"n_heads\": 16, # Number of attention heads\n", " \"n_layers\": 28, # Number of layers\n", " \"hidden_dim\": 3072, # Size of the intermediate dimension in FeedForward\n", " \"head_dim\": 128, # Size of the heads in GQA\n", " \"qk_norm\": True, # Whether to normalize queries and keys in GQA\n", " \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n", " \"rope_base\": 1_000_000.0, # The base in RoPE's \"theta\"\n", " \"dtype\": torch.bfloat16, # Lower-precision dtype to reduce memory usage\n", "}" ] }, { "cell_type": "markdown", "id": "7ad8e34c-8596-4df3-97cf-8fec6255bcf3", "metadata": {}, "source": [ "- The Qwen3 0.6B model supports up to 40,960 tokens in total\n", "- Within the 40,960 total token count, about 32,768 tokens are reserved for model outputs (the generated text) and 8,192 tokens for typical prompts (the user's question or instruction)\n", "- The model was trained on sequences of 32,768 tokens, so that length best reflects its effective working range\n", "- Either way, 32,768 are more than sufficient here (to provide some perspective a 40-thousand-token context is roughly half the length of the first Harry Potter book)" ] }, { "cell_type": "markdown", "id": "28c17624-e06c-4665-b028-e573d7f4323e", "metadata": {}, "source": [ " \n", "## C.7 KV cache" ] }, { "cell_type": "code", "execution_count": 9, "id": "f7794633-243f-4950-b642-71777196d947", "metadata": {}, "outputs": [], "source": [ "class KVCache:\n", " def __init__(self, n_layers):\n", " self.cache = [None] * n_layers\n", "\n", " def get(self, layer_idx):\n", " return self.cache[layer_idx]\n", "\n", " def update(self, layer_idx, value):\n", " self.cache[layer_idx] = value\n", "\n", " def get_all(self):\n", " return self.cache\n", "\n", " def reset(self):\n", " for i in range(len(self.cache)):\n", " self.cache[i] = None" ] }, { "cell_type": "markdown", "id": "c04d0ddf-cb0f-4911-be7e-dc42da7baaf2", "metadata": {}, "source": [ " \n", "## C.8 Tokenizer" ] }, { "cell_type": "code", "execution_count": 10, "id": "ebf9d3ea-42b1-4bcb-97d1-624afaefe237", "metadata": {}, "outputs": [], "source": [ "import re\n", "from tokenizers import Tokenizer\n", "\n", "class Qwen3Tokenizer:\n", " _SPECIALS = [\n", " \"<|endoftext|>\",\n", " \"<|im_start|>\", \"<|im_end|>\",\n", " \"<|object_ref_start|>\", \"<|object_ref_end|>\",\n", " \"<|box_start|>\", \"<|box_end|>\",\n", " \"<|quad_start|>\", \"<|quad_end|>\",\n", " \"<|vision_start|>\", \"<|vision_end|>\",\n", " \"<|vision_pad|>\", \"<|image_pad|>\", \"<|video_pad|>\",\n", " ]\n", " _SPLIT_RE = re.compile(r\"(<\\|[^>]+?\\|>)\")\n", "\n", " def __init__(self, tokenizer_file_path=\"tokenizer-base.json\",\n", " apply_chat_template=False,\n", " add_generation_prompt=False,\n", " add_thinking=False):\n", "\n", " self.apply_chat_template = apply_chat_template\n", " self.add_generation_prompt = add_generation_prompt\n", " self.add_thinking = add_thinking\n", "\n", " tok_path = Path(tokenizer_file_path)\n", " if not tok_path.is_file():\n", " raise FileNotFoundError(\n", " f\"Tokenizer file '{tok_path}' not found. 
\"\n", " )\n", "\n", " self._tok = Tokenizer.from_file(str(tok_path))\n", " self._special_to_id = {t: self._tok.token_to_id(t) \n", " for t in self._SPECIALS}\n", "\n", " self.pad_token = \"<|endoftext|>\"\n", " self.pad_token_id = self._special_to_id.get(self.pad_token)\n", "\n", " # Match HF behavior: chat model → <|im_end|>, base model → <|endoftext|>\n", " fname = tok_path.name.lower()\n", " if \"base\" in fname and \"reasoning\" not in fname:\n", " self.eos_token = \"<|endoftext|>\"\n", " else:\n", " self.eos_token = \"<|im_end|>\"\n", " self.eos_token_id = self._special_to_id.get(self.eos_token)\n", "\n", " def encode(self, prompt, chat_wrapped=None):\n", " if chat_wrapped is None:\n", " chat_wrapped = self.apply_chat_template\n", "\n", " stripped = prompt.strip()\n", " if stripped in self._special_to_id and \"\\n\" not in stripped:\n", " return [self._special_to_id[stripped]]\n", "\n", " if chat_wrapped:\n", " prompt = self._wrap_chat(prompt)\n", "\n", " ids = []\n", " for part in filter(None, self._SPLIT_RE.split(prompt)):\n", " if part in self._special_to_id:\n", " ids.append(self._special_to_id[part])\n", " else:\n", " ids.extend(self._tok.encode(part).ids)\n", " return ids\n", "\n", " def decode(self, token_ids):\n", " return self._tok.decode(token_ids, skip_special_tokens=False)\n", "\n", " def _wrap_chat(self, user_msg):\n", " s = f\"<|im_start|>user\\n{user_msg}<|im_end|>\\n\"\n", " if self.add_generation_prompt:\n", " s += \"<|im_start|>assistant\"\n", " if self.add_thinking:\n", " s += \"\\n\" # insert no tag, just a new line\n", " else:\n", " s += \"\\n\\n\\n\\n\\n\"\n", " return s" ] }, { "cell_type": "markdown", "id": "8e87228e-a531-4f06-b2db-c314dd2daf32", "metadata": {}, "source": [ " \n", "## C.9 Using the model" ] }, { "cell_type": "code", "execution_count": 11, "id": "720f4104-84c7-4e81-9265-d4075d7d4fc6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "✓ qwen3/qwen3-0.6B-base.pth already up-to-date\n" ] }, { "data": { "text/plain": [ "Qwen3Model(\n", " (tok_emb): Embedding(151936, 1024)\n", " (trf_blocks): ModuleList(\n", " (0-27): 28 x TransformerBlock(\n", " (att): GroupedQueryAttention(\n", " (W_query): Linear(in_features=1024, out_features=2048, bias=False)\n", " (W_key): Linear(in_features=1024, out_features=1024, bias=False)\n", " (W_value): Linear(in_features=1024, out_features=1024, bias=False)\n", " (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n", " (q_norm): RMSNorm()\n", " (k_norm): RMSNorm()\n", " )\n", " (ff): FeedForward(\n", " (fc1): Linear(in_features=1024, out_features=3072, bias=False)\n", " (fc2): Linear(in_features=1024, out_features=3072, bias=False)\n", " (fc3): Linear(in_features=3072, out_features=1024, bias=False)\n", " )\n", " (norm1): RMSNorm()\n", " (norm2): RMSNorm()\n", " )\n", " )\n", " (final_norm): RMSNorm()\n", " (out_head): Linear(in_features=1024, out_features=151936, bias=False)\n", ")" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from pathlib import Path\n", "import torch\n", "\n", "from reasoning_from_scratch.ch02 import get_device # noqa: F401\n", "from reasoning_from_scratch.qwen3 import download_qwen3_small\n", "\n", "# device = get_device() # Optional: Uncomment to use automatic device picker\n", "device = torch.device(\"cpu\")\n", "\n", "download_qwen3_small(kind=\"base\", tokenizer_only=False, out_dir=\"qwen3\")\n", "\n", "tokenizer_file_path = Path(\"qwen3\") / \"tokenizer-base.json\"\n", "model_file = 
Path(\"qwen3\") / \"qwen3-0.6B-base.pth\"\n", "\n", "tokenizer = Qwen3Tokenizer(tokenizer_file_path=tokenizer_file_path)\n", "model = Qwen3Model(QWEN_CONFIG_06_B)\n", "model.load_state_dict(torch.load(model_file))\n", "\n", "model.to(device)" ] }, { "cell_type": "code", "execution_count": 12, "id": "4b5ac463-c634-4470-a4ba-89432feeacb8", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " Large language models are artificial intelligence systems that can understand, generate, and process human language, enabling them to perform a wide range of tasks, from answering questions to writing articles, and even creating creative content.\n", "\n", "Time: 1.46 sec\n", "28 tokens/sec\n" ] } ], "source": [ "import time\n", "from reasoning_from_scratch.ch02 import (\n", " generate_text_basic_stream_cache,\n", " generate_stats\n", ")\n", "\n", "\n", "prompt = \"Explain large language models in a single sentence.\"\n", "input_token_ids_tensor = torch.tensor(\n", " tokenizer.encode(prompt),\n", " device=device\n", " ).unsqueeze(0)\n", "max_new_tokens = 200\n", "\n", "start_time = time.time()\n", "generated_ids = []\n", "\n", "for token in generate_text_basic_stream_cache(\n", " model=model,\n", " token_ids=input_token_ids_tensor,\n", " max_new_tokens=max_new_tokens,\n", " eos_token_id=tokenizer.eos_token_id\n", "):\n", " token_id = token.squeeze(0).tolist()\n", " print(\n", " tokenizer.decode(token_id),\n", " end=\"\",\n", " flush=True\n", " )\n", "\n", " next_token_id = token.squeeze(0)\n", " generated_ids.append(next_token_id) # Collect generated tokens\n", "\n", "end_time = time.time()\n", "\n", "output_token_ids_tensor = torch.cat(generated_ids, dim=0)\n", "generate_stats(output_token_ids_tensor, tokenizer, start_time, end_time)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.16" } }, "nbformat": 4, "nbformat_minor": 5 }