Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import logging 

2import queue 

3import threading 

4from enum import Enum 

5from functools import partial 

6 

7import ipywidgets as widgets 

8import msaview 

9import torch 

10from IPython.display import HTML, display 

11 

12from proteinsolver.dashboard import global_state 

13from proteinsolver.dashboard.download_button import create_download_button 

14from proteinsolver.dashboard.gpu_status import create_gpu_status_widget 

15from proteinsolver.dashboard.ps_process import ProteinSolverProcess 

16from proteinsolver.utils import AMINO_ACID_TO_IDX 

17 

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

19 

20 

class State(Enum):
    """The two display states of the run/cancel button (see ``button_states``)."""

    ENABLED = 0  # idle: the button reads "Run ProteinSolver!" and starts a run
    DISABLED = 1  # busy: the button reads "Cancel" and cancels the current run

24 

25 

# Widget attributes applied to the run button for each State by
# update_run_ps_button_state(). on_run_ps_button_clicked() compares the
# button's current description against these entries to know which state it is in.
button_states = {
    State.ENABLED: dict(
        description="Run ProteinSolver!",
        icon="check",
        button_style="",
        tooltip="Generate new sequences!",
    ),
    State.DISABLED: dict(
        description="Cancel",
        icon="ban",
        button_style="danger",
        tooltip="Cancel!",
    ),
}

40 

41 

