Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import heapq
2from dataclasses import dataclass, field
3from typing import Any, Optional, Tuple
5import torch
6import torch.nn as nn
7from torch_geometric.data import Data
def get_node_proba(net, x, edge_index, edge_attr, num_categories=20):
    """Deprecated stub kept so old callers fail loudly.

    Raises:
        NotImplementedError: Always; use `get_node_outputs` instead.
    """
    raise NotImplementedError("Use get_node_outputs instead!")
def get_node_value(net, x, edge_index, edge_attr, num_categories=20):
    """Deprecated stub kept so old callers fail loudly.

    Raises:
        NotImplementedError: Always; use `get_node_outputs` instead.
    """
    raise NotImplementedError("Use get_node_outputs instead!")
@torch.no_grad()
def get_node_outputs(
    net: nn.Module,
    x: torch.Tensor,
    edge_index: torch.Tensor,
    edge_attr: torch.Tensor,
    num_categories: int = 20,
    output_transform: Optional[str] = None,
    oneshot: bool = False,
) -> torch.Tensor:
    """Collect the network's output for every node of the reference sequence.

    Starting from a fully-masked copy of `x` (all nodes set to `num_categories`),
    repeatedly runs the network and commits the single most confident prediction,
    until every node has been assigned. With `oneshot=True` a single forward pass
    is used instead and its per-node outputs are returned directly.

    Args:
        net: The network used to make predictions.
        x: Node attributes for the target sequence.
        edge_index: Edge indices of the target sequence.
        edge_attr: Edge attributes of the target sequence.
        num_categories: Number of categories the network assigns to nodes; also
            serves as the "masked" sentinel value.
        output_transform: `None` (raw outputs), `"proba"` (softmax), or
            `"logproba"` (log of the softmax).
        oneshot: If True, return the outputs of one forward pass; otherwise fill
            in one node per pass, most confident first.

    Returns:
        A CPU tensor with one output value per node of `x`.
    """
    assert output_transform in [None, "proba", "logproba"]

    reference = x
    current = torch.ones_like(reference) * num_categories
    collected = torch.zeros_like(reference).to(torch.float)
    node_indices = torch.arange(reference.size(0))
    unfilled = current == num_categories
    while unfilled.any():
        raw = net(current, edge_index, edge_attr)
        if output_transform == "proba":
            raw = torch.softmax(raw, dim=1)
        elif output_transform == "logproba":
            raw = torch.softmax(raw, dim=1).log()
        per_node = raw.gather(1, reference.view(-1, 1))

        if oneshot:
            return per_node.data.cpu()

        candidates = per_node[unfilled]
        candidate_indices = node_indices[unfilled]
        best_value, best_subindex = candidates.max(dim=0)
        chosen = candidate_indices[best_subindex]
        # Invariants: the chosen node is still masked and has no recorded output.
        assert current[chosen] == num_categories
        assert collected[chosen] == 0
        true_category = reference[chosen].item()
        current[chosen] = true_category
        assert raw[chosen, true_category] == best_value
        collected[chosen] = best_value
        unfilled = current == num_categories
    return collected.data.cpu()
@torch.no_grad()
def scan_with_mask(
    net: nn.Module,
    x: torch.Tensor,
    edge_index: torch.Tensor,
    edge_attr: torch.Tensor,
    num_categories: int = 20,
    output_transform: Optional[str] = None,
) -> torch.Tensor:
    """Generate an output for each node in the sequence by masking one node at a time.

    For each position `i`, a copy of `x` with node `i` set to the mask value
    (`num_categories`) is passed through the network, and the network's output
    for the true category `x[i]` is recorded.

    Args:
        net: The network used to make predictions.
        x: Node attributes for the target sequence.
        edge_index: Edge indices of the target sequence.
        edge_attr: Edge attributes of the target sequence.
        num_categories: Number of categories; also the "masked" sentinel value.
        output_transform: `None` (raw outputs), `"proba"` (softmax), or
            `"logproba"` (log of the softmax).

    Returns:
        A CPU tensor with one output value per node of `x`.
    """
    assert output_transform in [None, "proba", "logproba"]

    x_ref = x
    output_for_mask = torch.zeros_like(x_ref).to(torch.float)
    for i in range(x_ref.size(0)):
        x = x_ref.clone()
        x[i] = num_categories
        output = net(x, edge_index, edge_attr)
        if output_transform == "proba":
            output = torch.softmax(output, dim=1)
        elif output_transform == "logproba":
            output = torch.softmax(output, dim=1).log()
        # Only the masked node's value is kept, so index the single element
        # directly instead of gathering over every node each iteration.
        output_for_mask[i] = output[i, x_ref[i]]
    return output_for_mask.data.cpu()
