Hello,
If I directly run this command suggested in the README:
python examples/run_expt.py --dataset fmow --algorithm DANN --unlabeled_split test_unlabeled --root_dir data
I get the following exception:
Traceback (most recent call last):
  File "/mnt/beegfs/bulk/mirror/jyf6/datasets/wilds/examples/run_expt.py", line 491, in <module>
    main()
  File "/mnt/beegfs/bulk/mirror/jyf6/datasets/wilds/examples/run_expt.py", line 454, in main
    train(
  File "/mnt/beegfs/bulk/mirror/jyf6/datasets/wilds/examples/train.py", line 114, in train
    run_epoch(algorithm, datasets['train'], general_logger, epoch, config, train=True, unlabeled_dataset=unlabeled_dataset)
  File "/mnt/beegfs/bulk/mirror/jyf6/datasets/wilds/examples/train.py", line 38, in run_epoch
    unlabeled_data_iterator = InfiniteDataIterator(unlabeled_dataset['loader'])
  File "/mnt/beegfs/bulk/mirror/jyf6/datasets/wilds/examples/utils.py", line 393, in __init__
    self.iter = iter(self.data_loader)
  File "/home/fs01/jyf6/miniconda3/envs/ponds/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 442, in __iter__
    return self._get_iterator()
  File "/home/fs01/jyf6/miniconda3/envs/ponds/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 388, in _get_iterator
    return _MultiProcessingDataLoaderIter(self)
  File "/home/fs01/jyf6/miniconda3/envs/ponds/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1085, in __init__
    self._reset(loader, first_iter=True)
  File "/home/fs01/jyf6/miniconda3/envs/ponds/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1118, in _reset
    self._try_put_index()
  File "/home/fs01/jyf6/miniconda3/envs/ponds/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1352, in _try_put_index
    index = self._next_index()
  File "/home/fs01/jyf6/miniconda3/envs/ponds/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 624, in _next_index
    return next(self._sampler_iter) # may raise StopIteration
  File "/mnt/beegfs/bulk/mirror/jyf6/datasets/wilds/wilds/common/data_loaders.py", line 131, in __iter__
    groups_for_batch = np.random.choice(
  File "mtrand.pyx", line 984, in numpy.random.mtrand.RandomState.choice
ValueError: Cannot take a larger sample than population when 'replace=False'
I think this occurs because there are only 2 unique years in the test_unlabeled split, but unlabeled_n_groups_per_batch is set to 8, so the group sampler calls np.random.choice to draw 8 distinct years from a population of 2 without replacement.
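The failure can be reproduced with NumPy alone. This is a minimal sketch assuming a split with exactly 2 distinct group ids (the values 0 and 1 are illustrative stand-ins for the two years) and the configured value of 8 groups per batch; `try_sample` is a hypothetical helper, not part of WILDS.

```python
import numpy as np


def try_sample(unique_groups, n_groups_per_batch):
    """Hypothetical helper: return the error message if sampling fails, else None."""
    try:
        # Same kind of call as in wilds/common/data_loaders.py: distinct groups, no replacement.
        np.random.choice(unique_groups, n_groups_per_batch, replace=False)
    except ValueError as err:
        return str(err)
    return None


# Illustrative stand-in for the 2 unique years in test_unlabeled.
unique_groups = np.array([0, 1])
print(try_sample(unique_groups, 8))  # fails: 8 > 2 distinct groups
print(try_sample(unique_groups, 2))  # succeeds, so None is printed
```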
I was able to fix this by changing the argument unlabeled_n_groups_per_batch to 2, here: https://github.com/p-lambda/wilds/blob/main/examples/configs/datasets.py#L220
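Instead of hard-coding 2, one option is to cap the requested number of groups at the number of distinct groups actually present in the split. The sketch below assumes the group ids are available as a NumPy array; `clamp_groups_per_batch` is a hypothetical helper and the year values are illustrative, not read from the fMoW metadata.

```python
import numpy as np


def clamp_groups_per_batch(requested: int, group_ids: np.ndarray) -> int:
    """Hypothetical helper: never request more distinct groups than the split contains."""
    n_unique = len(np.unique(group_ids))
    return min(requested, n_unique)


# Illustrative group ids: two distinct years, as reported for test_unlabeled.
group_ids = np.array([2016, 2016, 2017, 2017, 2017])
n = clamp_groups_per_batch(8, group_ids)  # 2 instead of 8
groups_for_batch = np.random.choice(np.unique(group_ids), n, replace=False)
print(n, sorted(groups_for_batch.tolist()))
```

Clamping keeps the sampler working on any split, but it silently changes the batch composition, so a warning when `requested` exceeds the number of groups may be preferable to a quiet fallback.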
It would be great if this could be fixed. Thank you so much for releasing these wonderful datasets and baseline algorithms!