{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Inference on shared models\n", "\n", "Run inference on the shared sigmoid-attention models (and its softmax baseline) using AXLearn.\n", "\n", "### Setup\n", "\n", "Please see [Getting Started](https://github.com/apple/axlearn/blob/main/README.md) for the general getting started guide for AXLearn. Including how to define new trainer configs yourself, how to use the CLI, how to launch trainings and why certain design decisions were made.\n", "\n", "To install required dependencies for this notebook, run:\n", "```shell\n", "pip install --ignore-installed \"axlearn[core,apple-silicon,gcp] @ git+https://github.com/apple/axlearn.git\"\n", "\n", "pip install tabulate\n", "```\n", "\n", "And then, make sure to authenticate to GCP: `gcloud auth login && gcloud auth application-default login`." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# General imports.\n", "from typing import List, Iterator, Tuple, Literal, Sequence\n", "\n", "# JAX must be imported before tensorflow.\n", "import jax\n", "import jax.numpy as jnp\n", "import seqio\n", "from tqdm import tqdm\n", "import os\n", "import subprocess\n", "from pathlib import Path\n", "\n", "# Necessary so that the checkpoint loader works inside a notebook.\n", "import nest_asyncio\n", "\n", "nest_asyncio.apply()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading state for 7b-sigmoid from:\n", "downloads/models/gala-7B-sigmoid-hybridnorm-alibi-sprp-2024-12-03-1002/checkpoints/step_00250000\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:Multiple positional arguments for .source_to_target. Consider using keyword arguments instead.\n" ] } ], "source": [ "# The LM configurations and checkpoints.\n", "from axlearn.experiments import get_named_trainer_config\n", "from axlearn.common.config import config_for_function\n", "\n", "# Stuff to configure the local device mesh.\n", "from axlearn.common import utils_spmd\n", "\n", "# To control what to load from the saved checkpoint.\n", "from axlearn.common import state_builder\n", "from axlearn.common.checkpointer import CheckpointValidationType\n", "\n", "# For the tokenizer/vocab.\n", "from axlearn.experiments.text import common\n", "\n", "# For typing stuff.\n", "from axlearn.common.utils import DataPartitionType, set_data_dir\n", "from axlearn.common import utils\n", "from axlearn.common.inference import InferenceRunner\n", "\n", "\n", "JAX_BACKEND: Literal[\"cpu\", \"tpu\", \"gpu\"] = \"cpu\"\n", "DATA_DIR: str = \"gs://axlearn-public/tensorflow_datasets\"\n", "\n", "REMOTE_MODEL_DIR = \"gs://axlearn-public/experiments/\"\n", "LOCAL_MODEL_DIR = \"downloads/models/\"\n", "MODEL_INFO: dict[str, dict[str, str]] = {\n", " # Sigmoid-based attention.\n", " \"7b-sigmoid\": {\n", " \"checkpoint_dir\": \"gala-7B-sigmoid-hybridnorm-alibi-sprp-2024-12-03-1002/checkpoints/step_00250000\",\n", " \"config_name\": \"gala-sigmoid-7B-4k-hybridnorm-alibi-sp-rp\",\n", " \"sentencepiece_model_name\": \"bpe_32k_c4.model\",\n", " \"config_module\": \"axlearn.experiments.text.gpt.pajama_sigmoid_trainer\",\n", " },\n", " # Softmax baseline.\n", " \"7b-softmax\": {\n", " \"checkpoint_dir\": \"gala-7B-hybridnorm-alibi-sprp-2024-12-02-1445/checkpoints/step_00250000\",\n", " \"config_name\": \"gala-7B-hybridnorm-alibi-flash-sp-rp\",\n", " \"sentencepiece_model_name\": \"bpe_32k_c4.model\",\n", " \"config_module\": \"axlearn.experiments.text.gpt.pajama_trainer\",\n", " },\n", "}\n", "\n", "utils_spmd.setup(jax_backend=JAX_BACKEND)\n", "\n", "\n", "def _init_state_builder_discard_optimizer(\n", " *,\n", " source_config_name: str,\n", " source_config_module: str,\n", " mesh_axis_names: Sequence[str],\n", " mesh_shape: Sequence[int],\n", " checkpoint_dir: str,\n", ") -> state_builder.Builder.Config:\n", " converter = state_builder.ModelStateScopeConverter.default_config().set(\n", " source_trainer_config=config_for_function(get_named_trainer_config).set(\n", " config_name=source_config_name,\n", " config_module=source_config_module,\n", " ),\n", " # Only keep `decoder` tree, which means we throw away optimizer.\n", " scope={\"decoder\": \"decoder\"},\n", " mesh_axis_names=mesh_axis_names,\n", " mesh_shape=mesh_shape,\n", " )\n", " init_state_builder = state_builder.RestoreAndConvertBuilder.default_config().set(\n", " builder=state_builder.TensorStoreStateStorageBuilder.default_config().set(\n", " validation=CheckpointValidationType.CONTAINS_STATE_UP_TO_DTYPE,\n", " dir=checkpoint_dir,\n", " ),\n", " converter=converter,\n", " )\n", " return init_state_builder\n", "\n", "\n", "def get_inference_runner(name: str, param_dtype: jnp.dtype) -> InferenceRunner:\n", " \"\"\"Make an inference runner initialized with pre-trained state according to model name.\"\"\"\n", " # Get the trainer configuration by name.\n", " ckpt_dir = MODEL_INFO[name][\"checkpoint_dir\"]\n", "\n", " # If we don't have a local version, first download it.\n", " local_ckpt_dir = Path(LOCAL_MODEL_DIR) / ckpt_dir\n", " if not local_ckpt_dir.exists():\n", " remote_ckpt_dir = os.path.join(REMOTE_MODEL_DIR, ckpt_dir)\n", " print(f\"Copying checkpoint from {remote_ckpt_dir} to {local_ckpt_dir}.\")\n", " os.makedirs(local_ckpt_dir, exist_ok=True)\n", " os.makedirs(local_ckpt_dir / \"gda\", exist_ok=True)\n", " # We only copy the weights, not `gda/learner`, which contains the full optimizer state.\n", " subprocess.run([\"gsutil\", \"-m\", \"cp\", \"-r\", os.path.join(remote_ckpt_dir, \"tf_*\"), local_ckpt_dir])\n", " subprocess.run([\"gsutil\", \"-m\", \"cp\", \"-r\", os.path.join(remote_ckpt_dir, \"gda\", \"model\"), local_ckpt_dir / \"gda\"])\n", " subprocess.run([\"gsutil\", \"-m\", \"cp\", \"-r\", os.path.join(remote_ckpt_dir, \"gda\", \"prng_key\"), local_ckpt_dir / \"gda\"])\n", " subprocess.run([\"gsutil\", \"cp\", os.path.join(remote_ckpt_dir, \"index\"), local_ckpt_dir])\n", "\n", " config_name = MODEL_INFO[name][\"config_name\"]\n", " config_module = MODEL_INFO[name][\"config_module\"]\n", " mesh_axis_names = (\n", " \"data\",\n", " \"expert\",\n", " \"fsdp\",\n", " \"model\",\n", " \"seq\",\n", " )\n", " mesh_shape = (\n", " 1,\n", " 1,\n", " 1,\n", " len(jax.devices()),\n", " 1,\n", " )\n", "\n", " trainer_cfg = get_named_trainer_config(\n", " config_name=config_name,\n", " config_module=config_module,\n", " )()\n", "\n", " # Do not load optimizer state, to speed up loading.\n", " init_state_builder = _init_state_builder_discard_optimizer(\n", " source_config_name=config_name,\n", " source_config_module=config_module,\n", " mesh_axis_names=mesh_axis_names,\n", " mesh_shape=mesh_shape,\n", " checkpoint_dir=str(local_ckpt_dir),\n", " )\n", "\n", " inference_runner_cfg = InferenceRunner.default_config().set(\n", " name=f\"{name}_inference_runner\",\n", " mesh_axis_names=mesh_axis_names,\n", " mesh_shape=mesh_shape,\n", " model=trainer_cfg.model.set(dtype=param_dtype),\n", " input_batch_partition_spec=DataPartitionType.REPLICATED, # FULL, REPLICATED\n", " init_state_builder=init_state_builder,\n", " )\n", " print(f\"Loading state for {name} from:\\n{local_ckpt_dir}\")\n", " inference_runner = inference_runner_cfg.instantiate(parent=None)\n", " return inference_runner\n", "\n", "\n", "def get_vocab(name: str) -> seqio.Vocabulary:\n", " \"\"\"Get the vocabulary based on the model's name.\"\"\"\n", " with set_data_dir(DATA_DIR):\n", " vocab = common.vocab(\n", " sentencepiece_model_name=MODEL_INFO[name][\"sentencepiece_model_name\"]\n", " )\n", " return vocab\n", "\n", "\n", "# Load the model checkpoint.\n", "model_name = \"7b-sigmoid\" # \"7b-softmax\"\n", "inference_runner = get_inference_runner(model_name, param_dtype=jnp.bfloat16)\n", "vocab = get_vocab(model_name)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "def _preprocess_text(text: str) -> str:\n", " \"\"\"Preprocesses text for tokenization.\n", "\n", " Our sentencepiece tokenizers have been trained to see in place of \\n.\n", " \"\\n\" will still work in most places as intended, but it's not always guaranteed\n", " to tokenize to an individual token, unlike .\n", " \"\"\"\n", " return text.replace(\"\\n\", \"\")\n", "\n", "\n", "def _postprocess_text(text: str) -> str:\n", " \"\"\"Postprocesses text after LM inference.\n", "\n", " Changes back to \\n. This is the opposite operation of preprocess_text.\n", " \"\"\"\n", " return text.replace(\"\", \"\\n\")\n", "\n", "\n", "def _preprocess_inference(\n", " contexts_list: List[str],\n", " vocab: seqio.Vocabulary,\n", " *,\n", " max_seq_len: int = 256,\n", " batch_size: int = 1,\n", ") -> Iterator[utils.NestedTensor]:\n", " batched_remainder = len(contexts_list) % batch_size\n", " if batched_remainder:\n", " contexts_list += [[\"\"]] * (batch_size - batched_remainder)\n", " for chunk_ix in range(0, len(contexts_list), batch_size):\n", " chunk = contexts_list[chunk_ix : chunk_ix + batch_size]\n", " output_buffer = []\n", " for context in chunk:\n", " output = vocab.encode(_preprocess_text(context))\n", " output += [vocab.pad_id] * (max_seq_len - len(output))\n", " output_buffer.append(output)\n", " yield dict(input_ids=jnp.asarray(output_buffer, dtype=jnp.int32))\n", "\n", "\n", "def postprocess_inference(\n", " output: utils.NestedTensor,\n", " vocab: seqio.Vocabulary,\n", ") -> List[Tuple[str, str]]:\n", " input_ids = utils.replicate_to_local_data(output[\"inputs\"][\"input_ids\"])\n", " # [batch_size, seq_len, vocab_size]\n", " predicted_logits = utils.replicate_to_local_data(output[\"outputs\"][\"logits\"])\n", " results = []\n", " for ix, in_val in enumerate(input_ids):\n", " if jnp.all(in_val == vocab.pad_id):\n", " # Skip fully padded input strings.\n", " continue\n", " if in_val[0] == vocab.eos_id:\n", " # Skip first EOS if it exists.\n", " in_val_eos_lstrip = in_val[1:]\n", " else:\n", " in_val_eos_lstrip = in_val\n", " input_str = _postprocess_text(vocab.decode(in_val_eos_lstrip))\n", " input_str_per_token = [_postprocess_text(vocab.decode([input_id])) for input_id in in_val]\n", " batch_predicted_logits = predicted_logits[ix]\n", " decoded_continuations = []\n", " speculative_continuation = \"\"\n", " speculative_continuation_per_token = []\n", " token_prediction_index = len(jnp.argwhere(in_val != vocab.pad_id)) - 1\n", " new_input_tokens = list(in_val[: token_prediction_index + 1])\n", " tokens = jnp.argmax(batch_predicted_logits, axis=-1)\n", " continuation_tokens = tokens\n", " decoded_continuations.append(_postprocess_text(vocab.decode(tokens)))\n", " decoded_continuations_per_token = [\n", " _postprocess_text(vocab.decode([token])) for token in tokens\n", " ]\n", " # Add next token info (for self speculative decoding setup).\n", " head_token_prediction = tokens[token_prediction_index]\n", " next_token_str = _postprocess_text(vocab.decode([head_token_prediction]))\n", " speculative_continuation += next_token_str\n", " speculative_continuation_per_token.append(next_token_str)\n", " new_input_tokens += [head_token_prediction]\n", " results.append(\n", " {\n", " \"input_str\": input_str,\n", " \"input_str_per_token\": input_str_per_token,\n", " \"decoded_continuations\": decoded_continuations,\n", " \"decoded_continuations_per_token\": decoded_continuations_per_token,\n", " \"input\": output[\"inputs\"][\"input_ids\"],\n", " \"output\": continuation_tokens,\n", " \"speculative_continuation\": speculative_continuation,\n", " \"speculative_continuation_per_token\": speculative_continuation_per_token,\n", " \"new_input_tokens\": new_input_tokens,\n", " \"new_input_str\": _postprocess_text(vocab.decode(new_input_tokens)),\n", " \"next_logits\": batch_predicted_logits[token_prediction_index],\n", " }\n", " )\n", " return results\n", "\n", "\n", "def decode_lm_response(\n", " vocab: seqio.Vocabulary,\n", " inference_runner: InferenceRunner,\n", " contexts_list: List[str],\n", " *,\n", " max_seq_len: int = 256,\n", " batch_size: int = 1,\n", " print_output: bool = False,\n", "):\n", " results = []\n", " for batch in inference_runner.run(\n", " _preprocess_inference(\n", " contexts_list=contexts_list,\n", " vocab=vocab,\n", " max_seq_len=max_seq_len,\n", " batch_size=batch_size,\n", " ),\n", " method=\"predict\",\n", " prng_key=jax.random.PRNGKey(11),\n", " ):\n", " for result in postprocess_inference(batch, vocab):\n", " results.append(result)\n", "\n", " if not print_output:\n", " return results\n", " # Print separately, so we always print it at the end of output in notebook.\n", " for result in results:\n", " from tabulate import tabulate\n", "\n", " per_token_table = [[\"Type\", *range(len(result[\"input\"][0]))]]\n", " per_token_table.append([\"input\"] + list(result[\"input\"][0]))\n", " per_token_table.append([\"input_str\"] + result[\"input_str_per_token\"])\n", " per_token_table.append([f\"prediction\"] + list(result[\"output\"]))\n", " per_token_table.append([f\"prediction_str\"] + result[\"decoded_continuations_per_token\"])\n", "\n", " # Also show empty chars.\n", " per_token_table_print = []\n", " for token_values in per_token_table:\n", " token_values_print = []\n", " for token in token_values:\n", " if isinstance(token, str):\n", " token = token.replace(\"\\n\", \"\")\n", " if token.isspace():\n", " token = f\"<{token}>\"\n", " token_values_print.append(token)\n", " per_token_table_print.append(token_values_print)\n", " print(tabulate(per_token_table_print, headers=\"firstrow\", tablefmt=\"grid\"))\n", "\n", " print()\n", " for k, v in result.items():\n", " print(f\"{k}: {v}\")\n", " return results" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:Passing DataPartitionType is deprecated. Please specify a PartitionSpec directly.\n", "WARNING:absl:dispatch_input_batch is deprecated. Please use `axlearn.common.input_dispatch` instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+----------------+----------+-----+------+--------+---------+---------+-------+------+------+\n", "| Type | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |\n", "+================+==========+=====+======+========+=========+=========+=======+======+======+\n", "| input | 4585 | 319 | 593 | 8152 | 262 | 14658 | 31905 | 1446 | 1137 |\n", "+----------------+----------+-----+------+--------+---------+---------+-------+------+------+\n", "| input_str | Life | is | like | riding | a | bicycle | . | To | keep |\n", "+----------------+----------+-----+------+--------+---------+---------+-------+------+------+\n", "| prediction | 10019 | 262 | 262 | 262 | 14658 | 31905 | 749 | 2254 | 380 |\n", "+----------------+----------+-----+------+--------+---------+---------+-------+------+------+\n", "| prediction_str | Sciences | a | a | a | bicycle | . | You | fall | your |\n", "+----------------+----------+-----+------+--------+---------+---------+-------+------+------+\n", "\n", "input_str: Life is like riding a bicycle. To keep\n", "input_str_per_token: ['Life', 'is', 'like', 'riding', 'a', 'bicycle', '.', 'To', 'keep']\n", "decoded_continuations: ['Sciences a a a bicycle. You fall your']\n", "decoded_continuations_per_token: ['Sciences', 'a', 'a', 'a', 'bicycle', '.', 'You', 'fall', 'your']\n", "input: [[ 4585 319 593 8152 262 14658 31905 1446 1137]]\n", "output: [10019 262 262 262 14658 31905 749 2254 380]\n", "speculative_continuation: your\n", "speculative_continuation_per_token: ['your']\n", "new_input_tokens: [4585, 319, 593, 8152, 262, 14658, 31905, 1446, 1137, Array(380, dtype=int32)]\n", "new_input_str: Life is like riding a bicycle. To keep your\n", "next_logits: [-153.86253 -133.40291 -153.86255 ... -153.86249 -153.86252 -153.86255]\n" ] } ], "source": [ "# Single token-generation:\n", "# printing all details about intermediate tokens, logits and\n", "# final next token predicted.\n", "prompt = \"Life is like riding a bicycle. To keep\"\n", "max_seq_len = 8\n", "_ = decode_lm_response(\n", " vocab=vocab,\n", " inference_runner=inference_runner,\n", " contexts_list=[prompt],\n", " max_seq_len=max_seq_len,\n", " print_output=True,\n", ")" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " your balance, you must keep moving.\\n—: 95%|█████████▌| 19/20 [01:03<00:03, 3.32s/it]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Life is like riding a bicycle. To keep your balance, you must keep moving.\n", "— Albert\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "# Generate multiple tokens sequentially.\n", "# This is quite slow+naive, but helps to understand the autoregressive setup.\n", "original_prompt = \"Life is like riding a bicycle. To keep\"\n", "max_seq_len = 20\n", "\n", "prompt = original_prompt\n", "pbar = tqdm(total=max_seq_len)\n", "while (n_tokens := len(vocab.encode(_preprocess_text(prompt)))) < max_seq_len:\n", " pbar.update(n_tokens - pbar.n)\n", " pbar.set_description(prompt[len(original_prompt):].replace(\"\\n\", \"\\\\n\"))\n", " lm_response = decode_lm_response(\n", " vocab=vocab,\n", " inference_runner=inference_runner,\n", " contexts_list=[prompt],\n", " max_seq_len=max_seq_len,\n", " print_output=False,\n", " )\n", " prompt = lm_response[0][\"new_input_str\"]\n", "pbar.close()\n", "print()\n", "print(prompt)\n", " " ] } ], "metadata": { "kernelspec": { "display_name": "sigmoid-attn-py310", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.17" } }, "nbformat": 4, "nbformat_minor": 2 }