Hot-keys on this page:
  r, m, x, p   toggle line displays
  j, k         next/previous highlighted chunk
  0 (zero)     top of page
  1 (one)      first highlighted chunk
1import codecs
2import io
3import re
4from pathlib import Path
6import matplotlib.pyplot as plt
7import numpy as np
8import torch_geometric
9from IPython.display import clear_output, display
10from ipywidgets import Button, FileUpload, HBox, Layout
11from kmbio import PDB
12from kmtools import structure_tools
14import proteinsolver
15from proteinsolver.dashboard import (
16 global_state,
17 update_sequence_generation,
18 update_target_selection,
19)
def load_structure(structure: PDB.Structure):
    """Load the first chain of ``structure`` into the dashboard's ``global_state``.

    The first chain of the first model is passed to
    ``proteinsolver.utils.get_interaction_dataset_wdistances`` (``r_cutoff=12``,
    hetero atoms removed). The returned residue-pair table is turned into
    ProteinSolver graph data.

    On return, ``global_state`` holds:
      * ``structure``: the domain returned by the interaction helper;
      * ``tdata``: the graph built by ``row_to_data``;
      * ``data``: a clone of ``tdata`` passed through ``transform_edge_attr``;
      * ``reference_sequence``: the sequence decoded from ``data.x`` (as a list);
      * ``target_sequence``: one ``"-"`` per reference residue.

    Raises:
        AssertionError: if the interaction table references a residue index
            outside the domain sequence.
    """
    first_model = next(structure.models)
    first_chain_id = next(first_model.chains).id

    domain, interactions = proteinsolver.utils.get_interaction_dataset_wdistances(
        structure, 0, first_chain_id, r_cutoff=12, remove_hetatms=True
    )
    sequence = structure_tools.get_chain_sequence(domain)

    # Sanity check: every residue index referenced by a contact must lie inside the sequence.
    for column in ("residue_idx_1", "residue_idx_2"):
        assert max(interactions[column].values) < len(sequence)

    row = proteinsolver.utils.ProteinData(
        sequence,
        interactions["residue_idx_1"].values,
        interactions["residue_idx_2"].values,
        interactions["distance"].values,
    )
    raw_graph = proteinsolver.datasets.protein.row_to_data(row)
    # Transform a copy so the untransformed graph stays available as ``tdata``.
    transformed_graph = proteinsolver.datasets.protein.transform_edge_attr(raw_graph.clone())

    global_state.structure = domain
    global_state.tdata = raw_graph
    global_state.data = transformed_graph
    global_state.reference_sequence = list(proteinsolver.utils.array_to_seq(transformed_graph.x))
    global_state.target_sequence = ["-"] * len(global_state.reference_sequence)
def _parse_distance_matrix(distance_matrix: str):
    """Parse distance-matrix text into ``(num_residues, pairs)``.

    Format:
      * blank lines and lines starting with ``#`` are ignored;
      * ``N: <int>`` declares the number of residues (``None`` is returned if absent);
      * every other line is ``<idx_1>, <idx_2>, <distance>``.

    Self-pairs (``idx_1 == idx_2``) are dropped. Each pair is stored with the
    smaller index first, and exact duplicate ``(idx_1, idx_2, distance)`` entries
    are removed. ``pairs`` is returned sorted, so the result is deterministic.

    Malformed lines raise ``ValueError`` (non-numeric fields) or ``IndexError``
    (fewer than three fields), as before.
    """
    num_residues = None
    pairs = set()
    for raw_line in distance_matrix.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue
        if line.startswith("N:"):
            num_residues = int(re.split(": *", line)[1])
            continue
        row = re.split(", *", line)
        idx_1, idx_2, distance = int(row[0]), int(row[1]), float(row[2])
        if idx_1 == idx_2:
            continue
        if idx_1 > idx_2:
            idx_1, idx_2 = idx_2, idx_1
        pairs.add((idx_1, idx_2, distance))
    return num_residues, sorted(pairs)


