マルチモーダル学習におけるAttention機構について
解決したいこと
テキストと音声のマルチモーダル学習を行っているのですが、それぞれ異なるモデルで特徴抽出をした後に、concatにより統合し、600次元の特徴を出力しています。
この後、Self-Attention(Scaled Dot-Product Attention)のようなAttention機構を適用し、特徴間の相互作用を強調したいと考えています。
以下のコードを作成したのですが、間違っていますか。
該当するソースコード
class MultimodalAttention(nn.Module):
def __init__(self, feature_dim):
super(MultimodalAttention, self).__init__()
self.query = nn.Linear(feature_dim, feature_dim)
self.key = nn.Linear(feature_dim, feature_dim)
self.value = nn.Linear(feature_dim, feature_dim)
self.softmax = nn.Softmax(dim=-1)
def forward(self, image_features, text_features):
"""
Args:
image_features: 画像の特徴 (B, F)
text_features: テキストの特徴 (B, F)
Returns:
attended_features: 統合された特徴 (B, F)
"""
Q = self.query(image_features) # (B, F)
K = self.key(text_features) # (B, F)
V = self.value(text_features) # (B, F)
# Scaled Dot-Product Attention
attention_scores = self.softmax(Q @ K.transpose(-2, -1) / (K.size(-1) ** 0.5)) # (B, F, F)
attended_features = attention_scores @ V # (B, F)
return attended_features
0