Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import codecs 

2import io 

3import re 

4from pathlib import Path 

5 

6import matplotlib.pyplot as plt 

7import numpy as np 

8import torch_geometric 

9from IPython.display import clear_output, display 

10from ipywidgets import Button, FileUpload, HBox, Layout 

11from kmbio import PDB 

12from kmtools import structure_tools 

13 

14import proteinsolver 

15from proteinsolver.dashboard import ( 

16 global_state, 

17 update_sequence_generation, 

18 update_target_selection, 

19) 

20 

21 

def load_structure(structure: PDB.Structure):
    """Populate ``global_state`` from the first chain of ``structure``.

    The chain id is taken from the first chain of the first model.  The
    interaction dataset computed for that chain is converted into the
    ``ProteinData`` / graph representations, and the results are stored on
    ``global_state`` (``structure``, ``tdata``, ``data``,
    ``reference_sequence`` and ``target_sequence``).

    Args:
        structure: Parsed structure to load.

    Raises:
        AssertionError: If a residue index in the interaction table is not
            smaller than the length of the extracted chain sequence.
    """
    first_model = next(structure.models)
    first_chain = next(first_model.chains)
    chain_id = first_chain.id

    # NOTE(review): the positional ``0`` is presumably the model index and
    # ``r_cutoff`` a distance cutoff -- confirm against proteinsolver.utils.
    domain, result_df = proteinsolver.utils.get_interaction_dataset_wdistances(
        structure, 0, chain_id, r_cutoff=12, remove_hetatms=True
    )
    domain_sequence = structure_tools.get_chain_sequence(domain)

    idx_1 = result_df["residue_idx_1"].values
    idx_2 = result_df["residue_idx_2"].values
    sequence_length = len(domain_sequence)
    # Every residue index must point inside the extracted sequence.
    assert max(idx_1) < sequence_length
    assert max(idx_2) < sequence_length

    pdata = proteinsolver.utils.ProteinData(
        domain_sequence, idx_1, idx_2, result_df["distance"].values
    )
    tdata = proteinsolver.datasets.protein.row_to_data(pdata)
    # Transform a clone so that ``tdata`` keeps its untransformed edge attributes.
    data = proteinsolver.datasets.protein.transform_edge_attr(tdata.clone())

    global_state.structure = domain
    global_state.tdata = tdata
    global_state.data = data
    reference_sequence = list(proteinsolver.utils.array_to_seq(data.x))
    global_state.reference_sequence = reference_sequence
    # The target starts out as one "-" placeholder per reference residue.
    global_state.target_sequence = ["-"] * len(reference_sequence)

46 

47 

def load_distance_matrix(distance_matrix: str):
    """Populate ``global_state`` from the text of a distance-matrix file.

    The text is processed line by line:

    - blank lines and lines starting with ``#`` are skipped;
    - a line starting with ``N:`` gives the number of residues;
    - any other line must contain ``residue_idx_1, residue_idx_2, distance``.

    Pairs of a residue with itself are dropped, each pair is stored with the
    smaller index first, and exact duplicate entries are removed.  A
    poly-glycine sequence of the right length stands in for the unknown
    sequence, and ``global_state.structure`` is set to ``None``.

    Args:
        distance_matrix: Full contents of the distance-matrix file.

    Raises:
        ValueError: If a line cannot be parsed, or if the input contains
            neither an ``N:`` line nor any residue pair (so the number of
            residues cannot be determined).
    """
    num_residues = None
    unique_pairs = set()
    for line_number, raw_line in enumerate(distance_matrix.split("\n"), start=1):
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue
        try:
            if line.startswith("N:"):
                num_residues = int(re.split(": *", line)[1])
                continue
            fields = re.split(", *", line)
            residue_idx_1 = int(fields[0])
            residue_idx_2 = int(fields[1])
            distance = float(fields[2])
        except (IndexError, ValueError) as err:
            # Report where the problem is instead of a bare conversion error.
            raise ValueError(
                "Cannot parse line {} of the distance matrix: {!r}".format(
                    line_number, line
                )
            ) from err
        if residue_idx_1 == residue_idx_2:
            continue
        unique_pairs.add(
            (
                min(residue_idx_1, residue_idx_2),
                max(residue_idx_1, residue_idx_2),
                distance,
            )
        )

    # Sort so that the edge order does not depend on set iteration order.
    edges = sorted(unique_pairs)
    residue_idx_1_lst = [edge[0] for edge in edges]
    residue_idx_2_lst = [edge[1] for edge in edges]
    distance_lst = [edge[2] for edge in edges]

    if num_residues is None:
        if not edges:
            raise ValueError(
                "Distance matrix has no 'N:' line and no residue pairs; "
                "cannot determine the number of residues."
            )
        # Indices are treated as 0-based (load_structure requires every index
        # to be smaller than the sequence length), so the residue count is the
        # largest index plus one.
        num_residues = max(residue_idx_1_lst + residue_idx_2_lst) + 1

    pdata = proteinsolver.utils.ProteinData(
        "G" * num_residues,
        np.array(residue_idx_1_lst),
        np.array(residue_idx_2_lst),
        np.array(distance_lst),
    )
    tdata = proteinsolver.datasets.protein.row_to_data(pdata)
    # Transform a clone so that ``tdata`` keeps its untransformed edge attributes.
    data = proteinsolver.datasets.protein.transform_edge_attr(tdata.clone())

    # There is no 3D structure to show for a bare distance matrix.
    global_state.structure = None
    global_state.tdata = tdata
    global_state.data = data
    global_state.reference_sequence = list(proteinsolver.utils.array_to_seq(data.x))
    global_state.target_sequence = ["-"] * len(global_state.reference_sequence)

