class GraphModule(torch.nn.Module):
    """Traced FX graph for one attention ``forward`` call, specialised to fixed shapes.

    This is a machine-generated graph printout.  The ``# File:`` comments
    record the original source statement each op was traced from, and the
    string annotations record the dtype/shape seen at trace time.

    NOTE(review): the class identifier was missing from the dump (leaving
    ``class (torch.nn.Module):``, which does not parse).  ``GraphModule`` is
    used here as the conventional name for FX graph printouts -- confirm
    against the tool that produced the dump if the name matters.

    ``torch`` must already be imported in the module that holds this class.
    """

    def forward(
        self,
        arg0_1: "f16[1, 256, 68, 68]",
        arg1_1: "f32[225, 256, 1, 1]",
        arg2_1: "f32[225]",
        arg3_1: "f16[1, 256, 68, 68]",
        arg4_1: "f16[1, 256, 68, 68]",
    ):
        """Scale/reshape q, k, v and compute the relative-embedding convolution.

        Per the traced source comments, ``arg3_1`` is ``q``, ``arg4_1`` is
        ``k`` and ``arg0_1`` is ``v``; ``arg1_1`` / ``arg2_1`` are the weight
        and bias handed to the convolution.

        Args:
            arg0_1: ``v``; reshaped to ``[-1, 1, 256, 4624]`` (4624 == 68 * 68).
            arg1_1: Convolution weight, cast to float16 before use.
            arg2_1: Convolution bias, cast to float16 before use.
            arg3_1: ``q``; divided by 16.0 for the first output, and also fed
                *unscaled* into the convolution.
            arg4_1: ``k``; reshaped to ``[-1, 256, 68, 68]``.

        Returns:
            A 6-tuple ``(q_scaled, k, v, 68, 68, relative_emb)`` where the two
            integers are the constants baked in at trace time and
            ``relative_emb`` has shape ``[1, 1, 225, 4624]``.
        """
        # File: /workspace/networks/layers/attention.py:335 in forward, code: q = q / self.T
        div: "f16[1, 256, 68, 68]" = torch.ops.aten.div.Tensor(arg3_1, 16.0)

        # File: /workspace/networks/layers/attention.py:337 in forward, code: q = q.view(-1, self.d_att, h, w)
        view: "f16[1, 256, 68, 68]" = torch.ops.aten.reshape.default(div, [-1, 256, 68, 68])
        # Rebinding to None drops the local reference to the intermediate.
        div = None

        # File: /workspace/networks/layers/attention.py:338 in forward, code: k = k.view(-1, self.d_att, h, w)
        view_1: "f16[1, 256, 68, 68]" = torch.ops.aten.reshape.default(arg4_1, [-1, 256, 68, 68])
        arg4_1 = None

        # File: /workspace/networks/layers/attention.py:339 in forward, code: v = v.view(-1, self.num_head, hidden_dim, h * w)
        view_2: "f16[1, 1, 256, 4624]" = torch.ops.aten.reshape.default(arg0_1, [-1, 1, 256, 4624])
        arg0_1 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/conv.py:453 in _conv_forward, code: return F.conv2d(input, weight, bias, self.stride,
        # The f32 weight and bias are cast to f16 so they match the f16 input.
        convert_element_type_1: "f16[225, 256, 1, 1]" = torch.ops.prims.convert_element_type.default(
            arg1_1, torch.float16
        )
        arg1_1 = None
        convert_element_type: "f16[225]" = torch.ops.prims.convert_element_type.default(
            arg2_1, torch.float16
        )
        arg2_1 = None
        # Positional arguments after the bias: stride, padding, dilation,
        # transposed, output_padding, groups.  The input is the original
        # (unscaled) arg3_1, exactly as recorded in the traced graph.
        convolution: "f16[1, 225, 68, 68]" = torch.ops.aten.convolution.default(
            arg3_1,
            convert_element_type_1,
            convert_element_type,
            [1, 1],
            [0, 0],
            [1, 1],
            False,
            [0, 0],
            1,
        )
        arg3_1 = convert_element_type_1 = convert_element_type = None

        # File: /workspace/networks/layers/attention.py:341 in forward, code: relative_emb = relative_emb.view(n, self.num_head,
        view_3: "f16[1, 1, 225, 4624]" = torch.ops.aten.reshape.default(convolution, [1, 1, 225, 4624])
        convolution = None
        return (view, view_1, view_2, 68, 68, view_3)