|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "markdown", |
| 5 | + "metadata": {}, |
| 6 | + "source": [ |
| 7 | + "# Pl@ntNet aggregation strategy\n", |
| 8 | + "\n", |
| 9 | + "This aggregation strategy, presented in [this paper](https://hal.science/hal-04603038), models the expertise of users through the number of labels they correctly interact with.\n", |
| 10 | + "\n", |
| 11 | + "Let us create a toy-dataset to run it with 4 users, 20 items and 9 classes.\n", |
| 12 | + "\n", |
| 13 | + "The full Pl@ntNet-CrowdSWE dataset is available [on zenodo](https://zenodo.org/records/10782465) with more than 6.5M items, 850K users and 11K classes.\n", |
| 14 | + "\n", |
| 15 | + "Each item (*e.g.*, a plant observation) has been labeled by at least one user. The ground truth is simulated, so everything is known to measure the accuracy (amongst other metrics). Each item has an authoring user (the picture is taken and uploaded by a user). In the algorithm, authoring users and users that vote on others' items are treated differently." |
| 16 | + ] |
| 17 | + }, |
| 18 | + { |
| 19 | + "cell_type": "code", |
| 20 | + "execution_count": null, |
| 21 | + "metadata": {}, |
| 22 | + "outputs": [], |
| 23 | + "source": [ |
| 24 | + "import numpy as np\n", |
| 25 | + "import peerannot.models as pmod\n", |
| 26 | + "from tqdm.auto import tqdm\n", |
| 27 | + "import matplotlib.pyplot as plt\n", |
| 28 | + "\n", |
| 29 | + "# Crowdsourced answers (typically stored in a .json file)\n", |
| 30 | + "votes = {\n", |
| 31 | + " 0: {0: 2, 1: 2, 2: 2},\n", |
| 32 | + " 1: {0: 6, 1: 2, 3: 2},\n", |
| 33 | + " 2: {1: 8, 2: 7, 3: 8},\n", |
| 34 | + " 3: {0: 1, 1: 1, 2: 5},\n", |
| 35 | + " 4: {2: 4},\n", |
| 36 | + " 5: {0: 0, 1: 0, 2: 1, 3: 6},\n", |
| 37 | + " 6: {1: 5, 3: 3},\n", |
| 38 | + " 7: {0: 3, 2: 6, 3: 4},\n", |
| 39 | + " 8: {1: 7, 3: 7},\n", |
| 40 | + " 9: {0: 8, 2: 1, 3: 1},\n", |
| 41 | + " 10: {0: 0, 1: 0, 2: 1},\n", |
| 42 | + " 11: {2: 3},\n", |
| 43 | + " 12: {0: 7, 2: 8, 3: 1},\n", |
| 44 | + " 13: {1: 3},\n", |
| 45 | + " 14: {0: 5, 2: 4, 3: 4},\n", |
| 46 | + " 15: {0: 5, 1: 7},\n", |
| 47 | + " 16: {0: 0, 1: 4, 3: 4},\n", |
| 48 | + " 17: {1: 5, 2: 7, 3: 7},\n", |
| 49 | + " 18: {0: 3},\n", |
| 50 | + " 19: {1: 7, 2: 7},\n", |
| 51 | + "}\n", |
| 52 | + "\n", |
| 53 | + "# Ground truth (gt) and authors of the observations\n", |
| 54 | + "authors = [0, 0, 1, 0, 2, 0, 1, 0, 3, 1, 1, 3, 0, 1, 0, 1, 0, 1, 0, 1]\n", |
| 55 | + "gt = [2, 6, 4, 1, 1, -1, 3, -1, 2, 8, 4, 1, 7, 0, 5, 5, 0, -1, 6, 7]\n", |
| 56 | + "np.savetxt(\"authors_toy.txt\", authors, fmt=\"%i\")" |
| 57 | + ] |
| 58 | + }, |
| 59 | + { |
| 60 | + "cell_type": "markdown", |
| 61 | + "metadata": {}, |
| 62 | + "source": [ |
| 63 | + "We will evaluate the performance of the method on two subsets:\n", |
| 64 | + "- The full dataset\n", |
| 65 | + "- The subset where the items have been voted on by at least two users and whose ground-truth label is not `-1`\n\n", |
| 66 | + "We also monitor the proportion of classes retrieved after the aggregation compared to the ground truth (if a class is never predicted by the aggregation, a model can later never be trained to recognize it)." |
| 67 | + ] |
| 68 | + }, |
| 69 | + { |
| 70 | + "cell_type": "code", |
| 71 | + "execution_count": null, |
| 72 | + "metadata": {}, |
| 73 | + "outputs": [], |
| 74 | + "source": [ |
| 75 | + "def build_mask_more_than_two(answers, gt):\n", |
| 76 | + " mask = np.zeros(len(answers), dtype=bool)\n", |
| 77 | + " for tt in tqdm(answers.keys()):\n", |
| 78 | + " if len(answers[tt]) >= 2 and gt[int(tt)] != -1:\n", |
| 79 | + " mask[int(tt)] = 1\n", |
| 80 | + " return mask\n", |
| 81 | + "\n", |
| 82 | + "\n", |
| 83 | + "mask_more_than_two = build_mask_more_than_two(votes, gt)\n", |
| 84 | + "\n", |
| 85 | + "\n", |
| 86 | + "def build_mask_more_than_two(answers, gt):\n", |
| 87 | + " mask = np.zeros(len(answers), dtype=bool)\n", |
| 88 | + " for tt in tqdm(answers.keys()):\n", |
| 89 | + " if len(answers[tt]) >= 2 and gt[int(tt)] != -1:\n", |
| 90 | + " mask[int(tt)] = 1\n", |
| 91 | + " return mask\n", |
| 92 | + "\n", |
| 93 | + "\n", |
| 94 | + "mask_more_than_two = build_mask_more_than_two(votes, gt)\n", |
| 95 | + "\n", |
| 96 | + "# %% Metric to compare the strategies where the ground truth is available (proportion of classes kept and accuracy)\n", |
| 97 | + "\n", |
| 98 | + "\n", |
| 99 | + "def vol_class_kept(preds, truth, mask):\n", |
| 100 | + " uni_test = np.unique(truth[mask])\n", |
| 101 | + " n_class_test = uni_test.shape[0]\n", |
| 102 | + " preds_uni = np.unique(preds[mask])\n", |
| 103 | + " if preds_uni[0] == -1:\n", |
| 104 | + " preds_uni = preds_uni[1:]\n", |
| 105 | + " n_class_pred = preds_uni.shape[0]\n", |
| 106 | + " n_common = len(set(preds_uni).intersection(set(uni_test)))\n", |
| 107 | + " vol_kept = n_common / n_class_test * 100\n", |
| 108 | + " return n_class_pred, n_class_test, vol_kept\n", |
| 109 | + "\n", |
| 110 | + "\n", |
| 111 | + "def accuracy(preds, truth, mask):\n", |
| 112 | + " return np.mean(preds[mask] == truth[mask])\n", |
| 113 | + "\n", |
| 114 | + "\n", |
| 115 | + "# %% Metric to compare the strategies where the ground truth is available (proportion of classes kept and accuracy)\n", |
| 116 | + "\n", |
| 117 | + "\n", |
| 118 | + "def vol_class_kept(preds, truth, mask):\n", |
| 119 | + " uni_test = np.unique(truth[mask])\n", |
| 120 | + " n_class_test = uni_test.shape[0]\n", |
| 121 | + " preds_uni = np.unique(preds[mask])\n", |
| 122 | + " if preds_uni[0] == -1:\n", |
| 123 | + " preds_uni = preds_uni[1:]\n", |
| 124 | + " n_class_pred = preds_uni.shape[0]\n", |
| 125 | + " n_common = len(set(preds_uni).intersection(set(uni_test)))\n", |
| 126 | + " vol_kept = n_common / n_class_test * 100\n", |
| 127 | + " return n_class_pred, n_class_test, vol_kept\n", |
| 128 | + "\n", |
| 129 | + "\n", |
| 130 | + "def accuracy(preds, truth, mask):\n", |
| 131 | + " return np.mean(preds[mask] == truth[mask])" |
| 132 | + ] |
| 133 | + }, |
| 134 | + { |
| 135 | + "cell_type": "markdown", |
| 136 | + "metadata": {}, |
| 137 | + "source": [ |
| 138 | + "We now run the Pl@ntNet strategy against other strategies available in `peerannot`." |
| 139 | + ] |
| 140 | + }, |
| 141 | + { |
| 142 | + "cell_type": "markdown", |
| 143 | + "metadata": {}, |
| 144 | + "source": [ |
| 145 | + "Each strategy is first instantiated. The `.run` method is called if any optimization procedure is necessary. Estimated labels are recovered with the `.get_answers()` method." |
| 146 | + ] |
| 147 | + }, |
| 148 | + { |
| 149 | + "cell_type": "code", |
| 150 | + "execution_count": null, |
| 151 | + "metadata": { |
| 152 | + "vscode": { |
| 153 | + "languageId": "plaintext" |
| 154 | + } |
| 155 | + }, |
| 156 | + "outputs": [], |
| 157 | + "source": [ |
| 158 | + "mv = pmod.MV(answers=votes, n_classes=9, n_workers=4)\n", |
| 159 | + "yhat_mv = mv.get_answers()\n", |
| 160 | + "wawa = pmod.Wawa(answers=votes, n_classes=9, n_workers=4)\n", |
| 161 | + "wawa.run()\n", |
| 162 | + "yhat_wawa = wawa.get_answers()\n", |
| 163 | + "twothird = pmod.TwoThird(answers=votes, n_classes=9, n_workers=4)\n", |
| 164 | + "yhat_twothird = twothird.get_answers()\n", |
| 165 | + "\n", |
| 166 | + "# %% run the PlantNet aggregation\n", |
| 167 | + "pn = pmod.PlantNet(\n", |
| 168 | + " answers=votes,\n", |
| 169 | + " n_classes=9,\n", |
| 170 | + " n_workers=4,\n", |
| 171 | + " alpha=0.5,\n", |
| 172 | + " beta=0.2,\n", |
| 173 | + " authors=\"authors_toy.txt\",\n", |
| 174 | + ")\n", |
| 175 | + "pn.run(maxiter=5, epsilon=1e-9)\n", |
| 176 | + "yhatpn = pn.get_answers()" |
| 177 | + ] |
| 178 | + }, |
| 179 | + { |
| 180 | + "cell_type": "markdown", |
| 181 | + "metadata": {}, |
| 182 | + "source": [ |
| 183 | + "Finally, we plot the metrics considered." |
| 184 | + ] |
| 185 | + }, |
| 186 | + { |
| 187 | + "cell_type": "code", |
| 188 | + "execution_count": null, |
| 189 | + "metadata": { |
| 190 | + "vscode": { |
| 191 | + "languageId": "plaintext" |
| 192 | + } |
| 193 | + }, |
| 194 | + "outputs": [], |
| 195 | + "source": [ |
| 196 | + "# %% Compute the metrics for each strategy\n", |
| 197 | + "res_full = []\n", |
| 198 | + "res_more_than_two = []\n", |
| 199 | + "vol_class_full = []\n", |
| 200 | + "vol_class_more_than_two = []\n", |
| 201 | + "\n", |
| 202 | + "gt = np.array(gt)\n", |
| 203 | + "strats = [\"MV\", \"WAWA\", \"TwoThird\", \"PlantNet\"]\n", |
| 204 | + "\n", |
| 205 | + "for strat, res in zip(strats, [yhat_mv, yhat_wawa, yhat_twothird, yhatpn]):\n", |
| 206 | + " res_full.append(accuracy(res, gt, np.ones(len(gt), dtype=bool)))\n", |
| 207 | + " vol_class_full.append(vol_class_kept(res, gt, np.ones(len(gt), dtype=bool))[2])\n", |
| 208 | + " res_more_than_two.append(accuracy(res, gt, mask_more_than_two))\n", |
| 209 | + " vol_class_more_than_two.append(vol_class_kept(res, gt, mask_more_than_two)[2])\n", |
| 210 | + "# %% Plot the accuracy against the proportion of classes kept\n", |
| 211 | + "plt.figure()\n", |
| 212 | + "for i, strat in enumerate(strats):\n", |
| 213 | + " plt.scatter(vol_class_full[i], res_full[i], label=strat)\n", |
| 214 | + "plt.title(r\"Full dataset\")\n", |
| 215 | + "plt.ylabel(\"Accuracy\")\n", |
| 216 | + "plt.xlabel(\"Proportion of classes kept (%)\")\n", |
| 217 | + "plt.legend()\n", |
| 218 | + "plt.show()\n", |
| 219 | + "\n", |
| 220 | + "plt.figure()\n", |
| 221 | + "for i, strat in enumerate(strats):\n", |
| 222 | + " plt.scatter(vol_class_more_than_two[i], res_more_than_two[i], label=strat)\n", |
| 223 | + "plt.title(r\"Dataset with at least 2 annotations per observation\")\n", |
| 224 | + "plt.ylabel(\"Accuracy\")\n", |
| 225 | + "plt.xlabel(\"Proportion of classes kept (%)\")\n", |
| 226 | + "plt.legend()\n", |
| 227 | + "plt.show()" |
| 228 | + ] |
| 229 | + } |
| 230 | + ], |
| 231 | + "metadata": { |
| 232 | + "language_info": { |
| 233 | + "name": "python" |
| 234 | + } |
| 235 | + }, |
| 236 | + "nbformat": 4, |
| 237 | + "nbformat_minor": 2 |
| 238 | +} |
0 commit comments