class (torch.nn.Module): def forward(self, arg0_1: "Sym(s0)", arg1_1: "Sym(s1)", arg2_1: "f16[1, 256, s0, s1]", arg3_1: "f32[225, 256, 1, 1]", arg4_1: "f32[225]", arg5_1: "f16[1, 256, s0, s1]", arg6_1: "f16[1, 256, s0, s1]"): # File: /workspace/networks/layers/attention.py:335 in forward, code: q = q / self.T div: "f16[1, 256, s0, s1]" = torch.ops.aten.div.Tensor(arg5_1, 16.0) # File: /workspace/networks/layers/attention.py:337 in forward, code: q = q.view(-1, self.d_att, h, w) view: "f16[1, 256, s0, s1]" = torch.ops.aten.reshape.default(div, [-1, 256, arg0_1, arg1_1]); div = None # File: /workspace/networks/layers/attention.py:338 in forward, code: k = k.view(-1, self.d_att, h, w) view_1: "f16[1, 256, s0, s1]" = torch.ops.aten.reshape.default(arg6_1, [-1, 256, arg0_1, arg1_1]); arg6_1 = None # File: /workspace/networks/layers/attention.py:339 in forward, code: v = v.view(-1, self.num_head, hidden_dim, h * w) mul: "Sym(s0*s1)" = arg0_1 * arg1_1 view_2: "f16[1, 1, 256, s0*s1]" = torch.ops.aten.reshape.default(arg2_1, [-1, 1, 256, mul]); arg2_1 = None # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/conv.py:453 in _conv_forward, code: return F.conv2d(input, weight, bias, self.stride, convert_element_type_1: "f16[225, 256, 1, 1]" = torch.ops.prims.convert_element_type.default(arg3_1, torch.float16); arg3_1 = None convert_element_type: "f16[225]" = torch.ops.prims.convert_element_type.default(arg4_1, torch.float16); arg4_1 = None convolution: "f16[1, 225, s0, s1]" = torch.ops.aten.convolution.default(arg5_1, convert_element_type_1, convert_element_type, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); arg5_1 = convert_element_type_1 = convert_element_type = None # File: /workspace/networks/layers/attention.py:341 in forward, code: relative_emb = relative_emb.view(n, self.num_head, view_3: "f16[1, 1, 225, s0*s1]" = torch.ops.aten.reshape.default(convolution, [1, 1, 225, mul]); convolution = mul = None return (view, view_1, view_2, arg0_1, arg1_1, view_3)