Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import torch
2import torch.nn as nn
3from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
4from torch.nn.parameter import Parameter
6from proteinsolver.nn.functional import sparse_multi_head_attention_forward
class SparseMultiheadAttention(nn.Module):
    r"""Multi-head attention module backed by ``sparse_multi_head_attention_forward``.

    Allows the model to jointly attend to information from different
    representation subspaces. See reference: Attention Is All You Need.

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
        \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

    This class only owns the parameters; the attention computation itself is
    delegated to ``proteinsolver.nn.functional.sparse_multi_head_attention_forward``,
    which receives one extra ``indices`` argument compared to the dense
    upstream implementation this module was adapted from.

    Args:
        embed_dim: total dimension of the model.
        num_heads: parallel attention heads. ``embed_dim`` must be divisible
            by ``num_heads``.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
        bias: add bias as module parameter. Default: True.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
            value sequences at dim=1.
        kdim: total number of features in key. Default: None.
        vdim: total number of features in value. Default: None.

    Note: if kdim and vdim are None, they will be set to embed_dim such that
    query, key, and value have the same number of features.

    Parameters:
        When ``kdim == vdim == embed_dim`` a single packed ``in_proj_weight``
        of shape ``(3 * embed_dim, embed_dim)`` is used and
        ``q_proj_weight`` / ``k_proj_weight`` / ``v_proj_weight`` are ``None``.
        Otherwise the three separate projection weights are used and
        ``in_proj_weight`` is ``None``. Both sets of attributes always exist.

    Raises:
        AssertionError: if ``embed_dim`` is not divisible by ``num_heads``.

    Examples::

        >>> multihead_attn = SparseMultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value, indices)

    Source:
        https://github.com/pytorch/pytorch/blob/70838ad08b90dc01380bf25f26efa5cfdfe4f0f4/torch/nn/modules/activation.py#L649
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type and message are
        # the same as before.
        if self.head_dim * num_heads != self.embed_dim:
            raise AssertionError("embed_dim must be divisible by num_heads")

        # Exactly one of the two projection layouts holds real parameters; the
        # other is registered as ``None`` (mirroring how ``in_proj_bias`` is
        # handled below) so attribute access never raises ``AttributeError``.
        if self._qkv_same_embed_dim:
            self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
            self.register_parameter("q_proj_weight", None)
            self.register_parameter("k_proj_weight", None)
            self.register_parameter("v_proj_weight", None)
        else:
            self.q_proj_weight = Parameter(torch.empty(embed_dim, embed_dim))
            self.k_proj_weight = Parameter(torch.empty(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.empty(embed_dim, self.vdim))
            self.register_parameter("in_proj_weight", None)

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
        else:
            self.register_parameter("in_proj_bias", None)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
            self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        # All tensors above are allocated uninitialised; fill them now.
        self._reset_parameters()

    def _reset_parameters(self):
        """Initialise weights with Xavier schemes and zero the biases."""
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.0)
            # ``out_proj`` is built with the same ``bias`` flag, so its bias
            # exists exactly when ``in_proj_bias`` does.
            constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def forward(
        self, query, key, value, indices, key_padding_mask=None, need_weights=True, attn_mask=None
    ):
        r"""Run attention by calling ``sparse_multi_head_attention_forward``.

        Args:
            query, key, value: map a query and a set of key-value pairs to an output.
                See "Attention Is All You Need" for more details.
            indices: forwarded unchanged, as the fourth positional argument, to
                ``sparse_multi_head_attention_forward``. NOTE(review): presumably
                the sparse index structure selecting which query/key pairs
                attend to each other; confirm the expected format against that
                function.
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. This is a binary mask. When the value is True,
                the corresponding value on the attention layer will be filled with -inf.
            need_weights: output attn_output_weights.
            attn_mask: mask that prevents attention to certain positions. This is an additive mask
                (i.e. the values will be added to the attention layer).

        Shape:
            NOTE(review): the shapes below are carried over from the dense
            upstream module this class was adapted from; confirm that they
            still hold for the sparse implementation.

            Inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
              E is the embedding dimension.
            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size,
              E is the embedding dimension.
            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size,
              E is the embedding dimension.
            - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size,
              S is the source sequence length.
            - attn_mask: :math:`(L, S)` where L is the target sequence length,
              S is the source sequence length.

            Outputs:
            - attn_output: :math:`(L, N, E)` where L is the target sequence length,
              N is the batch size, E is the embedding dimension.
            - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
              L is the target sequence length, S is the source sequence length.

        Returns:
            Whatever ``sparse_multi_head_attention_forward`` returns, unmodified.
        """
        if self._qkv_same_embed_dim:
            # Packed projection: rely on the callee's defaults for the
            # separate-projection keyword arguments.
            in_proj_weight = self.in_proj_weight
            separate_proj = {}
        else:
            # Separate projections: the packed weight is explicitly None.
            in_proj_weight = None
            separate_proj = {
                "use_separate_proj_weight": True,
                "q_proj_weight": self.q_proj_weight,
                "k_proj_weight": self.k_proj_weight,
                "v_proj_weight": self.v_proj_weight,
            }

        return sparse_multi_head_attention_forward(
            query,
            key,
            value,
            indices,
            self.embed_dim,
            self.num_heads,
            in_proj_weight,
            self.in_proj_bias,
            self.bias_k,
            self.bias_v,
            self.add_zero_attn,
            self.dropout,
            self.out_proj.weight,
            self.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            **separate_proj,
        )