{ "cells": [ { "cell_type": "markdown", "id": "83efb6df-7d99-4fee-99f3-f2f668292110", "metadata": {}, "source": [ "\n", "\n", "\n", "\n", "\n", "
\n", "\n", "Supplementary code for the Build a Reasoning Model (From Scratch) book by Sebastian Raschka
\n", "
Code repository: https://github.com/rasbt/reasoning-from-scratch\n", "
\n", "
\n", "\n", "
\n" ] }, { "cell_type": "markdown", "id": "ef2ac59f-0dc1-4c3e-bb8c-2ea79e0f6657", "metadata": {}, "source": [ "# Chapter 4: Exercise Solutions" ] }, { "cell_type": "markdown", "id": "4735f8bb-dd7f-4a4f-8761-269f26b38349", "metadata": {}, "source": [ "Packages that are being used in this notebook:" ] }, { "cell_type": "code", "execution_count": 1, "id": "00e26411-6a34-4c89-bc24-2e36dd14c8eb", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "reasoning_from_scratch version: 0.1.9\n", "torch version: 2.9.0\n", "tokenizers version: 0.21.4\n" ] } ], "source": [ "from importlib.metadata import version\n", "\n", "used_libraries = [\n", " \"reasoning_from_scratch\",\n", " \"torch\",\n", " \"tokenizers\" # Used by reasoning_from_scratch\n", "]\n", "\n", "for lib in used_libraries:\n", " print(f\"{lib} version: {version(lib)}\")" ] }, { "cell_type": "markdown", "id": "8d101721-6848-4871-826a-eaf194ddb26a", "metadata": {}, "source": [ " \n", "## Exercise 4.1: Use chain-of-thought prompting on MATH-500" ] }, { "cell_type": "markdown", "id": "5d9257c6-384b-46a0-9767-c2f3db7dbcf0", "metadata": {}, "source": [ "- The modification just requires adding a prompt suffix, for example, `\"\\n\\nExplain step by step.\"` after applying the prompt template\n", "- The modified MATH-500 evaluation function from chapter 3 is shown below (the changes are commented via `# NEW`)" ] }, { "cell_type": "markdown", "id": "733ded33-7ef8-4214-bb71-4b0c206d6867", "metadata": {}, "source": [ "```python\n", "import json\n", "from pathlib import Path\n", "import time\n", "\n", "from reasoning_from_scratch.ch03 import (\n", " eta_progress_message,\n", " extract_final_candidate,\n", " render_prompt,\n", " grade_answer,\n", " generate_text_stream_concat,\n", ")\n", "\n", "\n", "def evaluate_math500_stream(\n", " model,\n", " tokenizer,\n", " device,\n", " math_data,\n", " out_path=None,\n", " max_new_tokens=512,\n", " verbose=False,\n", " prompt_suffix=\"\" # NEW\n", 
"):\n", "\n", " if out_path is None:\n", " dev_name = str(device).replace(\":\", \"-\")\n", " out_path = Path(f\"math500-{dev_name}.jsonl\")\n", "\n", " num_examples = len(math_data)\n", " num_correct = 0\n", " start_time = time.time()\n", "\n", " with open(out_path, \"w\", encoding=\"utf-8\") as f:\n", " for i, row in enumerate(math_data, start=1):\n", " prompt = render_prompt(row[\"problem\"])\n", " prompt += prompt_suffix # NEW\n", " gen_text = generate_text_stream_concat(\n", " model, tokenizer, prompt, device,\n", " max_new_tokens=max_new_tokens,\n", " verbose=verbose,\n", " )\n", "\n", " extracted = extract_final_candidate(\n", " gen_text\n", " )\n", " is_correct = grade_answer(\n", " extracted, row[\"answer\"]\n", " )\n", " num_correct += int(is_correct)\n", "\n", " record = {\n", " \"index\": i,\n", " \"problem\": row[\"problem\"],\n", " \"gtruth_answer\": row[\"answer\"],\n", " \"generated_text\": gen_text,\n", " \"extracted\": extracted,\n", " \"correct\": bool(is_correct),\n", " }\n", " f.write(json.dumps(record, ensure_ascii=False) + \"\\n\")\n", "\n", " progress_msg = eta_progress_message(\n", " processed=i,\n", " total=num_examples,\n", " start_time=start_time,\n", " show_eta=True,\n", " label=\"MATH-500\",\n", " )\n", " print(progress_msg, end=\"\\r\", flush=True)\n", " if verbose:\n", " print(\n", " f\"\\n\\n{'='*50}\\n{progress_msg}\\n\"\n", " f\"{'='*50}\\nExtracted: {extracted}\\n\"\n", " f\"Expected: {row['answer']}\\n\"\n", " f\"Correct so far: {num_correct}\\n{'-'*50}\"\n", " )\n", "\n", " seconds_elapsed = time.time() - start_time\n", " acc = num_correct / num_examples if num_examples else 0.0\n", " print(f\"\\nAccuracy: {acc*100:.1f}% ({num_correct}/{num_examples})\")\n", " print(f\"Total time: {seconds_elapsed/60:.1f} min\")\n", " print(f\"Logs written to: {out_path}\")\n", " return num_correct, num_examples, acc\n", "```" ] }, { "cell_type": "markdown", "id": "0c760503-b57a-4947-b7bc-63620ebe2af9", "metadata": {}, "source": [ "- The 
improvements over the baseline in chapter 3 are shown below\n", "\n", "| | Method | Model | Accuracy | Time |\n", "|----|----------------------------------------------|-----------|----------|------------|\n", "| 1 | Baseline (chapter 3), greedy decoding | Base | 15.2% | 10.1 min |\n", "| 2 | Baseline (chapter 3), greedy decoding | Reasoning | 48.2% | 182.1 min |\n", "| 3 | Chain-of-thought prompting (\"CoT\") | Base | 40.6% | 84.5 min |" ] }, { "cell_type": "markdown", "id": "4971cdbf-2412-421e-b555-dbe2037b2c73", "metadata": {}, "source": [ "- For your convenience, you can run the [cot_prompting_math500.py](../02_math500-inference-scaling-scripts/cot_prompting_math500.py) script located in [../02_math500-inference-scaling-scripts](../02_math500-inference-scaling-scripts)" ] }, { "cell_type": "markdown", "id": "5bc62f87-3d9d-47cd-9eed-6e15982e478c", "metadata": {}, "source": [ " \n", "## Exercise 4.2: Use temperature scaling and top-p filtering on MATH-500" ] }, { "cell_type": "markdown", "id": "35b6eea4-7b54-42d5-91f2-777a6b73c2af", "metadata": {}, "source": [ "- Here, we have to swap the `generate_text_stream_concat` function for `generate_text_stream_concat_flex` and plug the `generate_text_top_p_stream_cache` function into it\n", "- The modified MATH-500 evaluation function from chapter 3 is shown below (the changes are commented via `# NEW`)" ] }, { "cell_type": "markdown", "id": "c59d8915-5d4b-48f2-85ff-b169622a35e6", "metadata": {}, "source": [ "```python\n", "import json\n", "from pathlib import Path\n", "import time\n", "\n", "from reasoning_from_scratch.ch03 import (\n", " eta_progress_message,\n", " extract_final_candidate,\n", " render_prompt,\n", " grade_answer,\n", ")\n", "from reasoning_from_scratch.ch04 import ( # NEW\n", " generate_text_stream_concat_flex,\n", " generate_text_top_p_stream_cache,\n", ")\n", "\n", "\n", "def evaluate_math500_stream(\n", " model,\n", " tokenizer,\n", " device,\n", " math_data,\n", " out_path=None,\n", " max_new_tokens=512,\n", " 
verbose=False,\n", " temperature=1.0, # NEW\n", " top_p=1.0, # NEW\n", "):\n", "\n", " if out_path is None:\n", " dev_name = str(device).replace(\":\", \"-\")\n", " out_path = Path(f\"math500-{dev_name}.jsonl\")\n", "\n", " num_examples = len(math_data)\n", " num_correct = 0\n", " start_time = time.time()\n", "\n", " with open(out_path, \"w\", encoding=\"utf-8\") as f:\n", " for i, row in enumerate(math_data, start=1):\n", " prompt = render_prompt(row[\"problem\"])\n", " gen_text = generate_text_stream_concat_flex( # NEW\n", " model, tokenizer, prompt, device,\n", " max_new_tokens=max_new_tokens,\n", " verbose=verbose,\n", " generate_func=generate_text_top_p_stream_cache, # NEW\n", " temperature=temperature, # NEW\n", " top_p=top_p # NEW\n", " )\n", "\n", " extracted = extract_final_candidate(\n", " gen_text\n", " )\n", " is_correct = grade_answer(\n", " extracted, row[\"answer\"]\n", " )\n", " num_correct += int(is_correct)\n", "\n", " record = {\n", " \"index\": i,\n", " \"problem\": row[\"problem\"],\n", " \"gtruth_answer\": row[\"answer\"],\n", " \"generated_text\": gen_text,\n", " \"extracted\": extracted,\n", " \"correct\": bool(is_correct),\n", " }\n", " f.write(json.dumps(record, ensure_ascii=False) + \"\\n\")\n", "\n", " progress_msg = eta_progress_message(\n", " processed=i,\n", " total=num_examples,\n", " start_time=start_time,\n", " show_eta=True,\n", " label=\"MATH-500\",\n", " )\n", " print(progress_msg, end=\"\\r\", flush=True)\n", " if verbose:\n", " print(\n", " f\"\\n\\n{'='*50}\\n{progress_msg}\\n\"\n", " f\"{'='*50}\\nExtracted: {extracted}\\n\"\n", " f\"Expected: {row['answer']}\\n\"\n", " f\"Correct so far: {num_correct}\\n{'-'*50}\"\n", " )\n", "\n", " seconds_elapsed = time.time() - start_time\n", " acc = num_correct / num_examples if num_examples else 0.0\n", " print(f\"\\nAccuracy: {acc*100:.1f}% ({num_correct}/{num_examples})\")\n", " print(f\"Total time: {seconds_elapsed/60:.1f} min\")\n", " print(f\"Logs written to: {out_path}\")\n", " 
return num_correct, num_examples, acc\n", "```" ] }, { "cell_type": "markdown", "id": "c361ae56-cd57-40a1-8a95-af74783e8f3c", "metadata": {}, "source": [ "- When running the method with `temperature=0.9` and `top_p=0.9`, there is only a minor difference compared to the baseline (row 1) in the table below; however, that is expected, as this is merely a setup for self-consistency sampling" ] }, { "cell_type": "markdown", "id": "b0727661-03bc-4cd0-8bfe-947c7e5b80c1", "metadata": {}, "source": [ "| | Method | Model | Accuracy | Time |\n", "| ---- | ----------------------------------------- | --------- | -------- | --------- |\n", "| 1 | Baseline (chapter 3), greedy decoding | Base | 15.2% | 10.1 min |\n", "| ... | ... | ... | ... | ... |\n", "| 4 | Temperature and top-p (\"Top-p\") | Base | 17.8% | 30.7 min |" ] }, { "cell_type": "markdown", "id": "7532db22-c5e8-4ce9-bef3-5061976d3268", "metadata": {}, "source": [ "- For your convenience, you can run the [self_consistency_math500.py](../02_math500-inference-scaling-scripts/self_consistency_math500.py) script located in [../02_math500-inference-scaling-scripts](../02_math500-inference-scaling-scripts)\n", "- Technically, it's a self-consistency sampling script, but setting `--num_samples 1` effectively disables the self-consistency sampling portion" ] }, { "cell_type": "markdown", "id": "bc6278a1-9c6c-45a9-9812-ce0c593bbf35", "metadata": {}, "source": [ " \n", "## Exercise 4.3: Use self-consistency sampling on MATH-500" ] }, { "cell_type": "markdown", "id": "0e430c19-6568-4498-a198-99567543d1f4", "metadata": {}, "source": [ "- Taking the `evaluate_math500_stream` function from chapter 3 as a basis, the first change is to swap out the `gen_text = generate_text_stream_concat(...)` portion with the `results = self_consistency_vote(...)` call from chapter 4\n", "- The second change involves implementing the simple tie-breaking rule, where the code takes the first instance of the most frequent group (e.g., if we 
have the results 1, 3, 5, 3, 5, then 3 and 5 are tied, and the code returns 3 as the answer because it appears first)\n", "- So, since the most frequent groups are recorded under `results[\"majority_winners\"]`, one approach to break ties is to get the first instance of `results[\"majority_winners\"]`, i.e., `results[\"majority_winners\"][0]`" ] }, { "cell_type": "markdown", "id": "0550e5ae-2739-4535-ad8f-e41d66194215", "metadata": {}, "source": [ "```python\n", "import json\n", "from pathlib import Path\n", "import time\n", "\n", "from reasoning_from_scratch.ch03 import (\n", " eta_progress_message,\n", " render_prompt,\n", " grade_answer,\n", ")\n", "from reasoning_from_scratch.ch04 import self_consistency_vote\n", "\n", "\n", "def evaluate_math500_stream(\n", " model,\n", " tokenizer,\n", " device,\n", " math_data,\n", " out_path=None,\n", " max_new_tokens=2048,\n", " verbose=False,\n", " prompt_suffix=\"\", # NEW\n", " temperature=1.0, # NEW\n", " top_p=1.0, # NEW\n", " seed=None, # NEW\n", " num_samples=10, # NEW\n", "):\n", "\n", " if out_path is None:\n", " dev_name = str(device).replace(\":\", \"-\")\n", " out_path = Path(f\"math500-{dev_name}.jsonl\")\n", "\n", " num_examples = len(math_data)\n", " num_correct = 0\n", " start_time = time.time()\n", "\n", " with open(out_path, \"w\", encoding=\"utf-8\") as f:\n", " for i, row in enumerate(math_data, start=1):\n", " prompt = render_prompt(row[\"problem\"])\n", "\n", " ##############################################################\n", " # NEW\n", " prompt += prompt_suffix\n", " results = self_consistency_vote(\n", " model=model,\n", " tokenizer=tokenizer,\n", " prompt=prompt,\n", " device=device,\n", " num_samples=num_samples,\n", " temperature=temperature,\n", " top_p=top_p,\n", " max_new_tokens=max_new_tokens,\n", " show_progress=False,\n", " show_long_answer=False,\n", " seed=seed,\n", " )\n", "\n", " # resolve ties\n", " if results[\"final_answer\"] is None:\n", " extracted = results[\"majority_winners\"][0]\n", " else:\n", " extracted = 
results[\"final_answer\"]\n", "\n", " # Optionally, get the long answer belonging to the winner\n", " gen_text = None # fallback in case no matching long answer exists\n", " if extracted is not None:\n", " for idx, s in enumerate(results[\"short_answers\"]):\n", " if s == extracted:\n", " gen_text = results[\"full_answers\"][idx]\n", " break\n", " ##############################################################\n", "\n", " is_correct = grade_answer(\n", " extracted, row[\"answer\"]\n", " )\n", " num_correct += int(is_correct)\n", "\n", " record = {\n", " \"index\": i,\n", " \"problem\": row[\"problem\"],\n", " \"gtruth_answer\": row[\"answer\"],\n", " \"generated_text\": gen_text,\n", " \"extracted\": extracted,\n", " \"correct\": bool(is_correct),\n", " }\n", " f.write(json.dumps(record, ensure_ascii=False) + \"\\n\")\n", "\n", " progress_msg = eta_progress_message(\n", " processed=i,\n", " total=num_examples,\n", " start_time=start_time,\n", " show_eta=True,\n", " label=\"MATH-500\",\n", " )\n", " print(progress_msg, end=\"\\r\", flush=True)\n", " if verbose:\n", " print(\n", " f\"\\n\\n{'='*50}\\n{progress_msg}\\n\"\n", " f\"{'='*50}\\nExtracted: {extracted}\\n\"\n", " f\"Expected: {row['answer']}\\n\"\n", " f\"Correct so far: {num_correct}\\n{'-'*50}\"\n", " )\n", "\n", " seconds_elapsed = time.time() - start_time\n", " acc = num_correct / num_examples if num_examples else 0.0\n", " print(f\"\\nAccuracy: {acc*100:.1f}% ({num_correct}/{num_examples})\")\n", " print(f\"Total time: {seconds_elapsed/60:.1f} min\")\n", " print(f\"Logs written to: {out_path}\")\n", " return num_correct, num_examples, acc\n", "```" ] }, { "cell_type": "markdown", "id": "4be37812-5f5c-4132-b8c6-d160dc9a9720", "metadata": {}, "source": [ "- The performance improvements when using self-consistency sampling are summarized in the table below (rows 5-7 and rows 9-12)" ] }, { "cell_type": "markdown", "id": "6fa71a13-0f76-4a8b-aa82-d3ec85122674", "metadata": {}, "source": [ "| | 
Method | Model | Accuracy | Time |\n", "| ---- | ----------------------------------------- | --------- | -------- | --------- |\n", "| 1 | Baseline (chapter 3), greedy decoding | Base | 15.2% | 10.1 min |\n", "| 2 | Baseline (chapter 3), greedy decoding | Reasoning | 48.2% | 182.1 min |\n", "| 3 | Chain-of-thought prompting (\"CoT\") | Base | 40.6% | 84.5 min |\n", "| 4 | Temperature and top-p (\"Top-p\") | Base | 17.8% | 30.7 min |\n", "| 5 | \"Top-p\" + Self-consistency (n=3) | Base | 29.6% | 97.6 min |\n", "| 6 | \"Top-p\" + Self-consistency (n=5) | Base | 27.8% | 116.8 min |\n", "| 7 | \"Top-p\" + Self-consistency (n=10) | Base | 31.6% | 300.4 min |\n", "| 8 | \"Top-p\" + \"CoT\" | Base | 33.4% | 129.2 min |\n", "| 9 | Self-consistency (n=3) + \"Top-p\" + \"CoT\" | Base | 42.2% | 211.6 min |\n", "| 10 | Self-consistency (n=5) + \"Top-p\" + \"CoT\" | Base | 48.0% | 452.9 min |\n", "| 11 | Self-consistency (n=10) + \"Top-p\" + \"CoT\" | Base | 52.0% | 862.6 min |\n", "| 12 | Self-consistency (n=3) + \"Top-p\" + \"CoT\" | Reasoning | 55.2% | 544.4 min |" ] }, { "cell_type": "markdown", "id": "27ee9bd2-853b-431b-b25c-748bdf88877d", "metadata": {}, "source": [ "- For your convenience, you can run the [self_consistency_math500.py](../02_math500-inference-scaling-scripts/self_consistency_math500.py) script located in [../02_math500-inference-scaling-scripts](../02_math500-inference-scaling-scripts) to reproduce these; the [../02_math500-inference-scaling-scripts](../02_math500-inference-scaling-scripts) contains further information on which settings to use" ] }, { "cell_type": "markdown", "id": "c985e7ed-bca4-42b0-bf81-aa0eb0b0f3a7", "metadata": {}, "source": [ " \n", "## Exercise 4.4: Early stopping in self-consistency sampling" ] }, { "cell_type": "markdown", "id": "edec5192-e333-4972-8816-9774c4c33485", "metadata": {}, "source": [ "- The early stopping check can be implemented by adding a few lines of code that check whether the given answer is already counted 
multiple times, or, more specifically, whether the given answer's count is greater than `num_samples / 2`:\n", "\n", "```python\n", "if early_stop and counts[short] > num_samples / 2:\n", " majority_winners = [short]\n", " final_answer = short\n", " break\n", "```\n", "\n", "- The complete, modified function is shown below, with the changes highlighted via `# NEW`" ] }, { "cell_type": "markdown", "id": "3633e475-4c49-4e6a-a84f-1d35d26c71c7", "metadata": {}, "source": [ "```python\n", "import torch\n", "from collections import Counter\n", "\n", "from reasoning_from_scratch.ch03 import (\n", " extract_final_candidate,\n", ")\n", "from reasoning_from_scratch.ch04 import (\n", " generate_text_stream_concat_flex,\n", " generate_text_top_p_stream_cache,\n", ")\n", "\n", "\n", "def self_consistency_vote(\n", " model,\n", " tokenizer,\n", " prompt,\n", " device,\n", " num_samples=10,\n", " temperature=0.8,\n", " top_p=0.9,\n", " max_new_tokens=2048,\n", " show_progress=True,\n", " show_long_answer=False,\n", " seed=None,\n", " early_stop=True, # NEW\n", "):\n", " full_answers, short_answers = [], []\n", " counts = Counter()\n", " groups = {}\n", " majority_winners, final_answer = [], None\n", "\n", " for i in range(num_samples):\n", " if seed is not None:\n", " torch.manual_seed(seed + i + 1)\n", "\n", " answer = generate_text_stream_concat_flex(\n", " model=model,\n", " tokenizer=tokenizer,\n", " prompt=prompt,\n", " device=device,\n", " max_new_tokens=max_new_tokens,\n", " verbose=show_long_answer,\n", " generate_func=generate_text_top_p_stream_cache,\n", " temperature=temperature,\n", " top_p=top_p,\n", " )\n", "\n", " short = extract_final_candidate(\n", " answer, fallback=\"number_then_full\"\n", " )\n", " full_answers.append(answer)\n", " short_answers.append(short)\n", " counts[short] += 1\n", " groups.setdefault(short, []).append(i)\n", "\n", " if show_progress:\n", " print(f\"[Sample {i+1}/{num_samples}] → {short!r}\")\n", "\n", " 
#########################################################\n", " # NEW\n", " # Early stop once one answer holds a strict majority (> 50%)\n", " if early_stop and counts[short] > num_samples / 2:\n", " majority_winners = [short]\n", " final_answer = short\n", " break\n", " #########################################################\n", "\n", " if final_answer is None:\n", " mc = counts.most_common()\n", " if mc:\n", " top_freq = mc[0][1]\n", " majority_winners = [s for s, f in mc if f == top_freq]\n", " final_answer = mc[0][0] if len(majority_winners) == 1 else None\n", "\n", " return {\n", " \"full_answers\": full_answers,\n", " \"short_answers\": short_answers,\n", " \"counts\": dict(counts),\n", " \"groups\": groups,\n", " \"majority_winners\": majority_winners,\n", " \"final_answer\": final_answer,\n", " }\n", "```" ] }, { "cell_type": "markdown", "id": "6c636069-817f-4e83-a60c-9cb398ae70cf", "metadata": {}, "source": [ "- For your convenience, you can run the [self_consistency_math500.py](../02_math500-inference-scaling-scripts/self_consistency_math500.py) script located in [../02_math500-inference-scaling-scripts](../02_math500-inference-scaling-scripts) with the `--early_stop` flag to use this modified function on the MATH-500 dataset" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.16" } }, "nbformat": 4, "nbformat_minor": 5 }