{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "authorship_tag": "ABX9TyMc59DuqL4h71ziHdKTBAFW" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "markdown", "source": [ "# Sharing Tensors Between PyTorch Processes (CUDA IPC)\n", "\n", "With CUDA IPC, processes can share a Tensor without copying the data, which also reduces GPU memory consumption. Common use cases: sharing KV-cache data within a node, and sharing weight data when RL training and inference are colocated on the same GPU.\n", "\n", "Two examples are included:\n", "1. A single script that starts multiple processes to share data, including an explicit way of passing the IPC metadata;\n", "2. Two scripts, one receiving and one sending, to complete the Tensor data sharing.\n", "\n", "Author: kaiyuan\n", "\n", "Email: kaiyuanxie@yeah.net\n", "\n" ], "metadata": { "id": "zcx1OTZ9f40L" } }, { "cell_type": "markdown", "source": [ "# 1 Single-Script Multi-Process Example\n", "\n", "An example of passing a GPU Tensor between processes via CUDA IPC in PyTorch.\n", "\n", "Notes\n", "\n", "1. Recommended approach: use ``torch.multiprocessing`` + ``spawn`` and pass ``cuda`` Tensors directly\n", " through a ``Queue`` / ``Pipe``. Serialization goes through the same reduction as pickle; under the hood it calls\n", " ``_share_cuda_()`` / ``_new_shared_cuda()`` on the ``UntypedStorage``, i.e. CUDA IPC.\n", "\n", "2. Explicit approach: manually take the metadata returned by ``TypedStorage._share_cuda_()``, then call the internal\n", " ``rebuild_cuda_tensor`` in the child process to reconstruct the Tensor (the same logic as PyTorch's multiprocessing serialization), which makes it easy to see exactly which IPC fields are transferred. It relies on non-public APIs and is for study only.\n", "\n", "Before running, confirm that a CUDA-enabled PyTorch is installed locally and that ``torch.cuda.is_available()`` returns True.\n", "\n", "Child processes must use ``spawn``; note that sharing CUDA tensors across processes is not supported on Windows, and CUDA IPC is generally used between processes on the same machine.\n", "\n", "If running the script from NFS or another network filesystem raises ``SemLock`` / ``FileNotFoundError``: this file already defaults to\n", "``TMPDIR=/tmp`` on Linux, and you can additionally ``export TMPDIR=/tmp`` in the shell before launching.\n", "\n", "The Queue example uses **three separate Queues**: the device index, the tensor, and the acknowledgement travel independently. If the same Queue is used to ``put``\n", "an ``int`` and then the tensor, and later to ``get`` the ack, then under ``spawn`` plus the internal feeder thread you may **receive an object of an unexpected type** (e.g. the first ``get`` returns the tensor, or the ack reads back as ``0``). Separate queues rule out the mis-ordering at the root.\n", "\n", "Example runtime environment:\n", "\n", "docker pull nvcr.io/nvidia/pytorch:26.01-py3\n" ], "metadata": { "id": "C4rbqwN7gw5N" } }, { "cell_type": "code", "source": [ "from __future__ import annotations\n", "\n", "# Must be set before importing multiprocessing: SemLock FileNotFoundError is common on NFS\n", "import os\n", "import sys\n", "\n", "if sys.platform.startswith(\"linux\"):\n", " 
os.environ.setdefault(\"TMPDIR\", \"/tmp\")\n", "\n", "from multiprocessing.connection import Connection\n", "\n", "import torch\n", "import torch.multiprocessing as mp\n", "\n", "\n", "def _cuda_startup(device_index: int = 0) -> None:\n", " \"\"\"\n", " Call in the child process **before** unpickling / rebuilding CUDA tensors.\n", " Order matters: set_device first, then init; with some drivers the reverse order raises invalid device context.\n", " Then allocate a tiny tensor to make sure the context on the current device is fully ready.\n", " \"\"\"\n", " torch.cuda.set_device(device_index)\n", " if hasattr(torch.cuda, \"init\"):\n", " torch.cuda.init()\n", " else:\n", " torch.cuda._lazy_init() # type: ignore[attr-defined]\n", " _ = torch.empty((1,), dtype=torch.uint8, device=f\"cuda:{device_index}\")\n", " torch.cuda.synchronize()\n", "\n", "\n", "def _ensure_spawn() -> None:\n", " try:\n", " mp.set_start_method(\"spawn\", force=True)\n", " except RuntimeError:\n", " pass\n", "\n", "\n", "# ---------------------------------------------------------------------------\n", "# Approach 1: pass the CUDA Tensor through a Queue (recommended; same path as everyday DataLoader-style usage)\n", "# ---------------------------------------------------------------------------\n", "\n", "\n", "def _receiver_queue(q_dev: mp.Queue, q_tensor: mp.Queue, q_ack: mp.Queue) -> None:\n", " raw = q_dev.get()\n", " if not isinstance(raw, int):\n", " raise TypeError(\n", " f\"expected int CUDA device index on q_dev, got {type(raw).__name__}\"\n", " )\n", " device_idx = raw\n", " _cuda_startup(device_idx)\n", " t: torch.Tensor = q_tensor.get()\n", " print(\n", " f\"[receiver] got tensor: shape={tuple(t.shape)} \"\n", " f\"device={t.device} dtype={t.dtype}\"\n", " )\n", " print(f\"[receiver] mean={float(t.mean()):.6f}\")\n", " # Shares the same GPU memory as the sender; while the sender still holds a reference, this in-place update is visible to it\n", " t.add_(1.0)\n", " q_ack.put(\"ok\")\n", "\n", "\n", "def run_queue_demo() -> None:\n", " \"\"\"The main process puts a Tensor into a Queue; the child process fetches it and adds 1 in place.\"\"\"\n", " if not torch.cuda.is_available():\n", " print(\"CUDA unavailable; skipping the Queue demo.\", file=sys.stderr)\n", " return\n", "\n", " _ensure_spawn()\n", " ctx = 
mp.get_context(\"spawn\")\n", " q_dev: mp.Queue = ctx.Queue()\n", " q_tensor: mp.Queue = ctx.Queue()\n", " q_ack: mp.Queue = ctx.Queue()\n", " p = ctx.Process(target=_receiver_queue, args=(q_dev, q_tensor, q_ack))\n", " p.start()\n", "\n", " _cuda_startup(0)\n", " x = torch.randn(4, 5, device=\"cuda\", dtype=torch.float32).detach()\n", " dev = int(x.device.index if x.device.index is not None else 0)\n", " print(f\"[sender] before send mean={float(x.mean()):.6f} device=cuda:{dev}\")\n", " q_dev.put(dev)\n", " q_tensor.put(x)\n", " ack = q_ack.get()\n", " print(f\"[sender] child ack: {ack}\")\n", " print(f\"[sender] mean after receiver add_(1)={float(x.mean()):.6f}\")\n", " p.join()\n", "\n", "\n", "# ---------------------------------------------------------------------------\n", "# Approach 2: explicit _share_cuda_ metadata + rebuild_cuda_tensor (mirrors the internal implementation)\n", "# ---------------------------------------------------------------------------\n", "\n", "\n", "def _receiver_explicit(conn: Connection) -> None:\n", " from torch.multiprocessing.reductions import rebuild_cuda_tensor\n", "\n", " _cuda_startup(0)\n", " payload = conn.recv()\n", " t = rebuild_cuda_tensor(*payload)\n", " print(\n", " f\"[receiver/explicit] rebuilt tensor: shape={tuple(t.shape)} \"\n", " f\"sum={float(t.sum()):.4f}\"\n", " )\n", " conn.send(\"done\")\n", "\n", "\n", "def run_explicit_ipc_demo() -> None:\n", " \"\"\"\n", " The main process calls ``storage._share_cuda_()`` by hand and sends the child the same parameter pack\n", " that ``reduce_tensor`` would produce; the child reconstructs it with ``rebuild_cuda_tensor``. After sending,\n", " the original tensor must stay alive until the child has opened the IPC handle (synchronized here via a Pipe).\n", " \"\"\"\n", " if not torch.cuda.is_available():\n", " print(\"CUDA unavailable; skipping the explicit IPC demo.\", file=sys.stderr)\n", " return\n", "\n", " # rebuild_cuda_tensor is only needed in the child; the parent uses the cache helpers\n", " from torch.multiprocessing.reductions import (\n", " StorageWeakRef,\n", " shared_cache,\n", " )\n", "\n", " _ensure_spawn()\n", " ctx = mp.get_context(\"spawn\")\n", " parent_conn, child_conn = ctx.Pipe()\n", "\n", " p = ctx.Process(target=_receiver_explicit, args=(child_conn,))\n", " p.start()\n", "\n", " 
_cuda_startup(0)\n", " tensor = torch.arange(12, dtype=torch.float32, device=\"cuda\").view(3, 4).detach()\n", " storage = tensor._typed_storage()\n", " (\n", " device,\n", " handle,\n", " storage_size_bytes,\n", " storage_offset_bytes,\n", " ref_counter_handle,\n", " ref_counter_offset,\n", " event_handle,\n", " event_sync_required,\n", " ) = storage._share_cuda_()\n", " # Matches the CUDA branch of torch.multiprocessing.reductions.reduce_tensor\n", " shared_cache[handle] = StorageWeakRef(storage)\n", "\n", " payload = (\n", " type(tensor),\n", " tensor.size(),\n", " tensor.stride(),\n", " tensor.storage_offset(),\n", " type(storage),\n", " tensor.dtype,\n", " device,\n", " handle,\n", " storage_size_bytes,\n", " storage_offset_bytes,\n", " tensor.requires_grad,\n", " ref_counter_handle,\n", " ref_counter_offset,\n", " event_handle,\n", " event_sync_required,\n", " )\n", " parent_conn.send(payload)\n", " parent_conn.recv() # wait for the child to finish rebuilding before leaving this scope\n", " p.join()\n", "\n", "\n", "def main() -> None:\n", " print(\"=== 1) Passing a CUDA Tensor via torch.multiprocessing.Queue ===\\n\")\n", " run_queue_demo()\n", " print(\"\\n=== 2) Explicit _share_cuda_ + rebuild_cuda_tensor ===\\n\")\n", " run_explicit_ipc_demo()\n", "\n", "\n", "if __name__ == \"__main__\":\n", " main()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "A_ZOYbQ2jX8N", "outputId": "23ae8f07-668d-4910-983f-55b1ca747d11" }, "execution_count": 1, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "=== 1) Passing a CUDA Tensor via torch.multiprocessing.Queue ===\n", "\n", "[sender] before send mean=0.160294 device=cuda:0\n", "[receiver] got tensor: shape=(4, 5) device=cuda:0 dtype=torch.float32\n", "[receiver] mean=0.160294\n", "[sender] child ack: ok\n", "[sender] mean after receiver add_(1)=1.160294\n", "\n", "=== 2) Explicit _share_cuda_ + rebuild_cuda_tensor ===\n", "\n", "[receiver/explicit] rebuilt tensor: shape=(3, 4) sum=66.0000\n", "\n" ] } ] }, { "cell_type": "markdown", "source": [ "# 2 Cross-Script Example\n", "\n", "An example of sharing a CUDA Tensor across scripts (two independent Python 
processes).\n", "\n", "\n", "Before running, set the environment variable: ``export TMPDIR=/tmp``\n", "\n", "How to run (open two terminals):\n", "1) Start the receiver first:\n", " python cuda_ipc_cross_script_demo.py --mode receiver --host 127.0.0.1 --port 29531\n", "\n", "2) Then start the sender:\n", " python cuda_ipc_cross_script_demo.py --mode sender --host 127.0.0.1 --port 29531 --device 0\n", "\n", "Expected behavior:\n", "- After receiving the CUDA tensor, the receiver runs t.add_(1.0)\n", "- After getting the ack, the sender sees the mean of its local x change in step (showing the two processes share the same GPU memory)\n" ], "metadata": { "id": "dgrn81UUgxea" } }, { "cell_type": "code", "source": [ "from __future__ import annotations\n", "\n", "import argparse\n", "import os\n", "import sys\n", "import time\n", "from multiprocessing.connection import Client, Listener\n", "from typing import Any\n", "\n", "if sys.platform.startswith(\"linux\"):\n", " os.environ.setdefault(\"TMPDIR\", \"/tmp\")\n", "\n", "import torch\n", "import torch.multiprocessing.reductions as mp_reductions\n", "\n", "\n", "def _cuda_startup(device_index: int) -> None:\n", " torch.cuda.set_device(device_index)\n", " if hasattr(torch.cuda, \"init\"):\n", " torch.cuda.init()\n", " else:\n", " torch.cuda._lazy_init() # type: ignore[attr-defined]\n", " _ = torch.empty((1,), dtype=torch.uint8, device=f\"cuda:{device_index}\")\n", " torch.cuda.synchronize()\n", "\n", "\n", "def run_receiver(host: str, port: int) -> None:\n", " listener = Listener((host, port), authkey=b\"cuda-ipc-demo\")\n", " print(f\"[receiver] listening at {host}:{port}\")\n", " conn = listener.accept()\n", " print(\"[receiver] sender connected\")\n", "\n", " first_msg: Any = conn.recv()\n", " if not isinstance(first_msg, int):\n", " raise TypeError(f\"expected device index int, got {type(first_msg).__name__}\")\n", " device_idx = first_msg\n", " _cuda_startup(device_idx)\n", "\n", " t: torch.Tensor = conn.recv()\n", " print(\n", " f\"[receiver] got tensor: shape={tuple(t.shape)} device={t.device} \"\n", " f\"mean={float(t.mean()):.6f}\"\n", " )\n", " t.add_(1.0)\n", " torch.cuda.synchronize()\n", " conn.send({\"status\": 
\"ok\", \"receiver_mean\": float(t.mean())})\n", " conn.close()\n", " listener.close()\n", " print(\"[receiver] done\")\n", "\n", "\n", "def run_sender(host: str, port: int, device: int) -> None:\n", " if not torch.cuda.is_available():\n", " raise RuntimeError(\"CUDA is not available\")\n", "\n", " _cuda_startup(device)\n", " x = torch.randn(4, 5, device=f\"cuda:{device}\", dtype=torch.float32).detach()\n", " before = float(x.mean())\n", " print(f\"[sender] before send mean={before:.6f}, device={x.device}\")\n", "\n", " conn = Client((host, port), authkey=b\"cuda-ipc-demo\")\n", " conn.send(device)\n", " conn.send(x)\n", " ack = conn.recv()\n", " torch.cuda.synchronize()\n", " after = float(x.mean())\n", " print(f\"[sender] ack={ack}\")\n", " print(f\"[sender] after receiver add_ mean={after:.6f}\")\n", " conn.close()\n", "\n", " delta = after - before\n", " print(f\"[sender] mean delta={delta:.6f}\")\n", " if delta > 0.5:\n", " print(\"[sender] shared CUDA memory verified\")\n", " else:\n", " print(\"[sender] WARNING: delta is small; verify synchronization/environment\")\n", " time.sleep(0.2)\n", "\n", "\n", "def parse_args() -> argparse.Namespace:\n", " parser = argparse.ArgumentParser(description=\"Cross-script CUDA IPC tensor sharing demo\")\n", " parser.add_argument(\"--mode\", choices=[\"sender\", \"receiver\"], required=True)\n", " parser.add_argument(\"--host\", default=\"127.0.0.1\")\n", " parser.add_argument(\"--port\", type=int, default=29531)\n", " parser.add_argument(\"--device\", type=int, default=0)\n", " return parser.parse_args()\n", "\n", "\n", "def main() -> None:\n", " # Explicitly install torch's reducers so Connection.send/recv takes the CUDA IPC path for CUDA tensors\n", " mp_reductions.init_reductions()\n", " args = parse_args()\n", " if args.mode == \"receiver\":\n", " run_receiver(args.host, args.port)\n", " else:\n", " run_sender(args.host, args.port, args.device)\n", "\n", "\n", "if __name__ == \"__main__\":\n", " main()\n" ], "metadata": { "id": "uCnO8yuFgyL5" }, 
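"execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Side note on the explicit approach in section 1: the parameter pack it sends over the Pipe is exactly what PyTorch's own pickling reduction produces for a CUDA tensor. Below is a minimal exploratory sketch (an added cell, not part of the original demos) that inspects this reduction via ``reduce_tensor``; it assumes a CUDA-enabled PyTorch, and ``reduce_tensor`` is a non-public API whose field layout may change between versions.\n" ], "metadata": {} }, { "cell_type": "code", "source": [ "# Exploratory sketch: inspect the reduction torch.multiprocessing applies to a CUDA tensor.\n", "# reduce_tensor is non-public; its field layout may differ across PyTorch versions.\n", "import torch\n", "from torch.multiprocessing.reductions import reduce_tensor\n", "\n", "if torch.cuda.is_available():\n", " t = torch.ones(2, 3, device=\"cuda\")\n", " rebuild_fn, args = reduce_tensor(t)\n", " # The rebuild function the receiving process would call, plus the IPC metadata it gets\n", " print(rebuild_fn.__name__)\n", " print(f\"{len(args)} metadata fields, e.g. size={args[1]}, dtype={args[5]}\")\n" ], "metadata": {}, 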
"execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Output:\n", "\n", "```\n", "[sender] before send mean=0.173753, device=cuda:0\n", "[sender] ack={'status': 'ok', 'receiver_mean': 1.1737531423568726}\n", "[sender] after receiver add_ mean=1.173753\n", "[sender] mean delta=1.000000\n", "[sender] shared CUDA memory verified\n", "\n", "\n", "[receiver] listening at 127.0.0.1:29531\n", "[receiver] sender connected\n", "[receiver] got tensor: shape=(4, 5) device=cuda:0 mean=0.173753\n", "[receiver] done\n", "\n", "```" ], "metadata": { "id": "2QRnAtUzh96D" } } ] }