
import math
from pathlib import Path
from typing import Optional

import pandas as pd
import pyarrow.parquet as pq
import torch
from torch_geometric.data import Data, Dataset, InMemoryDataset

from proteinsolver import settings
from proteinsolver.datasets import download_url
from proteinsolver.datasets.sudoku import str_to_tensor
from proteinsolver.utils import construct_solved_sudoku, gen_sudoku_graph, gen_sudoku_graph_featured


class SudokuDataset4(torch.utils.data.IterableDataset):
    def __init__(
        self, root, subset=None, data_url=None, transform=None, pre_transform=None, pre_filter=None
    ) -> None:
        """Create a new SudokuDataset that streams puzzles from a Parquet file."""
        super().__init__()
        self.root = Path(root).expanduser().resolve().as_posix()
        self.transform = transform
        self.pre_transform = pre_transform
        self.pre_filter = pre_filter

        if data_url is None:
            assert subset is not None
            file_name = f"{subset.replace('sudoku_', '')}.parquet"
            self.data_url = f"{settings.data_url}/deep-protein-gen/sudoku_difficult/{file_name}"
        else:
            self.data_url = data_url

        # All puzzles share the same sudoku constraint graph; store it once
        # as a sparse tensor and attach it to every sample during iteration.
        self.sudoku_graph = torch.from_numpy(gen_sudoku_graph_featured()).to_sparse(2)
        self.file = pq.ParquetFile(self.data_url)

    def __iter__(self):
        # Partition Parquet row groups across DataLoader workers so that each
        # worker reads a disjoint, contiguous slice of the file.
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            num_row_groups_per_worker = int(
                math.ceil(self.file.num_row_groups / worker_info.num_workers)
            )
            min_row_group_index = worker_info.id * num_row_groups_per_worker
            max_row_group_index = min(
                (worker_info.id + 1) * num_row_groups_per_worker, self.file.num_row_groups
            )
            row_group_indices = range(min_row_group_index, max_row_group_index)
        else:
            row_group_indices = range(0, self.file.num_row_groups)

        for row_group in row_group_indices:
            data_list = self._read_row_group(row_group)
            for data in data_list:
                data.edge_index = self.sudoku_graph.indices()
                data.edge_attr = self.sudoku_graph.values()
                yield data

    def _read_row_group(self, row_group: int):
        df = self.file.read_row_group(row_group).to_pandas(integer_object_nulls=True)

        data_list = []
        for tup in df.itertuples():
            # Shift digits 1-9 to classes 0-8; blank cells become class 9.
            puzzle = str_to_tensor(tup.puzzle) - 1
            puzzle = torch.where(puzzle >= 0, puzzle, torch.tensor(9))
            solution = str_to_tensor(tup.solution) - 1
            data = Data(x=puzzle, y=solution)
            # `pre_filter` is a predicate (as in SudokuDataset2.process below);
            # samples that fail it are dropped rather than overwritten.
            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            data_list.append(data)
        return data_list
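
# A minimal usage sketch (hypothetical root path and subset name; assumes the
# corresponding Parquet file is reachable under settings.data_url). Because
# this is an IterableDataset, each DataLoader worker iterates only its own
# slice of row groups:
#
#     dataset = SudokuDataset4(root="~/datasets/sudoku", subset="sudoku_train")
#     loader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=2)
#     for data in loader:
#         ...  # data.x: puzzle, data.y: solution, plus edge_index / edge_attr
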

class SudokuDataset3(Dataset):
    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None) -> None:
        # Puzzles are generated on the fly, so there is nothing to download
        # or process; only the shared constraint graph is precomputed.
        self._gen_puzzle = construct_solved_sudoku
        self._edge_index, _ = gen_sudoku_graph()
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        return []

    @property
    def processed_file_names(self):
        return []

    def __len__(self):
        return 700_000  # To be consistent with SudokuDataset2.

    def get(self, idx):
        # `idx` is ignored: every access returns a freshly generated solved grid.
        puzzle = torch.from_numpy(self._gen_puzzle().reshape(-1) - 1)
        data = Data(x=puzzle, edge_index=self._edge_index)
        return data
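
# A minimal usage sketch (hypothetical root path). Since `get` ignores the
# index, indexing is effectively sampling a new solved grid each time:
#
#     dataset = SudokuDataset3(root="/tmp/sudoku_generated")
#     data = dataset[0]
#     data.x.shape  # torch.Size([81]), digit classes 0-8
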

class SudokuDataset2(InMemoryDataset):
    def __init__(
        self,
        root,
        subset: Optional[str] = None,
        data_url: Optional[str] = None,
        make_local_copy: bool = False,  # currently unused
        transform=None,
        pre_transform=None,
        pre_filter=None,
    ) -> None:
        """Create a new SudokuDataset backed by a gzipped CSV file."""
        self.data_url = (
            f"{settings.data_url}/deep-protein-gen/sudoku/sudoku_{subset}.csv.gz"
            if data_url is None
            else data_url
        )
        self._raw_file_names = [self.data_url.rsplit("/", 1)[-1]]
        self._edge_index, _ = gen_sudoku_graph()
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return self._raw_file_names

    @property
    def processed_file_names(self):
        return self._raw_file_names

    def download(self):
        download_url(self.data_url, self.raw_dir)

    def process(self):
        df = pd.read_csv(self.raw_paths[0], index_col=False)
        df = df.rename(columns={"puzzle": "quizzes", "solution": "solutions"})

        data_list = []
        for tup in df.itertuples():
            # Shift digits 1-9 to classes 0-8; blank cells become class 9.
            solution = str_to_tensor(tup.solutions) - 1
            if hasattr(tup, "quizzes"):
                quiz = str_to_tensor(tup.quizzes) - 1
                quiz = torch.where(quiz >= 0, quiz, torch.tensor(9))
                data = Data(x=quiz, y=solution)
            else:
                data = Data(x=solution)
            data_list.append(data)

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

    def get(self, idx):
        # Attach the shared constraint graph at access time instead of storing
        # a copy of it with every collated sample.
        data = super().get(idx)
        data.edge_index = self._edge_index
        return data
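
# A minimal usage sketch (hypothetical root path and subset name; the first
# instantiation downloads the CSV and caches a collated copy under `root`):
#
#     dataset = SudokuDataset2(root="/tmp/sudoku", subset="train")
#     data = dataset[0]  # x: puzzle with blanks as class 9, y: solution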