Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import heapq 

2from dataclasses import dataclass, field 

3from typing import Any, Optional, Tuple 

4 

5import torch 

6import torch.nn as nn 

7from torch_geometric.data import Data 

8 

9 

def get_node_proba(net, x, edge_index, edge_attr, num_categories=20):
    """Deprecated: use `get_node_outputs` instead."""
    # NotImplementedError is the conventional exception for a removed/stubbed
    # API; it subclasses Exception, so existing `except Exception` handlers
    # still catch it.
    raise NotImplementedError("Use get_node_outputs instead!")

12 

13 

def get_node_value(net, x, edge_index, edge_attr, num_categories=20):
    """Deprecated: use `get_node_outputs` instead."""
    # NotImplementedError is the conventional exception for a removed/stubbed
    # API; it subclasses Exception, so existing `except Exception` handlers
    # still catch it.
    raise NotImplementedError("Use get_node_outputs instead!")

16 

17 

@torch.no_grad()
def get_node_outputs(
    net: nn.Module,
    x: torch.Tensor,
    edge_index: torch.Tensor,
    edge_attr: torch.Tensor,
    num_categories: int = 20,
    output_transform: Optional[str] = None,
    oneshot: bool = False,
) -> torch.Tensor:
    """Return network output for each node in the reference sequence.

    Args:
        net: The network to use for making predictions.
        x: Node attributes for the target sequence.
        edge_index: Edge indices of the target sequence.
        edge_attr: Edge attributes of the target sequence.
        num_categories: The number of categories to which the network assigns individual nodes
            (e.g. the number of amino acids for the protein design problem).
            Also used as the sentinel value marking "masked" positions.
        output_transform: Transformation to apply to network outputs.
            - `None` - No transformation.
            - `proba` - Apply the softmax transformation.
            - `logproba` - Apply the softmax transformation and log the results.
        oneshot: Whether predictions should be made using a single pass through the network,
            or incrementally, by making a single prediction at a time.

    Returns:
        A tensor of network predictions for each node in `x`.
    """
    assert output_transform in [None, "proba", "logproba"]

    x_ref = x
    # Start with every position masked out with the `num_categories` sentinel.
    x = torch.ones_like(x_ref) * num_categories
    # Accumulates the network output recorded for each position as it is revealed.
    x_proba = torch.zeros_like(x_ref).to(torch.float)
    index_array_ref = torch.arange(x_ref.size(0))
    mask = x == num_categories  # positions not yet assigned
    while mask.any():
        output = net(x, edge_index, edge_attr)
        if output_transform == "proba":
            output = torch.softmax(output, dim=1)
        elif output_transform == "logproba":
            output = torch.softmax(output, dim=1).log()

        # Per-node output for the *reference* category at each position.
        output_for_x = output.gather(1, x_ref.view(-1, 1))

        if oneshot:
            # Single fully-masked forward pass: return outputs immediately.
            return output_for_x.data.cpu()

        # Incremental mode: reveal exactly one position per forward pass — the
        # still-masked position whose reference category scored highest.
        output_for_x = output_for_x[mask]
        index_array = index_array_ref[mask]
        max_proba, max_proba_position = output_for_x.max(dim=0)

        # Sanity checks: the chosen position must be unassigned and unrecorded.
        assert x[index_array[max_proba_position]] == num_categories
        assert x_proba[index_array[max_proba_position]] == 0
        correct_amino_acid = x_ref[index_array[max_proba_position]].item()
        x[index_array[max_proba_position]] = correct_amino_acid
        assert output[index_array[max_proba_position], correct_amino_acid] == max_proba
        x_proba[index_array[max_proba_position]] = max_proba
        mask = x == num_categories
    return x_proba.data.cpu()

78 

79 