class ProteinSolverThread(threading.Thread):
    """Daemon worker that executes ProteinSolver design jobs for the dashboard.

    The thread sleeps until :meth:`start_new_design` hands it a job. It then starts a
    ``ProteinSolverProcess``, pulls designs from its ``output_queue`` into
    ``global_state.generated_sequences``, and keeps the progress bar, the MSA view and
    the run/cancel button up to date. A running job can be interrupted with
    :meth:`cancel`.
    """

    # Maximum number of (most recent) designs shown in the MSA view.
    _MSA_MAX_ROWS = 100

    def __init__(self, progress_bar, run_ps_status_out, msa_view, run_proteinsolver_button):
        super().__init__(daemon=True)
        self.progress_bar: widgets.IntProgress = progress_bar
        self.run_ps_status_out: widgets.Output = run_ps_status_out
        self.msa_view = msa_view
        self.run_proteinsolver_button: widgets.Button = run_proteinsolver_button

        # Parameters of the pending / current job, written by start_new_design().
        self.data = None
        self.num_designs = None
        self.temperature = None

        # Guards the job parameters and the "new job pending" flag. run() holds this
        # lock for the whole duration of a job (it is only released inside wait()),
        # so start_new_design() blocks until any job in progress has finished.
        self._run_condition = threading.Condition()
        self._start_new_design = False
        self._cancel_event = threading.Event()

    def start_new_design(self, data, num_designs, temperature) -> None:
        """Queue a new design job and wake up the worker.

        ``data.x`` is overwritten in place with the amino-acid indices of
        ``global_state.target_sequence``. Any pending cancellation is cleared.
        Blocks while a previous job is still running (see ``_run_condition``).
        """
        with self._run_condition:
            self.data = data
            self.data.x = torch.tensor(
                [AMINO_ACID_TO_IDX[aa] for aa in global_state.target_sequence], dtype=torch.long
            )
            self.num_designs = num_designs
            self.temperature = temperature
            self._start_new_design = True
            self._cancel_event.clear()
            assert not self.cancelled()
            self._run_condition.notify()

    def run(self):
        """Worker loop: wait for a job, run it, restore the run button, repeat forever."""
        with self._run_condition:
            while True:
                while not self._start_new_design:
                    self._run_condition.wait()
                self._start_new_design = False

                try:
                    self._run_design()
                except Exception as exc:
                    # Without this guard an unexpected error would kill the worker thread
                    # for good and leave the button stuck in its "Cancel" state.
                    logger.exception("ProteinSolver run failed.")
                    self.run_ps_status_out.append_stderr(
                        f"Encountered an exception: ({type(exc)} - {exc})."
                    )
                    self.progress_bar.bar_style = "danger"
                finally:
                    update_run_ps_button_state(self.run_proteinsolver_button, State.ENABLED)

    def _run_design(self):
        """Run one job with the current parameters, reporting progress in the UI."""
        update_run_ps_button_state(self.run_proteinsolver_button, State.DISABLED)
        self.progress_bar.value = 0
        self.progress_bar.bar_style = ""
        self.progress_bar.max = self.num_designs

        global_state.generated_sequences = []

        proc = ProteinSolverProcess(
            net_class=global_state.net_class,
            state_file=global_state.state_file,
            data=self.data,
            num_designs=self.num_designs,
            temperature=self.temperature,
            net_kwargs=global_state.net_kwargs,
        )
        proc.start()

        success = True
        while len(global_state.generated_sequences) < self.num_designs:
            if self.cancelled():
                success = False
                proc.cancel()
                break

            # Poll with a timeout so cancellation is noticed at least once a second.
            try:
                design = proc.output_queue.get(timeout=1.0)
            except queue.Empty:
                continue

            # An exception object on the queue signals a failure in the process.
            if isinstance(design, Exception):
                message = f"Encountered an exception: ({type(design)} - {design})."
                logger.error(message)
                self.run_ps_status_out.append_stderr(message)
                success = False
                proc.cancel()
                break

            global_state.generated_sequences.append(design)
            self.progress_bar.value += 1
            if self._should_refresh_msa(len(global_state.generated_sequences)):
                self._refresh_msa_view(global_state.generated_sequences)

        self.progress_bar.bar_style = "success" if success else "danger"

        # Always show the final result: the throttled in-loop refresh can skip the
        # last designs (e.g. with 11 designs the last in-loop refresh is at 10).
        self._refresh_msa_view(global_state.generated_sequences)

        proc.join()

    @staticmethod
    def _should_refresh_msa(num_generated: int) -> bool:
        """Return whether to redraw the MSA view after ``num_generated`` designs.

        Throttles redraws: refresh when ``num_generated`` is divisible by
        ``max(1, num_generated // 5)``.
        """
        return num_generated % max(1, num_generated // 5) == 0

    def _refresh_msa_view(self, sequences) -> None:
        """Show the most recent ``_MSA_MAX_ROWS`` of ``sequences`` in the MSA view, newest first."""
        self.msa_view.value = [
            {"id": seq.id, "name": seq.name, "seq": seq.seq}
            for seq in reversed(sequences[-self._MSA_MAX_ROWS :])
        ]

    def cancel(self):
        """Request cancellation of the current job; the worker checks this between designs."""
        self._cancel_event.set()

    def cancelled(self):
        """Return True if cancellation has been requested since the last job was queued."""
        return self._cancel_event.is_set()

151 

152 

def update_run_ps_button_state(run_ps_button: widgets.Button, state: State):
    """Copy the description, icon, style and tooltip registered for ``state`` onto the button."""
    settings = button_states[state]
    for attribute in ("description", "icon", "button_style", "tooltip"):
        setattr(run_ps_button, attribute, settings[attribute])

158 

159 

def on_run_ps_button_clicked(run_ps_button, num_designs_field, temperature_factor_field):
    """Click handler for the run/cancel button.

    The button's current description tells which state it is in. In the "Cancel"
    state, the running job is cancelled and the button is switched back. Otherwise
    a new job is queued on ``global_state.proteinsolver_thread`` using the values of
    the two input fields.
    """
    enabled_label = button_states[State.ENABLED]["description"]
    if run_ps_button.description != enabled_label:
        assert run_ps_button.description == button_states[State.DISABLED]["description"]
        global_state.proteinsolver_thread.cancel()
        update_run_ps_button_state(run_ps_button, State.ENABLED)
        return

    update_run_ps_button_state(run_ps_button, State.DISABLED)
    # Signal any job that may still be winding down; start_new_design() clears
    # the cancel flag again for the new job.
    global_state.proteinsolver_thread.cancel()
    global_state.proteinsolver_thread.start_new_design(
        global_state.data, num_designs_field.value, temperature_factor_field.value
    )

171 

172 

def update_sequence_generation(sequence_generation_out):
    """Render the "Run ProteinSolver" section into ``sequence_generation_out``.

    Runs only once: afterwards ``global_state.view_is_initialized`` is set and further
    calls return immediately.
    """
    if not global_state.view_is_initialized:
        sequence_generation_out.clear_output(wait=True)
        heading_markup = (
            '<p class="myheading" style="margin-top: 3rem">'
            "3. Run ProteinSolver to generate new designs"
            "</p>"
        )
        panel = create_sequence_generation_widget()
        with sequence_generation_out:
            display(HTML(heading_markup))
            display(panel)
        global_state.view_is_initialized = True

187 

188 

def create_sequence_generation_widget():
    """Build the "Run ProteinSolver" dashboard panel and start its worker thread.

    Left panel: number-of-sequences and temperature inputs, the run/cancel button, a
    status output, the GPU status widget and a download button. Right panel: a progress
    bar and an MSA view of the generated sequences.

    Side effect: assigns a newly started ``ProteinSolverThread`` (bound to the widgets
    created here) to ``global_state.proteinsolver_thread``.

    Returns:
        widgets.HBox: the assembled panel.
    """
    import html  # local import: only needed to escape the GPU error message

    num_designs_field = widgets.BoundedIntText(
        value=100,
        min=1,
        max=20_000,
        step=1,
        disabled=False,
        layout=widgets.Layout(width="100px"),
    )

    temperature_factor_field = widgets.BoundedFloatText(
        value=1.0,
        min=0.0001,
        max=100.0,
        step=0.0001,
        disabled=False,
        layout=widgets.Layout(width="95px"),
    )

    run_ps_button = widgets.Button(layout=widgets.Layout(width="auto"))
    update_run_ps_button_state(run_ps_button, State.ENABLED)
    run_ps_button.on_click(
        partial(
            on_run_ps_button_clicked,
            num_designs_field=num_designs_field,
            temperature_factor_field=temperature_factor_field,
        )
    )

    run_ps_status_out = widgets.Output(layout=widgets.Layout(height="75px"))

    progress_bar = widgets.IntProgress(
        value=0,
        min=0,
        max=100,
        step=1,
        bar_style="",  # 'success', 'info', 'warning', 'danger' or ''
        orientation="horizontal",
        layout=widgets.Layout(width="auto", height="15px"),
    )

    msa_view = msaview.MSAView()

    # Every call starts a fresh worker bound to the widgets above; whatever worker was
    # stored in global_state before is simply replaced.
    global_state.proteinsolver_thread = ProteinSolverThread(
        progress_bar, run_ps_status_out, msa_view, run_ps_button
    )
    global_state.proteinsolver_thread.start()

    # The remaining widgets are stateless with respect to sequence generation

    gpu_utilization_widget, gpu_error_message = create_gpu_status_widget()
    gpu_status_out = widgets.Output(layout=widgets.Layout(height="75px"))
    if gpu_error_message:
        # append_stdout() adds a plain-text stream, which would show the <p> tags
        # literally; append rich display data instead. The message is escaped because
        # error strings can contain "<...>" (e.g. class reprs).
        gpu_status_out.append_display_data(
            HTML(
                "<p>GPU monitoring not available "
                f"({html.escape(str(gpu_error_message))}).</p>"
            )
        )

    download_button = create_download_button(global_state.output_folder)

    # Put everything together

    left_panel = widgets.VBox(
        [
            widgets.VBox(
                [
                    widgets.HBox(
                        [
                            widgets.Label("Number of sequences:", layout={"width": "145px"}),
                            num_designs_field,
                        ]
                    ),
                    widgets.HBox(
                        [
                            widgets.Label("Temperature factor:", layout={"width": "145px"}),
                            temperature_factor_field,
                        ]
                    ),
                    run_ps_button,
                ]
            ),
            run_ps_status_out,
            gpu_utilization_widget,
            gpu_status_out,
            download_button,
        ],
        layout=widgets.Layout(
            flex_flow="column nowrap",
            justify_content="flex-start",
            width="240px",
            margin="0px 20px 0px 0px",
        ),
    )
    right_panel = widgets.VBox(
        [progress_bar, msa_view], layout=widgets.Layout(width="auto", flex="1 1 auto")
    )

    return widgets.HBox([left_panel, right_panel], layout=widgets.Layout(flex_flow="row nowrap"))