Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import logging
2import queue
3import threading
4from enum import Enum
5from functools import partial
7import ipywidgets as widgets
8import msaview
9import torch
10from IPython.display import HTML, display
12from proteinsolver.dashboard import global_state
13from proteinsolver.dashboard.download_button import create_download_button
14from proteinsolver.dashboard.gpu_status import create_gpu_status_widget
15from proteinsolver.dashboard.ps_process import ProteinSolverProcess
16from proteinsolver.utils import AMINO_ACID_TO_IDX
# Module-level logger; ProteinSolverThread reports design errors through it.
logger = logging.getLogger(name=__name__)
class State(Enum):
    """Two-state toggle used by the run/cancel button.

    ENABLED: no run in progress; the button offers to start one.
    DISABLED: a run is in progress; the button offers to cancel it.
    """

    ENABLED = 0
    DISABLED = 1
# Visual presets for the run/cancel toggle button, keyed by State. The
# "description" of each preset is also what on_run_ps_button_clicked compares
# against to decide which action a click means.
button_states = {
    State.ENABLED: dict(
        description="Run ProteinSolver!",
        icon="check",
        button_style="",
        tooltip="Generate new sequences!",
    ),
    State.DISABLED: dict(
        description="Cancel",
        icon="ban",
        button_style="danger",
        tooltip="Cancel!",
    ),
}
class ProteinSolverThread(threading.Thread):
    """Background worker that runs ProteinSolver and streams designs into the UI.

    The thread blocks on a condition variable until :meth:`start_new_design` is
    called. It then launches a ``ProteinSolverProcess``, moves designs from the
    process's ``output_queue`` into ``global_state.generated_sequences``, and
    keeps the progress bar, MSA view and run/cancel button in sync. A running
    job can be stopped with :meth:`cancel`.
    """

    # Maximum number of (most recent) designs shown in the MSA view.
    _MSA_VIEW_MAX_ROWS = 100

    def __init__(self, progress_bar, run_ps_status_out, msa_view, run_proteinsolver_button):
        # Daemon thread: run() loops forever and must not keep the interpreter alive.
        super().__init__(daemon=True)
        self.progress_bar: widgets.IntProgress = progress_bar
        self.run_ps_status_out: widgets.Output = run_ps_status_out
        self.msa_view = msa_view
        self.run_proteinsolver_button: widgets.Button = run_proteinsolver_button

        # Parameters of the pending/current run, written by start_new_design().
        self.data = None
        self.num_designs = None
        self.temperature = None

        # The condition's lock guards the run parameters and _start_new_design;
        # run() holds it for the whole duration of a design run and only
        # releases it while waiting for the next request.
        self._run_condition = threading.Condition()
        self._start_new_design = False
        self._cancel_event = threading.Event()

    def start_new_design(self, data, num_designs, temperature) -> None:
        """Queue a new design run and wake the worker.

        Blocks until the worker is idle, because run() holds the condition lock
        while a run is in progress. ``data.x`` is overwritten in place with the
        ``AMINO_ACID_TO_IDX`` indices of ``global_state.target_sequence``.

        Args:
            data: Input forwarded to ``ProteinSolverProcess``; mutated (``data.x``).
            num_designs: Number of sequences to collect before the run succeeds.
            temperature: Value forwarded as ``temperature`` to ``ProteinSolverProcess``.
        """
        with self._run_condition:
            self.data = data
            self.data.x = torch.tensor(
                [AMINO_ACID_TO_IDX[aa] for aa in global_state.target_sequence], dtype=torch.long
            )
            self.num_designs = num_designs
            self.temperature = temperature
            self._start_new_design = True
            # Clear any cancellation left over from the previous run.
            self._cancel_event.clear()
            assert not self.cancelled()
            self._run_condition.notify()

    def run(self):
        """Worker loop: wait for a request, execute it, then re-enable the button.

        Unexpected errors during a run are logged and reported in the status
        output instead of terminating the thread, so later requests are served.
        """
        with self._run_condition:
            while True:
                while not self._start_new_design:
                    self._run_condition.wait()
                self._start_new_design = False
                try:
                    self._run_design()
                except Exception as exc:
                    # Without this the thread would die silently: the button would
                    # stay on "Cancel" and later start_new_design() calls would
                    # never be picked up.
                    logger.exception("ProteinSolver run failed.")
                    self.run_ps_status_out.append_stderr(
                        f"Encountered an exception: ({type(exc)} - {exc})."
                    )
                    self.progress_bar.bar_style = "danger"
                finally:
                    update_run_ps_button_state(self.run_proteinsolver_button, State.ENABLED)

    def _run_design(self) -> None:
        """Execute one design run using the parameters set by start_new_design()."""
        update_run_ps_button_state(self.run_proteinsolver_button, State.DISABLED)
        self.progress_bar.value = 0
        self.progress_bar.bar_style = ""
        self.progress_bar.max = self.num_designs

        global_state.generated_sequences = []

        proc = ProteinSolverProcess(
            net_class=global_state.net_class,
            state_file=global_state.state_file,
            data=self.data,
            num_designs=self.num_designs,
            temperature=self.temperature,
            net_kwargs=global_state.net_kwargs,
        )
        proc.start()

        success = False
        try:
            success = self._collect_designs(proc)
        except Exception:
            # Stop the process (as the cancel/error paths do) before joining it.
            proc.cancel()
            raise
        finally:
            self.progress_bar.bar_style = "success" if success else "danger"
            # Always show the final state: the in-loop refresh is throttled and
            # may have skipped the last designs (e.g. 101 designs, or a cancel).
            self._refresh_msa_view()
            proc.join()

    def _collect_designs(self, proc) -> bool:
        """Move designs from ``proc.output_queue`` into ``global_state.generated_sequences``.

        Returns:
            True if ``num_designs`` designs were collected; False if the run was
            cancelled or the process delivered an exception. ``proc.cancel()`` is
            called in both failure cases.
        """
        while len(global_state.generated_sequences) < self.num_designs:
            if self.cancelled():
                proc.cancel()
                return False

            try:
                # Bounded wait so a cancel request is noticed within about a second.
                design = proc.output_queue.get(timeout=1.0)
            except queue.Empty:
                continue

            # Failures inside the process arrive as exception objects on the queue.
            if isinstance(design, Exception):
                message = f"Encountered an exception: ({type(design)} - {design})."
                logger.error(message)
                self.run_ps_status_out.append_stderr(message)
                proc.cancel()
                return False

            global_state.generated_sequences.append(design)
            self.progress_bar.value += 1

            # Throttle widget updates: refresh only when the count is a multiple
            # of count // 5, so refreshes get sparser as the run grows.
            num_generated = len(global_state.generated_sequences)
            if num_generated % max(1, num_generated // 5) == 0:
                self._refresh_msa_view()
        return True

    def _refresh_msa_view(self) -> None:
        """Show the most recent designs, newest first, in the MSA view."""
        recent = global_state.generated_sequences[-self._MSA_VIEW_MAX_ROWS:]
        self.msa_view.value = [
            {"id": seq.id, "name": seq.name, "seq": seq.seq} for seq in reversed(recent)
        ]

    def cancel(self):
        """Request cancellation of the current run; polled between queue reads."""
        self._cancel_event.set()

    def cancelled(self):
        """Return True if cancellation was requested since the last start_new_design()."""
        return self._cancel_event.is_set()
def update_run_ps_button_state(run_ps_button: widgets.Button, state: State):
    """Apply the ``button_states`` preset for ``state`` to ``run_ps_button``.

    Sets description, icon, button_style and tooltip, in that order.
    """
    preset = button_states[state]
    for attr_name in ("description", "icon", "button_style", "tooltip"):
        setattr(run_ps_button, attr_name, preset[attr_name])
def on_run_ps_button_clicked(run_ps_button, num_designs_field, temperature_factor_field):
    """Click handler for the run/cancel toggle button.

    The button's current description selects the action. If it shows the
    "run" label, any active run is cancelled and a new one is started with the
    values of the input fields. Otherwise the run in progress is cancelled and
    the button is switched back to its "run" state.
    """
    run_label = button_states[State.ENABLED]["description"]
    if run_ps_button.description != run_label:
        # Button currently shows the cancel label: stop the run in progress.
        assert run_ps_button.description == button_states[State.DISABLED]["description"]
        global_state.proteinsolver_thread.cancel()
        update_run_ps_button_state(run_ps_button, State.ENABLED)
        return

    update_run_ps_button_state(run_ps_button, State.DISABLED)
    worker = global_state.proteinsolver_thread
    # Cancel first so an unfinished run stops before the new request is queued.
    worker.cancel()
    worker.start_new_design(
        global_state.data, num_designs_field.value, temperature_factor_field.value
    )
def update_sequence_generation(sequence_generation_out):
    """Render the "Run ProteinSolver" section into ``sequence_generation_out`` once.

    Does nothing once ``global_state.view_is_initialized`` is set; otherwise it
    builds the section (which also starts a new worker thread), displays it,
    and sets the flag.
    """
    if not global_state.view_is_initialized:
        sequence_generation_out.clear_output(wait=True)
        heading_markup = "".join(
            [
                '<p class="myheading" style="margin-top: 3rem">',
                "3. Run ProteinSolver to generate new designs",
                "</p>",
            ]
        )
        panel = create_sequence_generation_widget()
        with sequence_generation_out:
            display(HTML(heading_markup))
            display(panel)
        global_state.view_is_initialized = True
def create_sequence_generation_widget():
    """Build the sequence-generation panel and start its background worker.

    Creates the input fields, run/cancel button, status outputs, progress bar
    and MSA view. Assigns a new ``ProteinSolverThread`` bound to these widgets
    to ``global_state.proteinsolver_thread`` and starts it.

    Returns:
        widgets.HBox: a left panel (inputs, run/cancel button, status output,
        GPU monitor, download button) and a right panel (progress bar and MSA
        view).
    """
    from html import escape

    num_designs_field = widgets.BoundedIntText(
        value=100,
        min=1,
        max=20_000,
        step=1,
        disabled=False,
        layout=widgets.Layout(width="100px"),
    )

    temperature_factor_field = widgets.BoundedFloatText(
        value=1.0,
        min=0.0001,
        max=100.0,
        step=0.0001,
        disabled=False,
        layout=widgets.Layout(width="95px"),
    )

    run_ps_button = widgets.Button(layout=widgets.Layout(width="auto"))
    update_run_ps_button_state(run_ps_button, State.ENABLED)
    run_ps_button.on_click(
        partial(
            on_run_ps_button_clicked,
            num_designs_field=num_designs_field,
            temperature_factor_field=temperature_factor_field,
        )
    )

    run_ps_status_out = widgets.Output(layout=widgets.Layout(height="75px"))

    progress_bar = widgets.IntProgress(
        value=0,
        min=0,
        max=100,
        step=1,
        bar_style="",  # 'success', 'info', 'warning', 'danger' or ''
        orientation="horizontal",
        layout=widgets.Layout(width="auto", height="15px"),
    )

    msa_view = msaview.MSAView()

    # NOTE(review): any previously created worker is not stopped here. This is
    # fine as long as the function runs once per view (update_sequence_generation
    # guards it with global_state.view_is_initialized).
    global_state.proteinsolver_thread = ProteinSolverThread(
        progress_bar, run_ps_status_out, msa_view, run_ps_button
    )
    global_state.proteinsolver_thread.start()

    # The remaining widgets are stateless with respect to sequence generation

    gpu_utilization_widget, gpu_error_message = create_gpu_status_widget()
    gpu_status_out = widgets.Output(layout=widgets.Layout(height="75px"))
    if gpu_error_message:
        # The message is HTML; append_stdout would print the tags literally, so
        # append it as rich display data (with the error text escaped).
        gpu_status_out.append_display_data(
            HTML(f"<p>GPU monitoring not available ({escape(str(gpu_error_message))}).</p>")
        )

    download_button = create_download_button(global_state.output_folder)

    # Put everything together

    left_panel = widgets.VBox(
        [
            widgets.VBox(
                [
                    widgets.HBox(
                        [
                            widgets.Label("Number of sequences:", layout={"width": "145px"}),
                            num_designs_field,
                        ]
                    ),
                    widgets.HBox(
                        [
                            widgets.Label("Temperature factor:", layout={"width": "145px"}),
                            temperature_factor_field,
                        ]
                    ),
                    run_ps_button,
                ]
            ),
            run_ps_status_out,
            gpu_utilization_widget,
            gpu_status_out,
            download_button,
        ],
        layout=widgets.Layout(
            flex_flow="column nowrap",
            justify_content="flex-start",
            width="240px",
            margin="0px 20px 0px 0px",
        ),
    )
    right_panel = widgets.VBox(
        [progress_bar, msa_view], layout=widgets.Layout(width="auto", flex="1 1 auto")
    )

    return widgets.HBox([left_panel, right_panel], layout=widgets.Layout(flex_flow="row nowrap"))