{ "cells": [ { "cell_type": "markdown", "id": "4936d1e6-5e7d-4e22-ae35-8e888927ce2d", "metadata": {}, "source": [ "# Use a Pre-trained CNN as a Feature Extractor" ] }, { "cell_type": "markdown", "id": "bf9e9fb5-7383-475a-93e1-decdbd59c247", "metadata": {}, "source": [ "Use MobileNetV3 as a feature extractor via the [embetter](https://github.com/koaning/embetter) scikit-learn library and [timm](https://github.com/rwightman/pytorch-image-models), then train a logistic regression classifier in scikit-learn on the resulting embeddings." ] }, { "cell_type": "markdown", "id": "96b717c7-54c9-40dc-ba80-0fb47da2c0bd", "metadata": {}, "source": [ "![](images/feature-extractor.png)" ] }, { "cell_type": "code", "execution_count": 1, "id": "64d1dd64-c45b-4092-84d1-1bfcd0998f15", "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "# pip install gitpython\n", "from git import Repo\n", "\n", "if not os.path.exists(\"mnist-pngs\"):\n", " Repo.clone_from(\"https://github.com/rasbt/mnist-pngs\", \"mnist-pngs\")" ] }, { "cell_type": "code", "execution_count": 2, "id": "3a892538-8d9b-4420-9525-26d1a4b37ae3", "metadata": {}, "outputs": [], "source": [ "import os\n", "import pandas as pd\n", "\n", "for name in (\"train\", \"test\"):\n", "\n", " df = pd.read_csv(f\"mnist-pngs/{name}.csv\")\n", " df[\"filepath\"] = df[\"filepath\"].apply(lambda x: \"mnist-pngs/\" + x)\n", " df = df.sample(frac=1, random_state=123).reset_index(drop=True)\n", " df.to_csv(f\"mnist-pngs/{name}_shuffled.csv\", index=None)" ] }, { "cell_type": "code", "execution_count": 3, "id": "5885e9bb-d43f-46ca-83ae-e2d63edcbb37", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1fba0fcb2b1f408f85013da0d1694dd3", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/60 [00:00