
Commit 7c82f29

add plantnet strategy tutorial

1 parent eaf380c

File tree

3 files changed: +247 −2 lines changed


.gitignore

Lines changed: 1 addition & 0 deletions

@@ -218,3 +218,4 @@ datasets/bird_audio/bird_sound_training_data/
 doc/_build/

 doc/interface_api/generated/
+doc/tutorials/authors_toy.txt

doc/tutorials/index.rst

Lines changed: 8 additions & 2 deletions

@@ -1,8 +1,8 @@
 Tutorials
 =========

-These tutorials follows part of the `paper published paper in Computo Journal <https://computo.sfds.asso.fr/published-202402-lefort-peerannot/>`_. Feel free to check it out for more details and examples.
-
+These tutorials follow part of the `paper published in Computo Journal <https://computo.sfds.asso.fr/published-202402-lefort-peerannot/>`_. Feel free to check it out for more details and examples.
+The `Pl@ntNet <https://plantnet.org/>`_ aggregation paper is available `here <https://hal.science/hal-04603038>`_.


 .. toctree::
@@ -22,3 +22,9 @@ These tutorials follow part of the `paper published in Computo Journal <h
    :caption: Full pipeline on a practical example:

    run_pipeline_notebook
+
+.. toctree::
+   :maxdepth: 1
+   :caption: How to run the Pl@ntNet strategy
+
+   run_plantnet_strategy
Lines changed: 238 additions & 0 deletions

@@ -0,0 +1,238 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Pl@ntNet aggregation strategy\n",
    "\n",
    "This aggregation strategy, presented in [this paper](https://hal.science/hal-04603038), models the expertise of users by the number of labels they correctly interact with.\n",
    "\n",
    "Let us create a toy dataset to run it with 4 users, 20 items and 9 classes.\n",
    "\n",
    "The full Pl@ntNet-CrowdSWE dataset is available [on zenodo](https://zenodo.org/records/10782465) with more than 6.5M items, 850K users and 11K classes.\n",
    "\n",
    "Each item (*e.g.* a plant observation) has been labeled by at least one user. The ground truth is simulated, so everything is known to measure the accuracy (amongst other metrics). Each item has an authoring user (the picture is taken and uploaded by a user). In the algorithm, authoring users and users that vote on others' items are treated differently."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import peerannot.models as pmod\n",
    "from tqdm.auto import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Crowdsourced answers (typically stored in a .json file)\n",
    "# item id -> {user id: voted class}\n",
    "votes = {\n",
    "    0: {0: 2, 1: 2, 2: 2},\n",
    "    1: {0: 6, 1: 2, 3: 2},\n",
    "    2: {1: 8, 2: 7, 3: 8},\n",
    "    3: {0: 1, 1: 1, 2: 5},\n",
    "    4: {2: 4},\n",
    "    5: {0: 0, 1: 0, 2: 1, 3: 6},\n",
    "    6: {1: 5, 3: 3},\n",
    "    7: {0: 3, 2: 6, 3: 4},\n",
    "    8: {1: 7, 3: 7},\n",
    "    9: {0: 8, 2: 1, 3: 1},\n",
    "    10: {0: 0, 1: 0, 2: 1},\n",
    "    11: {2: 3},\n",
    "    12: {0: 7, 2: 8, 3: 1},\n",
    "    13: {1: 3},\n",
    "    14: {0: 5, 2: 4, 3: 4},\n",
    "    15: {0: 5, 1: 7},\n",
    "    16: {0: 0, 1: 4, 3: 4},\n",
    "    17: {1: 5, 2: 7, 3: 7},\n",
    "    18: {0: 3},\n",
    "    19: {1: 7, 2: 7},\n",
    "}\n",
    "\n",
    "# Ground truth (gt) and authors of the observations\n",
    "authors = [0, 0, 1, 0, 2, 0, 1, 0, 3, 1, 1, 3, 0, 1, 0, 1, 0, 1, 0, 1]\n",
    "gt = [2, 6, 4, 1, 1, -1, 3, -1, 2, 8, 4, 1, 7, 0, 5, 5, 0, -1, 6, 7]\n",
    "np.savetxt(\"authors_toy.txt\", authors, fmt=\"%i\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will evaluate the performance of the method on two subsets:\n",
    "\n",
    "- The full dataset\n",
    "- The subset of items voted on by at least two users\n",
    "\n",
    "We also monitor the proportion of classes retrieved after the aggregation compared to the ground truth (if a class is never predicted by the aggregation, a model can later never be trained to recognize it)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_mask_more_than_two(answers, gt):\n",
    "    # Keep items with at least two votes and a known ground truth\n",
    "    mask = np.zeros(len(answers), dtype=bool)\n",
    "    for tt in tqdm(answers.keys()):\n",
    "        if len(answers[tt]) >= 2 and gt[int(tt)] != -1:\n",
    "            mask[int(tt)] = 1\n",
    "    return mask\n",
    "\n",
    "\n",
    "mask_more_than_two = build_mask_more_than_two(votes, gt)\n",
    "\n",
    "# %% Metrics to compare the strategies where the ground truth is\n",
    "# available (proportion of classes kept and accuracy)\n",
    "\n",
    "\n",
    "def vol_class_kept(preds, truth, mask):\n",
    "    uni_test = np.unique(truth[mask])\n",
    "    n_class_test = uni_test.shape[0]\n",
    "    preds_uni = np.unique(preds[mask])\n",
    "    if preds_uni[0] == -1:\n",
    "        preds_uni = preds_uni[1:]\n",
    "    n_class_pred = preds_uni.shape[0]\n",
    "    n_common = len(set(preds_uni).intersection(set(uni_test)))\n",
    "    vol_kept = n_common / n_class_test * 100\n",
    "    return n_class_pred, n_class_test, vol_kept\n",
    "\n",
    "\n",
    "def accuracy(preds, truth, mask):\n",
    "    return np.mean(preds[mask] == truth[mask])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now run the Pl@ntNet strategy against other strategies available in `peerannot`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each strategy is first instantiated. The `.run` method is called if any optimization procedure is necessary. Estimated labels are recovered with the `.get_answers()` method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "mv = pmod.MV(answers=votes, n_classes=9, n_workers=4)\n",
    "yhat_mv = mv.get_answers()\n",
    "wawa = pmod.Wawa(answers=votes, n_classes=9, n_workers=4)\n",
    "wawa.run()\n",
    "yhat_wawa = wawa.get_answers()\n",
    "twothird = pmod.TwoThird(answers=votes, n_classes=9, n_workers=4)\n",
    "yhat_twothird = twothird.get_answers()\n",
    "\n",
    "# %% Run the Pl@ntNet aggregation\n",
    "pn = pmod.PlantNet(\n",
    "    answers=votes,\n",
    "    n_classes=9,\n",
    "    n_workers=4,\n",
    "    alpha=0.5,\n",
    "    beta=0.2,\n",
    "    authors=\"authors_toy.txt\",\n",
    ")\n",
    "pn.run(maxiter=5, epsilon=1e-9)\n",
    "yhatpn = pn.get_answers()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we plot the metrics considered."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "# %% Compute the metrics for each strategy\n",
    "res_full = []\n",
    "res_more_than_two = []\n",
    "vol_class_full = []\n",
    "vol_class_more_than_two = []\n",
    "\n",
    "gt = np.array(gt)\n",
    "strats = [\"MV\", \"WAWA\", \"TwoThird\", \"PlantNet\"]\n",
    "\n",
    "for strat, res in zip(strats, [yhat_mv, yhat_wawa, yhat_twothird, yhatpn]):\n",
    "    res_full.append(accuracy(res, gt, np.ones(len(gt), dtype=bool)))\n",
    "    vol_class_full.append(vol_class_kept(res, gt, np.ones(len(gt), dtype=bool))[2])\n",
    "    res_more_than_two.append(accuracy(res, gt, mask_more_than_two))\n",
    "    vol_class_more_than_two.append(vol_class_kept(res, gt, mask_more_than_two)[2])\n",
    "\n",
    "# %% Plot the accuracy against the proportion of classes kept\n",
    "plt.figure()\n",
    "for i, strat in enumerate(strats):\n",
    "    plt.scatter(vol_class_full[i], res_full[i], label=strat)\n",
    "plt.title(\"Full dataset\")\n",
    "plt.ylabel(\"Accuracy\")\n",
    "plt.xlabel(\"Proportion of classes kept (%)\")\n",
    "plt.legend()\n",
    "plt.show()\n",
    "\n",
    "plt.figure()\n",
    "for i, strat in enumerate(strats):\n",
    "    plt.scatter(vol_class_more_than_two[i], res_more_than_two[i], label=strat)\n",
    "plt.title(\"Dataset with at least 2 annotations per observation\")\n",
    "plt.ylabel(\"Accuracy\")\n",
    "plt.xlabel(\"Proportion of classes kept (%)\")\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
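To make the vote dictionary in the notebook concrete, here is a minimal sketch of what a majority-vote baseline (the simplest of the compared strategies) computes on a few of the toy items. This is only an illustration in plain Python, not `peerannot`'s `MV` implementation; `majority_vote` and its tie-breaking rule are hypothetical choices made for this sketch.

```python
from collections import Counter

# A few items from the notebook's toy dataset:
# item id -> {user id: voted class}
votes = {
    0: {0: 2, 1: 2, 2: 2},  # unanimous vote for class 2
    1: {0: 6, 1: 2, 3: 2},  # class 2 wins 2 votes to 1
    4: {2: 4},              # a single vote decides
}


def majority_vote(item_votes):
    """Return the most frequent class among the users' votes.

    Ties are broken by the smallest class id (an arbitrary choice
    for this sketch; real strategies may break ties differently).
    """
    counts = Counter(item_votes.values())
    return max(counts.items(), key=lambda kv: (kv[1], -kv[0]))[0]


yhat = {item: majority_vote(v) for item, v in votes.items()}
print(yhat)  # {0: 2, 1: 2, 4: 4}
```

The Pl@ntNet strategy refines this baseline by weighting each user's vote by an estimated expertise, which is why the notebook compares its accuracy and class coverage against plain majority voting.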

0 commit comments