{ "cells": [ { "cell_type": "markdown", "id": "9bba24bf-3592-47d2-bfb1-5177324a418e", "metadata": {}, "source": [ "\n", "\n", "\n", "\n", "\n", "
\n", "\n", "Supplementary code for the Build a Reasoning Model (From Scratch) book by Sebastian Raschka
\n", "
Code repository: https://github.com/rasbt/reasoning-from-scratch\n", "
\n", "
\n", "\n", "
\n" ] }, { "cell_type": "markdown", "id": "90fa0a7f-b86b-4a92-9957-18f8a4398290", "metadata": {}, "source": [ "# Chapter 3: Evaluating Reasoning Models" ] }, { "cell_type": "code", "execution_count": 1, "id": "d2c83184-31d0-4bcd-a7ea-67ee366736ea", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "reasoning_from_scratch version: 0.1.10\n", "torch version: 2.9.0\n", "sympy version: 1.14.0\n", "tokenizers version: 0.21.4\n" ] } ], "source": [ "from importlib.metadata import version\n", "\n", "used_libraries = [\n", " \"reasoning_from_scratch\",\n", " \"torch\",\n", " \"sympy\",\n", " \"tokenizers\" # Used by reasoning_from_scratch\n", "]\n", "\n", "for lib in used_libraries:\n", " print(f\"{lib} version: {version(lib)}\")" ] }, { "cell_type": "markdown", "id": "eb4aa7b8-25bc-4a1a-978c-19cfdc059d97", "metadata": {}, "source": [ "
\n", "\n", "" ] }, { "cell_type": "markdown", "id": "65ecfa9d-b502-4ec5-b80c-e67142273718", "metadata": {}, "source": [ " \n", "## 3.1 Building a math verifier" ] }, { "cell_type": "markdown", "id": "17412c95-a620-4b3f-978f-39525dba7fd9", "metadata": {}, "source": [ "- No code in this section" ] }, { "cell_type": "markdown", "id": "f815d4c6-71ae-4d21-80a3-5451822d6bd3", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "0a65814d-2497-4270-aa02-14efa8a05658", "metadata": {}, "source": [ "
\n", "
\n", "
\n", "\n", "" ] }, { "cell_type": "markdown", "id": "57a963b9-5ae8-4471-8d80-c2295e034465", "metadata": {}, "source": [ " \n", "## 3.2 Loading a pre-trained model to generate text" ] }, { "cell_type": "markdown", "id": "387f8c7e-6508-494b-a233-1edee5c2649f", "metadata": {}, "source": [ "- In this section, we load the model (recap of chapter 2) that we want to evaluate" ] }, { "cell_type": "code", "execution_count": 2, "id": "242ef7f6-c9ac-49b0-bd57-a09d325f4dbc", "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "import torch\n", "\n", "\n", "from reasoning_from_scratch.qwen3 import (\n", " download_qwen3_small,\n", " Qwen3Tokenizer,\n", " Qwen3Model,\n", " QWEN_CONFIG_06_B\n", ")\n", "\n", "\n", "def load_model_and_tokenizer(\n", " which_model, device, use_compile, local_dir=\"qwen3\"\n", "):\n", " if which_model == \"base\":\n", "\n", " download_qwen3_small(\n", " kind=\"base\", tokenizer_only=False, out_dir=local_dir\n", " )\n", "\n", " tokenizer_path = Path(local_dir) / \"tokenizer-base.json\"\n", " model_path = Path(local_dir) / \"qwen3-0.6B-base.pth\"\n", " tokenizer = Qwen3Tokenizer(tokenizer_file_path=tokenizer_path)\n", "\n", " elif which_model == \"reasoning\":\n", "\n", " download_qwen3_small(\n", " kind=\"reasoning\", tokenizer_only=False, out_dir=local_dir\n", " )\n", "\n", " tokenizer_path = Path(local_dir) / \"tokenizer-reasoning.json\"\n", " model_path = Path(local_dir) / \"qwen3-0.6B-reasoning.pth\"\n", " tokenizer = Qwen3Tokenizer(\n", " tokenizer_file_path=tokenizer_path,\n", " apply_chat_template=True,\n", " add_generation_prompt=True,\n", " add_thinking=True,\n", " )\n", "\n", " else:\n", " raise ValueError(f\"Invalid choice: which_model={which_model}\")\n", "\n", " model = Qwen3Model(QWEN_CONFIG_06_B)\n", " model.load_state_dict(torch.load(model_path))\n", "\n", " model.to(device)\n", "\n", " if use_compile:\n", " torch._dynamo.config.allow_unspec_int_on_nn_module = True\n", " model = torch.compile(model)\n", "\n", " 
return model, tokenizer" ] }, { "cell_type": "markdown", "id": "08d43aef-fe2b-429a-9730-0464a9206df4", "metadata": {}, "source": [ "- Note that we use the base model here; once you have completed this chapter, you can rerun the notebook after changing `which_model=\"base\"` to `which_model=\"reasoning\"` to evaluate an already trained reasoning model" ] }, { "cell_type": "code", "execution_count": 3, "id": "348b3522-dcc9-4341-bc4e-e99ff95b952e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using Apple Silicon GPU (MPS)\n", "✓ qwen3/qwen3-0.6B-base.pth already up-to-date\n" ] } ], "source": [ "from reasoning_from_scratch.ch02 import (\n", " get_device\n", ")\n", "\n", "WHICH_MODEL = \"base\"\n", "device = get_device()\n", "\n", "# If you have compatibility issues, try to\n", "# uncomment the line below and rerun the notebook\n", "# device = torch.device(\"cpu\")\n", "\n", "model, tokenizer = load_model_and_tokenizer(\n", " which_model=WHICH_MODEL,\n", " device=device,\n", " use_compile=False\n", ")" ] }, { "cell_type": "code", "execution_count": 4, "id": "0e7082a7-3013-4489-9d87-a2e1790f1009", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " To find the value of \\( a^2 + b^2 \\) given that \\( a + b = 3 \\) and \\( ab = \\frac{13}{6} \\), we can use the following algebraic identity:\n", "\n", "\\[\n", "a^2 + b^2 = (a + b)^2 - 2ab\n", "\\]\n", "\n", "**Step 1:** Substitute the given values into the equation.\n", "\n", "\\[\n", "a^2 + b^2 = (3)^2 - 2 \\left( \\frac{13}{6} \\right)\n", "\\]\n", "\n", "**Step 2:** Calculate \\( (3)^2 \\).\n", "\n", "\\[\n", "(3)^2 = 9\n", "\\]\n", "\n", "**Step 3:** Calculate \\( 2 \\times \\frac{13}{6} \\).\n", "\n", "\\[\n", "2 \\times \\frac{13}{6} = \\frac{26}{6} = \\frac{13}{3}\n", "\\]\n", "\n", "**Step 4:** Subtract the second result from the first.\n", "\n", "\\[\n", "a^2 + b^2 = 9 - \\frac{13}{3}\n", "\\]\n", "\n", "**Step 5:** Convert 9 to a fraction 
with a denominator of 3 to perform the subtraction.\n", "\n", "\\[\n", "9 = \\frac{27}{3}\n", "\\]\n", "\n", "\\[\n", "a^2 + b^2 = \\frac{27}{3} - \\frac{13}{3} = \\frac{14}{3}\n", "\\]\n", "\n", "**Final Answer:**\n", "\n", "\\[\n", "\\boxed{\\dfrac{14}{3}}\n", "\\]" ] } ], "source": [ "from reasoning_from_scratch.ch02 import (\n", " generate_text_basic_stream_cache\n", ")\n", "\n", "prompt = (\n", " r\"If $a+b=3$ and $ab=\\tfrac{13}{6}$, \"\n", " r\"what is the value of $a^2+b^2$?\"\n", ")\n", "\n", "# Similar to chapter 2 exercise solution:\n", "input_token_ids_tensor = torch.tensor(\n", " tokenizer.encode(prompt),\n", " device=device\n", " ).unsqueeze(0)\n", "\n", "all_token_ids = []\n", "for token in generate_text_basic_stream_cache(\n", " model=model,\n", " token_ids=input_token_ids_tensor,\n", " max_new_tokens=2048,\n", " eos_token_id=tokenizer.eos_token_id\n", "):\n", " token_id = token.squeeze(0)\n", " decoded_id = tokenizer.decode(token_id.tolist())\n", " print(\n", " decoded_id,\n", " end=\"\",\n", " flush=True\n", " )\n", " all_token_ids.append(token_id)\n", "\n", "all_tokens = tokenizer.decode(all_token_ids)" ] }, { "cell_type": "markdown", "id": "d7300363-1109-4bed-9aeb-24d3d4c6c745", "metadata": {}, "source": [ "- If you are unfamiliar with LaTeX syntax, the response above can be very hard to read\n", "- You can use the `Latex` class to render the LaTeX syntax to improve readability, as shown below" ] }, { "cell_type": "code", "execution_count": 5, "id": "3fbbd72a-5dce-4d99-bcc3-b83645affeab", "metadata": {}, "outputs": [ { "data": { "text/latex": [ " To find the value of \\( a^2 + b^2 \\) given that \\( a + b = 3 \\) and \\( ab = \\frac{13}{6} \\), we can use the following algebraic identity:\n", "\n", "\\[\n", "a^2 + b^2 = (a + b)^2 - 2ab\n", "\\]\n", "\n", "**Step 1:** Substitute the given values into the equation.\n", "\n", "\\[\n", "a^2 + b^2 = (3)^2 - 2 \\left( \\frac{13}{6} \\right)\n", "\\]\n", "\n", "**Step 2:** Calculate \\( (3)^2 \\).\n", 
"\n", "\\[\n", "(3)^2 = 9\n", "\\]\n", "\n", "**Step 3:** Calculate \\( 2 \\times \\frac{13}{6} \\).\n", "\n", "\\[\n", "2 \\times \\frac{13}{6} = \\frac{26}{6} = \\frac{13}{3}\n", "\\]\n", "\n", "**Step 4:** Subtract the second result from the first.\n", "\n", "\\[\n", "a^2 + b^2 = 9 - \\frac{13}{3}\n", "\\]\n", "\n", "**Step 5:** Convert 9 to a fraction with a denominator of 3 to perform the subtraction.\n", "\n", "\\[\n", "9 = \\frac{27}{3}\n", "\\]\n", "\n", "\\[\n", "a^2 + b^2 = \\frac{27}{3} - \\frac{13}{3} = \\frac{14}{3}\n", "\\]\n", "\n", "**Final Answer:**\n", "\n", "\\[\n", "\\boxed{\\dfrac{14}{3}}\n", "\\]" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from IPython.display import Latex, display\n", "\n", "display(Latex(all_tokens))" ] }, { "cell_type": "markdown", "id": "7e0ab369-29b6-4145-9ae6-4341d7db3cbf", "metadata": {}, "source": [ "- If you only want to render specific math expressions, you can also use the `Math` class:" ] }, { "cell_type": "code", "execution_count": 6, "id": "a348e6e5-0f74-4061-9c3c-1b06d8171b89", "metadata": {}, "outputs": [ { "data": { "text/latex": [ "$\\displaystyle \\dfrac{14}{3}$" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from IPython.display import Math\n", "\n", "display(Math(r\"\\dfrac{14}{3}\"))" ] }, { "cell_type": "markdown", "id": "1f443a99-4eb9-40cd-a25e-c1f4f054aba0", "metadata": {}, "source": [ " \n", "## 3.3 Implementing a wrapper for easier text generation" ] }, { "cell_type": "markdown", "id": "566759b7-c950-4fa5-abd6-e967e91e497c", "metadata": {}, "source": [ "- Above, we loaded the pre-trained LLM and set up the text generation functionality (as illustrated in the figure below), which are the first two steps of the evaluation process covered in the remainder of this chapter" ] }, { "cell_type": "markdown", "id": "69e31a7b-c37a-40a6-95cc-fffe5a17a33e", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", 
"id": "a2f4ac6b-31b3-49af-9b80-bb479c124dbd", "metadata": {}, "source": [ "- For additional convenience, we create a wrapper function for the text generation function so that we only have to pass in the model, tokenizer, and prompt, along with some additional settings" ] }, { "cell_type": "code", "execution_count": 7, "id": "628423fc-87d3-4ae6-b453-399106f4173e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " To find the value of \\( a^2 + b^2 \\) given that \\( a + b = 3 \\) and \\( ab = \\frac{13}{6} \\), we can use the following algebraic identity:\n", "\n", "\\[\n", "a^2 + b^2 = (a + b)^2 - 2ab\n", "\\]\n", "\n", "**Step 1:** Substitute the given values into the equation.\n", "\n", "\\[\n", "a^2 + b^2 = (3)^2 - 2 \\left( \\frac{13}{6} \\right)\n", "\\]\n", "\n", "**Step 2:** Calculate \\( (3)^2 \\).\n", "\n", "\\[\n", "(3)^2 = 9\n", "\\]\n", "\n", "**Step 3:** Calculate \\( 2 \\times \\frac{13}{6} \\).\n", "\n", "\\[\n", "2 \\times \\frac{13}{6} = \\frac{26}{6} = \\frac{13}{3}\n", "\\]\n", "\n", "**Step 4:** Subtract the second result from the first.\n", "\n", "\\[\n", "a^2 + b^2 = 9 - \\frac{13}{3}\n", "\\]\n", "\n", "**Step 5:** Convert 9 to a fraction with a denominator of 3 to perform the subtraction.\n", "\n", "\\[\n", "9 = \\frac{27}{3}\n", "\\]\n", "\n", "\\[\n", "a^2 + b^2 = \\frac{27}{3} - \\frac{13}{3} = \\frac{14}{3}\n", "\\]\n", "\n", "**Final Answer:**\n", "\n", "\\[\n", "\\boxed{\\dfrac{14}{3}}\n", "\\]" ] } ], "source": [ "def generate_text_stream_concat(\n", " model, tokenizer, prompt, device, max_new_tokens,\n", " verbose=False,\n", "):\n", " input_ids = torch.tensor(\n", " tokenizer.encode(prompt), device=device\n", " ).unsqueeze(0)\n", "\n", " generated_ids = []\n", " for token in generate_text_basic_stream_cache(\n", " model=model,\n", " token_ids=input_ids,\n", " max_new_tokens=max_new_tokens,\n", " eos_token_id=tokenizer.eos_token_id,\n", " ):\n", " next_token_id = token.squeeze(0)\n", " 
generated_ids.append(next_token_id.item())\n", "\n", "        if verbose:\n", "            print(\n", "                tokenizer.decode(next_token_id.tolist()),\n", "                end=\"\",\n", "                flush=True\n", "            )\n", "    return tokenizer.decode(generated_ids)\n", "\n", "\n", "skip_portion = False\n", "\n", "if not skip_portion:\n", "    generated_text = generate_text_stream_concat(\n", "        model, tokenizer, prompt, device,\n", "        max_new_tokens=2048,\n", "        verbose=True\n", "    )" ] }, { "cell_type": "markdown", "id": "74f3939b-03f6-403f-9ee9-1ccd91f4fc19", "metadata": {}, "source": [ " \n", "## 3.4 Extracting the final answer box" ] }, { "cell_type": "markdown", "id": "5180c33c-a766-450b-a237-da1173174403", "metadata": {}, "source": [ "- In this section, we extract the answer box (step 3); later, in the next section, we will take the extracted answer and normalize it (step 4)" ] }, { "cell_type": "markdown", "id": "0f249918-b1e7-478a-ba90-876569805fba", "metadata": {}, "source": [ "" ] }, { "cell_type": "code", "execution_count": 8, "id": "188fec94-da7d-41ba-a101-62faed84ae48", "metadata": {}, "outputs": [], "source": [ "model_answer = (\n", "r\"\"\"... 
some explanation...\n", "**Final Answer:**\n", "\n", "\\[\n", "\\boxed{\\dfrac{14}{3}}\n", "\\]\n", "\"\"\")" ] }, { "cell_type": "code", "execution_count": 9, "id": "39a18b82-a592-49d2-8ae3-9db949fd3c1b", "metadata": {}, "outputs": [], "source": [ "def get_last_boxed(text):\n", " # Find the last occurrence of \"\\boxed\"\n", " boxed_start_idx = text.rfind(r\"\\boxed\")\n", " if boxed_start_idx == -1:\n", " return None\n", "\n", " # Get position after \"\\boxed\"\n", " current_idx = boxed_start_idx + len(r\"\\boxed\")\n", "\n", " # Skip any whitespace after \"\\boxed\"\n", " while current_idx < len(text) and text[current_idx].isspace():\n", " current_idx += 1\n", "\n", " # Expect an opening brace \"{\"\n", " if current_idx >= len(text) or text[current_idx] != \"{\":\n", " return None\n", "\n", " # Parse the braces with nesting\n", " current_idx += 1\n", " brace_depth = 1\n", " content_start_idx = current_idx\n", "\n", " while current_idx < len(text) and brace_depth > 0:\n", " char = text[current_idx]\n", " if char == \"{\":\n", " brace_depth += 1\n", " elif char == \"}\":\n", " brace_depth -= 1\n", " current_idx += 1\n", "\n", " # Account for unbalanced braces\n", " if brace_depth != 0:\n", " return None\n", "\n", " # Extract content inside the outermost braces\n", " return text[content_start_idx:current_idx-1]" ] }, { "cell_type": "code", "execution_count": 10, "id": "4c24fe48-80b5-4608-89f0-0e4afe7ab123", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\\dfrac{14}{3}\n" ] } ], "source": [ "extracted_answer = get_last_boxed(model_answer)\n", "print(extracted_answer)" ] }, { "cell_type": "code", "execution_count": 11, "id": "9fe03f2f-ff52-4009-a241-bf3eef68e4fc", "metadata": {}, "outputs": [], "source": [ "import re\n", "\n", "RE_NUMBER = re.compile(\n", " r\"-?(?:\\d+/\\d+|\\d+(?:\\.\\d+)?(?:[eE][+-]?\\d+)?)\"\n", ")\n", "\n", "def extract_final_candidate(text, fallback=\"number_then_full\"):\n", " # Default return value if 
nothing matches\n", "    result = \"\"\n", "\n", "    if text:\n", "        # Prefer the last boxed expression if present\n", "        boxed = get_last_boxed(text.strip())\n", "        if boxed:\n", "            result = boxed.strip().strip(\"$ \")\n", "\n", "        # If no boxed expression, try fallback\n", "        elif fallback in (\"number_then_full\", \"number_only\"):\n", "            m = RE_NUMBER.findall(text)\n", "            if m:\n", "                # Use last number\n", "                result = m[-1]\n", "            elif fallback == \"number_then_full\":\n", "                # Else return full text if no number found\n", "                result = text\n", "    return result" ] }, { "cell_type": "markdown", "id": "537225af-0184-48a5-a5a6-fc223c401231", "metadata": {}, "source": [ "- Fallback settings if no boxed content is found:\n", "  - \"number_then_full\": pick the last simple number, else the whole text\n", "  - \"number_only\": pick the last simple number, else return an empty string `\"\"`\n", "  - \"none\": extract only boxed content, else return an empty string `\"\"`" ] }, { "cell_type": "code", "execution_count": 12, "id": "5ee81df6-dafb-4bdf-8608-04a871121f06", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\\dfrac{14}{3}\n" ] } ], "source": [ "print(extract_final_candidate(model_answer))" ] }, { "cell_type": "code", "execution_count": 13, "id": "64b63d7b-db95-465a-ba64-d476fce63ad9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "14/3.\n" ] } ], "source": [ "print(extract_final_candidate(r\"\\boxed{ 14/3. 
}\"))" ] }, { "cell_type": "code", "execution_count": 14, "id": "07570c56-fb1f-45d6-a118-a4e71e116226", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "14/3\n" ] } ], "source": [ "print(extract_final_candidate(\"abc < > 14/3 abc\"))" ] }, { "cell_type": "code", "execution_count": 15, "id": "570ebcc5-f682-4f70-a018-7529e1402575", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Text without numbers\n" ] } ], "source": [ "print(extract_final_candidate(\"Text without numbers\"))" ] }, { "cell_type": "markdown", "id": "45892c39-d78d-4ee2-9811-25fb8d4a04d7", "metadata": {}, "source": [ " \n", "## 3.5 Normalizing the extracted answer" ] }, { "cell_type": "markdown", "id": "f92fad0e-72fa-4bab-bbb4-e0010d26c8b7", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "b870b559-3e14-44b7-91ce-ba01a6e08aab", "metadata": {}, "source": [ "- In the previous section, we extracted the answer (step 3), now we are normalizing it (step 4 in the previous figure)" ] }, { "cell_type": "code", "execution_count": 16, "id": "2a5a68e1-941a-4f3a-97f0-b83393aa923a", "metadata": {}, "outputs": [], "source": [ "LATEX_FIXES = [ # Latex formatting to be replaced\n", " (r\"\\\\left\\s*\", \"\"),\n", " (r\"\\\\right\\s*\", \"\"),\n", " (r\"\\\\,|\\\\!|\\\\;|\\\\:\", \"\"),\n", " (r\"\\\\cdot\", \"*\"),\n", " (r\"\\u00B7|\\u00D7\", \"*\"),\n", " (r\"\\\\\\^\\\\circ\", \"\"),\n", " (r\"\\\\dfrac\", r\"\\\\frac\"),\n", " (r\"\\\\tfrac\", r\"\\\\frac\"),\n", " (r\"°\", \"\"),\n", "]\n", "\n", "RE_SPECIAL = re.compile(r\"<\\|[^>]+?\\|>\") # strip chat special tokens like <|assistant|>\n", "\n", "def normalize_text(text):\n", " if not text:\n", " return \"\"\n", " text = RE_SPECIAL.sub(\"\", text).strip()\n", " SUPERSCRIPT_MAP = {\n", " \"⁰\": \"0\", \"¹\": \"1\", \"²\": \"2\", \"³\": \"3\", \"⁴\": \"4\",\n", " \"⁵\": \"5\", \"⁶\": \"6\", \"⁷\": \"7\", \"⁸\": \"8\", \"⁹\": \"9\",\n", " \"⁺\": \"+\", \"⁻\": \"-\", 
\"⁽\": \"(\", \"⁾\": \")\",\n", " }\n", "\n", " # Strip leading multiple-choice labels\n", " # E.g., like \"c. 3\" -> 3, or \"b: 2\" -> 2\n", " match = re.match(r\"^[A-Za-z]\\s*[.:]\\s*(.+)$\", text)\n", " if match:\n", " text = match.group(1)\n", " \n", " # Remove angle-degree markers\n", " text = re.sub(r\"\\^\\s*\\{\\s*\\\\circ\\s*\\}\", \"\", text) # ^{\\circ}\n", " text = re.sub(r\"\\^\\s*\\\\circ\", \"\", text) # ^\\circ\n", " text = text.replace(\"°\", \"\") # Unicode degree\n", "\n", " # unwrap \\text{...} if the whole string is wrapped\n", " match = re.match(r\"^\\\\text\\{(?P.+?)\\}$\", text)\n", " if match:\n", " text = match.group(\"x\")\n", "\n", " # strip inline/display math wrappers \\( \\) \\[ \\]\n", " text = re.sub(r\"\\\\\\(|\\\\\\)|\\\\\\[|\\\\\\]\", \"\", text)\n", "\n", " # light LaTeX canonicalization\n", " for pat, rep in LATEX_FIXES:\n", " text = re.sub(pat, rep, text)\n", "\n", " def convert_superscripts(s, base=None):\n", " converted = \"\".join(\n", " SUPERSCRIPT_MAP[ch] if ch in SUPERSCRIPT_MAP else ch\n", " for ch in s\n", " )\n", " if base is None:\n", " return converted\n", " return f\"{base}**{converted}\"\n", "\n", " # convert unicode superscripts into exponent form (e.g., 2² -> 2**2)m\n", " text = re.sub(\n", " r\"([0-9A-Za-z\\)\\]\\}])([⁰¹²³⁴⁵⁶⁷⁸⁹⁺⁻]+)\",\n", " lambda m: convert_superscripts(m.group(2), base=m.group(1)),\n", " text,\n", " )\n", " text = convert_superscripts(text)\n", " \n", " # numbers/roots\n", " text = text.replace(\"\\\\%\", \"%\").replace(\"$\", \"\").replace(\"%\", \"\")\n", " text = re.sub(\n", " r\"\\\\sqrt\\s*\\{([^}]*)\\}\",\n", " lambda match: f\"sqrt({match.group(1)})\",\n", " text,\n", " )\n", " text = re.sub(\n", " r\"\\\\sqrt\\s+([^\\\\\\s{}]+)\",\n", " lambda match: f\"sqrt({match.group(1)})\",\n", " text,\n", " )\n", "\n", " # fractions\n", " text = re.sub(\n", " r\"\\\\frac\\s*\\{([^{}]+)\\}\\s*\\{([^{}]+)\\}\",\n", " lambda match: f\"({match.group(1)})/({match.group(2)})\",\n", " text,\n", " 
)\n", " text = re.sub(\n", " r\"\\\\frac\\s+([^\\s{}]+)\\s+([^\\s{}]+)\",\n", " lambda match: f\"({match.group(1)})/({match.group(2)})\",\n", " text,\n", " )\n", "\n", " # exponent and mixed numbers\n", " text = text.replace(\"^\", \"**\")\n", " text = re.sub(\n", " r\"(?<=\\d)\\s+(\\d+/\\d+)\",\n", " lambda match: \"+\" + match.group(1),\n", " text,\n", " )\n", "\n", " # 1,234 -> 1234\n", " text = re.sub(\n", " r\"(?<=\\d),(?=\\d\\d\\d(\\D|$))\",\n", " \"\",\n", " text,\n", " )\n", "\n", " return text.replace(\"{\", \"\").replace(\"}\", \"\").strip().lower()" ] }, { "cell_type": "code", "execution_count": 17, "id": "b1668039-be17-46e0-8477-40307b8235fc", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(14)/(3)\n" ] } ], "source": [ "print(normalize_text(extract_final_candidate(model_answer)))" ] }, { "cell_type": "code", "execution_count": 18, "id": "7610b999-3460-4b5b-896b-a24f1582c48d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(14)/(3.)\n" ] } ], "source": [ "print(normalize_text(r\"$\\dfrac{14}{3.}$\"))" ] }, { "cell_type": "code", "execution_count": 19, "id": "f37027f8-8f56-4c8e-90d8-a87f1720732a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(14)/(3)\n" ] } ], "source": [ "print(normalize_text(r\"\\text{\\[\\frac{14}{3}\\]}\"))" ] }, { "cell_type": "code", "execution_count": 20, "id": "5bd95965-1e2c-4926-affb-dff0ed4d35ca", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "4/3\n" ] } ], "source": [ "print(normalize_text(\"4/3\"))" ] }, { "cell_type": "markdown", "id": "027df598-a37b-4951-8d3f-51e4e0435929", "metadata": {}, "source": [ " \n", "## 3.6 Verifying mathematical equivalence" ] }, { "cell_type": "markdown", "id": "a9a0e3e2-54f8-4878-8f1e-2f891775ae69", "metadata": {}, "source": [ "- In this section, we implement the basic functionality to check if the extracted answer (generated by the model) is 
equivalent to the correct answer (ground truth) provided in the dataset; this is step 5\n", "- In the next section (step 6), we make this process a bit more robust to grade the answer correctness" ] }, { "cell_type": "markdown", "id": "97eb9445-9eb8-464e-9116-eebe86a8a189", "metadata": {}, "source": [ "" ] }, { "cell_type": "code", "execution_count": 21, "id": "6f80dc58-c3f4-4d1c-b22a-8059b0b2d264", "metadata": {}, "outputs": [], "source": [ "from sympy.parsing import sympy_parser as spp\n", "from sympy.core.sympify import SympifyError\n", "from tokenize import TokenError\n", "from sympy.polys.polyerrors import PolynomialError\n", "\n", "def sympy_parser(expr):\n", "    # To avoid crashing on long garbage responses\n", "    # that some badly trained models (chapter 6) may emit\n", "    if expr is None or len(expr) > 2000:\n", "        return None\n", "    try:\n", "        return spp.parse_expr(\n", "            expr,\n", "            transformations=(\n", "                # Standard transformations like handling parentheses\n", "                *spp.standard_transformations,\n", "\n", "                # Allow omitted multiplication symbols (e.g., \"2x\" -> \"2*x\")\n", "                spp.implicit_multiplication_application,\n", "            ),\n", "\n", "            # Evaluate during parsing so simple constants simplify (e.g., 2+3 -> 5)\n", "            evaluate=True,\n", "        )\n", "    except (SympifyError, SyntaxError, TypeError, AttributeError,\n", "            IndexError, TokenError, ValueError, PolynomialError):\n", "        return None" ] }, { "cell_type": "markdown", "id": "7c7b29a0-b446-490e-a8f2-6fcafe598abd", "metadata": {}, "source": [ "- Note that this appears to be an excessive amount of error handling, but these are all errors that I encountered when evaluating the model on all 500 MATH-500 problems, as the model does not always generate perfectly formatted outputs" ] }, { "cell_type": "code", "execution_count": 22, "id": "7f9b25c6-1038-4d37-be3f-a6bb7c543f7b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "14/3\n" ] } ], "source": [ "print(sympy_parser(normalize_text(\n", 
" extract_final_candidate(model_answer)\n", ")))" ] }, { "cell_type": "code", "execution_count": 23, "id": "072e1d5f-c785-46ce-ad17-524c2f7a7ef4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "14/3\n" ] } ], "source": [ "print(sympy_parser(\"28/6\"))" ] }, { "cell_type": "code", "execution_count": 24, "id": "8c291308-571b-4bce-849d-27f297f6cc75", "metadata": {}, "outputs": [], "source": [ "from sympy import simplify\n", "\n", "def equality_check(expr_gtruth, expr_pred):\n", " # First, check if the two expressions are exactly the same string\n", " if expr_gtruth == expr_pred:\n", " return True\n", "\n", " # Parse both expressions into SymPy objects (returns None if parsing fails)\n", " gtruth, pred = sympy_parser(expr_gtruth), sympy_parser(expr_pred)\n", "\n", " # If both expressions were parsed successfully, try symbolic comparison\n", " if gtruth is not None and pred is not None:\n", " try:\n", " # If the difference is 0, they are equivalent\n", " return simplify(gtruth - pred) == 0\n", " except (SympifyError, TypeError):\n", " pass\n", "\n", " return False" ] }, { "cell_type": "code", "execution_count": 25, "id": "88bd9c22-44b1-4885-9bda-997f6b91ccfd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "True\n" ] } ], "source": [ "print(equality_check(\n", " normalize_text(\"13/4.\"),\n", " normalize_text(r\"(13)/(4)\")\n", "))" ] }, { "cell_type": "code", "execution_count": 26, "id": "ccf2e7ec-2d24-44db-8f4f-c530f460f1b5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "True\n" ] } ], "source": [ "print(equality_check(\n", " normalize_text(\"0.5\"),\n", " normalize_text(r\"(1)/(2)\")\n", "))" ] }, { "cell_type": "code", "execution_count": 27, "id": "c19d5b7f-5653-41c4-a4c3-9661ced08b12", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "False\n" ] } ], "source": [ "print(equality_check(\n", " 
normalize_text(\"14/3\"),\n", " normalize_text(\"15/3\")\n", "))" ] }, { "cell_type": "code", "execution_count": 28, "id": "ad64d868-32de-49d6-b078-7436d54e3252", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "False\n" ] } ], "source": [ "print(equality_check(\n", " normalize_text(\"(14/3, 2/3)\"),\n", " normalize_text(\"(14/3, 4/6)\")\n", "))" ] }, { "cell_type": "markdown", "id": "fe7581df-4461-483b-b612-57c9bdc52032", "metadata": {}, "source": [ " \n", "## 3.7 Grading answers" ] }, { "cell_type": "code", "execution_count": 29, "id": "24b6d454-feb7-4786-85b1-b55f33df840c", "metadata": {}, "outputs": [], "source": [ "def split_into_parts(text):\n", " result = [text]\n", "\n", " if text:\n", " # Check if text looks like a tuple or list, e.g. \"(a, b)\" or \"[a, b]\"\n", " if (\n", " len(text) >= 2\n", " and text[0] in \"([\" and text[-1] in \")]\"\n", " and \",\" in text[1:-1]\n", " ):\n", " # Split on commas inside brackets and strip whitespace\n", " items = [p.strip() for p in text[1:-1].split(\",\")]\n", " if all(items):\n", " result = items\n", " else:\n", " # If text is empty, return an empty list\n", " result = []\n", "\n", " return result" ] }, { "cell_type": "code", "execution_count": 30, "id": "07e3f1f8-6e8c-4440-bda8-96480ca0d5c6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['14/3', '2/3']" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "split_into_parts(normalize_text(r\"(14/3, 2/3)\"))" ] }, { "cell_type": "code", "execution_count": 31, "id": "bf17996c-3bd3-496a-a47b-b426b51ab208", "metadata": {}, "outputs": [], "source": [ "def grade_answer(pred_text, gt_text):\n", " result = False # Default outcome if checks fail\n", "\n", " # Only continue if both inputs are non-empty strings\n", " if pred_text is not None and gt_text is not None:\n", " gt_parts = split_into_parts(\n", " normalize_text(gt_text)\n", " ) # Break ground truth into comparable parts\n", 
"\n", " pred_parts = split_into_parts(\n", " normalize_text(pred_text)\n", " ) # Break prediction into comparable parts\n", "\n", " # Ensure both sides have same number of valid parts\n", " if (gt_parts and pred_parts\n", " and len(gt_parts) == len(pred_parts)):\n", " result = all(\n", " equality_check(gt, pred)\n", " for gt, pred in zip(gt_parts, pred_parts)\n", " ) # Check each part for mathematical equivalence\n", "\n", " return result # True only if all checks passed" ] }, { "cell_type": "code", "execution_count": 32, "id": "6da8ea2e-060a-4055-90db-c7f050677f53", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "grade_answer(\"14/3\", r\"\\frac{14}{3}\")" ] }, { "cell_type": "code", "execution_count": 33, "id": "f577aadf-7c42-49ff-a200-032b2f3f0aba", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "grade_answer(r\"(14/3, 2/3)\", \"(14/3, 4/6)\")" ] }, { "cell_type": "code", "execution_count": 34, "id": "7484d6ea-b5a1-4a3b-b9f1-dcda0dec8f38", "metadata": {}, "outputs": [], "source": [ "# Define test cases: (name, prediction, ground truth, expected result)\n", "tests = [\n", " (\"check_1\", \"3/4\", r\"\\frac{3}{4}\", True),\n", " (\"check_2\", \"(3)/(4)\", r\"3/4\", True),\n", " (\"check_3\", r\"\\frac{\\sqrt{8}}{2}\", \"sqrt(2)\", True),\n", " (\"check_4\", r\"\\( \\frac{1}{2} + \\frac{1}{6} \\)\", \"2/3\", True),\n", " (\"check_5\", \"(1, 2)\", r\"(1,2)\", True),\n", " (\"check_6\", \"(2, 1)\", \"(1, 2)\", False),\n", " (\"check_7\", \"(1, 2, 3)\", \"(1, 2)\", False),\n", " (\"check_8\", \"0.5\", \"1/2\", True),\n", " (\"check_9\", \"0.3333333333\", \"1/3\", False),\n", " (\"check_10\", \"1,234/2\", \"617\", True),\n", " (\"check_11\", r\"\\text{2/3}\", \"2/3\", True),\n", " (\"check_12\", \"50%\", \"1/2\", False),\n", " (\"check_13\", r\"2\\cdot 
3/4\", \"3/2\", True),\n", " (\"check_14\", r\"90^\\circ\", \"90\", True),\n", " (\"check_15\", r\"\\left(\\frac{3}{4}\\right)\", \"3/4\", True),\n", " (\"check_16\", r\"2²\", \"2**2\", True),\n", " ]\n", "\n", "\n", "def run_demos_table(tests):\n", " header = (\"Test\", \"Expect\", \"Got\", \"Status\")\n", " rows = []\n", " for name, pred, gtruth, expect in tests:\n", " got = grade_answer(pred, gtruth) # Run equality check\n", " status = \"PASS\" if got == expect else \"FAIL\"\n", " rows.append((name, str(expect), str(got), status))\n", "\n", " data = [header] + rows\n", " \n", " # Compute max width for each column to align table nicely\n", " col_widths = [\n", " max(len(row[i]) for row in data)\n", " for i in range(len(header))\n", " ]\n", "\n", " # Print table row by row\n", " for row in data:\n", " line = \" | \".join(\n", " row[i].ljust(col_widths[i])\n", " for i in range(len(header))\n", " )\n", " print(line)\n", "\n", " # Print summary of passed tests\n", " passed = sum(r[3] == \"PASS\" for r in rows)\n", " print(f\"\\nPassed {passed}/{len(rows)}\")" ] }, { "cell_type": "code", "execution_count": 35, "id": "7dc20c84-44eb-443d-9fd1-852f347fa7ff", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test | Expect | Got | Status\n", "check_1 | True | True | PASS \n", "check_2 | True | True | PASS \n", "check_3 | True | True | PASS \n", "check_4 | True | True | PASS \n", "check_5 | True | True | PASS \n", "check_6 | False | False | PASS \n", "check_7 | False | False | PASS \n", "check_8 | True | True | PASS \n", "check_9 | False | False | PASS \n", "check_10 | True | True | PASS \n", "check_11 | True | True | PASS \n", "check_12 | False | False | PASS \n", "check_13 | True | True | PASS \n", "check_14 | True | True | PASS \n", "check_15 | True | True | PASS \n", "check_16 | True | True | PASS \n", "\n", "Passed 16/16\n" ] } ], "source": [ "run_demos_table(tests)" ] }, { "cell_type": "markdown", "id": 
"34224da7-26c5-4523-96e4-10eeb0623a64", "metadata": {}, "source": [ " \n", "## 3.8 Loading the evaluation dataset" ] }, { "cell_type": "markdown", "id": "b109d67c-5af4-4c01-a6c0-64444854a911", "metadata": {}, "source": [ "- The previous section implemented the basic evaluation pipeline\n", "- In this section, we load the dataset (step 7) to which we will apply this pipeline in order to evaluate the model (step 8, next section)." ] }, { "cell_type": "markdown", "id": "c2c21118-483d-4674-811c-b0ede83dd5b9", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "0d2dfa44-bc95-4629-8773-63fe40581e80", "metadata": {}, "source": [ "- The dataset was downloaded and prepared via the following code from the [HuggingFaceH4/MATH-500](https://huggingface.co/datasets/HuggingFaceH4/MATH-500) repository, which requires the [`datasets`](https://huggingface.co/docs/datasets/en/index) package depencency (you don't need to execute this, it's only included for reference):\n", "\n", "```python\n", "from datasets import load_dataset\n", "import json\n", "\n", "dset = load_dataset(\"HuggingFaceH4/MATH-500\", split=\"test\")\n", "\n", "math_data = dset.to_list()\n", "with open(\"math500_test.json\", \"w\", encoding=\"utf-8\") as f:\n", " json.dump(math_data, f, ensure_ascii=False, indent=2)\n", "```" ] }, { "cell_type": "code", "execution_count": 36, "id": "b3381aa3-7980-49aa-a0db-11fe0ef837f1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of entries: 500\n" ] } ], "source": [ "import json\n", "import requests\n", "\n", "def load_math500_test(local_path=\"math500_test.json\", save_copy=True):\n", " local_path = Path(local_path)\n", " url = (\n", " \"https://raw.githubusercontent.com/rasbt/reasoning-from-scratch/\"\n", " \"main/ch03/01_main-chapter-code/math500_test.json\"\n", " )\n", "\n", " if local_path.exists():\n", " with local_path.open(\"r\", encoding=\"utf-8\") as f:\n", " data = json.load(f)\n", " else:\n", " r = 
requests.get(url, timeout=30)\n", " r.raise_for_status()\n", " data = r.json()\n", "\n", " if save_copy: # Saves a local copy\n", " with local_path.open(\"w\", encoding=\"utf-8\") as f:\n", " json.dump(data, f, indent=2)\n", "\n", " return data\n", "\n", "math_data = load_math500_test()\n", "print(\"Number of entries:\", len(math_data))" ] }, { "cell_type": "code", "execution_count": 37, "id": "17fd27ff-792c-4450-9680-b673db46210e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'answer': '\\\\left( 3, \\\\frac{\\\\pi}{2} \\\\right)',\n", " 'level': 2,\n", " 'problem': 'Convert the point $(0,3)$ in rectangular coordinates to polar '\n", " 'coordinates. Enter your answer in the form $(r,\\\\theta),$ where '\n", " '$r > 0$ and $0 \\\\le \\\\theta < 2 \\\\pi.$',\n", " 'solution': 'We have that $r = \\\\sqrt{0^2 + 3^2} = 3.$ Also, if we draw the '\n", " 'line connecting the origin and $(0,3),$ this line makes an angle '\n", " 'of $\\\\frac{\\\\pi}{2}$ with the positive $x$-axis.\\n'\n", " '\\n'\n", " '[asy]\\n'\n", " 'unitsize(0.8 cm);\\n'\n", " '\\n'\n", " 'draw((-0.5,0)--(3.5,0));\\n'\n", " 'draw((0,-0.5)--(0,3.5));\\n'\n", " 'draw(arc((0,0),3,0,90),red,Arrow(6));\\n'\n", " '\\n'\n", " 'dot((0,3), red);\\n'\n", " 'label(\"$(0,3)$\", (0,3), W);\\n'\n", " 'dot((3,0), red);\\n'\n", " '[/asy]\\n'\n", " '\\n'\n", " 'Therefore, the polar coordinates are $\\\\boxed{\\\\left( 3, '\n", " '\\\\frac{\\\\pi}{2} \\\\right)}.$',\n", " 'subject': 'Precalculus',\n", " 'unique_id': 'test/precalculus/807.json'}\n" ] } ], "source": [ "from pprint import pprint\n", "pprint(math_data[0])" ] }, { "cell_type": "markdown", "id": "6229ba79-216f-4d04-9654-e1df2796b9f9", "metadata": {}, "source": [ " \n", "## 3.9 Evaluating the model" ] }, { "cell_type": "markdown", "id": "41dc5c3c-3287-4ff4-b8df-719b59355a45", "metadata": {}, "source": [ "- In the previous section, we loaded the dataset; now we can apply the evaluation pipeline to evaluate the model on 
this dataset (step 8)" ] }, { "cell_type": "markdown", "id": "5cdd1b75-9234-4241-bcd5-2238908ca1f0", "metadata": {}, "source": [ "" ] }, { "cell_type": "code", "execution_count": 38, "id": "a017826b-6a5d-4586-8fb3-5c8fbf754eef", "metadata": {}, "outputs": [], "source": [ "def render_prompt(prompt):\n", " template = (\n", " \"You are a helpful math assistant.\\n\"\n", " \"Answer the question and write the final result on a new line as:\\n\", \"\\\\boxed{ANSWER}\\n\\n\"\n", " f\"Question:\\n{prompt}\\n\\nAnswer:\"\n", " )\n", " return template" ] }, { "cell_type": "code", "execution_count": 39, "id": "9386b940-e6da-4c47-8bea-9bc40c31eebc", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "You are a helpful math assistant.\n", "Answer the question and write the final result on a new line as:\n", "\\boxed{ANSWER}\n", "\n", "Question:\n", "If $a+b=3$ and $ab=\\tfrac{13}{6}$, what is the value of $a^2+b^2$?\n", "\n", "Answer:\n" ] } ], "source": [ "prompt = ( # Same prompt we used at the beginning of the chapter\n", " r\"If $a+b=3$ and $ab=\\tfrac{13}{6}$, \"\n", " r\"what is the value of $a^2+b^2$?\"\n", ")\n", "prompt_fmt = render_prompt(prompt)\n", "print(prompt_fmt)" ] }, { "cell_type": "code", "execution_count": 40, "id": "70d1cf2e-5a34-45b2-88dd-761c4fc11300", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " \\boxed{10}" ] } ], "source": [ "generated_text = generate_text_stream_concat(\n", " model, tokenizer, prompt_fmt, device,\n", " max_new_tokens=2048,\n", " verbose=True\n", ")" ] }, { "cell_type": "code", "execution_count": 41, "id": "34d82a23-cd19-4983-8593-795e569af49f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'\\ndef render_prompt(prompt):\\n template = (\\n \"You are a helpful math assistant.\\n\"\\n \"Solve the problem and write the final result on a new line as:\\n\"\\n \"\\\\boxed{ANSWER}\\n\\n\"\\n f\"Problem:\\n{prompt}\\n\\nAnswer:\"\\n )\\n return 
template\\n'" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Below is an alternative prompt template\n", "# which swaps \"Question\" with \"Problem\"\n", "\n", "\"\"\"\n", "def render_prompt(prompt):\n", " template = (\n", " \"You are a helpful math assistant.\\n\"\n", " \"Solve the problem and write the final result on a new line as:\\n\"\n", " \"\\\\boxed{ANSWER}\\n\\n\"\n", " f\"Problem:\\n{prompt}\\n\\nAnswer:\"\n", " )\n", " return template\n", "\"\"\"\n", "\n", "# This can noticeably affect the MATH-500 results:\n", "# Base model on mps: improves accuracy 20% -> 40%\n", "# Reasoning model on mps: worsens accuracy 90% -> 60%" ] }, { "cell_type": "code", "execution_count": 42, "id": "ff06715b-dff4-48fe-9880-aca24e19ec14", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'\\ndef render_prompt(prompt):\\n return prompt\\n'" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Alternatively, we may use no prompt template\n", "\n", "\"\"\"\n", "def render_prompt(prompt):\n", " return prompt\n", "\"\"\"\n", "\n", "# This can noticeably affect the MATH-500 results:\n", "# Base model on mps: improves accuracy 20% -> 70%\n", "# Reasoning model on mps: worsens accuracy 90% -> 50%" ] }, { "cell_type": "code", "execution_count": 43, "id": "9e84752c-a379-4c2e-b4e0-05ae2b98364f", "metadata": {}, "outputs": [], "source": [ "def mini_eval_demo(model, tokenizer, device):\n", " ex = { # Test example with \"problem\" and \"answer\" fields\n", " \"problem\": \"Compute 1/2 + 1/6.\",\n", " \"answer\": \"2/3\"\n", " }\n", " prompt = render_prompt(ex[\"problem\"]) # 1. Apply prompt template\n", " gen_text = generate_text_stream_concat( # 2. Generate response\n", " model, tokenizer, prompt, device,\n", " max_new_tokens=64,\n", " )\n", " pred_answer = extract_final_candidate(gen_text) # 3. Extract and normalize answer\n", " is_correct = grade_answer( # 4. 
Grade answer\n", " pred_answer, ex[\"answer\"]\n", " )\n", " print(f\"Device: {device}\")\n", " print(f\"Prediction: {pred_answer}\")\n", " print(f\"Ground truth: {ex['answer']}\")\n", " print(f\"Correct: {is_correct}\")" ] }, { "cell_type": "code", "execution_count": 44, "id": "9d45fa99-1502-4484-bfd6-6421ba76b3cb", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Device: mps\n", "Prediction: 1/3\n", "Ground truth: 2/3\n", "Correct: False\n" ] } ], "source": [ "mini_eval_demo(model, tokenizer, device)" ] }, { "cell_type": "code", "execution_count": 45, "id": "d0f91082-e578-4c82-a55d-a0bfcb82e151", "metadata": {}, "outputs": [], "source": [ "import time\n", "\n", "\n", "# Helper function to calculate remaining time\n", "def eta_progress_message(\n", " processed,\n", " total,\n", " start_time,\n", " show_eta=False,\n", " label=\"Progress\",\n", "):\n", " progress = f\"{label}: {processed}/{total}\"\n", " pad_width = len(f\"{label}: {total}/{total} | ETA: 00h 00m 00s\")\n", " if not show_eta or processed <= 0:\n", " return progress.ljust(pad_width)\n", "\n", " elapsed = time.time() - start_time\n", " if elapsed <= 0:\n", " return progress.ljust(pad_width)\n", "\n", " remaining = max(total - processed, 0)\n", "\n", " if processed:\n", " avg_time = elapsed / processed\n", " eta_seconds = avg_time * remaining\n", " else:\n", " eta_seconds = 0\n", "\n", " eta_seconds = max(int(round(eta_seconds)), 0)\n", " minutes, rem_seconds = divmod(eta_seconds, 60)\n", " hours, minutes = divmod(minutes, 60)\n", " if hours:\n", " eta = f\"{hours}h {minutes:02d}m {rem_seconds:02d}s\"\n", " elif minutes:\n", " eta = f\"{minutes:02d}m {rem_seconds:02d}s\"\n", " else:\n", " eta = f\"{rem_seconds:02d}s\"\n", "\n", " message = f\"{progress} | ETA: {eta}\"\n", " return message.ljust(pad_width)\n", "\n", "def evaluate_math500_stream(\n", " model,\n", " tokenizer,\n", " device,\n", " math_data,\n", " out_path=None,\n", " max_new_tokens=512,\n", " 
verbose=False,\n", "):\n", "\n", " if out_path is None:\n", " dev_name = str(device).replace(\":\", \"-\") # Make filename compatible with Windows\n", " out_path = Path(f\"math500-{dev_name}.jsonl\")\n", "\n", " num_examples = len(math_data)\n", " num_correct = 0\n", " start_time = time.time()\n", "\n", " with open(out_path, \"w\", encoding=\"utf-8\") as f: # Save results for inspection\n", " for i, row in enumerate(math_data, start=1):\n", " prompt = render_prompt(row[\"problem\"]) # 1. Apply prompt template\n", " gen_text = generate_text_stream_concat( # 2. Generate response\n", " model, tokenizer, prompt, device,\n", " max_new_tokens=max_new_tokens,\n", " verbose=verbose,\n", " )\n", "\n", " extracted = extract_final_candidate( # 3. Extract and normalize answer\n", " gen_text\n", " )\n", " is_correct = grade_answer( # 4. Grade answer\n", " extracted, row[\"answer\"]\n", " )\n", " num_correct += int(is_correct)\n", "\n", " record = { # Record to be saved for inspection\n", " \"index\": i,\n", " \"problem\": row[\"problem\"],\n", " \"gtruth_answer\": row[\"answer\"],\n", " \"generated_text\": gen_text,\n", " \"extracted\": extracted,\n", " \"correct\": bool(is_correct),\n", " }\n", " f.write(json.dumps(record, ensure_ascii=False) + \"\\n\")\n", "\n", " progress_msg = eta_progress_message(\n", " processed=i,\n", " total=num_examples,\n", " start_time=start_time,\n", " show_eta=True,\n", " label=\"MATH-500\",\n", " )\n", " print(progress_msg, end=\"\\r\", flush=True)\n", " if verbose: # Print responses during the generation process\n", " print(\n", " f\"\\n\\n{'='*50}\\n{progress_msg}\\n\"\n", " f\"{'='*50}\\nExtracted: {extracted}\\n\"\n", " f\"Expected: {row['answer']}\\n\"\n", " f\"Correct so far: {num_correct}\\n{'-'*50}\"\n", " )\n", "\n", " # Print summary information\n", " seconds_elapsed = time.time() - start_time\n", " acc = num_correct / num_examples if num_examples else 0.0\n", " print(f\"\\nAccuracy: {acc*100:.1f}% ({num_correct}/{num_examples})\")\n", " 
print(f\"Total time: {seconds_elapsed/60:.1f} min\")\n", " print(f\"Logs written to: {out_path}\")\n", " return num_correct, num_examples, acc" ] }, { "cell_type": "markdown", "id": "f5fa386b-de7d-46e1-b25a-5fec740421e8", "metadata": {}, "source": [ "- We only evaluate on 10 examples for demo purposes (to keep the runtime reasonable)" ] }, { "cell_type": "code", "execution_count": 46, "id": "42befbdf-feb9-453c-acb5-910c7c981ea3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: base\n", "Device: mps\n", "MATH-500: 10/10 | ETA: 00s\n", "Accuracy: 30.0% (3/10)\n", "Total time: 0.4 min\n", "Logs written to: math500-mps.jsonl\n" ] } ], "source": [ "print(\"Model:\", WHICH_MODEL)\n", "print(\"Device:\", device)\n", "num_correct, num_examples, acc = evaluate_math500_stream(\n", " model, tokenizer, device, \n", " math_data=math_data[:10],\n", " max_new_tokens=2048,\n", " verbose=False\n", ")" ] }, { "cell_type": "markdown", "id": "408afa29-fadf-4af1-adbf-74e4f87c3aba", "metadata": {}, "source": [ "| Mode | Device | Accuracy | MATH-500 size | Time |\n", "|-----------|--------|----------|---------------|-----------------------|\n", "| Base | CPU | 30% | 10 | 0.7 min (Mac Mini M4) |\n", "| Base | MPS | 20% | 10 | 0.4 min (Mac Mini M4) |\n", "| Base | CUDA | 30% | 10 | 0.2 min (DGX Spark) |\n", "| Base | XPU | 30% | 10 | 1.2 min (Intel) |\n", "| Reasoning | CPU | 90% | 10 | 9.5 min (Mac Mini M4) |\n", "| Reasoning | MPS | 80% | 10 | 3.8 min (Mac Mini M4) |\n", "| Reasoning | CUDA | 90% | 10 | 3.7 min (DGX Spark) |\n", "| Reasoning | XPU | 70% | 10 | 8.5 min (Intel) |\n", "\n", "\n", "| Mode | Device | Accuracy | MATH-500 size | Time |\n", "|-----------|--------|----------|------------------|------------------------|\n", "| Base | CUDA | 15.6% | 500 | 10.0 min (DGX Spark) |\n", "| Reasoning | CUDA | 50.8% | 500 | 182.2 min (DGX Spark) |\n", "\n", "- Note that these values were obtained in PyTorch 2.8 and can differ in different 
versions of PyTorch" ] }, { "cell_type": "markdown", "id": "9d69cd3e-202b-4537-bd80-312b160ec904", "metadata": {}, "source": [ "- For reference, the tables above list the accuracy values we obtained on the different devices\n", "- Note that \"GPU\" here refers to an NVIDIA (\"cuda\") GPU; MPS refers to an Apple Silicon M4 chip\n", "- The reasoning model is much slower because it produces much longer responses\n", "- While Qwen3-Base is a pre-trained base model and the Qwen3 team recommends using it without a chat template, changing `tokenizer = Qwen3Tokenizer(tokenizer_file_path=tokenizer_path)` to `tokenizer = Qwen3Tokenizer(tokenizer_file_path=tokenizer_path, apply_chat_template=True)` boosts the MATH-500 performance substantially (to 80%); note that it is not clear whether the MATH-500 test set was part of the training data; in the age of LLMs, we can assume that any data available on the internet has been part of the training data (also see the discussion [here](https://github.com/rasbt/LLMs-from-scratch/pull/828#issuecomment-3324829736))\n", "- The run over all 500 MATH-500 examples corresponds to changing `math_data=math_data[:10],` to `math_data=math_data,` in the `evaluate_math500_stream` function call above\n", "- The bonus materials contain a script to run the evaluation in batched mode for higher throughput (see [../02_math500-verifier-scripts/README.md](../02_math500-verifier-scripts/README.md)); on an H100, with a batch size of 128, the base model can be evaluated in 3.3 min, and the reasoning model in 14.6 min" ] }, { "cell_type": "markdown", "id": "98a2fb3d-0c5a-4edf-8a2a-9227e3acc34c", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "32ae09ba-9989-4a76-a3c3-72ca792c2637", "metadata": {}, "source": [ "- For convenience, you can use the [../02_math500-verifier-scripts/evaluate_math500.py](../02_math500-verifier-scripts/evaluate_math500.py) script, which runs the MATH-500 evaluation code as a standalone script from the command line (see the 
[../02_math500-verifier-scripts/README.md](../02_math500-verifier-scripts/README.md) for more usage information)\n", "- The [../02_math500-verifier-scripts/evaluate_math500_batched.py](../02_math500-verifier-scripts/evaluate_math500_batched.py) script runs the code in this chapter in batched mode\n", " - This means it processes multiple examples per forward pass to accelerate the evaluation while requiring more RAM\n", " - With a batch size of 128, this reduces the runtime of the base model, when evaluating all 500 samples, from 13.3 min to 3.3 min on an H100 GPU\n", " - Similarly, it reduces the runtime of the reasoning model from 185.4 min to 14.6 min for the 500 examples in the dataset\n", " - Note that the H100 is used as an example, and the script is compatible with other GPUs (or CPUs) as well" ] }, { "cell_type": "markdown", "id": "fdc25545-9cac-40a8-b08f-fd42701e332a", "metadata": {}, "source": [ " \n", "## Summary" ] }, { "cell_type": "markdown", "id": "7b66ffc8-3f44-4784-9cd3-4772dee6333d", "metadata": {}, "source": [ "- No code in this section" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.16" } }, "nbformat": 4, "nbformat_minor": 5 }