Description
I am migrating my question from huggingface/transformers#4009 (comment).
I tried to train a RoBERTa model from scratch using transformers, but I ran into OOM issues when loading a large text file.
Following the suggestion from @thomwolf, I tried to use datasets to load my text file. This test.txt is a simple sample in which each line is a sentence.
import torch
from datasets import load_dataset
dataset = load_dataset('text', data_files='test.txt',cache_dir="./")
dataset.set_format(type='torch',columns=["text"])
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
next(iter(dataloader))
But the dataloader cannot yield a sample, and the error is:
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-12-388aca337e2f> in <module>
----> 1 next(iter(dataloader))
/Library/Python/3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
361
362 def __next__(self):
--> 363 data = self._next_data()
364 self._num_yielded += 1
365 if self._dataset_kind == _DatasetKind.Iterable and \
/Library/Python/3.7/site-packages/torch/utils/data/dataloader.py in _next_data(self)
401 def _next_data(self):
402 index = self._next_index() # may raise StopIteration
--> 403 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
404 if self._pin_memory:
405 data = _utils.pin_memory.pin_memory(data)
/Library/Python/3.7/site-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
42 def fetch(self, possibly_batched_index):
43 if self.auto_collation:
---> 44 data = [self.dataset[idx] for idx in possibly_batched_index]
45 else:
46 data = self.dataset[possibly_batched_index]
/Library/Python/3.7/site-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
42 def fetch(self, possibly_batched_index):
43 if self.auto_collation:
---> 44 data = [self.dataset[idx] for idx in possibly_batched_index]
45 else:
46 data = self.dataset[possibly_batched_index]
KeyError: 0
Calling dataset.set_format(type='torch',columns=["text"]) produces a log message that says:
Set __getitem__(key) output type to torch for ['text'] columns (when key is int or slice) and don't output other (un-formatted) columns.
I noticed that the dataset is a DatasetDict: DatasetDict({'train': Dataset(features: {'text': Value(dtype='string', id=None)}, num_rows: 44)}). Each sample can be accessed through dataset["train"]["text"] rather than dataset["text"].
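To illustrate what I think is happening, here is a minimal sketch that uses only plain Python. It assumes the object returned by load_dataset behaves like a dict keyed by split name, so a plain dict stands in for the DatasetDict. The names fake_dataset_dict and fetch and the sample sentences are hypothetical and are not part of datasets or PyTorch.

```python
# Stand-in for the object returned by load_dataset(...): a plain dict keyed by
# split name. The rows are made-up placeholder sentences.
fake_dataset_dict = {
    "train": [{"text": "first sentence"}, {"text": "second sentence"}],
}


def fetch(dataset, indices):
    # Mirrors the list comprehension shown in the traceback:
    #   data = [self.dataset[idx] for idx in possibly_batched_index]
    return [dataset[idx] for idx in indices]


# Indexing the dict of splits with an integer fails, as in the traceback.
try:
    fetch(fake_dataset_dict, [0, 1])
    error_message = None
except KeyError as err:
    error_message = f"KeyError: {err}"

# Indexing a single split with integers works.
batch = fetch(fake_dataset_dict["train"], [0, 1])

print(error_message)
print(batch)
```

If this reading is right, passing dataset["train"] to the DataLoader instead of the whole DatasetDict might avoid the KeyError, but I have not confirmed that it is the intended usage.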
Could you please give me any suggestions on how to modify this code to load the text file?
Versions:
Python version 3.7.3
PyTorch version 1.6.0
TensorFlow version 2.3.0
datasets version: 1.0.1