Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from typing import Optional
3import pyarrow.parquet as pq
4import torch_geometric.transforms as T
5from torch_geometric.data import Dataset
7from proteinsolver import settings
8from proteinsolver.datasets.protein import iter_parquet_file, row_to_data, transform_edge_attr
class ProteinDataset2(Dataset):
    """Streaming dataset of protein graphs backed by a (remote) Parquet file.

    Rows are consumed strictly sequentially from the Parquet file via
    ``iter_parquet_file``; random access is not supported (see ``get``).
    Malformed rows (those ``row_to_data`` rejects) are silently skipped,
    which is why ``__len__`` over-estimates the true number of samples.
    """

    def __init__(
        self,
        root,
        subset: Optional[str] = None,
        data_url: Optional[str] = None,
        transform=None,
        pre_transform=None,
        pre_filter=None,
    ) -> None:
        """Create a new ProteinDataset2.

        Args:
            root: Root directory passed to the parent ``Dataset``.
            subset: Name ending in ``_<int>``; the trailing integer selects
                ``training_data_rs<int>.parquet`` under the default data URL.
                Required when ``data_url`` is not given.
            data_url: Explicit URL / path of the Parquet file. Takes
                precedence over ``subset``.
            transform: Optional per-sample transform, applied after the
                mandatory ``transform_edge_attr``.
            pre_transform: Optional transform applied once per sample in
                ``get``.
            pre_filter: Forwarded to the parent class; unused here.

        Raises:
            ValueError: If neither ``data_url`` nor ``subset`` is provided.
        """
        if data_url is None:
            if subset is None:
                raise ValueError("Either `subset` or `data_url` must be provided.")
            file_name = f"training_data_rs{int(subset.split('_')[-1])}.parquet"
            self.data_url = f"{settings.data_url}/deep-protein-gen/{file_name}"
        else:
            self.data_url = data_url
        # Must be set before super().__init__, which may consult raw_file_names
        # during its download/process bookkeeping.
        self._raw_file_names = [self.data_url.rsplit("/", 1)[-1]]
        # transform_edge_attr always runs first; the user transform (if any)
        # is chained after it.
        transform = T.Compose(
            [transform_edge_attr] + ([transform] if transform is not None else [])
        )
        super().__init__(root, transform, pre_transform, pre_filter)
        # Opened only to read metadata (row count) in __len__.
        self.file = pq.ParquetFile(self.data_url)
        self.reset()

    def reset(self):
        """Restart sequential iteration from the beginning of the file."""
        self.data_iterator = iter_parquet_file(self.data_url, [], {})
        self.prev_index = None

    def _download(self):
        # Data is read directly from `data_url`; nothing to download.
        pass

    def _process(self):
        # Rows are parsed lazily in `get`; no preprocessing step.
        pass

    def __len__(self):
        # Warning: this over-estimates the number of usable data points
        # because malformed rows are skipped at read time in `get`.
        return self.file.metadata.num_rows

    def get(self, idx):
        """Return the next valid sample; only sequential access is allowed.

        ``idx`` must be exactly one greater than the previously requested
        index (or 0 on the first call / after ``reset``), since the
        underlying Parquet iterator cannot seek.
        """
        if self.prev_index is None:
            assert idx == 0, idx
        else:
            assert self.prev_index == idx - 1
        self.prev_index = idx

        # Skip rows that fail to parse into a graph; StopIteration from the
        # exhausted iterator propagates to the caller.
        while True:
            tup = next(self.data_iterator)
            data = row_to_data(tup)
            if data is None:
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            return data