94 

95 

def update_displayed_structure(ngl_stage):
    """Replace the component shown in ``ngl_stage`` with the current structure.

    If the stage has any component, ``component_0`` is removed.  Nothing is
    added when ``global_state.structure`` is ``None``.

    Args:
        ngl_stage: Viewer object exposing ``n_components``, ``component_0``,
            ``remove_component`` and ``add_component``.
    """
    stage_has_components = bool(ngl_stage.n_components)
    if stage_has_components:
        ngl_stage.remove_component(ngl_stage.component_0)

    if global_state.structure is None:
        return
    ngl_component = PDB.structure_to_ngl(global_state.structure)
    ngl_stage.add_component(ngl_component)

101 

102 

def update_displayed_distance_matrix(distance_matrix_out):
    """Redraw the inverse-distance heat map inside ``distance_matrix_out``.

    A dense adjacency matrix is built from ``global_state.tdata`` using the
    reciprocal of the first edge-attribute column as edge weight, drawn with a
    colorbar, and displayed in ``distance_matrix_out`` after clearing its
    previous output.

    The figure is closed once it has been displayed (or if drawing fails), so
    repeated calls do not accumulate open pyplot figures.

    Args:
        distance_matrix_out: Context manager (an output widget) that captures
            the displayed figure.
    """
    tdata = global_state.tdata
    # NOTE(review): column 0 of ``edge_attr`` is presumably the distance, as
    # the colorbar label below suggests -- confirm against row_to_data.
    adj = torch_geometric.utils.to_dense_adj(
        edge_index=tdata.edge_index, edge_attr=1 / tdata.edge_attr[:, 0]
    ).squeeze()

    fig = plt.figure(constrained_layout=False, figsize=(4 * 0.8, 3 * 0.8))
    try:
        # Wide image axis plus a narrow axis reserved for the colorbar.
        gs = fig.add_gridspec(
            nrows=1,
            ncols=2,
            top=0.98,
            right=0.85,
            bottom=0.15,
            left=0.1,
            hspace=0,
            wspace=0,
            width_ratios=[3, 0.1],
        )

        ax = fig.add_subplot(gs[0, 0])
        cax = fig.add_subplot(gs[0, 1])

        out = ax.imshow(adj, cmap="Greys")
        ax.set_ylabel("Amino acid position")
        ax.set_xlabel("Amino acid position")
        ax.tick_params("both")
        cb = fig.colorbar(out, cax=cax)
        cb.set_label("1 / distance (Å$^{-1}$)")

        with distance_matrix_out:
            clear_output()
            display(fig, display_id="distance-matrix")
    finally:
        # pyplot keeps a reference to every figure it creates until it is
        # closed explicitly.
        plt.close(fig)

136 

137 

