{ "cells": [ { "cell_type": "markdown", "id": "4a379ea2-d831-4d46-b4fe-d6aa2b045389", "metadata": {}, "source": [ "\n", "\n", "\n", "\n", "\n", "
\n", "\n", "Supplementary code for the Build a Reasoning Model (From Scratch) book by Sebastian Raschka
\n", "
Code repository: https://github.com/rasbt/reasoning-from-scratch\n", "
\n", "
\n", "\n", "
" ] }, { "cell_type": "markdown", "id": "c6378508-8676-4c95-8293-814938aafd49", "metadata": {}, "source": [ "# Chapter 6: Training Reasoning Models with Reinforcement Learning" ] }, { "cell_type": "markdown", "id": "a8bfc1e7-b142-4b35-8060-d4dcacfbf408", "metadata": {}, "source": [ "Packages that are being used in this notebook:" ] }, { "cell_type": "code", "execution_count": 1, "id": "9e8b6231-fe94-4b6c-9bab-b65c2014da5c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "reasoning_from_scratch version: 0.1.16\n", "torch version: 2.10.0\n", "tokenizers version: 0.21.4\n" ] } ], "source": [ "from importlib.metadata import version\n", "\n", "used_libraries = [\n", " \"reasoning_from_scratch\",\n", " \"torch\",\n", " \"tokenizers\" # Used by reasoning_from_scratch\n", "]\n", "\n", "for lib in used_libraries:\n", " print(f\"{lib} version: {version(lib)}\")" ] }, { "cell_type": "markdown", "id": "b26000d2-3dbc-47f0-8a9b-fc071b98f378", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "fc80352d-293c-4826-8df2-d9de2163e5f5", "metadata": {}, "source": [ " \n", "## 6.1 Introduction to reinforcement learning for LLMs" ] }, { "cell_type": "markdown", "id": "30c4da65-ceb6-4ad4-ac60-a34a03220666", "metadata": {}, "source": [ "- Inference-time scaling improves reasoning by using more compute per generated answer\n", "- Training-time scaling improves reasoning by using additional compute during training, which is the focus of this chapter" ] }, { "cell_type": "markdown", "id": "ae80a712-b6ac-4464-8c31-c3f83cf50215", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "ed859a4a-abcf-4e46-930f-c150dbabca0a", "metadata": {}, "source": [ "- Inference-time scaling and training-time scaling can (/should) also be combined, for example, by applying inference-time techniques after RL-based reasoning training\n", "- In practice, RL for LLMs is applied as a post-training stage on top of a pre-trained model or following 
instruction fine-tuning" ] }, { "cell_type": "markdown", "id": "1c88db65-cd77-4e93-8e81-077d7ff290db", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "8e5e0aec-b945-40cd-a503-ab651c08dc7b", "metadata": {}, "source": [ "- Pre-training builds general knowledge via next-token prediction, whereas RL refines model behavior by optimizing sequence-level objectives such as answer correctness or preferences\n", "- RL for LLMs includes reasoning training and preference tuning, but reasoning-focused RL can also be applied directly to a pre-trained base model, as shown by DeepSeek-R1\n", "- Training reasoning directly on the base model produces a weaker but still capable model (but it offers a simpler setting for understanding what the reasoning stage contributes)" ] }, { "cell_type": "markdown", "id": "ff833682-54c7-48c4-892c-63591c56eac9", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "5efcf75b-9bb0-4cf8-b82e-1b36765eeeca", "metadata": {}, "source": [ " \n", "### 6.1.1 The original reinforcement learning pipeline with human feedback (RLHF)" ] }, { "cell_type": "markdown", "id": "6609266f-e01d-404e-b789-d8630dbc52b8", "metadata": {}, "source": [ "- RLHF was introduced in the InstructGPT work in 2022 and uses human preference labels to train LLMs (this was a key step in turning GPT-3 into the original ChatGPT)\n", "- Unlike pre-training and supervised fine-tuning, which optimize next-token prediction, RLHF optimizes models based on human preference labels of the model responses" ] }, { "cell_type": "markdown", "id": "d87dff9d-ef5b-4075-b528-1776ab377fe3", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "9b2ad33b-ee36-4d89-a28a-401827650ee4", "metadata": {}, "source": [ " \n", "### 6.1.2 From human feedback to verifiable rewards (RLVR)" ] }, { "cell_type": "markdown", "id": "46db1518-62a0-405f-9a95-0a80f60d444e", "metadata": {}, "source": [ "- RLHF requires training a separate reward model, which is often a large 
and expensive LLM\n", "- RLVR replaces the learned reward model with automatically verifiable, deterministic rewards" ] }, { "cell_type": "markdown", "id": "41735b73-2dae-47f5-9274-55c78bfc9606", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "d556e55d-175d-42dd-bbc9-b1b47dcc18d3", "metadata": {}, "source": [ "- The popularity of RLVR was driven in large part by the success of DeepSeek-R1 in 2025, which demonstrated strong reasoning performance without relying on human preference data or a learned reward model\n", "- DeepSeek-R1 trained reasoning behavior using automatically verifiable rewards, such as correctness checks for math problems and code compilation or execution for programming tasks\n", "- While this book focuses on math-based verification, the underlying idea is similar to code verification: rewards are computed automatically from binary success signals" ] }, { "cell_type": "markdown", "id": "76619875-2f05-46d7-8ecc-91f03978dd01", "metadata": {}, "source": [ " \n", "## 6.2 Reinforcement learning with verifiable rewards walkthrough using GRPO" ] }, { "cell_type": "markdown", "id": "eb104a51-262a-4e32-aef8-7c9a9253d70f", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "3b33b494-c311-4f3e-a0bb-50ad0f4a13c0", "metadata": {}, "source": [ "- Now, after introducing the big picture and seeing how RL fits into the development cycle of LLMs, we will implement RLVR to train a reasoning model (similar to DeepSeek-R1-Zero but on a much smaller scale, as a comparable run would require several hundred thousand dollars' worth of GPU compute)\n", "- RL for LLMs relies on a so-called policy gradient algorithm to update the LLM we want to train (the LLM being trained is called the \"policy\" in RL contexts)\n", "- A popular policy gradient algorithm for RLHF is proximal policy optimization (PPO); we could use the same algorithm for RLVR\n", "- However, the DeepSeek team used a simpler algorithm when they trained the DeepSeek-R1 reasoning 
models, namely group relative policy optimization (GRPO), which was first used in DeepSeekMath\n", "- GRPO is more resource-friendly: whereas PPO requires an additional LLM to estimate the value function, GRPO derives its learning signal from relative comparisons within a group of sampled responses\n", "- Interested readers can find a more detailed side-by-side comparison between PPO and GRPO in my article [The State of Reinforcement Learning for LLM Reasoning](https://magazine.sebastianraschka.com/p/the-state-of-llm-reasoning-model-training)\n", "- In this chapter, we implement RLVR using GRPO\n", "- The next chapter then introduces refinements to GRPO that improve training stability and the resulting model performance" ] }, { "cell_type": "markdown", "id": "7ff4e33f-2d63-409f-a701-62d40d291999", "metadata": {}, "source": [ "### 6.2.1 High-level GRPO intuition via a chef analogy" ] }, { "cell_type": "markdown", "id": "8ab5125f-5eae-4569-adb9-78a94a857479", "metadata": {}, "source": [ "- Since GRPO can look complicated at first glance, I wanted to start this section with a general big-picture overview using a \"chef & cooking\" analogy to introduce the terminology and provide some intuition" ] }, { "cell_type": "markdown", "id": "e400f5b6-df18-4b7a-bb4b-be57b67ecbf4", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "386f90a0-318a-41a2-a635-d5c4935c65b0", "metadata": {}, "source": [ "- In most RL-for-LLMs contexts, the terms rollout and completion are used interchangeably" ] }, { "cell_type": "markdown", "id": "26f0092f-fc60-42a7-ab85-50c30fea8a53", "metadata": {}, "source": [ "### 6.2.2 The high-level GRPO procedure" ] }, { "cell_type": "markdown", "id": "646504cb-11ec-4dd3-b771-d0cf1fe2ffb8", "metadata": {}, "source": [ "- The technical roadmap for implementing GRPO in the following sections:" ] }, { "cell_type": "markdown", "id": "4ed403d3-99c4-4fe5-8733-3f3e81a69bc3", "metadata": {}, "source": [ "" ] 
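Before implementing this roadmap with the actual model, the group-relative advantage computation at its core can be previewed with a dependency-free sketch (`group_advantages` is a hypothetical helper name chosen here; `statistics.stdev` is the sample standard deviation, which matches the default behavior of `torch.std` used later in the chapter):

```python
from statistics import mean, stdev

def group_advantages(rewards, eps=1e-4):
    # Normalize each reward relative to its group:
    # a positive advantage means "better than the group average"
    mu, sigma = mean(rewards), stdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]

# Two correct (reward 1.0) and two incorrect (reward 0.0) rollouts:
print(group_advantages([1.0, 1.0, 0.0, 0.0]))  # roughly [0.866, 0.866, -0.866, -0.866]

# If all rollouts score the same, the advantages vanish (no learning signal):
print(group_advantages([1.0, 1.0, 1.0, 1.0]))  # [0.0, 0.0, 0.0, 0.0]
```

Correct rollouts receive a positive advantage and incorrect ones a negative advantage of equal magnitude; groups with identical rewards yield all-zero advantages, which is why such groups produce no model update.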
}, { "cell_type": "markdown", "id": "45ea6cff-d0ab-45da-99ab-0b46d7565bf6", "metadata": {}, "source": [ " \n", "## 6.3 Loading a pre-trained model" ] }, { "cell_type": "markdown", "id": "c89a2d64-858e-4f96-a83e-ba77a4dce3a6", "metadata": {}, "source": [ "- The code in this chapter is identical to the one in previous chapters" ] }, { "cell_type": "markdown", "id": "20dae59f-317e-4a17-826c-19cd39285c41", "metadata": {}, "source": [ "" ] }, { "cell_type": "code", "execution_count": 2, "id": "78cf0c64-f245-4558-804b-0170e9bb4887", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using Apple Silicon GPU (MPS)\n", "✓ qwen3/qwen3-0.6B-base.pth already up-to-date\n" ] } ], "source": [ "import torch\n", "\n", "from reasoning_from_scratch.ch02 import get_device\n", "from reasoning_from_scratch.ch03 import (\n", " load_model_and_tokenizer\n", ")\n", "\n", "device = get_device()\n", "device = torch.device(\"cpu\")\n", "\n", "model, tokenizer = load_model_and_tokenizer(\n", " which_model=\"base\",\n", " device=device,\n", " use_compile=False\n", ")" ] }, { "cell_type": "code", "execution_count": 3, "id": "b3e3f518-6525-42e6-8888-bc078f907c1f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " \\boxed{58}" ] } ], "source": [ "from reasoning_from_scratch.ch03 import render_prompt\n", "from reasoning_from_scratch.ch04 import (\n", " generate_text_stream_concat_flex,\n", " generate_text_top_p_stream_cache\n", ")\n", "\n", "raw_prompt = (\n", " \"Half the value of $3x-9$ is $x+37$. 
\"\n", " \"What is the value of $x$?\"\n", ")\n", "prompt = render_prompt(raw_prompt)\n", "\n", "torch.manual_seed(0)\n", "response = generate_text_stream_concat_flex(\n", " model, tokenizer, prompt, device,\n", " max_new_tokens=2048, verbose=True,\n", " generate_func=generate_text_top_p_stream_cache,\n", " temperature=0.9,\n", " top_p=0.9\n", ")" ] }, { "cell_type": "markdown", "id": "e72a81ec-a7ec-48e3-983d-42eb387c0b8e", "metadata": {}, "source": [ " \n", "## 6.4 Loading a MATH training subset" ] }, { "cell_type": "markdown", "id": "41628fa8-20eb-4fc6-9879-c7ff227f608b", "metadata": {}, "source": [ "- We use a non-overlapping training subset derived from the original MATH dataset that explicitly excludes the MATH-500 examples used for model evaluation in the previous chapters (for more information about how the dataset was prepared, please see https://github.com/rasbt/math_full_minus_math500)" ] }, { "cell_type": "markdown", "id": "1dae3630-ee34-48fe-a01b-ea8ff1e8fccb", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "4f285e5e-5586-41ed-9d8c-ff36a7a3b102", "metadata": {}, "source": [ "- The following `load_math_train` function is similar to the [load_math500_test](https://github.com/rasbt/reasoning-from-scratch/blob/main/reasoning_from_scratch/ch03.py#L422) function in chapter 3, except that we specify a different file path" ] }, { "cell_type": "code", "execution_count": 4, "id": "7c660163-5c8d-4349-9546-5953a87c7440", "metadata": {}, "outputs": [], "source": [ "import json\n", "import requests\n", "from pathlib import Path\n", "\n", "def load_math_train(local_path=\"math_train.json\", save_copy=True):\n", " local_path = Path(local_path)\n", "\n", " url = (\n", " \"https://raw.githubusercontent.com/rasbt/\"\n", " \"math_full_minus_math500/refs/heads/main/\"\n", " \"math_full_minus_math500.json\"\n", " )\n", " backup_url = (\n", " \"https://f001.backblazeb2.com/file/reasoning-from-scratch/\"\n", " \"MATH/math_full_minus_math500.json\"\n", " 
)\n", "\n", " if local_path.exists():\n", " with local_path.open(\"r\", encoding=\"utf-8\") as f:\n", " data = json.load(f)\n", " else:\n", " try:\n", " r = requests.get(url, timeout=30)\n", " r.raise_for_status()\n", " except requests.RequestException:\n", " print(\"Using backup URL.\")\n", " r = requests.get(backup_url, timeout=30)\n", " r.raise_for_status()\n", "\n", " data = r.json()\n", "\n", " if save_copy:\n", " with local_path.open(\"w\", encoding=\"utf-8\") as f:\n", " json.dump(data, f, indent=2)\n", "\n", " return data" ] }, { "cell_type": "code", "execution_count": 5, "id": "ed5337fd-5b27-4859-ac06-df44edb0214d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dataset size: 12000\n" ] } ], "source": [ "math_train = load_math_train()\n", "\n", "print(\"Dataset size:\", len(math_train))" ] }, { "cell_type": "code", "execution_count": 6, "id": "cd410244-6a7b-408f-a696-9a08c3d3c39c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'answer': '6',\n", " 'level': 'Level 3',\n", " 'problem': 'Sam is hired for a 20-day period. On days that he works, he earns '\n", " '$\\\\$$60. For each day that he does not work, $\\\\$$30 is '\n", " 'subtracted from his earnings. At the end of the 20-day period, he '\n", " 'received $\\\\$$660. How many days did he not work?',\n", " 'solution': 'Call $x$ the number of days Sam works and $y$ the number of days '\n", " 'he does not. We can set up the following system of equations to '\n", " 'represent the given information: \\\\begin{align*}\\n'\n", " 'x+y &= 20 \\\\\\\\\\n'\n", " '60x - 30y &= 660 \\\\\\\\\\n'\n", " '\\\\end{align*} The first equation represents the total number of '\n", " 'days Sam works, and the second equation represents his total '\n", " 'profit. Solving for $x$ in the first equation yields $x = 20 - '\n", " 'y$. Substituting into the second equation gives $60(20-y) - 30y '\n", " '= 660$. 
Canceling a factor of $10$ and multiplying out gives '\n", " '$120 - 6y - 3y = 66$. This simplifies to $-9y = -54$, or $y = '\n", " '6$. Thus, Sam did not work for $\\\\boxed{6}$ days.',\n", " 'type': 'Algebra',\n", " 'unique_id': 4}\n" ] } ], "source": [ "from pprint import pprint\n", "\n", "pprint(math_train[4])" ] }, { "cell_type": "markdown", "id": "f230c7a1-0bb0-4641-b720-5c17729e2e77", "metadata": {}, "source": [ "- Note that we only need the `\"answer\"` and `\"problem\"` fields\n", "- It may be tempting to also use the `\"solution\"` field, but here we want to let the model explore solutions freely (instead of imitating one specific solution path and style)" ] }, { "cell_type": "markdown", "id": "3e086f43-c568-408c-8b56-a349b046164d", "metadata": {}, "source": [ " \n", "## 6.5 Sampling rollouts" ] }, { "cell_type": "markdown", "id": "b8a5bac2-dc3f-404b-866a-f640f2b5ac7e", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "3d8c5fb8-ff59-40d4-9f58-0087c55c459f", "metadata": {}, "source": [ "- Rollout is RL jargon for a generated response" ] }, { "cell_type": "markdown", "id": "7ce30bf5-8b6c-4cb2-836b-97488bcbbad7", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "2b7da7dd-0f75-4bcc-bd86-dfc172dfca68", "metadata": {}, "source": [ "- We need `@torch.no_grad` because we don't want to build a computation graph and backpropagate through the sampling step; `@torch.inference_mode` doesn't work here because its optimizations go further and mark the resulting tensors as inference-only, which results in\n", "\n", "> RuntimeError: Inference tensors cannot be saved for backward. Please do not use Tensors created in inference mode in computation tracked by autograd. To work around this, you can make a clone to get a normal tensor and use it in autograd, or use `torch.no_grad()` instead of `torch.inference_mode()`."
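The `sample_response` function below reuses the book's `top_p_filter` helper from chapter 4. As a refresher on what that filtering step does, here is a small self-contained plain-Python sketch (the name `top_p_filter_sketch` and the rule of keeping the token that crosses the `top_p` threshold are choices made here for illustration; the book's helper may handle the boundary differently):

```python
def top_p_filter_sketch(probs, top_p):
    # Nucleus (top-p) filtering: keep the smallest set of tokens whose
    # cumulative probability reaches top_p, zero out the rest, renormalize
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = set(), 0.0
    for i in order:
        kept.add(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    filtered = [p if i in kept else 0.0 for i, p in enumerate(probs)]
    total = sum(filtered)
    return [p / total for p in filtered]

# With top_p=0.8, only the two most likely tokens survive
print(top_p_filter_sketch([0.5, 0.3, 0.15, 0.05], top_p=0.8))
```

Sampling then draws from the filtered distribution; in `sample_response` this is done by `torch.multinomial`, which accepts unnormalized non-negative weights, so the explicit renormalization above is optional there.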
] }, { "cell_type": "code", "execution_count": 7, "id": "d79cf840-8d76-4490-83c0-87344433d016", "metadata": {}, "outputs": [], "source": [ "from reasoning_from_scratch.qwen3 import KVCache\n", "from reasoning_from_scratch.ch04 import top_p_filter\n", "\n", "\n", "@torch.no_grad()\n", "def sample_response(\n", " model,\n", " tokenizer,\n", " prompt,\n", " device,\n", " max_new_tokens=512,\n", " temperature=0.8,\n", " top_p=0.9,\n", "):\n", " input_ids = torch.tensor(\n", " tokenizer.encode(prompt),\n", " device=device\n", " )\n", "\n", " cache = KVCache(n_layers=model.cfg[\"n_layers\"])\n", " model.reset_kv_cache()\n", " logits = model(input_ids.unsqueeze(0), cache=cache)[:, -1]\n", "\n", " generated = []\n", " for _ in range(max_new_tokens):\n", " if temperature and temperature != 1.0:\n", " logits = logits / temperature\n", "\n", " probas = torch.softmax(logits, dim=-1)\n", " probas = top_p_filter(probas, top_p)\n", " next_token = torch.multinomial(\n", " probas.cpu(), num_samples=1\n", " ).to(device)\n", "\n", " token_id = next_token.item()\n", " generated.append(token_id)\n", "\n", " if (\n", " tokenizer.eos_token_id is not None\n", " and token_id == tokenizer.eos_token_id\n", " ):\n", " break\n", " logits = model(next_token, cache=cache)[:, -1]\n", "\n", " full_token_ids = torch.cat(\n", " [input_ids,\n", " torch.tensor(generated, device=device, dtype=input_ids.dtype),]\n", " )\n", " return full_token_ids, input_ids.numel(), tokenizer.decode(generated)" ] }, { "cell_type": "markdown", "id": "4e79b893-5423-4cff-9bc6-dfe7a0ac17d5", "metadata": {}, "source": [ "- There is nothing new here\n", "- The code above is simply a leaner version of what we have been developing previously; it combines the [generate_text_basic_stream_cache](https://github.com/rasbt/reasoning-from-scratch/blob/main/reasoning_from_scratch/ch02.py#L57) function from chapter 2 with temperature and top-p sampling from chapter 4 directly\n", "- Instead of yielding each token, we also now just 
collect the tokens in a tensor as we don't need to print the generated tokens live" ] }, { "cell_type": "code", "execution_count": 8, "id": "cf6554ab-38f8-448a-a985-1da3ce8aae98", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " \\boxed{58}<|endoftext|>\n" ] } ], "source": [ "torch.manual_seed(0)\n", "\n", "raw_prompt = (\n", " \"Half the value of $3x-9$ is $x+37$. \"\n", " \"What is the value of $x$?\"\n", ")\n", "prompt = render_prompt(raw_prompt)\n", "\n", "token_ids, prompt_len, answer_text = sample_response(\n", " model=model,\n", " tokenizer=tokenizer,\n", " prompt=prompt,\n", " device=device,\n", " max_new_tokens=512,\n", " temperature=0.9,\n", " top_p=0.9,\n", " )\n", "\n", "print(answer_text)" ] }, { "cell_type": "code", "execution_count": 9, "id": "ed8527a4-31a8-4238-ac94-c5b49dbd8f62", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " Let's solve the problem step by step.\n", "\n", "**Given:**\n", "\\[\n", "\\text{Half the value of } 3x - 9 \\text{ is } x + 37.\n", "\\]\n", "\n", "**Step 1: Translate the statement into an equation.**\n", "\\[\n", "\\frac{1}{2} (3x - 9) = x + 37\n", "\\]\n", "\n", "**Step 2: Eliminate the fraction by multiplying both sides by 2.**\n", "\\[\n", "3x - 9 = 2(x + 37)\n", "\\]\n", "\n", "**Step 3: Distribute the 2 on the right side.**\n", "\\[\n", "3x - 9 = 2x + 74\n", "\\]\n", "\n", "**Step 4: Subtract \\(2x\\) from both sides to get the \\(x\\)-terms on one side.**\n", "\\[\n", "3x - 2x - 9 = 74\n", "\\]\n", "\\[\n", "x - 9 = 74\n", "\\]\n", "\n", "**Step 5: Add 9 to both sides to solve for \\(x\\).**\n", "\\[\n", "x = 74 + 9\n", "\\]\n", "\\[\n", "x = 83\n", "\\]\n", "\n", "**Final Answer:**\n", "\\[\n", "\\boxed{83}\n", "\\]<|endoftext|>\n" ] } ], "source": [ "torch.manual_seed(5)\n", "\n", "token_ids, prompt_len, answer_text = sample_response(\n", " model=model,\n", " tokenizer=tokenizer,\n", " prompt=prompt,\n", " device=device,\n", " 
max_new_tokens=512,\n", " temperature=0.9,\n", " top_p=0.9,\n", " )\n", "\n", "print(answer_text)" ] }, { "cell_type": "markdown", "id": "2043a669-5410-4c8a-9934-c2a5e0d577ef", "metadata": {}, "source": [ "- In practice, we would call sample_response multiple times to generate rollouts\n", "- To keep the GRPO walkthrough simple and aligned with figure 6.13, we instead assume the model generated the following four responses:" ] }, { "cell_type": "code", "execution_count": 10, "id": "ffd0e902-d34e-449b-8dfe-9021fb3e5901", "metadata": {}, "outputs": [], "source": [ "rollouts = [\n", " r\"\\boxed{83}\",\n", " r\"The correct answer is \\boxed{83}\",\n", " r\"The final answer is 83\",\n", " r\"We get \\boxed{38}\",\n", "]" ] }, { "cell_type": "markdown", "id": "e1060430-5aa6-4171-a4d6-5e78e068b354", "metadata": {}, "source": [ " \n", "## 6.6 Calculating rewards" ] }, { "cell_type": "markdown", "id": "1c986eaf-de71-4071-8f52-a5812ff73ec0", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "6790c682-c835-46d0-9587-b8d64d741e98", "metadata": {}, "source": [ "- The rewards are simply correctness rewards, similar to chapter 3\n", "- However, there is also an implicit format reward, the reward of 1.0 is only given if the final answer is in the `\\boxed{}` format (via `fallback=None`)" ] }, { "cell_type": "code", "execution_count": 11, "id": "1670c8f1-1287-4202-9abb-574911757001", "metadata": {}, "outputs": [], "source": [ "from reasoning_from_scratch.ch03 import (\n", " extract_final_candidate, grade_answer\n", ")\n", "\n", "def reward_rlvr(answer_text, ground_truth):\n", " extracted = extract_final_candidate(\n", " answer_text, fallback=None # Require \\boxed{}\n", " )\n", " if not extracted:\n", " return 0.0\n", " correct = grade_answer(extracted, ground_truth)\n", " return float(correct)" ] }, { "cell_type": "code", "execution_count": 12, "id": "64c41031-9ef6-4eb7-ad72-f928a35eb3b7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": 
"stream", "text": [ "Answer: '\\\\boxed{83}'\n", "Reward: 1.0\n", "\n", "Answer: 'The correct answer is \\\\boxed{83}'\n", "Reward: 1.0\n", "\n", "Answer: 'The final answer is 83'\n", "Reward: 0.0\n", "\n", "Answer: 'We get \\\\boxed{38}'\n", "Reward: 0.0\n", "\n" ] } ], "source": [ "rollouts = [\n", " r\"\\boxed{83}\",\n", " r\"The correct answer is \\boxed{83}\",\n", " r\"The final answer is 83\",\n", " r\"We get \\boxed{38}\",\n", "]\n", "rollout_rewards = []\n", "\n", "for answer in rollouts:\n", " reward = reward_rlvr(answer_text=answer, ground_truth=\"83\")\n", " print(f\"Answer: {answer!r}\")\n", " print(f\"Reward: {reward}\\n\")\n", " rollout_rewards.append(reward)" ] }, { "cell_type": "markdown", "id": "d809f38e-839d-45da-af13-f16dfa5a6ea9", "metadata": {}, "source": [ "- Note: The DeepSeek-R1 team tried to use process reward models to score intermediate solution steps when training the model\n", "- However, these attempts were unsuccessful, and the researchers concluded that it is better to only train on the final answer correctness rewards without intermediate rewards" ] }, { "cell_type": "markdown", "id": "9b618020-123a-4800-b3d2-699cd7e6e694", "metadata": {}, "source": [ " \n", "## 6.7 Preparing learning signals from rollouts via advantages" ] }, { "cell_type": "markdown", "id": "adbda36a-4da4-488f-b6bb-80d8fb2d93cd", "metadata": {}, "source": [ "- The \"GR\" (group relative) in GRPO refers to the fact that GRPO generates multiple answers (rollouts) per prompt, and compares them relative to each other to construct a learning signal" ] }, { "cell_type": "markdown", "id": "b4e7b2d8-b705-42b2-a73d-f7883eda7efe", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "de36577d-ac54-440d-b467-ccb21c104877", "metadata": {}, "source": [ "- The formula is quite simple:\n", "\n", "$$\\text{advantages}_i = \\frac{r_i - \\mu_r}{\\sigma_r + \\epsilon}$$" ] }, { "cell_type": "markdown", "id": "b899fd5f-9a44-4437-bfa9-da38881ab87b", "metadata": {}, 
"source": [ "- Here, $r_i$ denotes the reward of the $i$-th rollout, $\\mu_r$ is the mean reward across the group of rollouts, $\\sigma_r$ is the corresponding standard deviation, and $\\epsilon$ is a small constant added for numerical stability to avoid zero-division errors" ] }, { "cell_type": "code", "execution_count": 13, "id": "33e1ba67-a5b7-4b18-8590-e6a6c0d6767c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([1., 1., 0., 0.])\n" ] } ], "source": [ "rewards = torch.tensor(rollout_rewards, device=device)\n", "print(rewards)" ] }, { "cell_type": "code", "execution_count": 14, "id": "60e3d1b3-3586-4c60-96b8-199cf7d69417", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([ 0.8659, 0.8659, -0.8659, -0.8659])\n" ] } ], "source": [ "advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-4)\n", "\n", "print(advantages)" ] }, { "cell_type": "markdown", "id": "3e740c9a-60f3-404f-b950-87e631af3fe0", "metadata": {}, "source": [ "- Note that if all rewards in a group are identical, for example, all 0 or all 1, then $r_i - \\mu_r = 0$ for all $i$ rollouts\n", "- This means the model is not updated if all answers are correct or all answers are incorrect" ] }, { "cell_type": "markdown", "id": "9f2bdd0f-549b-4b42-87f3-c0fa18a9be9a", "metadata": {}, "source": [ " \n", "## 6.8 Scoring rollouts with sequence log-probabilities" ] }, { "cell_type": "markdown", "id": "97516f0e-c4c9-4811-8727-7789950d270d", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "f50bca6d-adec-4d6a-ae9e-1361392fe555", "metadata": {}, "source": [ "- In the previous chapter, we implemented an `avg_logprob_answer` function that calculates per-token log-probabilities for the answer tokens\n", "- These averaged log-probabilities are also often referred to as token-level log-probabilities, and are commonly used for scoring LLM answers\n", "- This averaging is preferred for 
scoring as it provides length-normalization\n", "- Mathematically, this can be expressed as $\\frac{1}{T} \\sum_{t=1}^{T} \\log p_W(y_t \\mid y_{<t})$\n", "- For the GRPO loss, however, we work with the sequence-level log-probability, i.e., the sum of the per-token log-probabilities without the $\\frac{1}{T}$ length normalization" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(-16.2998, grad_fn=<SumBackward0>)\n" ] } ], "source": [ "def sequence_logprob_draft(model, token_ids, prompt_len):\n", " logits = model(token_ids.unsqueeze(0)).squeeze(0).float()\n", " logprobs = torch.log_softmax(logits, dim=-1)\n", "\n", " # Positions whose next-token probabilities we want\n", " # These correspond to predicting token_ids[t + 1] from position t\n", " start = prompt_len - 1\n", " end = token_ids.shape[0] - 1\n", "\n", " t_idx = torch.arange(start, end, device=token_ids.device)\n", " next_tokens = token_ids[start + 1 : end + 1]\n", " next_token_logps = logprobs[t_idx, next_tokens]\n", "\n", " # Sum log-probabilities over the answer tokens\n", " return torch.sum(next_token_logps)\n", "\n", "print(sequence_logprob_draft(model, token_ids, prompt_len))" ] }, { "cell_type": "markdown", "id": "947eee07-3f48-4dd8-b315-f2c7de99c5f5", "metadata": {}, "source": [ "- Note that we don't use `.item()` in `torch.sum(next_token_logps)` so that PyTorch returns a tensor (rather than a Python float), which is important for the gradient calculation\n", "- As we can see, the resulting value (-16.2998) is almost identical to the value we got previously when rescaling the `avg_logprob_val` by the number of answer tokens (-16.2390); the minor differences can be attributed to floating point rounding behavior" ] }, { "cell_type": "markdown", "id": "27ee0964-c5df-42e2-aacf-79cdf6b6456b", "metadata": {}, "source": [ "- Below, we will rewrite the function using `torch.gather`, which is more idiomatic PyTorch and better optimized for GPUs\n", "- However, both functions are mathematically equivalent" ] }, { "cell_type": "code", "execution_count": 19, "id": "c5f53255-c3c0-404d-b0a8-192fc35c582e", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(-16.2998, grad_fn=<SumBackward0>)\n" ] } ], "source": [ "def 
sequence_logprob(model, token_ids, prompt_len):\n", " logits = model(token_ids.unsqueeze(0)).squeeze(0).float()\n", " logprobs = torch.log_softmax(logits, dim=-1)\n", " selected = logprobs[:-1].gather(\n", " 1, token_ids[1:].unsqueeze(-1)\n", " ).squeeze(-1)\n", " return torch.sum(selected[prompt_len - 1:])\n", "\n", "print(sequence_logprob(model, token_ids, prompt_len))" ] }, { "cell_type": "code", "execution_count": 20, "id": "77874ce7-b29a-43f7-b863-562e80f9289d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Answer: \\boxed{83}\n", "Logprob: -7.9243\n", "\n", "Answer: The correct answer is \\boxed{83}\n", "Logprob: -20.1546\n", "\n", "Answer: The final answer is 83\n", "Logprob: -16.6130\n", "\n", "Answer: We get \\boxed{38}\n", "Logprob: -23.3677\n", "\n" ] } ], "source": [ "rollouts = [\n", " r\"\\boxed{83}\",\n", " r\"The correct answer is \\boxed{83}\",\n", " r\"The final answer is 83\",\n", " r\"We get \\boxed{38}\",\n", "]\n", "\n", "rollout_logps = []\n", "\n", "for text in rollouts:\n", " token_ids = tokenizer.encode(prompt + \" \" + text)\n", " logprob = sequence_logprob(\n", " model=model,\n", " token_ids=torch.tensor(token_ids, device=device),\n", " prompt_len=prompt_len,\n", " )\n", "\n", " print(f\"Answer: {text}\")\n", " print(f\"Logprob: {logprob.item():.4f}\\n\")\n", "\n", " rollout_logps.append(logprob)" ] }, { "cell_type": "markdown", "id": "7ba0f115-db52-4a57-83a8-a559634dafa6", "metadata": {}, "source": [ "- The trend here is that shorter and more concise answers receive higher (less negative) sequence-level log-probabilities\n", "- And the lowest score is assigned to the only answer containing an incorrect value (38 instead of 83)\n", "- Overall, summed log-probabilities favor concise and correct outputs" ] }, { "cell_type": "markdown", "id": "6b7b630a-93a2-4d90-8a7a-5dcc90dfd974", "metadata": {}, "source": [ " \n", "## 6.9 From advantages to policy updates via the GRPO loss" ] }, { "cell_type": 
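Before wiring the loss into code, it is worth verifying the arithmetic by hand. The following dependency-free check combines the rounded sequence log-probabilities printed above with the advantages from section 6.7 (values are copied from the notebook outputs, so the result only matches up to rounding):

```python
# Rounded values taken from the notebook outputs above
logps = [-7.9243, -20.1546, -16.6130, -23.3677]
advantages = [0.8659, 0.8659, -0.8659, -0.8659]

# Policy gradient loss: negative mean of the advantage-weighted
# log-probabilities (negated because optimizers minimize by default)
pg_loss = -sum(a * lp for a, lp in zip(advantages, logps)) / len(logps)
print(round(pg_loss, 4))
```

Rollouts with positive advantages pull their log-probabilities up, those with negative advantages push theirs down; the printed value agrees with the PyTorch computation in this section up to floating point rounding.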
"markdown", "id": "6b69ddff-149f-4449-8ddf-0a1e1eea77d5", "metadata": {}, "source": [ "" ] }, { "cell_type": "code", "execution_count": 21, "id": "13bde248-1396-49cf-a6d7-04569bd6a140", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([ -7.9243, -20.1546, -16.6130, -23.3677], grad_fn=)\n" ] } ], "source": [ "logps = torch.stack(rollout_logps)\n", "print(logps)" ] }, { "cell_type": "code", "execution_count": 22, "id": "9fde6ffa-af6c-4a8e-a218-42856d181a13", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(-2.5764, grad_fn=)\n" ] } ], "source": [ "pg_loss = -(advantages.detach() * logps).mean()\n", "print(pg_loss)" ] }, { "cell_type": "markdown", "id": "85571031-c486-4c64-b258-a9df8047a15b", "metadata": {}, "source": [ "- We need the `.detach()` because we want to treat the `advantages` as fixed learning signals; this way, we ensure that we only backprop through the logprobs\n", "- We need the negative sign because PyTorch optimizers minimize by default, and here we want to maximize the logprob-weighted advantages" ] }, { "cell_type": "markdown", "id": "6954ad35-ef9a-4ded-a1ce-bdf652a06826", "metadata": {}, "source": [ "- In mathematical notation, we can write the policy gradient loss as follows:\n", "\n", "$$\\mathcal{L}_{\\mathrm{PG}}\n", "= -\\frac{1}{N} \\sum_{i=1}^{N} A_i \\sum_{t=1}^{T_i} \\log p_W\\!\\left( y_t^{(i)} \\mid y_{" ] }, { "cell_type": "code", "execution_count": 23, "id": "e8d6d939-cfe3-4b5b-bc25-32a669f81db4", "metadata": {}, "outputs": [], "source": [ "def compute_grpo_loss(\n", " model,\n", " tokenizer,\n", " example,\n", " device,\n", " num_rollouts=2,\n", " max_new_tokens=256,\n", " temperature=0.8,\n", " top_p=0.9,\n", "):\n", " assert num_rollouts >= 2\n", " roll_logps, roll_rewards, samples = [], [], []\n", " prompt = render_prompt(example[\"problem\"])\n", "\n", " was_training = model.training\n", " model.eval()\n", "\n", " for _ in 
range(num_rollouts):\n", " # Stage 1: generate rollouts\n", " token_ids, prompt_len, text = sample_response(\n", " model=model,\n", " tokenizer=tokenizer,\n", " prompt=prompt,\n", " device=device,\n", " max_new_tokens=max_new_tokens,\n", " temperature=temperature,\n", " top_p=top_p,\n", " )\n", " # Stage 2: compute rewards\n", " reward = reward_rlvr(text, example[\"answer\"])\n", " \n", " # Stage 4: compute logprobs\n", " logp = sequence_logprob(model, token_ids, prompt_len)\n", "\n", " roll_logps.append(logp)\n", " roll_rewards.append(reward)\n", " samples.append(\n", " {\n", " \"text\": text,\n", " \"reward\": reward,\n", " \"gen_len\": token_ids.numel() - prompt_len,\n", " }\n", " )\n", "\n", " if was_training:\n", " model.train()\n", "\n", " # Stage 2: collect all rewards\n", " rewards = torch.tensor(roll_rewards, device=device)\n", "\n", " # Stage 3: compute advantages\n", " advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-4)\n", "\n", " # Stage 4: collect all logprobs\n", " logps = torch.stack(roll_logps)\n", "\n", " # Stage 5: compute policy gradient loss\n", " pg_loss = -(advantages.detach() * logps).mean()\n", " loss = pg_loss # In the next chapter we add a KL term here\n", "\n", " return {\n", " \"loss\": loss.item(),\n", " \"pg_loss\": pg_loss.item(),\n", " \"rewards\": roll_rewards,\n", " \"advantages\": advantages.detach().cpu().tolist(),\n", " \"samples\": samples,\n", " \"loss_tensor\": loss,\n", " }" ] }, { "cell_type": "markdown", "id": "21303aae-56a6-45a8-90dd-e0b5ea1f9e0c", "metadata": {}, "source": [ "- The stages in the code comments map to the stages in the GRPO figure\n", "- Note that following stage 1, we have stages 2 and 4 (instead of 3 and 4) in the code comments, since this results in a simpler code implementation (so that we don't have to implement multiple for-loops)" ] }, { "cell_type": "code", "execution_count": 24, "id": "f42a178a-5557-443b-9e32-9d96e63cb425", "metadata": {}, "outputs": [ { "name": "stdout", 
"output_type": "stream", "text": [ "{'advantages': [0.0, 0.0],\n", " 'loss': -0.0,\n", " 'loss_tensor': tensor(-0., grad_fn=),\n", " 'pg_loss': -0.0,\n", " 'rewards': [0.0, 0.0],\n", " 'samples': [{'gen_len': 4, 'reward': 0.0, 'text': ' 14<|endoftext|>'},\n", " {'gen_len': 256,\n", " 'reward': 0.0,\n", " 'text': ' 4\\n'\n", " '\\n'\n", " \"To solve the problem, let's break it down step by \"\n", " 'step:\\n'\n", " '\\n'\n", " '1. **Define Variables:**\\n'\n", " ' - Let \\\\( x \\\\) be the number of days Sam works.\\n'\n", " ' - Then, the number of days he does not work is \\\\( '\n", " '20 - x \\\\).\\n'\n", " '\\n'\n", " '2. **Set Up the Earnings Equation:**\\n'\n", " ' - For each day he works, he earns \\\\$60.\\n'\n", " ' - For each day he does not work, he loses \\\\$30.\\n'\n", " ' - His total earnings are \\\\$660.\\n'\n", " '\\n'\n", " ' The equation is:\\n'\n", " ' \\\\[\\n'\n", " ' 60x - 30(20 - x) = 660\\n'\n", " ' \\\\]\\n'\n", " '\\n'\n", " '3. **Simplify the Equation:**\\n'\n", " ' \\\\[\\n'\n", " ' 60x - 600 + 30x = 660\\n'\n", " ' \\\\]\\n'\n", " ' \\\\[\\n'\n", " ' 90x - 600 = 660\\n'\n", " ' \\\\]\\n'\n", " '\\n'\n", " '4. 
**Solve for \\\\( x \\\\):**\\n'\n", " ' \\\\[\\n'\n", " ' 90x = 660 + 600\\n'\n", " ' \\\\]\\n'\n", " ' \\\\[\\n'\n", " ' 90x = 1260\\n'\n", " ' \\\\]\\n'\n", " ' \\\\[\\n'}]}\n" ] } ], "source": [ "torch.manual_seed(123)\n", "\n", "stats = compute_grpo_loss(\n", " model=model,\n", " tokenizer=tokenizer,\n", " example=math_train[4],\n", " device=device,\n", " num_rollouts=2,\n", " max_new_tokens=256,\n", " temperature=0.8,\n", " top_p=0.9\n", ")\n", "\n", "pprint(stats)" ] }, { "cell_type": "markdown", "id": "fb2ff6a3-970b-425e-9bb8-7ee2c192b2b3", "metadata": {}, "source": [ " \n", "## 6.11 Implementing the GRPO training loop" ] }, { "cell_type": "markdown", "id": "54d62ad8-1612-4523-b47b-d21f90ea1301", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "0361c9fe-006b-4fee-b509-18248201eb01", "metadata": {}, "source": [ "- We skip batching due to the already expensive resource requirements" ] }, { "cell_type": "markdown", "id": "4fd75007-c0a1-45f5-b9d5-094fe19d026a", "metadata": {}, "source": [ "" ] }, { "cell_type": "code", "execution_count": 25, "id": "80998a74-6575-4706-8962-5b188f98b8eb", "metadata": {}, "outputs": [], "source": [ "import time\n", "\n", "def train_rlvr_grpo(\n", " model,\n", " tokenizer,\n", " math_data,\n", " device,\n", " steps=None,\n", " num_rollouts=2,\n", " max_new_tokens=256,\n", " temperature=0.8,\n", " top_p=0.9,\n", " lr=1e-5,\n", " checkpoint_every=50,\n", " checkpoint_dir=\".\",\n", " csv_log_path=None,\n", "\n", "):\n", " if steps is None:\n", " steps = len(math_data)\n", "\n", " # Stage 1: initialize optimizer\n", " # (the model was already initialized outside the function)\n", " optimizer = torch.optim.AdamW(model.parameters(), lr=lr)\n", " model.train()\n", " current_step = 0\n", " if csv_log_path is None:\n", " timestamp = time.strftime(\"%Y%m%d_%H%M%S\")\n", " csv_log_path = f\"train_rlvr_grpo_metrics_{timestamp}.csv\"\n", " csv_log_path = Path(csv_log_path)\n", "\n", " try:\n", " # Stage 2: Iterate over 
training steps\n", " for step in range(steps):\n", "\n", " # Stage 3: Reset loss gradient\n", " # (it's best practice to do this at the beginning of each step)\n", " optimizer.zero_grad()\n", "\n", " current_step = step + 1\n", " example = math_data[step % len(math_data)]\n", "\n", " # Stage 4: calculate GRPO loss\n", " stats = compute_grpo_loss(\n", " model=model,\n", " tokenizer=tokenizer,\n", " example=example,\n", " device=device,\n", " num_rollouts=num_rollouts,\n", " max_new_tokens=max_new_tokens,\n", " temperature=temperature,\n", " top_p=top_p,\n", " )\n", "\n", " # Stage 5: Backward pass to calculate loss gradients\n", " stats[\"loss_tensor\"].backward()\n", "\n", " # Clip large gradients to improve training stability\n", " torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n", "\n", " # Stage 6: Update model weights using loss gradients\n", " optimizer.step()\n", "\n", " # Stage 7: Collect rewards, response lengths, and losses\n", " reward_avg = torch.tensor(stats[\"rewards\"]).mean().item()\n", " step_tokens = sum(\n", " sample[\"gen_len\"] for sample in stats[\"samples\"]\n", " )\n", " avg_response_len = (\n", " step_tokens / len(stats[\"samples\"]) \n", " if stats[\"samples\"] else 0.0\n", " )\n", " append_csv_metrics(\n", " csv_log_path, current_step, steps, stats[\"loss\"],\n", " reward_avg, avg_response_len,\n", " )\n", "\n", " # Print step metrics\n", " print(\n", " f\"[Step {current_step}/{steps}] \"\n", " f\"loss={stats['loss']:.4f} \"\n", " f\"reward_avg={reward_avg:.3f} \"\n", " f\"avg_resp_len={avg_response_len:.1f}\"\n", " )\n", "\n", " # Sample outputs (every 10 steps) to check if model\n", " # generates coherent text\n", " if current_step % 10 == 0:\n", " print(f\"[Step {current_step}] sample outputs\")\n", " for i, sample in enumerate(stats[\"samples\"][:3]):\n", " text = sample[\"text\"].replace(\"\\n\", \"\\\\n\")\n", " print(\n", " f\" {i+1}) reward={sample['reward']:.3f} \"\n", " f\"len={sample['gen_len']}: {text}\"\n", " )\n", " 
print()\n", "\n", " # Stage 8: Save model checkpoint\n", " if checkpoint_every and current_step % checkpoint_every == 0:\n", " ckpt_path = save_checkpoint(\n", " model=model,\n", " checkpoint_dir=checkpoint_dir,\n", " step=current_step,\n", " )\n", " print(f\"Saved checkpoint to {ckpt_path}\")\n", "\n", " # Save a model checkpoint if we interrupt the training early\n", " except KeyboardInterrupt:\n", " ckpt_path = save_checkpoint(\n", " model=model,\n", " checkpoint_dir=checkpoint_dir,\n", " step=max(1, current_step),\n", " suffix=\"interrupt\",\n", " )\n", " print(f\"\\nKeyboardInterrupt. Saved checkpoint to {ckpt_path}\")\n", " return model\n", "\n", " return model\n", "\n", "\n", "def save_checkpoint(model, checkpoint_dir, step, suffix=\"\"):\n", " checkpoint_dir = Path(checkpoint_dir)\n", " checkpoint_dir.mkdir(parents=True, exist_ok=True)\n", " suffix = f\"-{suffix}\" if suffix else \"\"\n", " ckpt_path = (\n", " checkpoint_dir /\n", " f\"qwen3-0.6B-rlvr-grpo-step{step:05d}{suffix}.pth\"\n", " )\n", " torch.save(model.state_dict(), ckpt_path)\n", " return ckpt_path\n", "\n", "\n", "def append_csv_metrics(\n", " csv_log_path,\n", " step_idx,\n", " total_steps,\n", " loss,\n", " reward_avg,\n", " avg_response_len,\n", "):\n", " if not csv_log_path.exists():\n", " csv_log_path.write_text(\n", " \"step,total_steps,loss,reward_avg,avg_response_len\\n\",\n", " encoding=\"utf-8\",\n", " )\n", " with csv_log_path.open(\"a\", encoding=\"utf-8\") as f:\n", " f.write(\n", " f\"{step_idx},{total_steps},{loss:.6f},{reward_avg:.6f},\"\n", " f\"{avg_response_len:.6f}\\n\"\n", " )" ] }, { "cell_type": "markdown", "id": "c0673027-8a8b-4cc1-b634-d33fefd3be4c", "metadata": {}, "source": [ "- Everything except for stage 4, the GRPO loss calculation, is part of the standard training loop when training deep neural networks (including LLMs)\n", "- The `append_csv_metrics` records the results in a CSV file for record keeping (and to visualize the results in chapter 7)\n", "- For a 
general introduction to training neural networks in PyTorch, please see sections 3-8 in my [PyTorch in One Hour: From Tensors to Training Neural Networks on Multiple GPUs](https://sebastianraschka.com/teaching/pytorch-1h/) article" ] }, { "cell_type": "code", "execution_count": 26, "id": "2e8246b0-78de-4e06-afb5-e7faea4517b8", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using Apple Silicon GPU (MPS)\n", "[Step 1/50] loss=-0.0000 reward_avg=0.000 avg_resp_len=88.0\n", "[Step 2/50] loss=-0.0000 reward_avg=0.000 avg_resp_len=7.5\n", "[Step 3/50] loss=-0.0000 reward_avg=0.000 avg_resp_len=6.5\n", "[Step 4/50] loss=0.0909 reward_avg=0.250 avg_resp_len=6.5\n", "[Step 5/50] loss=1.1001 reward_avg=0.500 avg_resp_len=300.5\n", "Saved checkpoint to qwen3-0.6B-rlvr-grpo-step00005.pth\n", "\n", "KeyboardInterrupt. Saved checkpoint to qwen3-0.6B-rlvr-grpo-step00006-interrupt.pth\n" ] }, { "data": { "text/plain": [ "Qwen3Model(\n", " (tok_emb): Embedding(151936, 1024)\n", " (trf_blocks): ModuleList(\n", " (0-27): 28 x TransformerBlock(\n", " (att): GroupedQueryAttention(\n", " (W_query): Linear(in_features=1024, out_features=2048, bias=False)\n", " (W_key): Linear(in_features=1024, out_features=1024, bias=False)\n", " (W_value): Linear(in_features=1024, out_features=1024, bias=False)\n", " (out_proj): Linear(in_features=2048, out_features=1024, bias=False)\n", " (q_norm): RMSNorm()\n", " (k_norm): RMSNorm()\n", " )\n", " (ff): FeedForward(\n", " (fc1): Linear(in_features=1024, out_features=3072, bias=False)\n", " (fc2): Linear(in_features=1024, out_features=3072, bias=False)\n", " (fc3): Linear(in_features=3072, out_features=1024, bias=False)\n", " )\n", " (norm1): RMSNorm()\n", " (norm2): RMSNorm()\n", " )\n", " )\n", " (final_norm): RMSNorm()\n", " (out_head): Linear(in_features=1024, out_features=151936, bias=False)\n", ")" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "device = 
get_device()\n", "model.to(device)\n", "\n", "torch.manual_seed(0)\n", "\n", "train_rlvr_grpo(\n", " model=model,\n", " tokenizer=tokenizer,\n", " math_data=math_train,\n", " device=device,\n", " steps=50,\n", " num_rollouts=4,\n", " max_new_tokens=512,\n", " temperature=0.8,\n", " top_p=0.9,\n", " lr=1e-5,\n", " checkpoint_every=5,\n", " checkpoint_dir=\".\",\n", " csv_log_path=\"train_rlvr_grpo_metrics.csv\",\n", ")" ] }, { "cell_type": "markdown", "id": "86a39cb8-b739-4343-8e4c-7161fcd35368", "metadata": {}, "source": [ "- If you have memory-related issues when running the code above, you can lower the number of rollouts (e.g., `num_rollouts=2`) and number of tokens per rollout (e.g., `max_new_tokens=128`)\n", "- However, to get a relatively good model, it requires at least `num_rollouts=8` and `max_new_tokens=512`\n", "- If you can't run it on your available hardware, no worries, the next section shows how to download a pre-trained checkpoint" ] }, { "cell_type": "markdown", "id": "663e1cc7-77ee-4c3a-9e75-75a58865f487", "metadata": {}, "source": [ "- Note that either way, the code will likely run very slowly, because GRPO is a resource-intensive procedure\n", "- You can interrupt the run anytime, and it will save the latest model checkpoint in the `checkpoints` folder\n", "- If you are interested in using cloud GPUs, please see the [GPU Cloud Resources](../../ch02/02_setup-tips/gpu-instructions.md) document for recommendations" ] }, { "cell_type": "markdown", "id": "37159780-2208-41ac-82c0-b2f86ffa2058", "metadata": {}, "source": [ "- Note that this code does not support batched training\n", "- This is a deliberate choice to keep the code simpler and more readable, and because sampling multiple (potentially long) rollouts can already be very resource-intensive\n", "- However, if you have access to multiple GPUs, you can use the optional version of this code with batch and multi-GPU support that can be found in the supplementary materials at 
[../02_rlvr_grpo_scripts_intro](../02_rlvr_grpo_scripts_intro), which trains the model faster" ] }, { "cell_type": "markdown", "id": "77908793-9356-4baf-9925-f9e1f7579319", "metadata": {}, "source": [ " \n", "## 6.12 Loading and evaluating saved model checkpoints" ] }, { "cell_type": "markdown", "id": "bdaaa2ad-d5f8-4e3d-9314-b9fe613e9e56", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "f1a50988-dc08-4769-bb04-82b0828a655c", "metadata": {}, "source": [ "- The saved checkpoints can be loaded using the `model.load_state_dict(torch.load(model_path))` approach explained in chapter 2, where `model_path` references the checkpoint `.pth` file\n", "- These checkpoint files are also compatible with the model evaluation utilities from chapter 3\n", "- For your convenience, you can use the evaluation scripts provided in chapter 3's bonus materials:\n", "\n", "```bash\n", "uv run ../../ch03/02_math500-verifier-scripts/evaluate_math500.py \\\n", "--dataset_size 500 \\\n", "--which_model base \\\n", "--checkpoint_path checkpoints/qwen3-0.6B-rlvr-grpo-step00050.pth\n", "```\n", " \n", "\n", "- If you prefer not to run the GRPO training on your computer because it takes too long, you can also download the checkpoints that I uploaded to [rasbt/qwen3-from-scratch-grpo-checkpoints/tree/main/grpo_original_no_kl](https://huggingface.co/rasbt/qwen3-from-scratch-grpo-checkpoints/tree/main/grpo_original_no_kl) (click on the checkpoint file you want to download and then click the [download](https://huggingface.co/rasbt/qwen3-from-scratch-grpo-checkpoints/resolve/main/grpo_original_no_kl/qwen3-0.6B-rlvr-grpo-step00050.pth?download=true) button)\n", "- Alternatively, you can download the checkpoint directly here using Python" ] }, { "cell_type": "code", "execution_count": 27, "id": "96e07312-94e0-409c-b477-aa7da8f974ea", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "✓ qwen3-0.6B-rlvr-grpo-step00050.pth already up-to-date\n" 
] } ], "source": [ "from reasoning_from_scratch.qwen3 import download_qwen3_grpo_checkpoints\n", "\n", "download_qwen3_grpo_checkpoints(grpo_type=\"no_kl\", step=\"00050\")" ] }, { "cell_type": "markdown", "id": "3fa9fdbd-041b-411e-b2f5-65ecedb1f53c", "metadata": {}, "source": [ "| | Method | Step | Max tokens | Num rollouts | MATH-500 Acc | Avg # of tokens |\n", "| ---- | -------------------------------------- | ---- | ---------- | ------------ | ------------ | --------------- |\n", "| 1 | Base (chapter 3) | - | | | 15.2% | 78.85 |\n", "| 2 | Reasoning (chapter 3) | - | | | 48.2% | 1369.79 |\n", "| 3 | GRPO original but no KL (this chapter) | 50 | 512 | 8 | 47.4% | 586.11 | " ] }, { "cell_type": "markdown", "id": "a8935c79-a76a-49a0-8158-9777e6ea6cf3", "metadata": {}, "source": [ "- Based on the table above, we see that after only 50 steps, the trained model (row 3), which is initialized from the base model (row 1), is almost as good as the original reasoning variant (row 2)\n", "- Note that training for longer may not improve the model and could even make it worse, as GRPO can be relatively unstable; the next chapter introduces additional tricks to improve the GRPO algorithm" ] }, { "cell_type": "markdown", "id": "55cce81d-0cac-4e0e-96e5-2778f836d9f9", "metadata": {}, "source": [ " \n", "## 6.13 Summary" ] }, { "cell_type": "markdown", "id": "a26e0d6e-81f0-4a22-9277-8918adbb4a76", "metadata": {}, "source": [ "- No code in this section" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.16" } }, "nbformat": 4, "nbformat_minor": 5 }