{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "This notebook shows how to use our tokenizer to encode a point cloud into a latent,\n",
    "and then decode the latent into 3D Gaussian splats (3DGS), a mesh, and a sampled point cloud.\n",
    "\n",
    "Assumptions: a CUDA GPU is available, and the notebook is run from the `notebooks` folder\n",
    "(data and output paths are relative to it). Results are written to `recon_results/`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import math\n",
    "import os\n",
    "import typing as T\n",
    "\n",
    "import numpy as np\n",
    "import open3d as o3d\n",
    "import torch\n",
    "\n",
    "from lito.eval_scripts.st_model_utils import download_checkpoint, is_http_url\n",
    "from lito.trainers import lito_trainer\n",
    "from plibs import data_utils, gs_utils, o3d_utils, sh_utils, structures, utils\n",
    "\n",
    "# assume you run the notebook in the notebooks folder\n",
    "repo_root = os.path.abspath(\"..\")\n",
    "\n",
    "# all reconstructions (3dgs / mesh / point cloud) are saved here; create the folder up front\n",
    "# so the .ply writers in the cells below do not fail on a missing directory\n",
    "result_dir = os.path.abspath(\"recon_results\")\n",
    "os.makedirs(result_dir, exist_ok=True)\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\")\n",
    "else:\n",
    "    raise NotImplementedError(\"only support GPU currently\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# path or URL of the tokenizer checkpoint. A URL is downloaded into `<repo_root>/artifacts`\n",
    "# (`overwrite=False`, so an earlier download is presumably reused); a local path is used as-is.\n",
    "tokenizer_checkpoint_filename = \"https://ml-site.cdn-apple.com/models/lito/lito_new.ckpt\"\n",
    "\n",
    "if is_http_url(tokenizer_checkpoint_filename):\n",
    "    tokenizer_checkpoint_filename = download_checkpoint(\n",
    "        url=tokenizer_checkpoint_filename,\n",
    "        download_dir_root=os.path.join(repo_root, \"artifacts\"),\n",
    "        overwrite=False,\n",
    "    )\n",
    "\n",
    "model: lito_trainer.LightTokenizationTrainer = lito_trainer.LightTokenizationTrainer.load_from_checkpoint(\n",
    "    checkpoint_path=tokenizer_checkpoint_filename,\n",
    "    map_location=device,\n",
    "    strict=False,  # we might not have lpips\n",
    ")\n",
    "model.freeze()\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load data (see `render_data.ipynb` to see how the data was created)\n",
    "data_filename = \"assets/bunny.npz\"\n",
    "# the context manager closes the .npz file handle once all arrays have been read into memory.\n",
    "# NOTE: allow_pickle=True can execute code embedded in the file -- only load trusted data.\n",
    "with np.load(data_filename, allow_pickle=True) as npz_file:\n",
    "    pdict = dict(npz_file)\n",
    "\n",
    "point_xyz_w = torch.from_numpy(pdict[\"point_xyz_w\"]).to(device=device, dtype=torch.float)  # (n, 3) [-1, 1]\n",
    "point_rgb = torch.from_numpy(pdict[\"point_rgb\"]).to(device=device, dtype=torch.float)  # (n, 3) uint8 values, [0, 255]\n",
    "point_view_dir = torch.from_numpy(pdict[\"point_view_dir\"]).to(\n",
    "    device=device, dtype=torch.float\n",
    ")  # (n, 3) uint8 values, from pinhole to point, unnormalized\n",
    "\n",
    "# rgb and view_dir are saved as uint8 in the npz\n",
    "point_rgb = point_rgb / 255  # (n, 3) [0, 1]\n",
    "point_view_dir = torch.nn.functional.normalize(\n",
    "    point_view_dir * (2 / 255) - 1,  # [0, 255] -> [-1, 1]\n",
    "    dim=-1,\n",
    ")  # (n, 3) unit length, from pinhole to point"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# encode point cloud to latent\n",
    "num_points = 2**20  # 2**18  # (at most) this many leading points are used; the model is robust to fewer / more points\n",
    "num_latent = 8192  # 4096  # 16384  # the model is robust to slightly fewer or more latents\n",
    "\n",
    "with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=True):\n",
    "    latent = model.get_latents(\n",
    "        xyz_w=point_xyz_w[None, :num_points],  # (b=1, n, 3xyz_w)\n",
    "        rgb=point_rgb[None, :num_points],  # (b=1, n, 3rgb) [0, 1]\n",
    "        ray_origin_direction_w=torch.cat(\n",
    "            [\n",
    "                point_xyz_w[None, :num_points],\n",
    "                point_view_dir[None, :num_points],\n",
    "            ],\n",
    "            dim=-1,\n",
    "        ),  # (b=1, n, 6)\n",
    "        num_latent=num_latent,\n",
    "    )[\"latent_tokens\"]  # (b, num_latent, dim_latent)\n",
    "    print(f\"latent: {latent.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# decode to 3dgs (`result_dir` is created in the setup cell)\n",
    "with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=True):\n",
    "    print(\"decoding gaussian..\")\n",
    "    gs_dicts = model.inference_estimate_gaussians(\n",
    "        fpoint_latent=latent,  # (b, num_latent, dim_latent)\n",
    "        init_coord_src=\"voxel_decoder\",\n",
    "    )  # list of (b=1,)\n",
    "    print(\"Finished\")\n",
    "\n",
    "    # save to ply\n",
    "    print(\"saving gaussian..\")\n",
    "    filename = os.path.join(result_dir, f\"reconstructed_bunny_3dgs_nl{num_latent}_np{num_points}.ply\")\n",
    "    ib = 0\n",
    "    gs_dict = gs_dicts[ib]\n",
    "    _sh_degree = sh_utils.get_sh_degree_from_total_dim(gs_dict[\"rgb_sh\"].size(-2))\n",
    "    ngs = math.prod(gs_dict[\"xyz_w\"].shape[:-1])  # total number of gaussians\n",
    "    gs = gs_utils.Gaussians(\n",
    "        sh_degree=_sh_degree,\n",
    "        xyz_w=gs_dict[\"xyz_w\"].reshape(ngs, 3),  # (n, 3xyz)\n",
    "        rgb_sh=gs_dict[\"rgb_sh\"].reshape(ngs, -1, 3),\n",
    "        rgb_sh_dc=None,\n",
    "        rgb_sh_rest=None,\n",
    "        scaling_logit=None,\n",
    "        quaternion_prenorm=None,\n",
    "        opacity_logit=None,\n",
    "        scaling=gs_dict[\"scaling\"].reshape(ngs, 3),  # (n, 3xyz)\n",
    "        quaternion=gs_dict[\"quaternion\"].reshape(ngs, 4),  # (n, 4xyzw)\n",
    "        opacity=gs_dict[\"opacity\"].reshape(ngs, 1),  # (n, 1)\n",
    "        min_scaling=0,  # handled by network\n",
    "        scaling_activation_type=\"none\",\n",
    "    )\n",
    "    gs.save_ply(filename=filename)\n",
    "    print(\n",
    "        f\"saved 3dgs to {filename}, can use supersplat https://superspl.at/editor or sparkjs https://sparkjs.dev/viewer/ to view\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# decode mesh\n",
    "with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=True):\n",
    "    print(\"decoding mesh..\")\n",
    "    raw_meshes: T.List[structures.RawMesh] = model.inference_estimate_mesh(\n",
    "        fpoint_latent=latent,  # (b, num_latent, dim_latent)\n",
    "        init_coord_src=\"voxel_decoder\",\n",
    "    )  # list of (b=1,)\n",
    "    print(\"finished..\")\n",
    "\n",
    "    # save to ply\n",
    "    print(\"saving mesh..\")\n",
    "    ib = 0\n",
    "    o3d_mesh = raw_meshes[ib].get_o3d_mesh()\n",
    "    filename = os.path.join(result_dir, f\"reconstructed_bunny_mesh_nl{num_latent}_np{num_points}.ply\")\n",
    "    # open3d reports write failures via its boolean return value, not an exception\n",
    "    if not o3d.io.write_triangle_mesh(filename, o3d_mesh):\n",
    "        raise RuntimeError(f\"failed to write mesh to {filename}\")\n",
    "    print(f\"saved mesh to {filename}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sample point cloud\n",
    "# the sampler starts from an initial noise draw (presumably random); seed torch so re-runs reproduce the same points\n",
    "torch.manual_seed(0)\n",
    "with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=True):\n",
    "    print(\"sampling points..\")\n",
    "    num_points_to_sample = 16384\n",
    "    init_noise_dict = model.get_conditional_sampling_init_noise(1, num_points_to_sample)\n",
    "    sampled_x_flow_dict = model.conditional_sampling(\n",
    "        fpoint_latent=latent,  # (b, num_latent, dim_latent)\n",
    "        num_steps=100,\n",
    "        **init_noise_dict,\n",
    "        method=\"heun\",\n",
    "        latent_coord=None,  # (bl, dn) packed or None\n",
    "    )  # (b, num_points_to_sample, d)\n",
    "    sampled_xyz_w = sampled_x_flow_dict[\"xyz_w\"].float()  # (b, num_points_to_sample, 3xyz_w)\n",
    "\n",
    "    # save to ply\n",
    "    print(\"saving points..\")\n",
    "    ib = 0\n",
    "    o3d_pcd = o3d_utils.creat_pcd(\n",
    "        points=sampled_xyz_w[ib],\n",
    "        color=((init_noise_dict[\"init_xyz_w\"][ib] + 1) * 0.5),  # use init uvw as color\n",
    "    )\n",
    "    filename = os.path.join(result_dir, f\"reconstructed_bunny_pcd_nl{num_latent}_np{num_points}.ply\")\n",
    "    # open3d reports write failures via its boolean return value, not an exception\n",
    "    if not o3d.io.write_point_cloud(filename, o3d_pcd):\n",
    "        raise RuntimeError(f\"failed to write point cloud to {filename}\")\n",
    "    print(f\"saved points to {filename}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}