107# === Protein design ===
@torch.no_grad()
def design_sequence(
    net: nn.Module,
    data: Data,
    random_position: bool = False,
    value_selection_strategy: str = "map",
    num_categories: Optional[int] = None,
    temperature: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Generate new sequences.

    Args:
        net: A trained neural network to use for designing sequences.
        data: The data on which to base new sequences.
        random_position: Whether the next position to explore should be selected at random
            or by selecting the position for which we have the most confident predictions.
        value_selection_strategy: Controls the strategy for generating new sequences:
            - "map" - Select the most probable residue each time.
            - "multinomial" - Sample residues according to the probability assigned
              by the network.
            - "ref" - Select the residue provided by the `data.x` reference.
        num_categories: The number of categories possible.
            If `None`, assume that the number of categories corresponds to the maximum value
            in `data.x` (presumably the mask token appears in `data.x` in that case --
            TODO confirm callers guarantee this).
        temperature: Temperature by which the network logits are divided before the
            softmax inside `_select_residue_for_position`; >1 flattens, <1 sharpens
            the sampling distribution.

    Returns:
        A tuple `(x, x_proba)` of CPU tensors: the designed sequence for each node,
        and the probability the network assigned to each selected residue.
    """
    assert value_selection_strategy in ("map", "multinomial", "ref")

    if num_categories is None:
        num_categories = data.x.max().item()

    if hasattr(data, "batch"):
        batch_size = data.batch.max().item() + 1
    else:
        batch_size = 1

    # Use ground-truth labels (`data.y`) as the reference when present.
    x_ref = data.y if hasattr(data, "y") and data.y is not None else data.x
    x = torch.ones_like(data.x) * num_categories
    x_proba = torch.zeros_like(x).to(torch.float)

    # First, gather probabilities for pre-assigned residues
    mask_filled = (x_ref != num_categories) & (x == num_categories)
    while mask_filled.any():
        for (
            max_proba_index,
            chosen_category,
            chosen_category_proba,
        ) in _select_residue_for_position(
            net,
            x,
            x_ref,
            data,
            batch_size,
            mask_filled,
            random_position,
            "ref",  # pre-assigned positions always copy the reference residue
            temperature=temperature,
        ):
            # Each yielded position must still be unassigned.
            assert chosen_category != num_categories
            assert x[max_proba_index] == num_categories
            assert x_proba[max_proba_index] == 0
            x[max_proba_index] = chosen_category
            x_proba[max_proba_index] = chosen_category_proba
        mask_filled = (x_ref != num_categories) & (x == num_categories)
    # Holds because positions masked in `x_ref` are still masked in `x` too.
    assert (x == x_ref).all().item()

    # Next, select residues for unassigned positions
    mask_empty = x == num_categories
    while mask_empty.any():
        for (
            max_proba_index,
            chosen_category,
            chosen_category_proba,
        ) in _select_residue_for_position(
            net,
            x,
            x_ref,
            data,
            batch_size,
            mask_empty,
            random_position,
            value_selection_strategy,
            temperature=temperature,
        ):
            assert chosen_category != num_categories
            assert x[max_proba_index] == num_categories
            assert x_proba[max_proba_index] == 0
            x[max_proba_index] = chosen_category
            x_proba[max_proba_index] = chosen_category_proba
        mask_empty = x == num_categories

    return x.cpu(), x_proba.cpu()
def _select_residue_for_position(
    net,
    x,
    x_ref,
    data,
    batch_size,
    mask_ref,
    random_position,
    value_selection_strategy,
    temperature=1.0,
):
    """Yield one (position, residue, probability) choice per batch element.

    For every batch in `batch_size`, pick a position from `mask_ref` -- randomly
    or the one with the most confident prediction -- and choose a residue for it
    according to `value_selection_strategy` ("map", "multinomial", or "ref").
    """
    assert value_selection_strategy in ("map", "multinomial", "ref")

    # One forward pass serves every batch element; temperature scales the logits.
    logits = net(x, data.edge_index, data.edge_attr) / temperature
    all_probas = torch.softmax(logits, dim=1)
    best_proba_per_node, _ = all_probas.max(dim=1)
    all_indices = torch.arange(x.size(0))

    for batch_idx in range(batch_size):
        batch_mask = mask_ref if batch_size <= 1 else mask_ref & (data.batch == batch_idx)

        candidate_indices = all_indices[batch_mask]
        candidate_probas = best_proba_per_node[batch_mask]

        if random_position:
            pick = torch.randint(0, candidate_probas.size(0), (1,)).item()
        else:
            pick = candidate_probas.argmax().item()
        position = candidate_indices[pick]

        residue_probas = all_probas[position]

        if value_selection_strategy == "map":
            chosen_proba, chosen = residue_probas.max(dim=0)
        elif value_selection_strategy == "multinomial":
            chosen = torch.multinomial(residue_probas, 1).item()
            chosen_proba = residue_probas[chosen]
        else:  # "ref"
            chosen = x_ref[position]
            chosen_proba = residue_probas[chosen]

        yield position, chosen, chosen_proba
255# ASTAR approach
@torch.no_grad()
def get_descendents(net, x, x_proba, edge_index, edge_attr, cutoff, num_categories=20):
    """Expand the most promising unassigned position into child states.

    Finds the unassigned node (value == `num_categories`) with the most confident
    prediction and returns one child per category, each with that node filled in
    and the category's log-probability recorded.

    Args:
        net: Network mapping (x, edge_index, edge_attr) to per-node logits.
        x: Current partially assigned sequence; unassigned nodes hold `num_categories`.
        x_proba: Per-node log-probabilities; unassigned nodes hold `cutoff`.
        edge_index: Edge indices, passed through to `net`.
        edge_attr: Edge attributes, passed through to `net`.
        cutoff: Sentinel log-probability marking unassigned positions.
        num_categories: Number of categories / mask token value. Defaults to 20,
            matching the previously hard-coded amino-acid alphabet size.

    Returns:
        A list of `(x_child, x_proba_child)` tuples, one per category.
    """
    index_array = torch.arange(x.size(0))
    mask = x == num_categories

    output = net(x, edge_index, edge_attr)
    output = torch.softmax(output, dim=1)
    output = output[mask]
    index_array = index_array[mask]

    # Most confident prediction among the unassigned nodes.
    max_proba, max_index = output.max(dim=1)[0].max(dim=0)
    row_with_max_proba = output[max_index]

    sum_log_prob = x_proba.sum()
    assert sum_log_prob.item() <= 0, x_proba

    children = []
    for i, p in enumerate(row_with_max_proba):
        x_clone = x.clone()
        x_proba_clone = x_proba.clone()
        # The expanded node must still be unassigned in the parent state.
        assert x_clone[index_array[max_index]] == num_categories
        assert x_proba_clone[index_array[max_index]] == cutoff
        x_clone[index_array[max_index]] = i
        x_proba_clone[index_array[max_index]] = torch.log(p)
        children.append((x_clone, x_proba_clone))
    return children
@dataclass(order=True)
class PrioritizedItem:
    """Heap entry for the best-first search; ordered by priority `p` only."""

    # Priority used by heapq (lower pops first). Callers in this file push the
    # negated sum of log-probabilities -- presumably a 0-dim tensor rather than
    # a plain float; TODO confirm comparisons stay consistent across callers.
    p: float
    # Partially designed sequence; excluded from ordering comparisons.
    x: Any = field(compare=False)
    # Per-node log-probabilities matching `x`; excluded from ordering comparisons.
    x_proba: float = field(compare=False)
@torch.no_grad()
def design_protein(net, x, edge_index, edge_attr, results, cutoff):
    """Design protein sequences using a search strategy.

    Pops the highest-priority partial sequence off a heap; when no position is
    left unassigned (== 20) it is appended to `results`, otherwise its children
    from `get_descendents` are pushed back onto the heap. Returns `results`.
    """
    initial_proba = torch.ones_like(x).to(torch.float) * cutoff
    frontier = [PrioritizedItem(0, x, initial_proba)]
    step = 0
    while frontier:
        item = heapq.heappop(frontier)
        if step % 1000 == 0:
            # Periodic progress report.
            print(
                f"i: {step}; p: {item.p:.4f}; num missing: {(item.x == 20).sum()}; "
                f"heap size: {len(frontier):7d}; results size: {len(results)}"
            )
        if (item.x == 20).any():
            # Expand the best unassigned position into one child per category.
            for child_x, child_proba in get_descendents(
                net, item.x, item.x_proba, edge_index, edge_attr, cutoff
            ):
                heapq.heappush(
                    frontier,
                    PrioritizedItem(-child_proba.sum(), child_x, child_proba),
                )
        else:
            results.append(item)
        step += 1
        if len(frontier) > 1_000_000:
            # Bound memory: keep a prefix of the heap array (a prefix of a valid
            # heap is itself a valid heap), then re-heapify as the original did.
            frontier = frontier[:700_000]
            heapq.heapify(frontier)
    return results