def create_load_structure_button(
    ngl_stage, distance_matrix_out, target_selection_out, sequence_generation_out
):
    """Create a ``FileUpload`` widget that loads an uploaded structure file.

    When the widget's ``value`` changes, the most recent upload is decoded as
    UTF-8, parsed with the parser matching the file extension, and passed to
    ``load_structure``.  All dependent views are then refreshed and the
    widget is reset.

    Args:
        ngl_stage: Viewer passed to ``update_displayed_structure``.
        distance_matrix_out: Output passed to
            ``update_displayed_distance_matrix``.
        target_selection_out: Output passed to ``update_target_selection``.
        sequence_generation_out: Output passed to
            ``update_sequence_generation``.

    Returns:
        The configured ``FileUpload`` widget.
    """

    def handle_upload(change):
        # Keep only the last file (there must be a better way!)
        uploads = list(change["new"].values())
        newest = uploads[-1]

        name = newest["metadata"]["name"]
        # Text before the first "." is the id; text after the last "." picks
        # the parser.
        structure_id = name.partition(".")[0]
        extension = name.rpartition(".")[2]

        text = codecs.decode(newest["content"], encoding="utf-8")
        stream = io.StringIO(text)
        structure = PDB.get_parser(extension).get_structure(
            stream, structure_id=structure_id
        )

        # TODO: We may need to lock global_state at this point?
        load_structure(structure)

        update_target_selection(target_selection_out)
        update_sequence_generation(sequence_generation_out)
        update_displayed_structure(ngl_stage)
        update_displayed_distance_matrix(distance_matrix_out)

        # Reset the widget: drop the stored uploads and zero its counter.
        uploader.value.clear()
        uploader._counter = 0

    uploader = FileUpload(
        description="Load structure",
        accept=".pdb,.cif,.mmcif",
        multiple=False,
        layout=Layout(width="11rem"),
    )
    uploader.observe(handle_upload, names="value")
    return uploader

176 

177 

def create_load_distance_matrix_button(
    ngl_stage, distance_matrix_out, target_selection_out, sequence_generation_out
):
    """Create a ``FileUpload`` widget that loads an uploaded distance matrix.

    When the widget's ``value`` changes, the most recent upload is decoded as
    UTF-8 and passed to ``load_distance_matrix``.  All dependent views are
    then refreshed and the widget is reset.

    Args:
        ngl_stage: Viewer passed to ``update_displayed_structure``.
        distance_matrix_out: Output passed to
            ``update_displayed_distance_matrix``.
        target_selection_out: Output passed to ``update_target_selection``.
        sequence_generation_out: Output passed to
            ``update_sequence_generation``.

    Returns:
        The configured ``FileUpload`` widget.
    """

    def handle_upload(change):
        # Keep only the last file (there must be a better way!)
        uploads = list(change["new"].values())
        newest = uploads[-1]
        text = codecs.decode(newest["content"], encoding="utf-8")

        # TODO: We may need to lock global_state at this point?
        load_distance_matrix(text)

        update_target_selection(target_selection_out)
        update_sequence_generation(sequence_generation_out)
        update_displayed_structure(ngl_stage)
        update_displayed_distance_matrix(distance_matrix_out)

        # Reset the widget: drop the stored uploads and zero its counter.
        uploader.value.clear()
        uploader._counter = 0

    uploader = FileUpload(
        description="Load distance matrix",
        accept=".txt",
        multiple=False,
        layout=Layout(width="11rem"),
    )
    uploader.observe(handle_upload, names="value")
    return uploader

207 

208 

def create_load_example_buttons(
    ngl_stage, distance_matrix_out, target_selection_out, sequence_generation_out
):
    """Create a row of buttons that load the bundled example structures.

    The example files are looked up in the ``data/inputs`` folder of the
    installed ``proteinsolver`` package.  Clicking a button loads its file
    with ``PDB.load``, passes it to ``load_structure`` and refreshes all
    dependent views.

    Args:
        ngl_stage: Viewer passed to ``update_displayed_structure``.
        distance_matrix_out: Output passed to
            ``update_displayed_distance_matrix``.
        target_selection_out: Output passed to ``update_target_selection``.
        sequence_generation_out: Output passed to
            ``update_sequence_generation``.

    Returns:
        An ``HBox`` holding one button per example file.
    """
    package_dir = Path(proteinsolver.__path__[0]).resolve(strict=True)
    inputs_dir = package_dir / "data" / "inputs"
    example_names = ("1n5uA03.pdb", "4beuA02.pdb", "4unuA00.pdb", "4z8jA00.pdb")

    def make_example_button(example_path):
        # A factory gives every callback its own ``example_path`` binding.
        def load_example(_clicked):
            # TODO: We may need to lock global_state at this point?
            load_structure(PDB.load(example_path))
            update_target_selection(target_selection_out)
            update_sequence_generation(sequence_generation_out)
            update_displayed_structure(ngl_stage)
            update_displayed_distance_matrix(distance_matrix_out)

        example_button = Button(
            description=example_path.stem, layout=Layout(width="8.25rem")
        )
        example_button.on_click(load_example)
        return example_button

    return HBox(
        [make_example_button(inputs_dir / name) for name in example_names],
        layout=Layout(flex_flow="row", align_items="center"),
    )