class (torch.nn.Module): def forward(self, arg0_1: "f32[256]", arg1_1: "f32[256]", arg2_1: "Sym(s0)", arg3_1: "f32[s0, 1, 256]", arg4_1: "f32[s0, 1, 256]", arg5_1: "Sym(s4)", arg6_1: "Sym(s5)", arg7_1: "f32[256, 256]", arg8_1: "f32[256]", arg9_1: "f32[256, 256]", arg10_1: "f32[256]", arg11_1: "f32[256, 256]", arg12_1: "f32[256]", arg13_1: "Sym(s6)", arg14_1: "f32[256, 256]", arg15_1: "f32[256]", arg16_1: "f32[256]", arg17_1: "f32[256]", arg18_1: "f32[512, 256]", arg19_1: "f32[512]", arg20_1: "Sym(s7)", arg21_1: "Sym(s8)", arg22_1: "Sym(s10)", arg23_1: "f16[s10, 1, 256]", arg24_1: "f16[s10, 1, 256]", arg25_1: "f32[256, 256]", arg26_1: "f32[256]"): # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/normalization.py:201 in forward, code: return F.layer_norm( var_mean = torch.ops.aten.var_mean.correction(arg3_1, [2], correction = 0, keepdim = True) getitem: "f32[s0, 1, 1]" = var_mean[0] getitem_1: "f32[s0, 1, 1]" = var_mean[1]; var_mean = None add: "f32[s0, 1, 1]" = torch.ops.aten.add.Tensor(getitem, 1e-05); getitem = None rsqrt: "f32[s0, 1, 1]" = torch.ops.aten.rsqrt.default(add); add = None sub: "f32[s0, 1, 256]" = torch.ops.aten.sub.Tensor(arg3_1, getitem_1); getitem_1 = None mul: "f32[s0, 1, 256]" = torch.ops.aten.mul.Tensor(sub, rsqrt); sub = rsqrt = None mul_1: "f32[s0, 1, 256]" = torch.ops.aten.mul.Tensor(mul, arg0_1); mul = arg0_1 = None add_1: "f32[s0, 1, 256]" = torch.ops.aten.add.Tensor(mul_1, arg1_1); mul_1 = arg1_1 = None # File: /workspace/networks/layers/transformer.py:752 in with_pos_embed, code: return tensor if pos is None else tensor + pos add_2: "f32[s0, 1, 256]" = torch.ops.aten.add.Tensor(add_1, arg4_1); arg4_1 = None # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) convert_element_type: "f16[256]" = torch.ops.prims.convert_element_type.default(arg8_1, torch.float16); arg8_1 = None convert_element_type_1: "f16[256, 256]" = torch.ops.prims.convert_element_type.default(arg7_1, torch.float16); arg7_1 = None convert_element_type_2: "f16[s0, 1, 256]" = torch.ops.prims.convert_element_type.default(add_2, torch.float16) view: "f16[s0, 256]" = torch.ops.aten.view.default(convert_element_type_2, [arg2_1, 256]); convert_element_type_2 = None permute: "f16[256, 256]" = torch.ops.aten.permute.default(convert_element_type_1, [1, 0]); convert_element_type_1 = None addmm: "f16[s0, 256]" = torch.ops.aten.addmm.default(convert_element_type, view, permute); convert_element_type = view = permute = None view_1: "f16[s0, 1, 256]" = torch.ops.aten.view.default(addmm, [arg2_1, 1, 256]); addmm = None # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) convert_element_type_6: "f16[256]" = torch.ops.prims.convert_element_type.default(arg10_1, torch.float16); arg10_1 = None convert_element_type_7: "f16[256, 256]" = torch.ops.prims.convert_element_type.default(arg9_1, torch.float16); arg9_1 = None convert_element_type_8: "f16[s0, 1, 256]" = torch.ops.prims.convert_element_type.default(add_2, torch.float16); add_2 = None view_2: "f16[s0, 256]" = torch.ops.aten.view.default(convert_element_type_8, [arg2_1, 256]); convert_element_type_8 = None permute_1: "f16[256, 256]" = torch.ops.aten.permute.default(convert_element_type_7, [1, 0]); convert_element_type_7 = None addmm_1: "f16[s0, 256]" = torch.ops.aten.addmm.default(convert_element_type_6, view_2, permute_1); convert_element_type_6 = view_2 = permute_1 = None view_3: "f16[s0, 1, 256]" = torch.ops.aten.view.default(addmm_1, [arg2_1, 1, 256]); addmm_1 = None # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) convert_element_type_12: "f16[256]" = torch.ops.prims.convert_element_type.default(arg12_1, torch.float16); arg12_1 = None convert_element_type_13: "f16[256, 256]" = torch.ops.prims.convert_element_type.default(arg11_1, torch.float16); arg11_1 = None convert_element_type_14: "f16[s0, 1, 256]" = torch.ops.prims.convert_element_type.default(add_1, torch.float16); add_1 = None view_4: "f16[s0, 256]" = torch.ops.aten.view.default(convert_element_type_14, [arg2_1, 256]); convert_element_type_14 = None permute_2: "f16[256, 256]" = torch.ops.aten.permute.default(convert_element_type_13, [1, 0]); convert_element_type_13 = None addmm_2: "f16[s0, 256]" = torch.ops.aten.addmm.default(convert_element_type_12, view_4, permute_2); convert_element_type_12 = view_4 = permute_2 = None view_5: "f16[s0, 1, 256]" = torch.ops.aten.view.default(addmm_2, [arg2_1, 1, 256]); addmm_2 = None # File: /workspace/networks/layers/attention.py:80 in forward, code: Q = Q / self.T div: "f16[s0, 1, 256]" = torch.ops.aten.div.Tensor(view_1, 5.656854249492381); view_1 = None # File: /workspace/networks/layers/attention.py:90 in forward, code: Q = Q.view(-1, bs, num_head, self.d_att).permute(1, 2, 0, 3) view_6: "f16[s0, 1, s4, (256//s4)]" = torch.ops.aten.view.default(div, [-1, 1, arg5_1, arg13_1]); div = None permute_3: "f16[1, s4, s0, (256//s4)]" = torch.ops.aten.permute.default(view_6, [1, 2, 0, 3]) # File: /workspace/networks/layers/attention.py:91 in forward, code: K = K.view(-1, bs, num_head, self.d_att).permute(1, 2, 3, 0) view_7: "f16[s0, 1, s4, (256//s4)]" = torch.ops.aten.view.default(view_3, [-1, 1, arg5_1, arg13_1]); view_3 = arg13_1 = None permute_4: "f16[1, s4, (256//s4), s0]" = torch.ops.aten.permute.default(view_7, [1, 2, 3, 0]) # File: /workspace/networks/layers/attention.py:92 in forward, code: V = V.view(-1, bs, num_head, hidden_dim).permute(1, 2, 0, 3) view_8: "f16[s0, 1, s4, (256//s4)]" = torch.ops.aten.view.default(view_5, [-1, 1, arg5_1, arg6_1]); view_5 = arg6_1 = None permute_5: "f16[1, s4, s0, (256//s4)]" = torch.ops.aten.permute.default(view_8, [1, 2, 0, 3]) # File: /workspace/networks/layers/attention.py:8 in multiply_by_ychunks, code: return x @ y sym_size_int_1: "Sym((256//s4))" = torch.ops.aten.sym_size.int(view_6, 3); view_6 = None expand: "f16[1, s4, s0, (256//s4)]" = torch.ops.aten.expand.default(permute_3, [1, arg5_1, arg2_1, sym_size_int_1]); permute_3 = None view_9: "f16[s4, s0, (256//s4)]" = torch.ops.aten.view.default(expand, [arg5_1, arg2_1, sym_size_int_1]); expand = sym_size_int_1 = None sym_size_int_2: "Sym((256//s4))" = torch.ops.aten.sym_size.int(view_7, 3); view_7 = None expand_1: "f16[1, s4, (256//s4), s0]" = torch.ops.aten.expand.default(permute_4, [1, arg5_1, sym_size_int_2, arg2_1]); permute_4 = None view_10: "f16[s4, (256//s4), s0]" = torch.ops.aten.view.default(expand_1, [arg5_1, sym_size_int_2, arg2_1]); expand_1 = sym_size_int_2 = None bmm: "f16[s4, s0, s0]" = torch.ops.aten.bmm.default(view_9, view_10); view_9 = view_10 = None view_11: "f16[1, s4, s0, s0]" = torch.ops.aten.view.default(bmm, [1, arg5_1, arg2_1, arg2_1]); bmm = None # File: /workspace/networks/layers/attention.py:114 in forward, code: attn = torch.softmax(QK, dim=-1) convert_element_type_20: "f32[1, s4, s0, s0]" = torch.ops.prims.convert_element_type.default(view_11, torch.float32); view_11 = None amax: "f32[1, s4, s0, 1]" = torch.ops.aten.amax.default(convert_element_type_20, [-1], True) sub_1: "f32[1, s4, s0, s0]" = torch.ops.aten.sub.Tensor(convert_element_type_20, amax); convert_element_type_20 = amax = None exp: "f32[1, s4, s0, s0]" = torch.ops.aten.exp.default(sub_1); sub_1 = None sum_1: "f32[1, s4, s0, 1]" = torch.ops.aten.sum.dim_IntList(exp, [-1], True) div_1: "f32[1, s4, s0, s0]" = torch.ops.aten.div.Tensor(exp, sum_1); exp = sum_1 = None # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/dropout.py:59 in forward, code: return F.dropout(input, self.p, self.training, self.inplace) clone: "f32[1, s4, s0, s0]" = torch.ops.aten.clone.default(div_1); div_1 = None # File: /workspace/networks/layers/attention.py:14 in multiply_by_xchunks, code: return x @ y convert_element_type_21: "f16[1, s4, s0, s0]" = torch.ops.prims.convert_element_type.default(clone, torch.float16); clone = None expand_2: "f16[1, s4, s0, s0]" = torch.ops.aten.expand.default(convert_element_type_21, [1, arg5_1, arg2_1, arg2_1]); convert_element_type_21 = None view_12: "f16[s4, s0, s0]" = torch.ops.aten.view.default(expand_2, [arg5_1, arg2_1, arg2_1]); expand_2 = None sym_size_int_4: "Sym((256//s4))" = torch.ops.aten.sym_size.int(view_8, 3); view_8 = None expand_3: "f16[1, s4, s0, (256//s4)]" = torch.ops.aten.expand.default(permute_5, [1, arg5_1, arg2_1, sym_size_int_4]); permute_5 = None view_13: "f16[s4, s0, (256//s4)]" = torch.ops.aten.view.default(expand_3, [arg5_1, arg2_1, sym_size_int_4]); expand_3 = None bmm_1: "f16[s4, s0, (256//s4)]" = torch.ops.aten.bmm.default(view_12, view_13); view_12 = view_13 = None view_14: "f16[1, s4, s0, (256//s4)]" = torch.ops.aten.view.default(bmm_1, [1, arg5_1, arg2_1, sym_size_int_4]); bmm_1 = None # File: /workspace/networks/layers/attention.py:120 in forward, code: self.qk_chunks).permute(2, 0, 1, 3) permute_6: "f16[s0, 1, s4, (256//s4)]" = torch.ops.aten.permute.default(view_14, [2, 0, 1, 3]); view_14 = None # File: /workspace/networks/layers/attention.py:122 in forward, code: outputs = outputs.reshape(-1, bs, self.d_model) clone_1: "f16[s0, 1, s4, (256//s4)]" = torch.ops.aten.clone.default(permute_6, memory_format = torch.contiguous_format); permute_6 = None mul_2: "Sym(s0*s4)" = arg2_1 * arg5_1; arg5_1 = None mul_3: "Sym(s0*s4*((256//s4)))" = mul_2 * sym_size_int_4; mul_2 = sym_size_int_4 = None floordiv: "Sym(((s0*s4*((256//s4)))//256))" = mul_3 // 256; mul_3 = None view_15: "f16[s0, 1, s4*((256//s4))]" = torch.ops.aten.view.default(clone_1, [floordiv, 1, 256]); clone_1 = floordiv = None # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) convert_element_type_24: "f16[256]" = torch.ops.prims.convert_element_type.default(arg15_1, torch.float16); arg15_1 = None convert_element_type_25: "f16[256, 256]" = torch.ops.prims.convert_element_type.default(arg14_1, torch.float16); arg14_1 = None sym_size_int_6: "Sym(s4*((256//s4)))" = torch.ops.aten.sym_size.int(view_15, 2) view_16: "f16[s0, s4*((256//s4))]" = torch.ops.aten.view.default(view_15, [arg2_1, sym_size_int_6]); view_15 = sym_size_int_6 = None permute_7: "f16[256, 256]" = torch.ops.aten.permute.default(convert_element_type_25, [1, 0]); convert_element_type_25 = None addmm_3: "f16[s0, 256]" = torch.ops.aten.addmm.default(convert_element_type_24, view_16, permute_7); convert_element_type_24 = view_16 = permute_7 = None view_17: "f16[s0, 1, 256]" = torch.ops.aten.view.default(addmm_3, [arg2_1, 1, 256]); addmm_3 = None # File: /workspace/networks/layers/transformer.py:772 in forward, code: tgt = tgt + self.droppath(tgt2) add_3: "f32[s0, 1, 256]" = torch.ops.aten.add.Tensor(arg3_1, view_17); arg3_1 = view_17 = None # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/normalization.py:201 in forward, code: return F.layer_norm( var_mean_1 = torch.ops.aten.var_mean.correction(add_3, [2], correction = 0, keepdim = True) getitem_2: "f32[s0, 1, 1]" = var_mean_1[0] getitem_3: "f32[s0, 1, 1]" = var_mean_1[1]; var_mean_1 = None add_4: "f32[s0, 1, 1]" = torch.ops.aten.add.Tensor(getitem_2, 1e-05); getitem_2 = None rsqrt_1: "f32[s0, 1, 1]" = torch.ops.aten.rsqrt.default(add_4); add_4 = None sub_2: "f32[s0, 1, 256]" = torch.ops.aten.sub.Tensor(add_3, getitem_3); getitem_3 = None mul_4: "f32[s0, 1, 256]" = torch.ops.aten.mul.Tensor(sub_2, rsqrt_1); sub_2 = rsqrt_1 = None mul_5: "f32[s0, 1, 256]" = torch.ops.aten.mul.Tensor(mul_4, arg16_1); mul_4 = arg16_1 = None add_5: "f32[s0, 1, 256]" = torch.ops.aten.add.Tensor(mul_5, arg17_1); mul_5 = arg17_1 = None # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) convert_element_type_29: "f16[512]" = torch.ops.prims.convert_element_type.default(arg19_1, torch.float16); arg19_1 = None convert_element_type_30: "f16[512, 256]" = torch.ops.prims.convert_element_type.default(arg18_1, torch.float16); arg18_1 = None convert_element_type_31: "f16[s0, 1, 256]" = torch.ops.prims.convert_element_type.default(add_5, torch.float16) view_18: "f16[s0, 256]" = torch.ops.aten.view.default(convert_element_type_31, [arg2_1, 256]); convert_element_type_31 = None permute_8: "f16[256, 512]" = torch.ops.aten.permute.default(convert_element_type_30, [1, 0]); convert_element_type_30 = None addmm_4: "f16[s0, 512]" = torch.ops.aten.addmm.default(convert_element_type_29, view_18, permute_8); convert_element_type_29 = view_18 = permute_8 = None view_19: "f16[s0, 1, 512]" = torch.ops.aten.view.default(addmm_4, [arg2_1, 1, 512]); addmm_4 = None # File: /workspace/networks/layers/transformer.py:778 in forward, code: curr_QV = torch.split(curr_QV, self.d_model, dim=2) split = torch.ops.aten.split.Tensor(view_19, 256, 2); view_19 = None getitem_4: "f16[s0, 1, 256]" = split[0] getitem_5: "f16[s0, 1, 256]" = split[1]; split = None # File: /workspace/networks/layers/basic.py:93 in seq_to_2d, code: tensor = tensor.view(h, w, n, c).permute(2, 3, 0, 1).contiguous() view_20: "f16[s7, (s0//s7), 1, 256]" = torch.ops.aten.view.default(getitem_4, [arg20_1, arg21_1, 1, 256]); arg20_1 = arg21_1 = None permute_9: "f16[1, 256, s7, (s0//s7)]" = torch.ops.aten.permute.default(view_20, [2, 3, 0, 1]); view_20 = None clone_2: "f16[1, 256, s7, (s0//s7)]" = torch.ops.aten.clone.default(permute_9, memory_format = torch.contiguous_format); permute_9 = None # File: /workspace/networks/layers/attention.py:80 in forward, code: Q = Q / self.T div_2: "f16[s0, 1, 256]" = torch.ops.aten.div.Tensor(getitem_4, 16.0) # File: /workspace/networks/layers/attention.py:90 in forward, code: Q = Q.view(-1, bs, num_head, self.d_att).permute(1, 2, 0, 3) view_21: "f16[s0, 1, 1, 256]" = torch.ops.aten.view.default(div_2, [-1, 1, 1, 256]); div_2 = None permute_10: "f16[1, 1, s0, 256]" = torch.ops.aten.permute.default(view_21, [1, 2, 0, 3]); view_21 = None # File: /workspace/networks/layers/attention.py:91 in forward, code: K = K.view(-1, bs, num_head, self.d_att).permute(1, 2, 3, 0) view_22: "f16[s10, 1, 1, 256]" = torch.ops.aten.view.default(arg23_1, [-1, 1, 1, 256]); arg23_1 = None permute_11: "f16[1, 1, 256, s10]" = torch.ops.aten.permute.default(view_22, [1, 2, 3, 0]); view_22 = None # File: /workspace/networks/layers/attention.py:92 in forward, code: V = V.view(-1, bs, num_head, hidden_dim).permute(1, 2, 0, 3) view_23: "f16[s10, 1, 1, 256]" = torch.ops.aten.view.default(arg24_1, [-1, 1, 1, 256]); arg24_1 = None permute_12: "f16[1, 1, s10, 256]" = torch.ops.aten.permute.default(view_23, [1, 2, 0, 3]); view_23 = None # File: /workspace/networks/layers/attention.py:8 in multiply_by_ychunks, code: return x @ y expand_4: "f16[1, 1, s0, 256]" = torch.ops.aten.expand.default(permute_10, [1, 1, arg2_1, 256]); permute_10 = None view_24: "f16[1, s0, 256]" = torch.ops.aten.view.default(expand_4, [1, arg2_1, 256]); expand_4 = None expand_5: "f16[1, 1, 256, s10]" = torch.ops.aten.expand.default(permute_11, [1, 1, 256, arg22_1]); permute_11 = None view_25: "f16[1, 256, s10]" = torch.ops.aten.view.default(expand_5, [1, 256, arg22_1]); expand_5 = None bmm_2: "f16[1, s0, s10]" = torch.ops.aten.bmm.default(view_24, view_25); view_24 = view_25 = None view_26: "f16[1, 1, s0, s10]" = torch.ops.aten.view.default(bmm_2, [1, 1, arg2_1, arg22_1]); bmm_2 = None # File: /workspace/networks/layers/attention.py:114 in forward, code: attn = torch.softmax(QK, dim=-1) convert_element_type_37: "f32[1, 1, s0, s10]" = torch.ops.prims.convert_element_type.default(view_26, torch.float32); view_26 = None amax_1: "f32[1, 1, s0, 1]" = torch.ops.aten.amax.default(convert_element_type_37, [-1], True) sub_3: "f32[1, 1, s0, s10]" = torch.ops.aten.sub.Tensor(convert_element_type_37, amax_1); convert_element_type_37 = amax_1 = None exp_1: "f32[1, 1, s0, s10]" = torch.ops.aten.exp.default(sub_3); sub_3 = None sum_2: "f32[1, 1, s0, 1]" = torch.ops.aten.sum.dim_IntList(exp_1, [-1], True) div_3: "f32[1, 1, s0, s10]" = torch.ops.aten.div.Tensor(exp_1, sum_2); exp_1 = sum_2 = None # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/dropout.py:59 in forward, code: return F.dropout(input, self.p, self.training, self.inplace) clone_3: "f32[1, 1, s0, s10]" = torch.ops.aten.clone.default(div_3); div_3 = None # File: /workspace/networks/layers/attention.py:14 in multiply_by_xchunks, code: return x @ y convert_element_type_38: "f16[1, 1, s0, s10]" = torch.ops.prims.convert_element_type.default(clone_3, torch.float16); clone_3 = None expand_6: "f16[1, 1, s0, s10]" = torch.ops.aten.expand.default(convert_element_type_38, [1, 1, arg2_1, arg22_1]); convert_element_type_38 = None view_27: "f16[1, s0, s10]" = torch.ops.aten.view.default(expand_6, [1, arg2_1, arg22_1]); expand_6 = None expand_7: "f16[1, 1, s10, 256]" = torch.ops.aten.expand.default(permute_12, [1, 1, arg22_1, 256]); permute_12 = None view_28: "f16[1, s10, 256]" = torch.ops.aten.view.default(expand_7, [1, arg22_1, 256]); expand_7 = arg22_1 = None bmm_3: "f16[1, s0, 256]" = torch.ops.aten.bmm.default(view_27, view_28); view_27 = view_28 = None view_29: "f16[1, 1, s0, 256]" = torch.ops.aten.view.default(bmm_3, [1, 1, arg2_1, 256]); bmm_3 = None # File: /workspace/networks/layers/attention.py:120 in forward, code: self.qk_chunks).permute(2, 0, 1, 3) permute_13: "f16[s0, 1, 1, 256]" = torch.ops.aten.permute.default(view_29, [2, 0, 1, 3]); view_29 = None # File: /workspace/networks/layers/attention.py:122 in forward, code: outputs = outputs.reshape(-1, bs, self.d_model) view_30: "f16[s0, 1, 256]" = torch.ops.aten.view.default(permute_13, [-1, 1, 256]); permute_13 = None # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) convert_element_type_41: "f16[256]" = torch.ops.prims.convert_element_type.default(arg26_1, torch.float16); arg26_1 = None convert_element_type_42: "f16[256, 256]" = torch.ops.prims.convert_element_type.default(arg25_1, torch.float16); arg25_1 = None view_31: "f16[s0, 256]" = torch.ops.aten.view.default(view_30, [arg2_1, 256]); view_30 = None permute_14: "f16[256, 256]" = torch.ops.aten.permute.default(convert_element_type_42, [1, 0]); convert_element_type_42 = None addmm_5: "f16[s0, 256]" = torch.ops.aten.addmm.default(convert_element_type_41, view_31, permute_14); convert_element_type_41 = view_31 = permute_14 = None view_32: "f16[s0, 1, 256]" = torch.ops.aten.view.default(addmm_5, [arg2_1, 1, 256]); addmm_5 = arg2_1 = None return (clone_2, add_3, add_5, view_32, getitem_4, getitem_5)