{ "cells": [ { "cell_type": "markdown", "id": "136a4efe-fb99-4311-8679-e0a5b6282755", "metadata": {}, "source": [ "\n", "\n", "\n", "\n", "\n", "
\n", "\n", "Supplementary code for the Build a Large Language Model From Scratch book by Sebastian Raschka
\n", "
Code repository: https://github.com/rasbt/LLMs-from-scratch\n", "
Chinese translation repository: https://github.com/GoatCsu/CN-LLMs-from-scratch.git\n", "
\n", "
\n", "\n", "
\n" ] }, { "cell_type": "markdown", "id": "b1910a06-e8a3-40ac-8201-ff70615b1ba4", "metadata": { "tags": [] }, "source": [ "# Improving Instruction Data with GPT-4 via Reflection-Tuning\n" ] }, { "cell_type": "markdown", "id": "a128651b-f326-4232-a994-42f38b7ed520", "metadata": {}, "source": [ "- This notebook uses OpenAI's GPT-4 API to implement the dataset refinement process from the [Reflection-Tuning: Data Recycling Improves LLM Instruction-Tuning](https://arxiv.org/abs/2310.11716) paper.\n", "\n", "![](https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/reflection-tuning/reflection-tuning.webp)\n", "\n", "- In the original paper, the researchers refined the [Alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca) and [WizardLM](https://huggingface.co/datasets/WizardLMTeam/WizardLM_evol_instruct_70k) instruction-finetuning datasets.\n", "- In this notebook, we refine the [instruction dataset used in chapter 7](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch07/01_main-chapter-code/instruction-data.json).\n", "    - Since it has the same format as the Alpaca dataset, the code works with that dataset as well.\n", "\n", "- The expected dataset format is as follows:\n", "\n", "\n", "```python\n", "    {\n", "        \"instruction\": \"Edit the following sentence for grammar.\",\n", "        \"input\": \"He go to the park every day.\",\n", "        \"output\": \"He goes to the park every day.\"\n", "    },\n", "    {\n", "        \"instruction\": \"Convert 45 kilometers to meters.\",\n", "        \"input\": \"\",\n", "        \"output\": \"45 kilometers is 45000 meters.\"\n", "    },\n", "```" ] }, { "cell_type": "markdown", "id": "86ac82b4-e3d5-4ed5-8f46-6c97a9313463", "metadata": {}, "source": [ "> Please note that this notebook reproduces the approach from the paper, which uses the GPT API to refine existing datasets. However, be aware that the [OpenAI Terms of Use](https://openai.com/policies/row-terms-of-use/) prohibit using GPT API-generated data to develop models that compete with OpenAI. \n", "> The relevant clause reads: \n", "> *“What you cannot do... 
Use Output to develop models that compete with OpenAI.”* \n", "> For a related discussion, see this [Reddit thread](https://www.reddit.com/r/LocalLLaMA/comments/17vbg1f/does_openai_tos_prohibit_generating_datasets_for/)." ] }, { "cell_type": "code", "execution_count": 1, "id": "267ba0d1-b884-42df-85bd-0be746fd47a5", "metadata": {}, "outputs": [], "source": [ "# pip install -r requirements-extra.txt" ] }, { "cell_type": "code", "execution_count": 2, "id": "63610acc-db94-437f-8d38-e99dca0299cb", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "openai version: 1.30.3\n", "tqdm version: 4.66.4\n" ] } ], "source": [ "from importlib.metadata import version\n", "\n", "pkgs = [\n", "    \"openai\",  # OpenAI API\n", "    \"tqdm\",    # Progress bar\n", "]\n", "\n", "for p in pkgs:\n", "    print(f\"{p} version: {version(p)}\")" ] }, { "cell_type": "markdown", "id": "8bcdcb34-ac75-4f4f-9505-3ce0666c42d5", "metadata": {}, "source": [ "## Test OpenAI API" ] }, { "cell_type": "markdown", "id": "9558a522-650d-401a-84fc-9fd7b1f39da7", "metadata": {}, "source": [ "- First, let's test whether the OpenAI API is correctly set up.\n", "- If you don't have an account yet, create one at the [OpenAI platform](https://platform.openai.com/).\n", "- Note that you will also have to add funds to your account, as the GPT-4 API is not free (see the [billing page](https://platform.openai.com/settings/organization/billing/overview)).\n", "\n", "- First, we need to provide our OpenAI API secret key, which can be found on the [API keys page](https://platform.openai.com/api-keys).\n", "- Make sure not to share this key with anyone.\n", "- Add this secret key (`\"sk-...\"`) to the `config.json` file in this folder." ] }, { "cell_type": "code", "execution_count": 3, "id": "65b0ba76-1fb1-4306-a7c2-8f3bb637ccdb", "metadata": {}, "outputs": [], "source": [ "import json\n", "from openai import OpenAI\n", "\n", "# Load the API key from a JSON file.\n", "# Make sure to replace \"sk-...\" with your actual API key from https://platform.openai.com/api-keys\n", "with open(\"config.json\", \"r\") as config_file:\n", "    config = json.load(config_file)\n", "    api_key = config[\"OPENAI_API_KEY\"]\n", "\n", "client = OpenAI(api_key=api_key)" ] }, { "cell_type": "markdown", "id": 
"16642a48-1cab-40d2-af08-ab8c2fbf5876", "metadata": {}, "source": [ "- First, let's try the API with a simple example to make sure it works as intended:" ] }, { "cell_type": "code", "execution_count": 4, "id": "08e9ef2e-e816-4283-840e-43625791ad33", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'hello world'" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def run_chatgpt(prompt, client, model=\"gpt-4o-mini\", system_prompt=None):\n", "    # Define the system message if a system prompt is provided\n", "    messages = []\n", "\n", "    if system_prompt:\n", "        messages.append({\"role\": \"system\", \"content\": system_prompt})\n", "\n", "    # Add the user prompt to the messages\n", "    messages.append({\"role\": \"user\", \"content\": prompt})\n", "\n", "    # Call the API\n", "    response = client.chat.completions.create(\n", "        model=model,\n", "        messages=messages,\n", "        temperature=0.0,\n", "        seed=123,\n", "    )\n", "\n", "    # Return the model's response\n", "    return response.choices[0].message.content\n", "\n", "\n", "prompt = \"Respond with 'hello world' if you got this message.\"\n", "run_chatgpt(prompt, client)" ] }, { "cell_type": "markdown", "id": "162a4739-6f03-4092-a5c2-f57a0b6a4c4d", "metadata": {}, "source": [ "## Load JSON Entries\n", "\n", "- Next, let's load and process the instruction dataset.\n", "- Here, we assume the chapter 7 instruction dataset is saved as a JSON file (`instruction-data.json`) that we can load as follows:" ] }, { "cell_type": "code", "execution_count": 5, "id": "8b2d393a-aa92-4190-9d44-44326a6f699b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of entries: 1100\n" ] } ], "source": [ "from pathlib import Path\n", "\n", "\n", "json_file = Path(\"..\") / \"01_main-chapter-code\" / \"instruction-data.json\"\n", "\n", "with open(json_file, \"r\") as file:\n", "    json_data = json.load(file)\n", "\n", "print(\"Number of entries:\", len(json_data))" ] }, { "cell_type": "markdown", "id": "b6c9751b-59b7-43fe-acc7-14e8daf2fa66", "metadata": {}, "source": [ "- Let's print one of the dataset entries to see its structure:" ] }, { "cell_type": "code", "execution_count": 6, "id": 
"ce187422-a4e6-4f3c-b0d1-b03257f5bcdd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'instruction': 'Evaluate the following phrase by transforming it into the '\n", " 'spelling given.',\n", " 'input': 'freind --> friend',\n", " 'output': 'The spelling of the given phrase \"freind\" is incorrect, the '\n", " 'correct spelling is \"friend\".'}\n" ] } ], "source": [ "from pprint import pp as pprint\n", "\n", "pprint(json_data[0])" ] }, { "cell_type": "markdown", "id": "3fce41e0-433f-49aa-82b7-a9d1a1d41604", "metadata": { "tags": [] }, "source": [ "## Improve Instructions\n", "\n", "- The Reflection-Tuning authors propose two refinement approaches:\n", "    1. Improving the instructions\n", "    2. Improving the responses\n", "- Let's begin by improving the instructions in the dataset.\n", "- Below is a utility function from the [Reflection-Tuning repository](https://github.com/tianyi-lab/Reflection_Tuning/blob/main/reflection_code/reflect_response.py) that formats the inputs for the GPT-4 model for dataset refinement." ] }, { "cell_type": "code", "execution_count": 7, "id": "76d28ada-9e1a-4818-8a49-82be44141533", "metadata": {}, "outputs": [], "source": [ "def instr_prompt_no_input(ins, outp):\n", "\n", "    sys_prompt = \"You are a helpful, precise but picky assistant for checking the quality of a given instruction.\"\n", "    prompt_template = \"[Instruction]\\n{ins}\\n\\n[The Start of Answer]\\n{outp}\\n\\n[The End of Answer]\\n\\n[System]\\n{criteria}\\n\\n\"\n", "    criteria = \"We would like you to answer several questions related to the quality of a given instruction. \\n\" + \\\n", "                \"1. Why this instruction is not good? First analyse the instruction based on Complexity of the Topic, Level of Detail Required, Knowledge Required, Ambiguity of the Instruction and Logical Reasoning or Problem-Solving Involved. \\n\" + \\\n", "                \"Then analyse why this answer is not good for the given instruction? Analyse based on the Helpfulness, Relevance, Accuracy and Level of Details. \\n\" + \\\n", "                \"Finally analyse why this bad instruction lead to a bad answer. 
\" +\\\n", "                \"2. Based on the reason you provided, generate a new and complete instruction which is complex and difficult to answer directly. \" + \\\n", "                \"Make sure the new instruction is relevent but independent to the original instruction, which can be answered without knowing the original instruction, put the new instruction in the format of [New Instruction] your instruction [End]\" +\\\n", "                \"3. Answer the newly generated instruction as detailed as possible, in the format of [New Answer] your answer [End] \\n\"\n", "    prompt = prompt_template.format(\n", "        ins=ins, outp=outp, criteria=criteria\n", "    )\n", "    return sys_prompt, prompt" ] }, { "cell_type": "markdown", "id": "74dd8a16-3f96-4662-b8c3-97dced794c6c", "metadata": {}, "source": [ "- To see how it works, consider the dataset entry `json_data[2]`:" ] }, { "cell_type": "code", "execution_count": 8, "id": "203807bf-9a76-4e12-b801-b4ae518f30a2", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'instruction': 'Convert 45 kilometers to meters.',\n", " 'input': '',\n", " 'output': '45 kilometers is 45000 meters.'}\n" ] } ], "source": [ "pprint(json_data[2])" ] }, { "cell_type": "markdown", "id": "9572a1aa-532a-4a76-9fa3-3b59d996ba13", "metadata": {}, "source": [ "- We can refine the instruction with the `instr_prompt_no_input` function defined above, as follows:" ] }, { "cell_type": "code", "execution_count": null, "id": "4a7a0a76-cc22-4bda-ae26-4b2540afb4cf", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1. 
**Analysis of the Instruction:**\n", "\n", " - **Complexity of the Topic:** The topic of converting kilometers to meters is relatively simple and straightforward, as it involves basic unit conversion.\n", " - **Level of Detail Required:** The instruction does not require much detail; it simply asks for a conversion without any additional context or explanation.\n", " - **Knowledge Required:** Basic knowledge of metric units and their conversions is required, which is common knowledge.\n", " - **Ambiguity of the Instruction:** The instruction is clear and unambiguous; it specifies exactly what needs to be converted.\n", " - **Logical Reasoning or Problem-Solving Involved:** There is minimal logical reasoning involved, as the conversion factor (1 kilometer = 1000 meters) is a standard fact.\n", "\n", " **Analysis of the Answer:**\n", "\n", " - **Helpfulness:** The answer is helpful in that it provides the correct conversion.\n", " - **Relevance:** The answer is relevant to the instruction, as it directly addresses the conversion requested.\n", " - **Accuracy:** The answer is accurate; 45 kilometers does indeed equal 45,000 meters.\n", " - **Level of Details:** The answer lacks detail. It does not explain the conversion process or provide any context, which could be beneficial for someone unfamiliar with metric conversions.\n", "\n", " **Why the Bad Instruction Leads to a Bad Answer:** While the instruction itself is not bad, the simplicity of the task may lead to a lack of depth in the answer. The answer could have been improved by including an explanation of the conversion process, which would enhance understanding.\n", "\n", "2. **New Instruction:**\n", " [New Instruction] Explain the significance of the metric system in global trade and provide examples of how unit conversions can impact international business transactions. [End]\n", "\n", "3. 
**New Answer:**\n", " [New Answer] The metric system, also known as the International System of Units (SI), is a decimal-based system of measurement that is used globally. Its significance in global trade lies in its standardization, which facilitates international communication and commerce. \n", "\n", " One of the primary advantages of the metric system is that it is universally recognized, which reduces confusion and errors in measurement. For example, when a company in the United States imports goods from Europe, the specifications for those goods are often provided in metric units. If the U.S. company is accustomed to using imperial units (like inches or pounds), they must convert these measurements to ensure compatibility. \n", "\n", " Unit conversions can significantly impact international business transactions. For instance, if a manufacturer orders 100 kilograms of a product but mistakenly interprets it as 100 pounds, they will receive a much smaller quantity than intended, leading to production delays and financial losses. \n", "\n", " Additionally, in industries such as pharmaceuticals, precise measurements are critical. A dosage specified in milligrams must be accurately converted to ensure patient safety. \n", "\n", " In summary, the metric system's role in global trade is crucial for maintaining consistency and accuracy in measurements, which ultimately supports efficient and effective international business operations. 
[End]\n" ] } ], "source": [ "entry = json_data[2]\n", "\n", "system_prompt, prompt = instr_prompt_no_input(ins=entry[\"instruction\"], outp=entry[\"output\"])\n", "output = run_chatgpt(prompt=prompt, client=client, system_prompt=system_prompt)\n", "\n", "print(output)" ] }, { "cell_type": "markdown", "id": "5fdc4575-caae-45cc-bf9c-fde9322cf3df", "metadata": {}, "source": [ "- The response is very verbose, which is useful for analysis; it also lets the GPT-4 model improve the data via chain-of-thought prompting.\n", "- However, to construct the improved dataset, we only need the new instruction and output, not the analysis.\n", "- We can use the following utility code from the [Reflection-Tuning repository](https://github.com/tianyi-lab/Reflection_Tuning/blob/main/reflection_code/reflect_response.py) to extract the model's improved instruction and output:" ] }, { "cell_type": "code", "execution_count": null, "id": "bb38d406-69a5-448a-8d20-bd48c47eb485", "metadata": {}, "outputs": [], "source": [ "import re\n", "\n", "def extract_ins(text, no_input=True):\n", "    if '[New Instruction]' in text:\n", "        pattern = r'(\\[New Instruction\\])(.*?)(\\[End\\]|\\[New Answer\\]|New Answer:)'\n", "    else:\n", "        pattern = r'(New Instruction:)(.*?)(\\[End\\]|\\[New Answer\\]|New Answer:)'\n", "    segments = re.findall(pattern, text, re.DOTALL)\n", "    if len(segments) == 0:\n", "        seg_ins = ''\n", "    else:\n", "        seg_ins = segments[0][1].strip()\n", "    if seg_ins.endswith(\"\\n\\n3.\"):\n", "        seg_ins = seg_ins[:-4]\n", "    return seg_ins\n", "\n", "\n", "def extract_oup(text, no_input=True):\n", "    if '[New Answer]' in text:\n", "        pattern = r'(\\[New Answer\\])(.*?)(\\[End\\]|$)'\n", "    else:\n", "        pattern = r'(New Answer:)(.*?)(\\[End\\]|$)'\n", "        # pattern = r'(\\[New Answer\\]|New Answer:)(.*?)(\\[End\\]|$)'\n", "    segments = re.findall(pattern, text, re.DOTALL)\n", "    if len(segments) == 0:\n", "        seg_oup = ''\n", "    else:\n", "        seg_oup = segments[0][1].strip()\n", "    return seg_oup\n", "\n", "\n", "def extract_instruction(text):\n", "    if text == '':\n", "        return []\n", "    seg_ins = extract_ins(text, no_input=True)\n", "    seg_oup = extract_oup(text, 
no_input=True)\n", "    return [seg_ins, seg_oup]" ] }, { "cell_type": "markdown", "id": "5eaacf01-6f00-4fa6-9f2c-cf688d58237a", "metadata": {}, "source": [ "- Now, let's use these utility functions to extract the improved instruction and response from the lengthy GPT-4 output generated earlier:" ] }, { "cell_type": "code", "execution_count": null, "id": "9699b79b-959e-492e-9fad-f7e451d56777", "metadata": {}, "outputs": [], "source": [ "new_instr, new_outp = extract_instruction(output)" ] }, { "cell_type": "code", "execution_count": null, "id": "2cad89ee-9c63-42c8-a113-a98b13a1fbe2", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Explain the significance of the metric system in global trade and provide examples of how unit conversions can impact international business transactions.\n" ] } ], "source": [ "print(new_instr)" ] }, { "cell_type": "code", "execution_count": null, "id": "3bcf3fb6-f572-44ea-aea6-17f64288fd62", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The metric system, also known as the International System of Units (SI), is a decimal-based system of measurement that is used globally. Its significance in global trade lies in its standardization, which facilitates international communication and commerce. \n", "\n", " One of the primary advantages of the metric system is that it is universally recognized, which reduces confusion and errors in measurement. For example, when a company in the United States imports goods from Europe, the specifications for those goods are often provided in metric units. If the U.S. company is accustomed to using imperial units (like inches or pounds), they must convert these measurements to ensure compatibility. \n", "\n", " Unit conversions can significantly impact international business transactions. For instance, if a manufacturer orders 100 kilograms of a product but mistakenly interprets it as 100 pounds, they will receive a much smaller quantity than intended, leading to production delays and financial losses. 
\n", "\n", " Additionally, in industries such as pharmaceuticals, precise measurements are critical. A dosage specified in milligrams must be accurately converted to ensure patient safety. \n", "\n", " In summary, the metric system's role in global trade is crucial for maintaining consistency and accuracy in measurements, which ultimately supports efficient and effective international business operations.\n" ] } ], "source": [ "print(new_outp)" ] }, { "cell_type": "markdown", "id": "5dec63dc-eaf1-4bcf-87ab-c63b2924cc67", "metadata": {}, "source": [ "- Note that the instruction refinement is currently only implemented for dataset entries whose `\"input\"` field is empty." ] }, { "cell_type": "markdown", "id": "978b1559-61c9-4ab9-a353-4962e4ec6d38", "metadata": {}, "source": [ "## Improve Responses\n", "\n", "- In a similar fashion, we can apply the Reflection-Tuning refinement process to the responses (i.e., the `\"output\"` fields) in the dataset.\n", "- Below are two utility functions from the [Reflection-Tuning repository](https://github.com/tianyi-lab/Reflection_Tuning/blob/main/reflection_code/reflect_response.py) that format the inputs for the GPT-4 model for dataset refinement." ] }, { "cell_type": "code", "execution_count": null, "id": "5f78806c-abc1-4f38-afc9-9582bb48b668", "metadata": {}, "outputs": [], "source": [ "def res_gen_prompt_no_input(ins, outp):\n", "\n", "    sys_prompt = \"You are a helpful, precise but picky assistant for checking the quality of the answer to a given instruction.\"\n", "    prompt_template = \"[Instruction]\\n{ins}\\n\\n[The Start of Answer]\\n{outp}\\n\\n[The End of Answer]\\n\\n[System]\\n{criteria}\\n\\n\"\n", "    criteria = \"We would like you to answer several questions related to the quality of the answer to the given instruction. \\n\" + \\\n", "                \"1. Why this answer is not good for the given instruction? Analyse based on the Helpfulness, Relevance, Accuracy and Level of Details. \\n\" + \\\n", "                \"2. 
Based on the reason you provided, generate a better answer, new and complete, as detailed as possible, in the format of [Better Answer] your answer [End] \\n\"\n", "    prompt = prompt_template.format(\n", "        ins=ins, outp=outp, criteria=criteria\n", "    )\n", "    return sys_prompt, prompt\n", "\n", "\n", "def res_gen_prompt_input(ins, inp, outp):\n", "\n", "    sys_prompt = \"You are a helpful and precise assistant for checking the quality of the answer to a given instruction and its input.\"\n", "    prompt_template = \"[Instruction]\\n{ins}\\n\\n[The Start of Input]\\n{inp}\\n\\n[The End of Input]\\n\\n[The Start of Answer]\\n{outp}\\n\\n[The End of Answer]\\n\\n[System]\\n{criteria}\\n\\n\"\n", "    criteria = \"We would like you to answer several questions related to the quality of the answer to the given instruction and corresponding input. \\n\" + \\\n", "                \"1. Why this answer is not good for the given instruction and corresponding input? Analyse based on the Helpfulness, Relevance, Accuracy and Level of Details. \\n\" + \\\n", "                \"2. Based on the reason you provided, generate a better answer, new and complete, as detailed as possible, in the format of [Better Answer] your answer [End] \\n\"\n", "    prompt = prompt_template.format(\n", "        ins=ins, inp=inp, outp=outp, criteria=criteria\n", "    )\n", "    return sys_prompt, prompt" ] }, { "cell_type": "markdown", "id": "39a55283-7d51-4136-ba60-f799d49f4098", "metadata": {}, "source": [ "- Again, let's apply it to one of the dataset entries and generate an improved response to see how it works:" ] }, { "cell_type": "code", "execution_count": 15, "id": "126c4aa3-687c-4328-b174-84f1078cac72", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1. The answer provided is not good for the given instruction for several reasons:\n", "\n", "- **Helpfulness**: While the answer does provide the correct conversion, it lacks any explanation or context. 
A more helpful answer would include a brief explanation of the conversion process, which would aid understanding.\n", "\n", "- **Relevance**: The answer is relevant in that it addresses the instruction to convert kilometers to meters, but it could be more relevant by including the conversion factor used (1 kilometer = 1000 meters).\n", "\n", "- **Accuracy**: The answer is accurate in terms of the numerical conversion (45 kilometers = 45000 meters). However, it could be misleading if the reader does not understand how the conversion was derived.\n", "\n", "- **Level of Details**: The answer is very brief and lacks detail. A more detailed response would include the conversion factor and a step-by-step explanation of how the conversion is performed.\n", "\n", "2. [Better Answer] To convert kilometers to meters, you can use the conversion factor that 1 kilometer is equal to 1000 meters. Therefore, to convert 45 kilometers to meters, you multiply 45 by 1000. \n", "\n", "So, 45 kilometers × 1000 meters/kilometer = 45000 meters. \n", "\n", "Thus, 45 kilometers is equal to 45000 meters. 
[End]\n" ] } ], "source": [ "entry = json_data[2]\n", "\n", "system_prompt, prompt = res_gen_prompt_no_input(ins=entry[\"instruction\"], outp=entry[\"output\"])\n", "output = run_chatgpt(prompt=prompt, client=client, system_prompt=system_prompt)\n", "\n", "print(output)" ] }, { "cell_type": "markdown", "id": "c206abe9-5a64-4532-90d5-661d63670531", "metadata": {}, "source": [ "- As we can see above, the generated output includes an analysis of the original response.\n", "- We can extract the new, improved response using the following utility function from the [Reflection-Tuning repository](https://github.com/tianyi-lab/Reflection_Tuning/blob/main/reflection_code/reflect_response.py):" ] }, { "cell_type": "code", "execution_count": 16, "id": "164cb816-f7dd-4399-a0be-300f4518cf8c", "metadata": {}, "outputs": [], "source": [ "def extract_response(text):\n", "    if text.count('[Better Answer]') >= 2:\n", "        pattern = r'\\[(Better Answer)\\](.*?)(\\[End\\]|\\[Better Answer\\]|$)'\n", "        segments = re.findall(pattern, text, re.DOTALL)\n", "    else:\n", "        # pattern = r'\\[(Better Answer)\\](.*?)\\[End\\]'\n", "        pattern = r'\\[(Better Answer)\\](.*?)(\\[End\\]|End|$)'\n", "        segments = re.findall(pattern, text, re.DOTALL)\n", "    return [segment[1].strip() for segment in segments]" ] }, { "cell_type": "code", "execution_count": 17, "id": "95174f90-dc02-483d-a335-8b448d1b1e22", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "To convert kilometers to meters, you can use the conversion factor that 1 kilometer is equal to 1000 meters. Therefore, to convert 45 kilometers to meters, you multiply 45 by 1000. \n", "\n", "So, 45 kilometers × 1000 meters/kilometer = 45000 meters. 
\n", "\n", "Thus, 45 kilometers is equal to 45000 meters.\n" ] } ], "source": [ "response = extract_response(output)[0]\n", "print(response)" ] }, { "cell_type": "markdown", "id": "cdf583fb-7e18-4b84-89dc-1c5d162c67ea", "metadata": {}, "source": [ "## Improving the Dataset" ] }, { "cell_type": "markdown", "id": "142dfaa7-429f-4eb0-b74d-ff327f79547a", "metadata": {}, "source": [ "- Now, let's apply the instruction-reflection and response-reflection techniques to the actual dataset.\n", "- Note: we only apply them to a small subset of the data here for demonstration purposes; to apply them to the whole dataset, change\n", "\n", "```python\n", "data_to_process = json_data[:3]\n", "```\n", "\n", "to\n", "\n", "```python\n", "data_to_process = json_data\n", "```" ] }, { "cell_type": "markdown", "id": "5deb631a-cde5-4f5c-8eae-0065b4723abb", "metadata": {}, "source": [ "### Reflect instructions" ] }, { "cell_type": "markdown", "id": "d4cc4fb6-9c95-4999-ba1f-c333e701d779", "metadata": {}, "source": [ "- The following code applies the Reflection-Tuning methodology to refine the instructions in the original dataset:" ] }, { "cell_type": "code", "execution_count": 18, "id": "9a4564aa-2d3e-46a8-a339-295c5ff177b6", "metadata": {}, "outputs": [], "source": [ "data_to_process = json_data[:3]" ] }, { "cell_type": "code", "execution_count": 19, "id": "3552bdfb-7511-42ac-a9ec-da672e2a5468", "metadata": {}, "outputs": [], "source": [ "from tqdm import tqdm\n", "\n", "\n", "def reflect_instructions(json_data, client):\n", "    new_json_data = []\n", "\n", "    for entry in tqdm(json_data):\n", "\n", "        if not entry[\"input\"]:\n", "            system_prompt, prompt = instr_prompt_no_input(ins=entry[\"instruction\"], outp=entry[\"output\"])\n", "            output = run_chatgpt(prompt=prompt, client=client, system_prompt=system_prompt)\n", "            new_instr, new_outp = extract_instruction(output)\n", "            new_entry = {\"instruction\": new_instr, \"input\": \"\", \"output\": new_outp}\n", "            new_json_data.append(new_entry)\n", "        else:\n", "            new_json_data.append(entry)\n", "\n", "    return new_json_data" ] }, { "cell_type": "code", "execution_count": 20, "id": 
"d897eda7-ebd6-4a09-a3ae-8d05a2f234dd", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|█████████████████████████████████████████████| 3/3 [00:06<00:00, 2.17s/it]\n" ] } ], "source": [ "data_to_process = json_data[:3]\n", "\n", "new_json_data = reflect_instructions(data_to_process, client)" ] }, { "cell_type": "code", "execution_count": 21, "id": "d1a677a2-d590-4ffb-a202-5fe79a317d4b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'instruction': 'Evaluate the following phrase by transforming it into the '\n", " 'spelling given.',\n", " 'input': 'freind --> friend',\n", " 'output': 'The spelling of the given phrase \"freind\" is incorrect, the '\n", " 'correct spelling is \"friend\".'}\n", "\n", "\n", "\n", "{'instruction': 'Edit the following sentence for grammar.',\n", " 'input': 'He go to the park every day.',\n", " 'output': 'He goes to the park every day.'}\n", "\n", "\n", "\n", "{'instruction': 'Explain the significance of understanding metric conversions '\n", " 'in scientific research, and provide an example of how a '\n", " 'miscalculation in unit conversion could impact experimental '\n", " 'results.',\n", " 'input': '',\n", " 'output': 'Understanding metric conversions is crucial in scientific research '\n", " 'because accurate measurements are fundamental to the validity of '\n", " 'experimental results. The metric system is widely used in '\n", " 'scientific disciplines due to its ease of use and universal '\n", " 'acceptance, allowing researchers from different countries to '\n", " 'communicate their findings effectively.\\n'\n", " '\\n'\n", " ' For example, consider a scenario in a chemistry experiment '\n", " 'where a researcher needs to prepare a solution with a specific '\n", " 'concentration. 
If the researcher intends to prepare a 1 molar (1 '\n", " 'M) solution of sodium chloride (NaCl) in 1 liter of water, they '\n", " 'must accurately measure the mass of NaCl required. The molar mass '\n", " 'of NaCl is approximately 58.44 grams per mole. Therefore, to '\n", " 'prepare 1 liter of a 1 M solution, the researcher needs to '\n", " 'dissolve 58.44 grams of NaCl in water.\\n'\n", " '\\n'\n", " ' However, if the researcher mistakenly converts the volume from '\n", " 'liters to milliliters and uses 1 mL instead of 1 L, they would '\n", " 'only need 0.05844 grams of NaCl. This significant error in unit '\n", " 'conversion would lead to a solution that is 1,000 times more '\n", " 'concentrated than intended. Such a miscalculation could result in '\n", " 'erroneous experimental outcomes, potentially leading to incorrect '\n", " 'conclusions about the behavior of the solution in reactions or '\n", " 'biological systems. This example highlights the importance of '\n", " 'precise unit conversions in ensuring the accuracy and reliability '\n", " 'of scientific research.'}\n", "\n", "\n", "\n" ] } ], "source": [ "for i in new_json_data[:3]:\n", "    pprint(i)\n", "    print(\"\\n\\n\")" ] }, { "cell_type": "markdown", "id": "17840244-a7f9-47e4-8551-671fdedfc856", "metadata": {}, "source": [ "- Save the new dataset:" ] }, { "cell_type": "code", "execution_count": 22, "id": "f9710e60-6c3a-42ab-ab66-24005db2e19f", "metadata": {}, "outputs": [], "source": [ "with open(\"instruction-reflected.json\", \"w\") as file:\n", "    json.dump(new_json_data, file, indent=4)" ] }, { "cell_type": "markdown", "id": "3a081ff5-7aa7-4651-934a-34ce56b7ee5e", "metadata": {}, "source": [ "### Reflect responses\n", "\n", "- Now, let's apply the same procedure to the responses in the dataset (response reflection):" ] }, { "cell_type": "code", "execution_count": 23, "id": "835da869-965a-4a4c-9799-56dbcd559d83", "metadata": {}, "outputs": [], "source": [ "data_to_process = json_data[:3]" ] }, { "cell_type": "code", "execution_count": 24, 
"id": "38f436b6-1b6c-45e7-a538-a47021607ea5", "metadata": {}, "outputs": [], "source": [ "def reflect_responses(json_data, client):\n", "    new_json_data = []\n", "\n", "    for entry in tqdm(json_data):\n", "\n", "        if not entry[\"input\"]:\n", "            system_prompt, prompt = res_gen_prompt_no_input(ins=entry[\"instruction\"], outp=entry[\"output\"])\n", "            output = run_chatgpt(prompt=prompt, client=client, system_prompt=system_prompt)\n", "            new_response = extract_response(output)\n", "\n", "            if not len(new_response):\n", "                # Fall back to the original output, wrapped in a list so that\n", "                # new_response[0] below yields the full string, not its first character\n", "                new_response = [entry[\"output\"]]\n", "\n", "            new_entry = {\"instruction\": entry[\"instruction\"], \"input\": \"\", \"output\": new_response[0]}\n", "            new_json_data.append(new_entry)\n", "\n", "        else:\n", "            system_prompt, prompt = res_gen_prompt_input(ins=entry[\"instruction\"], inp=entry[\"input\"], outp=entry[\"output\"])\n", "            output = run_chatgpt(prompt=prompt, client=client, system_prompt=system_prompt)\n", "            new_response = extract_response(output)\n", "\n", "            if not len(new_response):\n", "                # Same fallback as above: keep the original output as a one-element list\n", "                new_response = [entry[\"output\"]]\n", "\n", "            new_entry = {\"instruction\": entry[\"instruction\"], \"input\": entry[\"input\"], \"output\": new_response[0]}\n", "            new_json_data.append(new_entry)\n", "\n", "    return new_json_data" ] }, { "cell_type": "code", "execution_count": 25, "id": "0168fb7e-bef4-43e1-967d-f294909b6883", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|█████████████████████████████████████████████| 3/3 [00:07<00:00, 2.40s/it]\n" ] } ], "source": [ "new_json_data = reflect_responses(data_to_process, client)" ] }, { "cell_type": "code", "execution_count": 26, "id": "1a0949dc-70f3-4adb-9d0a-7f387c0702c7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'instruction': 'Evaluate the following phrase by transforming it into the '\n", " 'spelling given.',\n", " 'input': 'freind --> friend',\n", " 'output': 'The input phrase \"freind\" contains a spelling error. 
The correct '\n", " 'transformation of the word is as follows: \"freind\" should be '\n", " 'corrected to \"friend.\" Therefore, the correct spelling is '\n", " '\"friend.\"'}\n", "\n", "\n", "\n", "{'instruction': 'Edit the following sentence for grammar.',\n", " 'input': 'He go to the park every day.',\n", " 'output': 'The original sentence \"He go to the park every day\" contains a '\n", " 'grammatical error in the verb form. The correct form should be \"He '\n", " 'goes to the park every day.\" This is because the subject \"He\" is '\n", " 'third person singular, and in English, the verb \"to go\" changes to '\n", " '\"goes\" when used with third person singular subjects. Therefore, '\n", " 'the corrected sentence is grammatically accurate and maintains the '\n", " 'original meaning.'}\n", "\n", "\n", "\n", "{'instruction': 'Convert 45 kilometers to meters.',\n", " 'input': '',\n", " 'output': 'To convert kilometers to meters, you can use the conversion factor '\n", " 'that 1 kilometer is equal to 1,000 meters. Therefore, to convert '\n", " '45 kilometers to meters, you multiply 45 by 1,000. \\n'\n", " '\\n'\n", " 'So, 45 kilometers is equal to 45,000 meters (45 km × 1,000 m/km = '\n", " '45,000 m). 
\\n'\n", " '\\n'\n", " 'This conversion is useful in various contexts, such as distance '\n", " 'measurement in travel or scientific calculations.'}\n", "\n", "\n", "\n" ] } ], "source": [ "for i in new_json_data[:3]:\n", "    pprint(i)\n", "    print(\"\\n\\n\")" ] }, { "cell_type": "markdown", "id": "9603159d-e9fa-42cd-a5ab-b528e534e103", "metadata": {}, "source": [ "- Save the updated dataset:" ] }, { "cell_type": "code", "execution_count": 27, "id": "9e763966-6a43-4706-879d-1f413a85ec03", "metadata": {}, "outputs": [], "source": [ "with open(\"response-reflected.json\", \"w\") as file:\n", "    json.dump(new_json_data, file, indent=4)" ] }, { "cell_type": "markdown", "id": "52dc19b4-926b-4497-8370-496efd970366", "metadata": {}, "source": [ "## Creating Improved Instruction Data\n", "\n", "- Applying the two methods above to all 1100 entries of the chapter 7 instruction dataset costs about $0.60 (60 cents).\n", "- To avoid bloating the GitHub repository, the processed dataset files are available for download from Google Drive:\n", "    - [instruction-reflected.json](https://drive.google.com/file/d/1c1QnuTdt9nP1u51vBn4_b05mWR_ZNGBv/view?usp=sharing)\n", "    - [response-reflected.json](https://drive.google.com/file/d/1RNckTZ2ELcdUoJtaylao6NvyZPMtNv1v/view?usp=sharing)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.6" } }, "nbformat": 4, "nbformat_minor": 5 }