import torch
import torch.nn as nn
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.parameter import Parameter

from proteinsolver.nn.functional import sparse_multi_head_attention_forward


class SparseMultiheadAttention(nn.Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.

    See reference: Attention Is All You Need.

    .. math::

        \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W^O

        \text{where } \text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V)

    Args:
        embed_dim: total dimension of the model.
        num_heads: parallel attention heads.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
        bias: add bias as module parameter. Default: True.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
            value sequences at dim=1.
        kdim: total number of features in key. Default: None.
        vdim: total number of features in value. Default: None.

    Note: if kdim and vdim are None, they will be set to embed_dim such that
    query, key, and value have the same number of features.

    Examples::

        >>> multihead_attn = SparseMultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value, indices)

    Source:
        https://github.com/pytorch/pytorch/blob/70838ad08b90dc01380bf25f26efa5cfdfe4f0f4/torch/nn/modules/activation.py#L649
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        if self._qkv_same_embed_dim is False:
            self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
            self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
        else:
            self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
        else:
            self.register_parameter("in_proj_bias", None)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
            self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

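    # Note on the packed projection above (an explanatory sketch, not original
    # code): when query, key, and value share ``embed_dim``, a single
    # ``in_proj_weight`` of shape (3 * embed_dim, embed_dim) holds all three
    # input projections, and the attention kernel is expected to slice it into
    # the Q, K, and V weights, e.g. roughly
    #
    #     w_q, w_k, w_v = in_proj_weight.chunk(3, dim=0)
    #
    # Otherwise, separate q/k/v projection weights are kept so that keys and
    # values may have different feature dimensions (``kdim``/``vdim``).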

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.0)
            constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

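    # For reference (added note, not original code): ``xavier_uniform_`` draws
    # weights from U(-a, a) with a = gain * sqrt(6 / (fan_in + fan_out)), and
    # ``xavier_normal_`` samples from a normal distribution with
    # std = gain * sqrt(2 / (fan_in + fan_out)); biases start at zero.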

    def forward(
        self, query, key, value, indices, key_padding_mask=None, need_weights=True, attn_mask=None
    ):
        r"""
        Args:
            query, key, value: map a query and a set of key-value pairs to an output.
                See "Attention Is All You Need" for more details.
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. This is a binary mask. When the value is True,
                the corresponding value on the attention layer will be filled with -inf.
            need_weights: output attn_output_weights.
            attn_mask: mask that prevents attention to certain positions. This is an additive mask
                (i.e. the values will be added to the attention layer).

        Shape:
            Inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
              E is the embedding dimension.
            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size,
              E is the embedding dimension.
            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size,
              E is the embedding dimension.
            - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size,
              S is the source sequence length.
            - attn_mask: :math:`(L, S)` where L is the target sequence length,
              S is the source sequence length.

            Outputs:
            - attn_output: :math:`(L, N, E)` where L is the target sequence length,
              N is the batch size, E is the embedding dimension.
            - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
              L is the target sequence length, S is the source sequence length.
        """

        if self._qkv_same_embed_dim:
            # Query, key, and value share a feature dimension, so the single
            # packed ``in_proj_weight`` is used.
            return sparse_multi_head_attention_forward(
                query,
                key,
                value,
                indices,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
            )
        else:
            # Keys and/or values have their own feature dimensions, so the
            # separate q/k/v projection weights are passed instead.
            return sparse_multi_head_attention_forward(
                query,
                key,
                value,
                indices,
                self.embed_dim,
                self.num_heads,
                None,  # set in_proj_weight = None
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight,
                k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight,
            )
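

# A minimal usage sketch (added illustration, not part of the original module).
# The tensor shapes follow the forward() docstring; ``indices`` is assumed here
# to be a (2, num_edges) LongTensor of (target, source) positions, which is an
# assumption about ``sparse_multi_head_attention_forward`` rather than a
# documented contract, so the example is left commented out:
#
#     embed_dim, num_heads, seq_len = 64, 4, 10
#     attn = SparseMultiheadAttention(embed_dim, num_heads)
#     query = key = value = torch.randn(seq_len, 1, embed_dim)  # (L, N, E)
#     indices = torch.stack([torch.arange(seq_len), torch.arange(seq_len)])
#     attn_output, attn_weights = attn(query, key, value, indices)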