"""Sudoku datasets for proteinsolver: streaming (Parquet), generated, and in-memory (CSV) variants."""
1import math
2from pathlib import Path
3from typing import Optional
5import pandas as pd
6import pyarrow.parquet as pq
7import torch
8from torch_geometric.data import Data, Dataset, InMemoryDataset
10from proteinsolver import settings
11from proteinsolver.datasets import download_url
12from proteinsolver.datasets.sudoku import str_to_tensor
13from proteinsolver.utils import construct_solved_sudoku, gen_sudoku_graph, gen_sudoku_graph_featured
class SudokuDataset4(torch.utils.data.IterableDataset):
    """Iterable dataset that streams Sudoku puzzles from a Parquet file.

    Each yielded ``Data`` object carries the puzzle as ``x``, the solution as
    ``y``, and the shared Sudoku constraint graph as ``edge_index`` /
    ``edge_attr``. Row groups are partitioned across DataLoader workers so
    that every sample is produced exactly once.
    """

    def __init__(
        self, root, subset=None, data_url=None, transform=None, pre_transform=None, pre_filter=None
    ) -> None:
        """Create new SudokuDataset.

        Args:
            root: Dataset root directory (kept for API symmetry with the
                other dataset classes; no files are written here).
            subset: Name of the remote subset; required when `data_url` is None.
            data_url: Explicit URL / path of the Parquet file to read.
            transform: Callable applied to each `Data` object as it is yielded.
            pre_transform: Callable applied to each `Data` object when its row
                group is decoded.
            pre_filter: Predicate; samples for which it returns falsy are skipped.
        """
        super().__init__()
        self.root = Path(root).expanduser().resolve().as_posix()
        self.transform = transform
        self.pre_transform = pre_transform
        self.pre_filter = pre_filter

        if data_url is None:
            assert subset is not None
            file_name = f"{subset.replace('sudoku_', '')}.parquet"
            self.data_url = f"{settings.data_url}/deep-protein-gen/sudoku_difficult/{file_name}"
        else:
            self.data_url = data_url

        # Sparse adjacency of the Sudoku constraint graph; its indices/values
        # are attached to every yielded sample.
        self.sudoku_graph = torch.from_numpy(gen_sudoku_graph_featured()).to_sparse(2)
        self.file = pq.ParquetFile(self.data_url)

    def __iter__(self):
        # Split row groups evenly across DataLoader workers; the last worker
        # may get fewer (or zero) groups.
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            num_row_groups_per_worker = int(
                math.ceil(self.file.num_row_groups / worker_info.num_workers)
            )
            min_row_group_index = worker_info.id * num_row_groups_per_worker
            max_row_group_index = min(
                (worker_info.id + 1) * num_row_groups_per_worker, self.file.num_row_groups
            )
            row_group_indices = range(min_row_group_index, max_row_group_index)
        else:
            row_group_indices = range(0, self.file.num_row_groups)

        for row_group in row_group_indices:
            data_list = self._read_row_group(row_group)
            for data in data_list:
                data.edge_index = self.sudoku_graph.indices()
                data.edge_attr = self.sudoku_graph.values()
                # BUGFIX: `transform` was accepted in __init__ but never
                # applied; IterableDataset does not apply it automatically.
                if self.transform is not None:
                    data = self.transform(data)
                yield data

    def _read_row_group(self, row_group: int):
        """Decode one Parquet row group into a list of `Data` objects."""
        df = self.file.read_row_group(row_group).to_pandas(integer_object_nulls=True)

        data_list = []
        for tup in df.itertuples():
            # Shift digits 1-9 to classes 0-8; empty cells (non-positive
            # after the shift) become the "unknown" class 9.
            puzzle = str_to_tensor(tup.puzzle) - 1
            puzzle = torch.where(puzzle >= 0, puzzle, torch.tensor(9))
            solution = str_to_tensor(tup.solution) - 1
            data = Data(x=puzzle, y=solution)
            # BUGFIX: `pre_filter` is a predicate (cf. SudokuDataset2.process);
            # previously its return value was assigned to `data`, corrupting
            # every sample whenever a filter was supplied.
            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            data_list.append(data)
        return data_list
class SudokuDataset3(Dataset):
    """Dataset that generates a fresh solved Sudoku board on every access.

    No files are downloaded or cached; `get` ignores the index and produces
    a brand-new board each call.
    """

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None) -> None:
        self._gen_puzzle = construct_solved_sudoku
        self._edge_index, _ = gen_sudoku_graph()
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        # Nothing to download; boards are generated on the fly.
        return []

    @property
    def processed_file_names(self):
        # Nothing is cached on disk.
        return []

    def __len__(self):
        return 700_000  # To be consistent with SudokuDataset2

    def get(self, idx):
        # `idx` is ignored: every call yields a new solved board, with the
        # digits 1-9 shifted down to classes 0-8.
        board = self._gen_puzzle()
        flat = torch.from_numpy(board.reshape(-1) - 1)
        return Data(x=flat, edge_index=self._edge_index)
class SudokuDataset2(InMemoryDataset):
    """In-memory Sudoku dataset backed by a gzipped CSV file.

    The CSV is downloaded on first use, parsed into ``Data`` objects, and
    cached as a single collated file; the shared constraint graph is attached
    to each sample on retrieval.
    """

    def __init__(
        self,
        root,
        subset: Optional[str] = None,
        data_url: Optional[str] = None,
        make_local_copy: bool = False,
        transform=None,
        pre_transform=None,
        pre_filter=None,
    ) -> None:
        """Create new SudokuDataset."""
        if data_url is not None:
            self.data_url = data_url
        else:
            self.data_url = f"{settings.data_url}/deep-protein-gen/sudoku/sudoku_{subset}.csv.gz"
        self._raw_file_names = [self.data_url.rsplit("/")[-1]]
        self._edge_index, _ = gen_sudoku_graph()
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return self._raw_file_names

    @property
    def processed_file_names(self):
        return self._raw_file_names

    def download(self):
        download_url(self.data_url, self.raw_dir)

    def process(self):
        """Parse the raw CSV into `Data` objects and cache the collated result."""
        df = pd.read_csv(self.raw_paths[0], index_col=False)
        # Normalize alternative column names used by some source files.
        df = df.rename(columns={"puzzle": "quizzes", "solution": "solutions"})

        data_list = []
        for row in df.itertuples():
            solution = str_to_tensor(row.solutions) - 1
            if not hasattr(row, "quizzes"):
                # Solutions-only file: the solved board is the input.
                data = Data(x=solution)
            else:
                # Empty cells become class 9 ("unknown") after the 1-9 -> 0-8 shift.
                quiz = str_to_tensor(row.quizzes) - 1
                quiz = torch.where(quiz >= 0, quiz, torch.tensor(9))
                data = Data(x=quiz, y=solution)
            data_list.append(data)

        if self.pre_filter is not None:
            data_list = [d for d in data_list if self.pre_filter(d)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]

        torch.save(self.collate(data_list), self.processed_paths[0])

    def get(self, idx):
        # Attach the shared constraint graph on the way out.
        sample = super().get(idx)
        sample.edge_index = self._edge_index
        return sample