class (torch.nn.Module):
    def forward(self, arg0_1: "Sym(s0)", arg1_1: "f16[s0, 1, 256]", arg2_1: "f16[s0, 1, 256]", arg3_1: "f32[s0, 1, 256]", arg4_1: "f32[256]", arg5_1: "f32[256]", arg6_1: "f32[1024, 256]", arg7_1: "f32[1024]", arg8_1: "Sym(s3)", arg9_1: "Sym(s4)", arg10_1: "f32[1024]", arg11_1: "f32[1024]", arg12_1: "f32[1024, 1, 5, 5]", arg13_1: "f32[256, 1024]", arg14_1: "f32[256]"):
        # File: /workspace/networks/layers/transformer.py:839 in torch_dynamo_resume_in_forward_at_836, code: tgt = tgt + self.droppath(tgt2 + tgt3)
        add: "f16[s0, 1, 256]" = torch.ops.aten.add.Tensor(arg2_1, arg1_1); arg2_1 = arg1_1 = None
        add_1: "f32[s0, 1, 256]" = torch.ops.aten.add.Tensor(arg3_1, add); arg3_1 = add = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/normalization.py:201 in forward, code: return F.layer_norm(
        clone: "f32[s0, 1, 256]" = torch.ops.aten.clone.default(add_1, memory_format = torch.contiguous_format)
        var_mean = torch.ops.aten.var_mean.correction(clone, [2], correction = 0, keepdim = True)
        getitem: "f32[s0, 1, 1]" = var_mean[0]
        getitem_1: "f32[s0, 1, 1]" = var_mean[1]; var_mean = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        convert_element_type: "f16[1024]" = torch.ops.prims.convert_element_type.default(arg7_1, torch.float16); arg7_1 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/normalization.py:201 in forward, code: return F.layer_norm(
        sub: "f32[s0, 1, 256]" = torch.ops.aten.sub.Tensor(clone, getitem_1); clone = getitem_1 = None
        add_2: "f32[s0, 1, 1]" = torch.ops.aten.add.Tensor(getitem, 1e-05); getitem = None
        rsqrt: "f32[s0, 1, 1]" = torch.ops.aten.rsqrt.default(add_2); add_2 = None
        mul: "f32[s0, 1, 256]" = torch.ops.aten.mul.Tensor(sub, rsqrt); sub = rsqrt = None
        mul_1: "f32[s0, 1, 256]" = torch.ops.aten.mul.Tensor(mul, arg4_1); mul = arg4_1 = None
        add_3: "f32[s0, 1, 256]" = torch.ops.aten.add.Tensor(mul_1, arg5_1); mul_1 = arg5_1 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        convert_element_type_2: "f16[s0, 1, 256]" = torch.ops.prims.convert_element_type.default(add_3, torch.float16); add_3 = None
        view: "f16[s0, 256]" = torch.ops.aten.reshape.default(convert_element_type_2, [arg0_1, 256]); convert_element_type_2 = None
        convert_element_type_1: "f16[1024, 256]" = torch.ops.prims.convert_element_type.default(arg6_1, torch.float16); arg6_1 = None
        permute: "f16[256, 1024]" = torch.ops.aten.permute.default(convert_element_type_1, [1, 0]); convert_element_type_1 = None
        addmm: "f16[s0, 1024]" = torch.ops.aten.addmm.default(convert_element_type, view, permute); convert_element_type = view = permute = None
        view_1: "f16[s0, 1, 1024]" = torch.ops.aten.reshape.default(addmm, [arg0_1, 1, 1024]); addmm = arg0_1 = None

        # File: /workspace/networks/layers/basic.py:30 in forward, code: x = x.view(h, w, bs, c).permute(2, 3, 0, 1)
        view_2: "f16[s3, (s0//s3), 1, 1024]" = torch.ops.aten.reshape.default(view_1, [arg8_1, arg9_1, 1, 1024]); view_1 = None
        permute_1: "f16[1, 1024, s3, (s0//s3)]" = torch.ops.aten.permute.default(view_2, [2, 3, 0, 1])

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/normalization.py:287 in forward, code: return F.group_norm(
        convert_element_type_6: "f32[1, 1024, s3, (s0//s3)]" = torch.ops.prims.convert_element_type.default(permute_1, torch.float32); permute_1 = None
        clone_1: "f32[1, 1024, s3, (s0//s3)]" = torch.ops.aten.clone.default(convert_element_type_6, memory_format = torch.contiguous_format); convert_element_type_6 = None
        sym_size_int_1: "Sym((s0//s3))" = torch.ops.aten.sym_size.int(view_2, 1); view_2 = None
        mul_2: "Sym(s3*((s0//s3)))" = arg8_1 * sym_size_int_1
        view_3: "f32[1, 32, 32, s3*((s0//s3))]" = torch.ops.aten.reshape.default(clone_1, [1, 32, 32, mul_2]); clone_1 = mul_2 = None
        var_mean_1 = torch.ops.aten.var_mean.correction(view_3, [2, 3], correction = 0, keepdim = True)
        getitem_2: "f32[1, 32, 1, 1]" = var_mean_1[0]
        getitem_3: "f32[1, 32, 1, 1]" = var_mean_1[1]; var_mean_1 = None
        sub_1: "f32[1, 32, 32, s3*((s0//s3))]" = torch.ops.aten.sub.Tensor(view_3, getitem_3); view_3 = getitem_3 = None
        add_4: "f32[1, 32, 1, 1]" = torch.ops.aten.add.Tensor(getitem_2, 1e-05); getitem_2 = None
        rsqrt_1: "f32[1, 32, 1, 1]" = torch.ops.aten.rsqrt.default(add_4); add_4 = None
        mul_3: "f32[1, 32, 32, s3*((s0//s3))]" = torch.ops.aten.mul.Tensor(sub_1, rsqrt_1); sub_1 = rsqrt_1 = None
        view_4: "f32[1, 1024, s3, (s0//s3)]" = torch.ops.aten.reshape.default(mul_3, [1, 1024, arg8_1, sym_size_int_1]); mul_3 = sym_size_int_1 = None
        unsqueeze_3: "f32[1, 1024]" = torch.ops.aten.unsqueeze.default(arg10_1, 0); arg10_1 = None
        unsqueeze_4: "f32[1, 1024, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze_3, 2); unsqueeze_3 = None
        unsqueeze_5: "f32[1, 1024, 1, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze_4, 3); unsqueeze_4 = None
        mul_4: "f32[1, 1024, s3, (s0//s3)]" = torch.ops.aten.mul.Tensor(view_4, unsqueeze_5); view_4 = unsqueeze_5 = None
        unsqueeze: "f32[1, 1024]" = torch.ops.aten.unsqueeze.default(arg11_1, 0); arg11_1 = None
        unsqueeze_1: "f32[1, 1024, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze, 2); unsqueeze = None
        unsqueeze_2: "f32[1, 1024, 1, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze_1, 3); unsqueeze_1 = None
        add_5: "f32[1, 1024, s3, (s0//s3)]" = torch.ops.aten.add.Tensor(mul_4, unsqueeze_2); mul_4 = unsqueeze_2 = None

        # File: /workspace/networks/layers/basic.py:32 in forward, code: x = F.gelu(x)
        mul_5: "f32[1, 1024, s3, (s0//s3)]" = torch.ops.aten.mul.Tensor(add_5, 0.5)
        mul_6: "f32[1, 1024, s3, (s0//s3)]" = torch.ops.aten.mul.Tensor(add_5, 0.7071067811865476); add_5 = None
        erf: "f32[1, 1024, s3, (s0//s3)]" = torch.ops.aten.erf.default(mul_6); mul_6 = None
        add_6: "f32[1, 1024, s3, (s0//s3)]" = torch.ops.aten.add.Tensor(erf, 1); erf = None
        mul_7: "f32[1, 1024, s3, (s0//s3)]" = torch.ops.aten.mul.Tensor(mul_5, add_6); mul_5 = add_6 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/conv.py:453 in _conv_forward, code: return F.conv2d(input, weight, bias, self.stride,
        convert_element_type_8: "f16[1, 1024, s3, (s0//s3)]" = torch.ops.prims.convert_element_type.default(mul_7, torch.float16); mul_7 = None
        convert_element_type_7: "f16[1024, 1, 5, 5]" = torch.ops.prims.convert_element_type.default(arg12_1, torch.float16); arg12_1 = None
        convolution: "f16[1, 1024, s3, (s0//s3)]" = torch.ops.aten.convolution.default(convert_element_type_8, convert_element_type_7, None, [1, 1], [2, 2], [1, 1], False, [0, 0], 1024); convert_element_type_8 = convert_element_type_7 = None

        # File: /workspace/networks/layers/basic.py:34 in forward, code: x = x.view(bs, c, h * w).permute(2, 0, 1)
        mul_8: "Sym(s3*s4)" = arg8_1 * arg9_1; arg8_1 = arg9_1 = None
        view_5: "f16[1, 1024, s3*((s0//s3))]" = torch.ops.aten.reshape.default(convolution, [1, 1024, mul_8]); convolution = mul_8 = None
        permute_2: "f16[s3*((s0//s3)), 1, 1024]" = torch.ops.aten.permute.default(view_5, [2, 0, 1])

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        sym_size_int_2: "Sym(s3*((s0//s3)))" = torch.ops.aten.sym_size.int(view_5, 2); view_5 = None
        expand: "f16[s3*((s0//s3)), 1, 1024]" = torch.ops.aten.expand.default(permute_2, [sym_size_int_2, 1, 1024]); permute_2 = None
        view_6: "f16[s3*((s0//s3)), 1, 1024]" = torch.ops.aten.reshape.default(expand, [sym_size_int_2, 1, 1024]); expand = None
        convert_element_type_10: "f16[256, 1024]" = torch.ops.prims.convert_element_type.default(arg13_1, torch.float16); arg13_1 = None
        permute_3: "f16[1024, 256]" = torch.ops.aten.permute.default(convert_element_type_10, [1, 0]); convert_element_type_10 = None
        expand_1: "f16[s3*((s0//s3)), 1024, 256]" = torch.ops.aten.expand.default(permute_3, [sym_size_int_2, 1024, 256]); permute_3 = None
        view_7: "f16[s3*((s0//s3)), 1024, 256]" = torch.ops.aten.reshape.default(expand_1, [sym_size_int_2, 1024, 256]); expand_1 = None
        bmm: "f16[s3*((s0//s3)), 1, 256]" = torch.ops.aten.bmm.default(view_6, view_7); view_6 = view_7 = None
        view_8: "f16[s3*((s0//s3)), 1, 256]" = torch.ops.aten.reshape.default(bmm, [sym_size_int_2, 1, 256]); bmm = sym_size_int_2 = None
        convert_element_type_9: "f16[256]" = torch.ops.prims.convert_element_type.default(arg14_1, torch.float16); arg14_1 = None
        add_7: "f16[s3*((s0//s3)), 1, 256]" = torch.ops.aten.add.Tensor(view_8, convert_element_type_9); view_8 = convert_element_type_9 = None

        # File: /workspace/networks/layers/transformer.py:848 in torch_dynamo_resume_in_forward_at_836, code: tgt = tgt + self.droppath(tgt2)
        add_8: "f32[s0, 1, 256]" = torch.ops.aten.add.Tensor(add_1, add_7); add_1 = add_7 = None
        return (add_8,)
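For reference, the following is a minimal eager-mode sketch of the block this graph appears to capture, inferred from the "# File:" source comments and the argument shapes (LayerNorm over 256 features, Linear 256 to 1024, GroupNorm with 32 groups, a 5x5 depthwise conv with groups=1024 and padding 2, Linear 1024 to 256). All names here (ConvFFN, norm, fc1, gn, dwconv, fc2, d_model, d_hidden) are hypothetical; the drop path is shown as an identity because no drop-path op appears in the captured graph, and the f16 casts in the dump suggest the module was executed under autocast rather than in pure fp32.

# Hypothetical reconstruction, not the authors' actual module; only the op
# sequence, shapes, and the quoted source lines come from the graph dump above.
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvFFN(nn.Module):
    def __init__(self, d_model: int = 256, d_hidden: int = 1024):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)        # decomposed to var_mean/rsqrt/mul/add in the graph
        self.fc1 = nn.Linear(d_model, d_hidden)  # appears as addmm on fp16-cast weights
        self.gn = nn.GroupNorm(32, d_hidden)     # second var_mean block (32 groups over 1024 channels)
        self.dwconv = nn.Conv2d(d_hidden, d_hidden, kernel_size=5,
                                padding=2, groups=d_hidden)  # aten.convolution with groups=1024
        self.fc2 = nn.Linear(d_hidden, d_model)  # appears as expand + bmm + bias add
        self.droppath = nn.Identity()            # no drop-path op in the captured graph

    def forward(self, tgt, tgt2, tgt3, h: int, w: int):
        # transformer.py:839: tgt = tgt + self.droppath(tgt2 + tgt3)
        tgt = tgt + self.droppath(tgt2 + tgt3)
        x = self.fc1(self.norm(tgt))                  # (h*w, bs, 1024)
        _, bs, c = x.shape
        # basic.py:30: x = x.view(h, w, bs, c).permute(2, 3, 0, 1)
        x = x.view(h, w, bs, c).permute(2, 3, 0, 1)   # (bs, 1024, h, w)
        x = self.dwconv(F.gelu(self.gn(x)))
        # basic.py:34: x = x.view(bs, c, h * w).permute(2, 0, 1)
        x = x.view(bs, c, h * w).permute(2, 0, 1)     # back to (h*w, bs, 1024)
        tgt2 = self.fc2(x)
        # transformer.py:848: tgt = tgt + self.droppath(tgt2)
        return tgt + self.droppath(tgt2)


# Illustrative only: compiling with dynamic shapes under autocast reproduces the
# mixture of f32 parameters and f16 compute seen in the dump.
if __name__ == "__main__":
    h, w, bs, d = 24, 24, 1, 256
    block = ConvFFN(d).cuda()
    tgt = torch.randn(h * w, bs, d, device="cuda")
    tgt2 = torch.randn(h * w, bs, d, device="cuda", dtype=torch.float16)
    tgt3 = torch.randn(h * w, bs, d, device="cuda", dtype=torch.float16)
    compiled = torch.compile(block, dynamic=True)
    with torch.autocast("cuda", dtype=torch.float16):
        out = compiled(tgt, tgt2, tgt3, h, w)
    print(out.shape, out.dtype)  # torch.Size([576, 1, 256]) torch.float32

The dump itself shows how these modules decompose at the aten level: both LayerNorm and GroupNorm become explicit var_mean / rsqrt / mul / add chains, F.gelu is expanded into its erf form, and the two Linear layers are lowered differently (the first is flattened to a 2-D addmm, the second becomes expand + bmm followed by a separate bias add).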