@torch.no_grad()
def scan_with_mask(
    net: nn.Module,
    x: torch.Tensor,
    edge_index: torch.Tensor,
    edge_attr: torch.Tensor,
    num_categories: int = 20,
    output_transform: Optional[str] = None,
) -> torch.Tensor:
    """Generate an output for each node in the sequence by masking one node at a time.

    Args:
        net: The network to use for making predictions.
        x: Node attributes for the target sequence.
        edge_index: Edge indices of the target sequence.
        edge_attr: Edge attributes of the target sequence.
        num_categories: Sentinel value used to mask a position.
        output_transform: `None`, `"proba"` (softmax), or `"logproba"`
            (log-softmax) applied to the network output.

    Returns:
        A 1-D tensor with the network's output for the reference category of
        each node, computed with that node masked.

    Raises:
        ValueError: If `output_transform` is not one of the supported values.
    """
    # Validate explicitly rather than with `assert`, which is stripped under -O.
    if output_transform not in [None, "proba", "logproba"]:
        raise ValueError(f"Invalid output_transform: {output_transform!r}")

    x_ref = x
    output_for_mask = torch.zeros_like(x_ref).to(torch.float)
    for i in range(x_ref.size(0)):
        x = x_ref.clone()
        x[i] = num_categories  # mask only position i
        output = net(x, edge_index, edge_attr)
        if output_transform == "proba":
            output = torch.softmax(output, dim=1)
        elif output_transform == "logproba":
            output = torch.softmax(output, dim=1).log()
        # Only the masked position's score is needed; index it directly instead
        # of gathering reference-category outputs for every node.
        output_for_mask[i] = output[i, x_ref[i]]
    return output_for_mask.data.cpu()

105 

106 

107# === Protein design === 

108 

109 

