|
32 | 32 | }, |
33 | 33 | { |
34 | 34 | "cell_type": "code", |
35 | | - "execution_count": null, |
| 35 | + "execution_count": 75, |
36 | 36 | "metadata": {}, |
37 | | - "outputs": [], |
| 37 | + "outputs": [ |
| 38 | + { |
| 39 | + "data": { |
| 40 | + "text/plain": [ |
| 41 | + "['Adelie', 'Chinstrap', 'Gentoo']" |
| 42 | + ] |
| 43 | + }, |
| 44 | + "execution_count": 75, |
| 45 | + "metadata": {}, |
| 46 | + "output_type": "execute_result" |
| 47 | + } |
| 48 | + ], |
38 | 49 | "source": [ |
39 | | - "from palmerpenguins import load_penguins" |
| 50 | + "from palmerpenguins import load_penguins\n", |
| 51 | + "\n", |
| 52 | + "data = load_penguins()\n", |
| 53 | + "\n", |
| 54 | + "data\n", |
| 55 | + "\n", |
| 56 | + "target_names = sorted(data.species.unique())\n", |
| 57 | + "\n", |
| 58 | + "target_names\n" |
40 | 59 | ] |
41 | 60 | }, |
42 | 61 | { |
|
55 | 74 | "source": [ |
56 | 75 | "### Task 2: creating a ``torch.utils.data.Dataset``\n", |
57 | 76 | "\n", |
| 77 | + "The penguin data reading and processing can be encapsulated in a PyTorch dataset class.\n", |
| 78 | + "\n", |
| 79 | + "- This is helpful because...\n", |
| 80 | + "\n", |
58 | 81 | "All PyTorch dataset objects are subclasses of the ``torch.utils.data.Dataset`` class. To make a custom dataset, create a class which inherits from the ``Dataset`` class, implement some methods (the Python magic (or dunder) methods ``__len__`` and ``__getitem__``) and supply some data.\n", |
59 | 82 | "\n", |
60 | | - "Spoiler alert: we've done this for you already in ``src/ml_workshop/_penguins.py``.\n", |
| 83 | + "Spoiler alert: we've done this for you already below (see ``src/ml_workshop/_penguins.py`` for a more sophisticated implementation)\n", |
61 | 84 | "\n", |
62 | 85 | "- Open the file ``src/ml_workshop/_penguins.py``.\n", |
63 | 86 | "- Let's examine, and discuss, each of the methods together.\n", |
|
75 | 98 | " - ``y_tfms``— ..." |
76 | 99 | ] |
77 | 100 | }, |
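|  |  | + { |
|  |  | + "cell_type": "markdown", |
|  |  | + "metadata": {}, |
|  |  | + "source": [ |
|  |  | + "As a warm-up, here is a minimal sketch of that interface, backed by nothing more than a plain Python list. (The class name ``ToyDataset`` and its sample data are illustrative assumptions, not part of the workshop code.)\n" |
|  |  | + ] |
|  |  | + }, |
|  |  | + { |
|  |  | + "cell_type": "code", |
|  |  | + "execution_count": null, |
|  |  | + "metadata": {}, |
|  |  | + "outputs": [], |
|  |  | + "source": [ |
|  |  | + "from torch.utils.data import Dataset\n", |
|  |  | + "\n", |
|  |  | + "\n", |
|  |  | + "class ToyDataset(Dataset):\n", |
|  |  | + "    \"\"\"A bare-bones dataset: any indexable collection can back it.\"\"\"\n", |
|  |  | + "\n", |
|  |  | + "    def __init__(self, samples):\n", |
|  |  | + "        self.samples = samples\n", |
|  |  | + "\n", |
|  |  | + "    def __len__(self):\n", |
|  |  | + "        # The number of items in the dataset.\n", |
|  |  | + "        return len(self.samples)\n", |
|  |  | + "\n", |
|  |  | + "    def __getitem__(self, idx):\n", |
|  |  | + "        # The item at position ``idx``.\n", |
|  |  | + "        return self.samples[idx]\n", |
|  |  | + "\n", |
|  |  | + "\n", |
|  |  | + "toy = ToyDataset([(1.0, \"a\"), (2.0, \"b\")])\n", |
|  |  | + "len(toy), toy[0]" |
|  |  | + ] |
|  |  | + }, |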
| 101 | + { |
| 102 | + "cell_type": "code", |
| 103 | + "execution_count": 108, |
| 104 | + "metadata": {}, |
| 105 | + "outputs": [], |
| 106 | + "source": [ |
| 107 | + "from typing import Optional, List, Dict, Tuple, Any\n", |
| 108 | + "\n", |
| 109 | + "# import pytorch functions necessary for transformations:\n", |
| 110 | + "from torch import tensor, float32, eye\n", |
| 111 | + "\n", |
| 112 | + "from torch.utils.data import Dataset\n", |
| 113 | + "from torchvision.transforms import Compose\n", |
| 114 | + "\n", |
| 115 | + "from pandas import DataFrame\n", |
| 116 | + "\n", |
| 117 | + "from palmerpenguins import load_penguins\n", |
| 118 | + "\n", |
| 119 | + "\n", |
| 120 | + "class PenguinDataset(Dataset):\n", |
| 121 | + " def __init__(\n", |
| 122 | + " self,\n", |
| 123 | + " input_keys: List[str],\n", |
| 124 | + " target_keys: List[str],\n", |
| 125 | + " train: bool,\n", |
| 126 | + " ):\n", |
| 127 | + " \"\"\"Build ``PenguinDataset``.\"\"\"\n", |
| 128 | + " self.input_keys = input_keys\n", |
| 129 | + " self.target_keys = target_keys\n", |
| 130 | + "\n", |
| 131 | + " data = load_penguins()\n", |
| 132 | + " data = (\n", |
| 133 | + " data.loc[~data.isna().any(axis=1)]\n", |
| 134 | + " .sort_values(by=sorted(data.keys()))\n", |
| 135 | + " .reset_index(drop=True)\n", |
| 136 | + " )\n", |
| 137 | + " # Transform the sex field into a float, with male represented by 1.0, female by 0.0\n", |
| 138 | + " data.sex = (data.sex == \"male\").astype(float)\n", |
| 139 | + " self.full_df = data\n", |
| 140 | + "\n", |
| 141 | + " valid_df = self.full_df.groupby(by=[\"species\", \"sex\"]).sample(\n", |
| 142 | + " n=10,\n", |
| 143 | + " random_state=123,\n", |
| 144 | + " )\n", |
| 145 | + " # The training items are simply the items *not* in the valid split\n", |
| 146 | + " train_df = self.full_df.loc[~self.full_df.index.isin(valid_df.index)]\n", |
| 147 | + "\n", |
| 148 | + " self.split = {\"train\": train_df, \"valid\": valid_df}[\"train\" if train is True else \"valid\"]\n", |
| 149 | + "\n", |
| 150 | + "\n", |
| 151 | + " def __len__(self) -> int:\n", |
| 152 | + " return len(self.split)\n", |
| 153 | + " \n", |
| 154 | + " def __getitem__(self, idx: int) -> Tuple[Any, Any]:\n", |
| 155 | + " # get the row index (idx) from the dataframe and \n", |
| 156 | + " # select relevant column features (provided as input_keys)\n", |
| 157 | + " feats = self.split.iloc[idx][self.input_keys]\n", |
| 158 | + "\n", |
| 159 | + " # this gives a 'species' i.e. one of ('Gentoo',), ('Chinstrap',), or ('Adelie',) \n", |
| 160 | + " tgts = self.split.iloc[idx][self.target_keys]\n", |
| 161 | + "\n", |
| 162 | + " # Exercise #1: convert the feats to PyTorch\n", |
| 163 | + " feats = tensor(feats.values, dtype=float32)\n", |
| 164 | + "\n", |
| 165 | + " # Exercise #2: convert this to a 'one-hot vector' \n", |
| 166 | + " target_names = sorted(self.full_df.species.unique())\n", |
| 167 | + " \n", |
| 168 | + " tgts = eye(len(target_names))[target_names.index(tgts.values[0])]\n", |
| 169 | + " \n", |
| 170 | + " return (feats, tgts)" |
| 171 | + ] |
| 172 | + }, |
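|  |  | + { |
|  |  | + "cell_type": "markdown", |
|  |  | + "metadata": {}, |
|  |  | + "source": [ |
|  |  | + "To see the one-hot trick from ``__getitem__`` in isolation: ``eye(3)`` is the 3x3 identity matrix, so indexing one of its rows yields the one-hot vector for that class. (A standalone illustration, with the species names hard-coded for clarity.)\n" |
|  |  | + ] |
|  |  | + }, |
|  |  | + { |
|  |  | + "cell_type": "code", |
|  |  | + "execution_count": null, |
|  |  | + "metadata": {}, |
|  |  | + "outputs": [], |
|  |  | + "source": [ |
|  |  | + "from torch import eye\n", |
|  |  | + "\n", |
|  |  | + "target_names = [\"Adelie\", \"Chinstrap\", \"Gentoo\"]\n", |
|  |  | + "\n", |
|  |  | + "# Row 2 of the 3x3 identity: the one-hot vector for 'Gentoo'.\n", |
|  |  | + "eye(len(target_names))[target_names.index(\"Gentoo\")]" |
|  |  | + ] |
|  |  | + }, |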
78 | 173 | { |
79 | 174 | "cell_type": "markdown", |
80 | 175 | "metadata": {}, |
|
93 | 188 | }, |
94 | 189 | { |
95 | 190 | "cell_type": "code", |
96 | | - "execution_count": null, |
| 191 | + "execution_count": 109, |
97 | 192 | "metadata": {}, |
98 | | - "outputs": [], |
| 193 | + "outputs": [ |
| 194 | + { |
| 195 | + "data": { |
| 196 | + "text/plain": [ |
| 197 | + "(tensor([ 42.9000, 5000.0000]), tensor([0., 0., 1.]))" |
| 198 | + ] |
| 199 | + }, |
| 200 | + "execution_count": 109, |
| 201 | + "metadata": {}, |
| 202 | + "output_type": "execute_result" |
| 203 | + } |
| 204 | + ], |
99 | 205 | "source": [ |
100 | | - "from ml_workshop import PenguinDataset\n", |
| 206 | + "# from ml_workshop import PenguinDataset\n", |
101 | 207 | "\n", |
102 | | - "data_set = PenguinDataset(\n", |
| 208 | + "data_set_1 = PenguinDataset(\n", |
103 | 209 | " input_keys=[\"bill_length_mm\", \"body_mass_g\"],\n", |
104 | 210 | " target_keys=[\"species\"],\n", |
105 | 211 | " train=True,\n", |
106 | 212 | ")\n", |
107 | 213 | "\n", |
108 | 214 | "\n", |
109 | | - "for features, target in data_set:\n", |
110 | | - " # print the features and targets here\n", |
111 | | - " pass" |
| 215 | + "# for features, target in data_set:\n", |
| 216 | + "# # print the features and targets here\n", |
| 217 | + "# print(features, target)\n", |
| 218 | + "\n", |
| 219 | + "\n", |
| 220 | + "data_set_1[0]" |
112 | 221 | ] |
113 | 222 | }, |
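|  |  | + { |
|  |  | + "cell_type": "markdown", |
|  |  | + "metadata": {}, |
|  |  | + "source": [ |
|  |  | + "Because ``PenguinDataset`` is a ``torch.utils.data.Dataset``, it plugs straight into a ``DataLoader``, which handles batching and shuffling for us. A minimal sketch (the batch size of 16 is an arbitrary illustrative choice):\n" |
|  |  | + ] |
|  |  | + }, |
|  |  | + { |
|  |  | + "cell_type": "code", |
|  |  | + "execution_count": null, |
|  |  | + "metadata": {}, |
|  |  | + "outputs": [], |
|  |  | + "source": [ |
|  |  | + "from torch.utils.data import DataLoader\n", |
|  |  | + "\n", |
|  |  | + "# Wrap the dataset to draw shuffled mini-batches of (features, targets).\n", |
|  |  | + "train_loader = DataLoader(data_set_1, batch_size=16, shuffle=True)\n", |
|  |  | + "\n", |
|  |  | + "batch_feats, batch_tgts = next(iter(train_loader))\n", |
|  |  | + "batch_feats.shape, batch_tgts.shape" |
|  |  | + ] |
|  |  | + }, |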
114 | 223 | { |
|
417 | 526 | "name": "python", |
418 | 527 | "nbconvert_exporter": "python", |
419 | 528 | "pygments_lexer": "ipython3", |
420 | | - "version": "3.11.4" |
| 529 | + "version": "3.12.4" |
421 | 530 | } |
422 | 531 | }, |
423 | 532 | "nbformat": 4, |
|