Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1from typing import Optional 

2 

3import pyarrow.parquet as pq 

4import torch_geometric.transforms as T 

5from torch_geometric.data import Dataset 

6 

7from proteinsolver import settings 

8from proteinsolver.datasets.protein import iter_parquet_file, row_to_data, transform_edge_attr 

9 

10 

class ProteinDataset2(Dataset):
    """Sequential, streaming dataset of protein graphs read from a parquet file.

    Rows are read lazily, in file order, through ``iter_parquet_file`` and
    converted to graph objects with ``row_to_data``. Because the underlying
    iterator only moves forward, items must be requested with consecutive
    indices starting at 0 (``get(0)``, ``get(1)``, ...); call :meth:`reset`
    to start over from the beginning of the file.
    """

    def __init__(
        self,
        root,
        subset: Optional[str] = None,
        data_url: Optional[str] = None,
        transform=None,
        pre_transform=None,
        pre_filter=None,
    ) -> None:
        """Create a new ProteinDataset2.

        Args:
            root: Root directory, forwarded to the base ``Dataset``.
            subset: Name whose last ``_``-separated token is an integer ``N``;
                selects ``training_data_rs<N>.parquet`` under
                ``settings.data_url``. Required when ``data_url`` is None.
            data_url: Explicit location of the parquet file; takes precedence
                over ``subset``.
            transform: Optional extra transform. It is composed after
                ``transform_edge_attr`` and the composition is passed to the
                base ``Dataset`` as its ``transform``.
            pre_transform: Optional callable applied to each item in
                :meth:`get`.
            pre_filter: Optional predicate; items for which it returns a falsy
                value are skipped in :meth:`get`.

        Raises:
            ValueError: If neither ``subset`` nor ``data_url`` is given, or if
                the last ``_``-separated token of ``subset`` is not an integer.
        """
        if data_url is None:
            if subset is None:
                raise ValueError("Either `subset` or `data_url` must be provided.")
            file_name = f"training_data_rs{int(subset.split('_')[-1])}.parquet"
            self.data_url = f"{settings.data_url}/deep-protein-gen/{file_name}"
        else:
            self.data_url = data_url
        # The raw file name is the last "/"-separated component of the URL.
        self._raw_file_names = [self.data_url.rsplit("/", 1)[-1]]
        # transform_edge_attr always runs first; the caller's transform (if any) after it.
        transform = T.Compose(
            [transform_edge_attr] + ([transform] if transform is not None else [])
        )
        super().__init__(root, transform, pre_transform, pre_filter)
        # Within this class, the handle is only read for its metadata in __len__.
        self.file = pq.ParquetFile(self.data_url)
        self.reset()

    def reset(self):
        """Restart reading from the first row of the parquet file.

        After a reset, the next call to :meth:`get` must use index 0.
        """
        # NOTE(review): the meaning of the empty list / dict arguments is defined
        # by iter_parquet_file (not visible here) — confirm before changing them.
        self.data_iterator = iter_parquet_file(self.data_url, [], {})
        self.prev_index = None

    def _download(self):
        """No-op: skip the base class download step; rows are read from data_url."""
        pass

    def _process(self):
        """No-op: skip the base class processing step; items are built in get()."""
        pass

    def __len__(self):
        """Return the number of rows recorded in the parquet file's metadata."""
        # Warning: This over-estimates the number of data points because some rows
        # are malformed (and because pre_filter may reject rows in get()).
        return self.file.metadata.num_rows

    def get(self, idx):
        """Return the next protein graph; ``idx`` must be the next consecutive index.

        Rows for which ``row_to_data`` returns None (the malformed rows that
        ``__len__`` warns about) and rows rejected by ``pre_filter`` are
        skipped.

        Raises:
            AssertionError: If ``idx`` is not the expected next index.
            StopIteration: Propagated from the underlying iterator once the
                file is exhausted. Because ``__len__`` over-counts, this can
                happen before ``len(self)`` items have been returned.
        """
        expected = 0 if self.prev_index is None else self.prev_index + 1
        assert idx == expected, (
            f"ProteinDataset2 must be read sequentially: expected index "
            f"{expected}, got {idx}; call reset() to start over."
        )
        self.prev_index = idx

        while True:
            data = row_to_data(next(self.data_iterator))
            if data is None:
                continue
            # pre_filter is honoured here because _process() is a no-op and
            # never gets the chance to apply it.
            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            return data