{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Simple visualizer for log files written by the training loop" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def parse_logfile(logfile):\n", " # so the tricky part we have to deal with in these log files\n", " # is that the job could crash and get restarted, which will\n", " # re-wind back and start re-logging older steps. So we keep\n", " # all the data as dictionary and over-write old data with new\n", " # and then at the end compile everything together\n", "\n", " # read raw data\n", " streams = {} # stream:str -> {step: val}\n", " with open(logfile, \"r\") as f:\n", " for line in f:\n", " parts = line.split()\n", " step = int(parts[0].split(\":\")[1])\n", " stream = parts[1].split(\":\")[0]\n", " val = float(parts[1].split(\":\")[1])\n", " if not stream in streams:\n", " streams[stream] = {}\n", " d = streams[stream]\n", " d[step] = val\n", " # now re-represent as list of (step, val) tuples\n", " streams_xy = {}\n", " for k, v in streams.items():\n", " # get all (step, val) items, sort them\n", " xy = sorted(list(v.items()))\n", " # unpack the list of tuples to tuple of lists\n", " streams_xy[k] = zip(*xy)\n", " # return the xs, ys lists\n", " return streams_xy\n", "\n", "parse_logfile(\"../log124M/main.log\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "sz = \"124M\"\n", "loss_baseline = {\n", " \"124M\": 3.424958,\n", " \"350M\": 3.083089,\n", " \"774M\": 3.000580,\n", " \"1558M\": 2.831273,\n", "}[sz]\n", "hella2_baseline = { # for GPT-2\n", " \"124M\": 0.294463,\n", " \"350M\": 0.375224,\n", " \"774M\": 0.431986,\n", " \"1558M\": 0.488946,\n", "}[sz]\n", "hella3_baseline = { # for GPT-3\n", " \"124M\": 0.337,\n", " \"350M\": 0.436,\n", " \"774M\": 0.510,\n", " \"1558M\": 0.547,\n", "}[sz]\n", "# assumes each model run is stored in this way\n", "logfile = f\"../log_gpt2_{sz}/main.log\"\n", "streams = parse_logfile(logfile)\n", "\n", "# optional function that smooths out the loss some\n", "def smooth_moving_average(signal, window_size):\n", " if signal.ndim != 1:\n", " raise ValueError(\"smooth_moving_average only accepts 1D arrays.\")\n", " if signal.size < window_size:\n", " raise ValueError(\"Input vector needs to be bigger than window size.\")\n", " if window_size < 3:\n", " return signal\n", "\n", " s = np.pad(signal, (window_size//2, window_size-1-window_size//2), mode='edge')\n", " w = np.ones(window_size) / window_size\n", " smoothed_signal = np.convolve(s, w, mode='valid')\n", " return smoothed_signal\n", "\n", "plt.figure(figsize=(16, 6))\n", "\n", "# Panel 1: losses: both train and val\n", "plt.subplot(121)\n", "xs, ys = streams[\"trl\"] # training loss\n", "ys = np.array(ys)\n", "# smooth out ys using a rolling window\n", "# ys = smooth_moving_average(ys, 21) # optional\n", "plt.plot(xs, ys, label=f'llm.c ({sz}) train loss')\n", "print(\"Min Train Loss:\", min(ys))\n", "xs, ys = streams[\"tel\"] # validation loss\n", "plt.plot(xs, ys, label=f'llm.c ({sz}) val loss')\n", "# horizontal line at GPT-2 baseline\n", "# we don't have GPT-3 loss on this dataset because the weights were never released\n", "if loss_baseline is not None:\n", " plt.axhline(y=loss_baseline, color='r', linestyle='--', label=f\"OpenAI GPT-2 ({sz}) checkpoint val loss\")\n", "plt.xlabel(\"steps\")\n", "plt.ylabel(\"loss\")\n", "plt.yscale('log')\n", "plt.ylim(top=4.0)\n", "plt.legend()\n", "plt.title(\"Loss\")\n", "print(\"Min Validation Loss:\", min(ys))\n", "\n", "# Panel 2: HellaSwag eval\n", "plt.subplot(122)\n", "if \"eval\" in streams:\n", " xs, ys = streams[\"eval\"] # HellaSwag eval\n", " ys = np.array(ys)\n", " plt.plot(xs, ys, label=f\"llm.c ({sz})\")\n", " # horizontal line at GPT-2/3 baselines\n", " if hella2_baseline:\n", " plt.axhline(y=hella2_baseline, color='r', linestyle='--', label=f\"OpenAI GPT-2 ({sz}) checkpoint\")\n", " if hella3_baseline:\n", " plt.axhline(y=hella3_baseline, color='g', linestyle='--', label=f\"OpenAI GPT-3 ({sz}) checkpoint\")\n", " plt.xlabel(\"steps\")\n", " plt.ylabel(\"accuracy\")\n", " plt.legend()\n", " plt.title(\"HellaSwag eval\")\n", " print(\"Max Hellaswag eval:\", max(ys))\n" ] } ], "metadata": { "kernelspec": { "display_name": "pytorch3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 2 }