def load_distance_matrix(distance_matrix: str):
    """Parse a plain-text distance matrix and load it into the dashboard's ``global_state``.

    See ``_parse_distance_matrix`` for the accepted format. Residue indices are
    0-based, matching ``load_structure``, where every index must be
    ``< len(sequence)``. If no ``N:`` line is given, the residue count is inferred
    as ``max(index) + 1``.

    No structure is available for a distance matrix, so ``global_state.structure``
    is set to ``None`` and the reference sequence is a placeholder of ``"G"`` residues.

    Raises:
        ValueError: if the text has neither residue pairs nor an ``N:`` line, if a
            residue index falls outside ``[0, num_residues)``, or if a line cannot
            be parsed.
        IndexError: if a pair line has fewer than three fields.
    """
    num_residues, pairs = _parse_distance_matrix(distance_matrix)

    if pairs:
        residue_idx_1_lst, residue_idx_2_lst, distance_lst = zip(*pairs)
    else:
        residue_idx_1_lst, residue_idx_2_lst, distance_lst = (), (), ()
    all_indices = residue_idx_1_lst + residue_idx_2_lst

    if num_residues is None:
        if not all_indices:
            raise ValueError(
                "Distance matrix contains no residue pairs and no 'N: <num_residues>' line."
            )
        # Indices are 0-based, so the residue count is one more than the largest index.
        num_residues = max(all_indices) + 1

    # Every edge must point at a node that exists in the "G" * num_residues sequence.
    bad_indices = sorted({idx for idx in all_indices if not 0 <= idx < num_residues})
    if bad_indices:
        raise ValueError(
            f"Residue indices {bad_indices} are outside the valid range [0, {num_residues})."
        )

    pdata = proteinsolver.utils.ProteinData(
        "G" * num_residues,
        np.array(residue_idx_1_lst),
        np.array(residue_idx_2_lst),
        np.array(distance_lst),
    )
    tdata = proteinsolver.datasets.protein.row_to_data(pdata)
    # Transform a copy so the untransformed graph stays available as ``tdata``.
    data = proteinsolver.datasets.protein.transform_edge_attr(tdata.clone())

    global_state.structure = None
    global_state.tdata = tdata
    global_state.data = data
    global_state.reference_sequence = list(proteinsolver.utils.array_to_seq(data.x))
    global_state.target_sequence = ["-"] * len(global_state.reference_sequence)
def update_displayed_structure(ngl_stage):
    """Refresh the NGL viewer so it reflects ``global_state.structure``.

    If the stage reports any components, ``component_0`` is removed first. The
    current structure, when one is loaded, is then converted with
    ``PDB.structure_to_ngl`` and added. When ``global_state.structure`` is
    ``None`` (as ``load_distance_matrix`` leaves it), no new component is added.
    """
    if ngl_stage.n_components:
        ngl_stage.remove_component(ngl_stage.component_0)

    current = global_state.structure
    if current is None:
        return
    ngl_stage.add_component(PDB.structure_to_ngl(current))
def update_displayed_distance_matrix(distance_matrix_out):
    """Draw the contact map of ``global_state.tdata`` into an output widget.

    Each edge is drawn with intensity ``1 / tdata.edge_attr[:, 0]``; the colour
    bar labels that column as a distance. The new plot replaces whatever
    ``distance_matrix_out`` previously showed.

    The figure is closed once it has been displayed. Pyplot keeps every figure
    created with ``plt.figure`` alive until it is explicitly closed, so without
    this each upload would leak a figure.
    """
    tdata = global_state.tdata
    # to_dense_adj returns a batch of matrices of shape [1, N, N]. Drop only the
    # batch axis so that a single-residue graph still yields a 2-D image.
    adj = torch_geometric.utils.to_dense_adj(
        edge_index=tdata.edge_index, edge_attr=1 / tdata.edge_attr[:, 0]
    ).squeeze(0)

    fig = plt.figure(constrained_layout=False, figsize=(4 * 0.8, 3 * 0.8))

    # Two columns: the heat map and a narrow colour-bar axis.
    gs = fig.add_gridspec(
        nrows=1,
        ncols=2,
        top=0.98,
        right=0.85,
        bottom=0.15,
        left=0.1,
        hspace=0,
        wspace=0,
        width_ratios=[3, 0.1],  # 16
    )

    ax = fig.add_subplot(gs[0, 0])
    cax = fig.add_subplot(gs[0, 1])

    out = ax.imshow(adj, cmap="Greys")
    ax.set_ylabel("Amino acid position")
    ax.set_xlabel("Amino acid position")
    ax.tick_params("both")
    cb = fig.colorbar(out, cax=cax)
    cb.set_label("1 / distance (Å$^{-1}$)")

    with distance_matrix_out:
        clear_output()
        display(fig, display_id="distance-matrix")
    # Release the figure now that it has been rendered into the output widget.
    plt.close(fig)
