diff --git a/hypercore/nn/attention/lorentz_former_conv.py b/hypercore/nn/attention/lorentz_former_conv.py
index 37015c2..a032149 100644
--- a/hypercore/nn/attention/lorentz_former_conv.py
+++ b/hypercore/nn/attention/lorentz_former_conv.py
@@ -53,12 +53,17 @@ def __init__(self, manifold: Lorentz, in_channels, out_channels, num_heads, use
         self.scale = nn.Parameter(torch.tensor([math.sqrt(num_heads * out_channels)]))
         self.bias = nn.Parameter(torch.zeros(()))
         if self.attention_type=='linear_focused':
-            self.v_map_mlp = nn.Linear(in_channels - 1, out_channels, bias=True)
+            _v_space = (out_channels - 2) if use_weight else (out_channels - 1)
+            self.v_map_mlp = nn.Linear(_v_space, _v_space, bias=True)
             self.norm_scale = nn.Parameter(torch.ones(()))
             self.power_k = power_k
         self.trans_heads_concat = trans_heads_concat
         if self.trans_heads_concat:
-            self.final_linear = nn.Linear(self.num_heads * (self.out_channels), self.num_heads * self.out_channels - 1)
+            if self.attention_type == 'linear_focused':
+                _v_space = (out_channels - 2) if use_weight else (out_channels - 1)
+                self.final_linear = nn.Linear(self.num_heads * _v_space, self.num_heads * self.out_channels - 1)
+            else:
+                self.final_linear = nn.Linear(self.num_heads * self.out_channels, self.num_heads * self.out_channels - 1)
         self.normalize = normalize

     @staticmethod
@@ -175,7 +180,7 @@ def linear_focus_attention(self, hyp_qs, hyp_ks, hyp_vs, output_attentions=False
         attn_output = attn_output + vss # preserve its rank, [B, N, H, D]

         if self.trans_heads_concat:
-            attn_output = self.final_linear(attn_output.reshape(attn_output.size(0), -1, self.num_heads * self.out_channels))
+            attn_output = self.final_linear(attn_output.reshape(attn_output.size(0), attn_output.size(1), -1))
         else:
             attn_output = attn_output.mean(dim=1)
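
For reviewers, a minimal self-contained sketch of the shape bookkeeping this patch establishes. It assumes the convention visible in the diff itself: for Lorentz features of width out_channels, the space-like part has out_channels - 1 dims, and use_weight consumes one more. The helper v_space_dim and the toy sizes B, N, H below are illustrative, not names from lorentz_former_conv.py.

import torch
import torch.nn as nn

def v_space_dim(out_channels: int, use_weight: bool) -> int:
    # Mirrors the patched _v_space computation: drop one dim for the Lorentz
    # time coordinate, plus one more when use_weight is set (an assumption
    # read off the diff, not from library docs).
    return out_channels - 2 if use_weight else out_channels - 1

B, N, H, out_channels, use_weight = 2, 16, 4, 32, True   # toy sizes
D = v_space_dim(out_channels, use_weight)                # 30 here

v_map_mlp = nn.Linear(D, D, bias=True)                   # now square: D -> D
final_linear = nn.Linear(H * D, H * out_channels - 1)    # matches the patched linear_focused branch

attn_output = torch.randn(B, N, H, D)                    # [B, N, H, D], per the diff comment
# The patched reshape flattens only the trailing head/feature dims, keeping N intact:
flat = attn_output.reshape(attn_output.size(0), attn_output.size(1), -1)  # [B, N, H*D]
out = final_linear(flat)                                 # [B, N, H*out_channels - 1]
print(out.shape)                                         # torch.Size([2, 16, 127])

The old reshape hard-coded the last dim as self.num_heads * self.out_channels, which no longer matches the tensor once the per-head feature width shrinks to _v_space; flattening the trailing dims with -1, and sizing final_linear from _v_space, keeps the two in agreement.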