@torch.no_grad()
def design_sequence(
    net: nn.Module,
    data: Data,
    random_position: bool = False,
    value_selection_strategy: str = "map",
    num_categories: int = None,
    temperature: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Generate new sequences.

    Args:
        net: A trained neural network to use for designing sequences.
        data: The data on which to base new sequences.
        random_position: Whether the next position to explore should be selected at random
            or by selecting the position for which we have the most confident predictions.
        value_selection_strategy: Controls the strategy for generating new sequences:
            - "map" - Select the most probable residue each time.
            - "multinomial" - Sample residues according to the probability assigned
              by the network.
            - "ref" - Select the residue provided by the `data.x` reference.
        num_categories: The number of categories possible.
            If `None`, assume that the number of categories corresponds to the maximum value
            in `data.x`.
        temperature: Softmax temperature applied to network outputs before
            selecting residues (forwarded to `_select_residue_for_position`).

    Returns:
        A tuple `(x, x_proba)`: the designed sequence tensor and the probability
        recorded for each assigned position.
    """
    assert value_selection_strategy in ("map", "multinomial", "ref")

    if num_categories is None:
        num_categories = data.x.max().item()

    # `data.batch` is present when several graphs are batched into one `data`.
    if hasattr(data, "batch"):
        batch_size = data.batch.max().item() + 1
    else:
        batch_size = 1

    # Prefer ground-truth labels (`data.y`) as the reference when available.
    x_ref = data.y if hasattr(data, "y") and data.y is not None else data.x
    # `num_categories` doubles as the sentinel for "not yet assigned".
    x = torch.ones_like(data.x) * num_categories
    x_proba = torch.zeros_like(x).to(torch.float)

    # First, gather probabilities for pre-assigned residues
    # (positions where the reference already provides a residue). One residue
    # per graph is assigned per network pass, always copying the reference.
    mask_filled = (x_ref != num_categories) & (x == num_categories)
    while mask_filled.any():
        for (
            max_proba_index,
            chosen_category,
            chosen_category_proba,
        ) in _select_residue_for_position(
            net,
            x,
            x_ref,
            data,
            batch_size,
            mask_filled,
            random_position,
            "ref",
            temperature=temperature,
        ):
            # Invariants: the selected position must be unassigned and unrecorded.
            assert chosen_category != num_categories
            assert x[max_proba_index] == num_categories
            assert x_proba[max_proba_index] == 0
            x[max_proba_index] = chosen_category
            x_proba[max_proba_index] = chosen_category_proba
        mask_filled = (x_ref != num_categories) & (x == num_categories)
    # After phase one, every position matches the reference (reference
    # positions equal to the sentinel remain masked in both tensors).
    assert (x == x_ref).all().item()

    # Next, select residues for unassigned positions
    # using the caller's strategy ("map", "multinomial", or "ref").
    mask_empty = x == num_categories
    while mask_empty.any():
        for (
            max_proba_index,
            chosen_category,
            chosen_category_proba,
        ) in _select_residue_for_position(
            net,
            x,
            x_ref,
            data,
            batch_size,
            mask_empty,
            random_position,
            value_selection_strategy,
            temperature=temperature,
        ):
            assert chosen_category != num_categories
            assert x[max_proba_index] == num_categories
            assert x_proba[max_proba_index] == 0
            x[max_proba_index] = chosen_category
            x_proba[max_proba_index] = chosen_category_proba
        mask_empty = x == num_categories

    return x.cpu(), x_proba.cpu()

204 

205 

def _select_residue_for_position(
    net: nn.Module,
    x: torch.Tensor,
    x_ref: torch.Tensor,
    data: Data,
    batch_size: int,
    mask_ref: torch.Tensor,
    random_position: bool,
    value_selection_strategy: str,
    temperature: float = 1.0,
):
    """Predict a new residue for an unassigned position for each batch in `batch_size`.

    Performs a single forward pass over the whole (possibly batched) graph and
    yields one `(position_index, chosen_category, chosen_category_proba)`
    triple per graph in the batch.
    """
    assert value_selection_strategy in ("map", "multinomial", "ref")

    output = net(x, data.edge_index, data.edge_attr)
    output = output / temperature  # temperature-scale logits before softmax
    output_proba_ref = torch.softmax(output, dim=1)
    # Highest per-node probability, used to rank candidate positions.
    output_proba_max_ref, _ = output_proba_ref.max(dim=1)
    index_array_ref = torch.arange(x.size(0))

    for i in range(batch_size):
        # Restrict the candidate mask to the i-th graph of the batch.
        mask = mask_ref
        if batch_size > 1:
            mask = mask & (data.batch == i)

        index_array = index_array_ref[mask]
        max_probas = output_proba_max_ref[mask]

        # Choose WHICH masked position to fill next: at random, or the one
        # with the most confident prediction.
        if random_position:
            selected_residue_subindex = torch.randint(0, max_probas.size(0), (1,)).item()
            max_proba_index = index_array[selected_residue_subindex]
        else:
            selected_residue_subindex = max_probas.argmax().item()
            max_proba_index = index_array[selected_residue_subindex]

        category_probas = output_proba_ref[max_proba_index]

        # Choose WHICH category to assign at that position.
        if value_selection_strategy == "map":
            chosen_category_proba, chosen_category = category_probas.max(dim=0)
        elif value_selection_strategy == "multinomial":
            chosen_category = torch.multinomial(category_probas, 1).item()
            chosen_category_proba = category_probas[chosen_category]
        elif value_selection_strategy == "ref":
            chosen_category = x_ref[max_proba_index]
            chosen_category_proba = category_probas[chosen_category]

        yield max_proba_index, chosen_category, chosen_category_proba

253 

254 

255# ASTAR approach 

256 

257 

@torch.no_grad()
def get_descendents(net, x, x_proba, edge_index, edge_attr, cutoff, num_categories=20):
    """Expand a partially-assigned sequence into its child search states.

    Finds the unassigned position where the network makes its single most
    confident prediction and generates one child per category, with that
    position filled in and its log-probability recorded.

    Args:
        net: Network mapping `(x, edge_index, edge_attr)` to per-node logits.
        x: Current sequence; unassigned positions hold `num_categories`.
        x_proba: Per-node log-probabilities; unassigned positions hold `cutoff`.
        edge_index: Edge indices of the graph.
        edge_attr: Edge attributes of the graph.
        cutoff: Sentinel value marking unassigned entries in `x_proba`.
        num_categories: Number of categories; also the sentinel value in `x`
            marking unassigned positions (defaults to 20, as before).

    Returns:
        A list of `(x, x_proba)` child tuples, one per category.
    """
    index_array = torch.arange(x.size(0))
    mask = x == num_categories  # positions still unassigned

    output = net(x, edge_index, edge_attr)
    output = torch.softmax(output, dim=1)
    output = output[mask]
    index_array = index_array[mask]

    # Locate the unassigned position holding the single most confident
    # prediction; the max probability itself is not needed.
    _, max_index = output.max(dim=1)[0].max(dim=0)
    row_with_max_proba = output[max_index]

    # x_proba holds log-probabilities (and the non-positive `cutoff` sentinel),
    # so its sum must never be positive.
    sum_log_prob = x_proba.sum()
    assert sum_log_prob.item() <= 0, x_proba

    children = []
    for i, p in enumerate(row_with_max_proba):
        x_clone = x.clone()
        x_proba_clone = x_proba.clone()
        # The expanded position must still carry its sentinel values.
        assert x_clone[index_array[max_index]] == num_categories
        assert x_proba_clone[index_array[max_index]] == cutoff
        x_clone[index_array[max_index]] = i
        x_proba_clone[index_array[max_index]] = torch.log(p)
        children.append((x_clone, x_proba_clone))
    return children

287 

288 

@dataclass(order=True)
class PrioritizedItem:
    """Heap entry for the best-first search; ordered by priority `p` only."""

    # Priority: the negated sum of log-probabilities pushed by `design_protein`,
    # so smaller values are more promising.
    p: float
    # Partially-assigned sequence tensor (excluded from ordering comparisons).
    x: Any = field(compare=False)
    # Per-node log-probabilities for `x` (excluded from ordering comparisons).
    # NOTE(review): annotated `float` originally, but `design_protein` stores a
    # tensor here — annotate as Any.
    x_proba: Any = field(compare=False)

294 

295 

@torch.no_grad()
def design_protein(net, x, edge_index, edge_attr, results, cutoff):
    """Design protein sequences using a best-first search strategy.

    Args:
        net: Network used to score candidate residues.
        x: Initial sequence tensor; positions equal to 20 are unassigned.
        edge_index: Edge indices of the graph.
        edge_attr: Edge attributes of the graph.
        results: List that completed items are appended to (mutated in place).
        cutoff: Sentinel log-probability for unassigned positions.

    Returns:
        The `results` list of `PrioritizedItem`s for completed sequences.
    """
    x_proba = torch.ones_like(x).to(torch.float) * cutoff
    heap = [PrioritizedItem(0, x, x_proba)]
    i = 0
    while heap:
        item = heapq.heappop(heap)
        if i % 1000 == 0:
            print(
                f"i: {i}; p: {item.p:.4f}; num missing: {(item.x == 20).sum()}; "
                f"heap size: {len(heap):7d}; results size: {len(results)}"
            )
        if not (item.x == 20).any():
            # No masked positions remain: a fully designed sequence.
            results.append(item)
        else:
            children = get_descendents(net, item.x, item.x_proba, edge_index, edge_attr, cutoff)
            for x, x_proba in children:
                heapq.heappush(heap, PrioritizedItem(-x_proba.sum(), x, x_proba))
        i += 1
        if len(heap) > 1_000_000:
            # Prune to the 700k *best* (lowest-priority) candidates. Slicing the
            # heap array would keep an arbitrary subset, since a binary heap is
            # only partially ordered. `nsmallest` returns an ascending-sorted
            # list, which already satisfies the heap invariant, so no heapify
            # is needed.
            heap = heapq.nsmallest(700_000, heap)
    return results