def create_load_structure_button(
    ngl_stage, distance_matrix_out, target_selection_out, sequence_generation_out
):
    """Create a ``FileUpload`` widget that loads a PDB/mmCIF structure into the dashboard.

    When a file is uploaded, it is parsed with the ``PDB`` parser matching its
    extension. The parsed structure goes through ``load_structure``, and the target
    selection, sequence generation, NGL structure and distance-matrix views are
    then refreshed.

    The uploader is reset after every upload attempt, including failed ones, so
    the next upload starts from a clean state. Exceptions raised while parsing
    or loading still propagate to the widget's callback machinery.

    Returns:
        The configured ``FileUpload`` widget.
    """
    uploader = FileUpload(
        description="Load structure",
        accept=".pdb,.cif,.mmcif",
        multiple=False,
        layout=Layout(width="11rem"),
    )

    def handle_upload(change):
        uploaded = change["new"]
        if not uploaded:
            # Nothing to load (e.g. the value was reset to an empty mapping).
            return
        try:
            # Keep only the last file (there must be a better way!)
            last_item = list(uploaded.values())[-1]

            filename = last_item["metadata"]["name"]
            structure_id = filename.split(".")[0]
            suffix = filename.split(".")[-1]

            data = codecs.decode(last_item["content"], encoding="utf-8")
            parser = PDB.get_parser(suffix)
            structure = parser.get_structure(io.StringIO(data), structure_id=structure_id)

            # TODO: We may need to lock global_state at this point?
            load_structure(structure)

            update_target_selection(target_selection_out)
            update_sequence_generation(sequence_generation_out)
            update_displayed_structure(ngl_stage)
            update_displayed_distance_matrix(distance_matrix_out)
        finally:
            # Always reset the widget, even if parsing/loading failed.
            uploader.value.clear()
            uploader._counter = 0

    uploader.observe(handle_upload, names="value")
    return uploader
def create_load_distance_matrix_button(
    ngl_stage, distance_matrix_out, target_selection_out, sequence_generation_out
):
    """Create a ``FileUpload`` widget that loads a text distance matrix into the dashboard.

    When a file is uploaded, it is decoded as UTF-8 and passed to
    ``load_distance_matrix``. The target selection, sequence generation, NGL
    structure (cleared, since a distance matrix has no structure) and
    distance-matrix views are then refreshed.

    The uploader is reset after every upload attempt, including failed ones, so
    the next upload starts from a clean state. Exceptions raised while parsing
    or loading still propagate to the widget's callback machinery.

    Returns:
        The configured ``FileUpload`` widget.
    """
    uploader = FileUpload(
        description="Load distance matrix",
        accept=".txt",
        multiple=False,
        layout=Layout(width="11rem"),
    )

    def handle_upload(change):
        uploaded = change["new"]
        if not uploaded:
            # Nothing to load (e.g. the value was reset to an empty mapping).
            return
        try:
            # Keep only the last file (there must be a better way!)
            last_item = list(uploaded.values())[-1]

            data = codecs.decode(last_item["content"], encoding="utf-8")

            # TODO: We may need to lock global_state at this point?
            load_distance_matrix(data)

            update_target_selection(target_selection_out)
            update_sequence_generation(sequence_generation_out)
            update_displayed_structure(ngl_stage)
            update_displayed_distance_matrix(distance_matrix_out)
        finally:
            # Always reset the widget, even if parsing/loading failed.
            uploader.value.clear()
            uploader._counter = 0

    uploader.observe(handle_upload, names="value")
    return uploader
def create_load_example_buttons(
    ngl_stage, distance_matrix_out, target_selection_out, sequence_generation_out
):
    """Return an ``HBox`` holding one button per bundled example structure.

    The example PDB files are looked up under ``data/inputs`` inside the
    installed ``proteinsolver`` package. The package directory is resolved
    strictly, so a missing package path raises immediately. Clicking a button
    loads that file with ``PDB.load``, passes it to ``load_structure`` and
    refreshes every dashboard view.
    """
    package_root = Path(proteinsolver.__path__[0]).resolve(strict=True)
    examples_folder = package_root.joinpath("data", "inputs")
    example_names = ("1n5uA03.pdb", "4beuA02.pdb", "4unuA00.pdb", "4z8jA00.pdb")
    example_paths = [examples_folder.joinpath(name) for name in example_names]

    def make_button(example_path):
        """Build a button labelled with the file stem that loads ``example_path``."""

        def on_click(_button):
            structure = PDB.load(example_path)
            # TODO: We may need to lock global_state at this point?
            load_structure(structure)
            update_target_selection(target_selection_out)
            update_sequence_generation(sequence_generation_out)
            update_displayed_structure(ngl_stage)
            update_displayed_distance_matrix(distance_matrix_out)

        new_button = Button(description=example_path.stem, layout=Layout(width="8.25rem"))
        new_button.on_click(on_click)
        return new_button

    buttons = []
    for path in example_paths:
        buttons.append(make_button(path))
    return HBox(buttons, layout=Layout(flex_flow="row", align_items="center"))