class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[256]", arg1_1: "f32[256]", arg2_1: "Sym(s13)", arg3_1: "f32[s13, 1, 256]", arg4_1: "f32[s13, 1, 256]", arg5_1: "Sym(s6)", arg6_1: "Sym(s7)", arg7_1: "f32[256, 256]", arg8_1: "f32[256]", arg9_1: "f32[256, 256]", arg10_1: "f32[256]", arg11_1: "f32[256, 256]", arg12_1: "f32[256]", arg13_1: "Sym(s8)", arg14_1: "f32[256, 256]", arg15_1: "f32[256]", arg16_1: "f32[256]", arg17_1: "f32[256]", arg18_1: "f32[512, 256]", arg19_1: "f32[512]", arg20_1: "Sym(s9)", arg21_1: "Sym(s10)", arg22_1: "f32[s13, 1, 256]", arg23_1: "f32[257, 256]", arg24_1: "f32[257]", arg25_1: "f32[256, 256]", arg26_1: "f32[256]"):
        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/normalization.py:201 in forward, code: return F.layer_norm(
        clone: "f32[s13, 1, 256]" = torch.ops.aten.clone.default(arg3_1, memory_format = torch.contiguous_format)
        var_mean = torch.ops.aten.var_mean.correction(clone, [2], correction = 0, keepdim = True)
        getitem: "f32[s13, 1, 1]" = var_mean[0]
        getitem_1: "f32[s13, 1, 1]" = var_mean[1]; var_mean = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        convert_element_type_24: "f16[256]" = torch.ops.prims.convert_element_type.default(arg15_1, torch.float16); arg15_1 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        convert_element_type: "f16[256]" = torch.ops.prims.convert_element_type.default(arg8_1, torch.float16); arg8_1 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/normalization.py:201 in forward, code: return F.layer_norm(
        sub: "f32[s13, 1, 256]" = torch.ops.aten.sub.Tensor(clone, getitem_1); clone = getitem_1 = None
        add: "f32[s13, 1, 1]" = torch.ops.aten.add.Tensor(getitem, 1e-05); getitem = None
        rsqrt: "f32[s13, 1, 1]" = torch.ops.aten.rsqrt.default(add); add = None
        mul: "f32[s13, 1, 256]" = torch.ops.aten.mul.Tensor(sub, rsqrt); sub = rsqrt = None
        mul_1: "f32[s13, 1, 256]" = torch.ops.aten.mul.Tensor(mul, arg0_1); mul = arg0_1 = None
        add_1: "f32[s13, 1, 256]" = torch.ops.aten.add.Tensor(mul_1, arg1_1); mul_1 = arg1_1 = None

        # File: /workspace/networks/layers/transformer.py:752 in with_pos_embed, code: return tensor if pos is None else tensor + pos
        add_2: "f32[s13, 1, 256]" = torch.ops.aten.add.Tensor(add_1, arg4_1); arg4_1 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        convert_element_type_2: "f16[s13, 1, 256]" = torch.ops.prims.convert_element_type.default(add_2, torch.float16)
        view: "f16[s13, 256]" = torch.ops.aten.reshape.default(convert_element_type_2, [arg2_1, 256]); convert_element_type_2 = None
        convert_element_type_1: "f16[256, 256]" = torch.ops.prims.convert_element_type.default(arg7_1, torch.float16); arg7_1 = None
        permute: "f16[256, 256]" = torch.ops.aten.permute.default(convert_element_type_1, [1, 0]); convert_element_type_1 = None

        # No stacktrace found for following nodes
        mm_default_2: "f16[s13, 256]" = torch.ops.aten.mm.default(view, permute); view = permute = None
        add_tensor_2: "f16[s13, 256]" = torch.ops.aten.add.Tensor(mm_default_2, convert_element_type); mm_default_2 = convert_element_type = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        view_1: "f16[s13, 1, 256]" = torch.ops.aten.reshape.default(add_tensor_2, [arg2_1, 1, 256]); add_tensor_2 = None

        # File: /workspace/networks/layers/attention.py:80 in forward, code: Q = Q / self.T
        div: "f16[s13, 1, 256]" = torch.ops.aten.div.Tensor(view_1, 8.0); view_1 = None

        # File: /workspace/networks/layers/attention.py:90 in forward, code: Q = Q.view(-1, bs, num_head, self.d_att).permute(1, 2, 0, 3)
        view_6: "f16[s13, 1, s6, (256//s6)]" = torch.ops.aten.reshape.default(div, [-1, 1, arg5_1, arg13_1]); div = None
        permute_3: "f16[1, s6, s13, (256//s6)]" = torch.ops.aten.permute.default(view_6, [1, 2, 0, 3])

        # File: /workspace/networks/layers/attention.py:8 in multiply_by_ychunks, code: return x @ y
        sym_size_int_3: "Sym((256//s6))" = torch.ops.aten.sym_size.int(view_6, 3); view_6 = None
        expand: "f16[1, s6, s13, (256//s6)]" = torch.ops.aten.expand.default(permute_3, [1, arg5_1, arg2_1, sym_size_int_3]); permute_3 = None
        view_9: "f16[s6, s13, (256//s6)]" = torch.ops.aten.reshape.default(expand, [arg5_1, arg2_1, sym_size_int_3]); expand = sym_size_int_3 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        convert_element_type_6: "f16[256]" = torch.ops.prims.convert_element_type.default(arg10_1, torch.float16); arg10_1 = None

        # File: /workspace/networks/layers/transformer.py:768 in forward, code: k = k[::self.global_dilation,:,:]
        slice_1: "f32[((s13 + 1)//2), 1, 256]" = torch.ops.aten.slice.Tensor(add_2, 0, 0, 9223372036854775807, 2); add_2 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        convert_element_type_8: "f16[((s13 + 1)//2), 1, 256]" = torch.ops.prims.convert_element_type.default(slice_1, torch.float16)
        sym_size_int_1: "Sym(((s13 + 1)//2))" = torch.ops.aten.sym_size.int(slice_1, 0); slice_1 = None
        view_2: "f16[((s13 + 1)//2), 256]" = torch.ops.aten.reshape.default(convert_element_type_8, [sym_size_int_1, 256]); convert_element_type_8 = None
        convert_element_type_7: "f16[256, 256]" = torch.ops.prims.convert_element_type.default(arg9_1, torch.float16); arg9_1 = None
        permute_1: "f16[256, 256]" = torch.ops.aten.permute.default(convert_element_type_7, [1, 0]); convert_element_type_7 = None
        addmm_1: "f16[((s13 + 1)//2), 256]" = torch.ops.aten.addmm.default(convert_element_type_6, view_2, permute_1); convert_element_type_6 = view_2 = permute_1 = None
        view_3: "f16[((s13 + 1)//2), 1, 256]" = torch.ops.aten.reshape.default(addmm_1, [sym_size_int_1, 1, 256]); addmm_1 = None

        # File: /workspace/networks/layers/attention.py:91 in forward, code: K = K.view(-1, bs, num_head, self.d_att).permute(1, 2, 3, 0)
        view_7: "f16[((s13 + 1)//2), 1, s6, (256//s6)]" = torch.ops.aten.reshape.default(view_3, [-1, 1, arg5_1, arg13_1]); view_3 = arg13_1 = None
        permute_4: "f16[1, s6, (256//s6), ((s13 + 1)//2)]" = torch.ops.aten.permute.default(view_7, [1, 2, 3, 0])

        # File: /workspace/networks/layers/attention.py:8 in multiply_by_ychunks, code: return x @ y
        sym_size_int_4: "Sym((256//s6))" = torch.ops.aten.sym_size.int(view_7, 3); view_7 = None
        expand_1: "f16[1, s6, (256//s6), ((s13 + 1)//2)]" = torch.ops.aten.expand.default(permute_4, [1, arg5_1, sym_size_int_4, sym_size_int_1]); permute_4 = None
        view_10: "f16[s6, (256//s6), ((s13 + 1)//2)]" = torch.ops.aten.reshape.default(expand_1, [arg5_1, sym_size_int_4, sym_size_int_1]); expand_1 = sym_size_int_4 = None
        bmm: "f16[s6, s13, ((s13 + 1)//2)]" = torch.ops.aten.bmm.default(view_9, view_10); view_9 = view_10 = None
        view_11: "f16[1, s6, s13, ((s13 + 1)//2)]" = torch.ops.aten.reshape.default(bmm, [1, arg5_1, arg2_1, sym_size_int_1]); bmm = None

        # File: /workspace/networks/layers/attention.py:114 in forward, code: attn = torch.softmax(QK, dim=-1)
        convert_element_type_20: "f32[1, s6, s13, ((s13 + 1)//2)]" = torch.ops.prims.convert_element_type.default(view_11, torch.float32); view_11 = None
        amax: "f32[1, s6, s13, 1]" = torch.ops.aten.amax.default(convert_element_type_20, [-1], True)
        sub_1: "f32[1, s6, s13, ((s13 + 1)//2)]" = torch.ops.aten.sub.Tensor(convert_element_type_20, amax); convert_element_type_20 = amax = None
        exp: "f32[1, s6, s13, ((s13 + 1)//2)]" = torch.ops.aten.exp.default(sub_1); sub_1 = None
        sum_1: "f32[1, s6, s13, 1]" = torch.ops.aten.sum.dim_IntList(exp, [-1], True)
        div_1: "f32[1, s6, s13, ((s13 + 1)//2)]" = torch.ops.aten.div.Tensor(exp, sum_1); exp = sum_1 = None

        # File: /workspace/networks/layers/attention.py:14 in multiply_by_xchunks, code: return x @ y
        convert_element_type_21: "f16[1, s6, s13, ((s13 + 1)//2)]" = torch.ops.prims.convert_element_type.default(div_1, torch.float16); div_1 = None
        expand_2: "f16[1, s6, s13, ((s13 + 1)//2)]" = torch.ops.aten.expand.default(convert_element_type_21, [1, arg5_1, arg2_1, sym_size_int_1]); convert_element_type_21 = None
        view_12: "f16[s6, s13, ((s13 + 1)//2)]" = torch.ops.aten.reshape.default(expand_2, [arg5_1, arg2_1, sym_size_int_1]); expand_2 = sym_size_int_1 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        convert_element_type_12: "f16[256]" = torch.ops.prims.convert_element_type.default(arg12_1, torch.float16); arg12_1 = None

        # File: /workspace/networks/layers/transformer.py:769 in forward, code: v = v[::self.global_dilation,:,:]
        slice_4: "f32[((s13 + 1)//2), 1, 256]" = torch.ops.aten.slice.Tensor(add_1, 0, 0, 9223372036854775807, 2); add_1 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        convert_element_type_14: "f16[((s13 + 1)//2), 1, 256]" = torch.ops.prims.convert_element_type.default(slice_4, torch.float16)
        sym_size_int_2: "Sym(((s13 + 1)//2))" = torch.ops.aten.sym_size.int(slice_4, 0); slice_4 = None
        view_4: "f16[((s13 + 1)//2), 256]" = torch.ops.aten.reshape.default(convert_element_type_14, [sym_size_int_2, 256]); convert_element_type_14 = None
        convert_element_type_13: "f16[256, 256]" = torch.ops.prims.convert_element_type.default(arg11_1, torch.float16); arg11_1 = None
        permute_2: "f16[256, 256]" = torch.ops.aten.permute.default(convert_element_type_13, [1, 0]); convert_element_type_13 = None
        addmm_2: "f16[((s13 + 1)//2), 256]" = torch.ops.aten.addmm.default(convert_element_type_12, view_4, permute_2); convert_element_type_12 = view_4 = permute_2 = None
        view_5: "f16[((s13 + 1)//2), 1, 256]" = torch.ops.aten.reshape.default(addmm_2, [sym_size_int_2, 1, 256]); addmm_2 = None

        # File: /workspace/networks/layers/attention.py:92 in forward, code: V = V.view(-1, bs, num_head, hidden_dim).permute(1, 2, 0, 3)
        view_8: "f16[((s13 + 1)//2), 1, s6, (256//s6)]" = torch.ops.aten.reshape.default(view_5, [-1, 1, arg5_1, arg6_1]); view_5 = arg6_1 = None
        permute_5: "f16[1, s6, ((s13 + 1)//2), (256//s6)]" = torch.ops.aten.permute.default(view_8, [1, 2, 0, 3])

        # File: /workspace/networks/layers/attention.py:14 in multiply_by_xchunks, code: return x @ y
        sym_size_int_6: "Sym((256//s6))" = torch.ops.aten.sym_size.int(view_8, 3); view_8 = None
        expand_3: "f16[1, s6, ((s13 + 1)//2), (256//s6)]" = torch.ops.aten.expand.default(permute_5, [1, arg5_1, sym_size_int_2, sym_size_int_6]); permute_5 = None
        view_13: "f16[s6, ((s13 + 1)//2), (256//s6)]" = torch.ops.aten.reshape.default(expand_3, [arg5_1, sym_size_int_2, sym_size_int_6]); expand_3 = sym_size_int_2 = None
        bmm_1: "f16[s6, s13, (256//s6)]" = torch.ops.aten.bmm.default(view_12, view_13); view_12 = view_13 = None
        view_14: "f16[1, s6, s13, (256//s6)]" = torch.ops.aten.reshape.default(bmm_1, [1, arg5_1, arg2_1, sym_size_int_6]); bmm_1 = None

        # File: /workspace/networks/layers/attention.py:120 in forward, code: self.qk_chunks).permute(2, 0, 1, 3)
        permute_6: "f16[s13, 1, s6, (256//s6)]" = torch.ops.aten.permute.default(view_14, [2, 0, 1, 3]); view_14 = None

        # File: /workspace/networks/layers/attention.py:122 in forward, code: outputs = outputs.reshape(-1, bs, self.d_model)
        clone_2: "f16[s13, 1, s6, (256//s6)]" = torch.ops.aten.clone.default(permute_6, memory_format = torch.contiguous_format); permute_6 = None
        mul_2: "Sym(s13*s6)" = arg2_1 * arg5_1; arg5_1 = None
        mul_3: "Sym(s13*s6*((256//s6)))" = mul_2 * sym_size_int_6; mul_2 = sym_size_int_6 = None
        floordiv: "Sym(((s13*s6*((256//s6)))//256))" = mul_3 // 256; mul_3 = None
        view_15: "f16[s13, 1, s6*((256//s6))]" = torch.ops.aten.reshape.default(clone_2, [floordiv, 1, 256]); clone_2 = floordiv = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        sym_size_int_8: "Sym(s6*((256//s6)))" = torch.ops.aten.sym_size.int(view_15, 2)
        view_16: "f16[s13, s6*((256//s6))]" = torch.ops.aten.reshape.default(view_15, [arg2_1, sym_size_int_8]); view_15 = sym_size_int_8 = None
        convert_element_type_25: "f16[256, 256]" = torch.ops.prims.convert_element_type.default(arg14_1, torch.float16); arg14_1 = None
        permute_7: "f16[256, 256]" = torch.ops.aten.permute.default(convert_element_type_25, [1, 0]); convert_element_type_25 = None

        # No stacktrace found for following nodes
        mm_default_1: "f16[s13, 256]" = torch.ops.aten.mm.default(view_16, permute_7); view_16 = permute_7 = None
        add_tensor_1: "f16[s13, 256]" = torch.ops.aten.add.Tensor(mm_default_1, convert_element_type_24); mm_default_1 = convert_element_type_24 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        view_17: "f16[s13, 1, 256]" = torch.ops.aten.reshape.default(add_tensor_1, [arg2_1, 1, 256]); add_tensor_1 = None

        # File: /workspace/networks/layers/transformer.py:772 in forward, code: tgt = tgt + self.droppath(tgt2)
        add_3: "f32[s13, 1, 256]" = torch.ops.aten.add.Tensor(arg3_1, view_17); arg3_1 = view_17 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/normalization.py:201 in forward, code: return F.layer_norm(
        clone_3: "f32[s13, 1, 256]" = torch.ops.aten.clone.default(add_3, memory_format = torch.contiguous_format)
        var_mean_1 = torch.ops.aten.var_mean.correction(clone_3, [2], correction = 0, keepdim = True)
        getitem_2: "f32[s13, 1, 1]" = var_mean_1[0]
        getitem_3: "f32[s13, 1, 1]" = var_mean_1[1]; var_mean_1 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        convert_element_type_29: "f16[512]" = torch.ops.prims.convert_element_type.default(arg19_1, torch.float16); arg19_1 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/normalization.py:201 in forward, code: return F.layer_norm(
        sub_2: "f32[s13, 1, 256]" = torch.ops.aten.sub.Tensor(clone_3, getitem_3); clone_3 = getitem_3 = None
        add_4: "f32[s13, 1, 1]" = torch.ops.aten.add.Tensor(getitem_2, 1e-05); getitem_2 = None
        rsqrt_1: "f32[s13, 1, 1]" = torch.ops.aten.rsqrt.default(add_4); add_4 = None
        mul_4: "f32[s13, 1, 256]" = torch.ops.aten.mul.Tensor(sub_2, rsqrt_1); sub_2 = rsqrt_1 = None
        mul_5: "f32[s13, 1, 256]" = torch.ops.aten.mul.Tensor(mul_4, arg16_1); mul_4 = arg16_1 = None
        add_5: "f32[s13, 1, 256]" = torch.ops.aten.add.Tensor(mul_5, arg17_1); mul_5 = arg17_1 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        convert_element_type_31: "f16[s13, 1, 256]" = torch.ops.prims.convert_element_type.default(add_5, torch.float16)
        view_18: "f16[s13, 256]" = torch.ops.aten.reshape.default(convert_element_type_31, [arg2_1, 256]); convert_element_type_31 = None
        convert_element_type_30: "f16[512, 256]" = torch.ops.prims.convert_element_type.default(arg18_1, torch.float16); arg18_1 = None
        permute_8: "f16[256, 512]" = torch.ops.aten.permute.default(convert_element_type_30, [1, 0]); convert_element_type_30 = None
        addmm_4: "f16[s13, 512]" = torch.ops.aten.addmm.default(convert_element_type_29, view_18, permute_8); convert_element_type_29 = view_18 = permute_8 = None
        view_19: "f16[s13, 1, 512]" = torch.ops.aten.reshape.default(addmm_4, [arg2_1, 1, 512]); addmm_4 = None

        # File: /workspace/networks/layers/transformer.py:778 in forward, code: curr_QV = torch.split(curr_QV, self.d_model, dim=2)
        split = torch.ops.aten.split.Tensor(view_19, 256, 2); view_19 = None
        getitem_4: "f16[s13, 1, 256]" = split[0]
        getitem_5: "f16[s13, 1, 256]" = split[1]; split = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        convert_element_type_35: "f16[257]" = torch.ops.prims.convert_element_type.default(arg24_1, torch.float16); arg24_1 = None

        # No stacktrace found for following nodes
        full_default: "f16[7]" = torch.ops.aten.full.default([7], 0, dtype = torch.float16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
        cat_default: "f16[264]" = torch.ops.aten.cat.default([convert_element_type_35, full_default]); convert_element_type_35 = full_default = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        convert_element_type_37: "f16[s13, 1, 256]" = torch.ops.prims.convert_element_type.default(arg22_1, torch.float16); arg22_1 = None
        view_21: "f16[s13, 256]" = torch.ops.aten.reshape.default(convert_element_type_37, [arg2_1, 256]); convert_element_type_37 = None
        convert_element_type_36: "f16[257, 256]" = torch.ops.prims.convert_element_type.default(arg23_1, torch.float16); arg23_1 = None
        permute_10: "f16[256, 257]" = torch.ops.aten.permute.default(convert_element_type_36, [1, 0]); convert_element_type_36 = None

        # No stacktrace found for following nodes
        constant_pad_nd_default: "f16[256, 264]" = torch.ops.aten.constant_pad_nd.default(permute_10, [0, 7, 0, 0]); permute_10 = None
        mm_default: "f16[s13, 264]" = torch.ops.aten.mm.default(view_21, constant_pad_nd_default); view_21 = constant_pad_nd_default = None
        add_tensor: "f16[s13, 264]" = torch.ops.aten.add.Tensor(mm_default, cat_default); mm_default = cat_default = None
        slice_tensor_1: "f16[s13, 257]" = torch.ops.aten.slice.Tensor(add_tensor, 1, 0, -7); add_tensor = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        view_22: "f16[s13, 1, 257]" = torch.ops.aten.reshape.default(slice_tensor_1, [arg2_1, 1, 257]); slice_tensor_1 = None

        # File: /workspace/networks/layers/transformer.py:855 in fuse_key_value_id, code: ID_K, ID_V = torch.split(ID_KV, [self.att_nhead, self.d_model], dim=2)
        split_with_sizes = torch.ops.aten.split_with_sizes.default(view_22, [1, 256], 2); view_22 = None
        getitem_6: "f16[s13, 1, 1]" = split_with_sizes[0]
        getitem_7: "f16[s13, 1, 256]" = split_with_sizes[1]; split_with_sizes = None

        # File: /workspace/networks/layers/basic.py:93 in seq_to_2d, code: tensor = tensor.view(h, w, n, c).permute(2, 3, 0, 1).contiguous()
        view_20: "f16[s9, (s13//s9), 1, 256]" = torch.ops.aten.reshape.default(getitem_4, [arg20_1, arg21_1, 1, 256])
        permute_9: "f16[1, 256, s9, (s13//s9)]" = torch.ops.aten.permute.default(view_20, [2, 3, 0, 1]); view_20 = None
        clone_4: "f16[1, 256, s9, (s13//s9)]" = torch.ops.aten.clone.default(permute_9, memory_format = torch.contiguous_format); permute_9 = None

        # File: /workspace/networks/layers/transformer.py:857 in fuse_key_value_id, code: K = key.view(-1, bs, self.att_nhead, self.d_model //
        view_23: "f16[s13, 1, 1, 256]" = torch.ops.aten.reshape.default(getitem_4, [-1, 1, 1, 256])

        # File: /workspace/networks/layers/transformer.py:858 in fuse_key_value_id, code: self.att_nhead) * (1 + torch.tanh(ID_K)).unsqueeze(-1)
        tanh: "f16[s13, 1, 1]" = torch.ops.aten.tanh.default(getitem_6); getitem_6 = None
        add_6: "f16[s13, 1, 1]" = torch.ops.aten.add.Tensor(tanh, 1); tanh = None
        unsqueeze: "f16[s13, 1, 1, 1]" = torch.ops.aten.unsqueeze.default(add_6, -1); add_6 = None

        # File: /workspace/networks/layers/transformer.py:857 in fuse_key_value_id, code: K = key.view(-1, bs, self.att_nhead, self.d_model //
        mul_6: "f16[s13, 1, 1, 256]" = torch.ops.aten.mul.Tensor(view_23, unsqueeze); view_23 = unsqueeze = None

        # File: /workspace/networks/layers/transformer.py:859 in fuse_key_value_id, code: K = K.view(-1, bs, self.d_model)
        view_24: "f16[s13, 1, 256]" = torch.ops.aten.reshape.default(mul_6, [-1, 1, 256]); mul_6 = None

        # File: /workspace/networks/layers/basic.py:93 in seq_to_2d, code: tensor = tensor.view(h, w, n, c).permute(2, 3, 0, 1).contiguous()
        view_25: "f16[s9, (s13//s9), 1, 256]" = torch.ops.aten.reshape.default(view_24, [arg20_1, arg21_1, 1, 256])
        permute_11: "f16[1, 256, s9, (s13//s9)]" = torch.ops.aten.permute.default(view_25, [2, 3, 0, 1]); view_25 = None
        clone_5: "f16[1, 256, s9, (s13//s9)]" = torch.ops.aten.clone.default(permute_11, memory_format = torch.contiguous_format); permute_11 = None

        # File: /workspace/networks/layers/transformer.py:860 in fuse_key_value_id, code: V = value + ID_V
        add_7: "f16[s13, 1, 256]" = torch.ops.aten.add.Tensor(getitem_5, getitem_7); getitem_7 = None

        # File: /workspace/networks/layers/basic.py:93 in seq_to_2d, code: tensor = tensor.view(h, w, n, c).permute(2, 3, 0, 1).contiguous()
        view_26: "f16[s9, (s13//s9), 1, 256]" = torch.ops.aten.reshape.default(add_7, [arg20_1, arg21_1, 1, 256])
        permute_12: "f16[1, 256, s9, (s13//s9)]" = torch.ops.aten.permute.default(view_26, [2, 3, 0, 1]); view_26 = None
        clone_6: "f16[1, 256, s9, (s13//s9)]" = torch.ops.aten.clone.default(permute_12, memory_format = torch.contiguous_format); permute_12 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        convert_element_type_47: "f16[256]" = torch.ops.prims.convert_element_type.default(arg26_1, torch.float16); arg26_1 = None

        # File: /workspace/networks/layers/attention.py:80 in forward, code: Q = Q / self.T
        div_2: "f16[s13, 1, 256]" = torch.ops.aten.div.Tensor(getitem_4, 16.0)

        # File: /workspace/networks/layers/attention.py:90 in forward, code: Q = Q.view(-1, bs, num_head, self.d_att).permute(1, 2, 0, 3)
        view_31: "f16[s13, 1, 1, 256]" = torch.ops.aten.reshape.default(div_2, [-1, 1, 1, 256]); div_2 = None
        permute_13: "f16[1, 1, s13, 256]" = torch.ops.aten.permute.default(view_31, [1, 2, 0, 3]); view_31 = None

        # File: /workspace/networks/layers/attention.py:8 in multiply_by_ychunks, code: return x @ y
        expand_4: "f16[1, 1, s13, 256]" = torch.ops.aten.expand.default(permute_13, [1, 1, arg2_1, 256]); permute_13 = None
        view_34: "f16[1, s13, 256]" = torch.ops.aten.reshape.default(expand_4, [1, arg2_1, 256]); expand_4 = None

        # File: /workspace/networks/layers/transformer.py:805 in forward, code: unfold_K = global_K.view(size_2d[0],size_2d[1],bs,ck)
        view_27: "f16[s9, (s13//s9), 1, 256]" = torch.ops.aten.reshape.default(view_24, [arg20_1, arg21_1, 1, 256]); view_24 = None

        # File: /workspace/networks/layers/transformer.py:807 in forward, code: global_K = unfold_K[::d,::d,:,:].reshape(-1,bs,ck)
        slice_7: "f16[((s9 + 1)//2), (s13//s9), 1, 256]" = torch.ops.aten.slice.Tensor(view_27, 0, 0, 9223372036854775807, 2); view_27 = None
        slice_8: "f16[((s9 + 1)//2), (((s13//s9) + 1)//2), 1, 256]" = torch.ops.aten.slice.Tensor(slice_7, 1, 0, 9223372036854775807, 2)
        clone_7: "f16[((s9 + 1)//2), (((s13//s9) + 1)//2), 1, 256]" = torch.ops.aten.clone.default(slice_8, memory_format = torch.contiguous_format)
        sym_size_int_10: "Sym(((s9 + 1)//2))" = torch.ops.aten.sym_size.int(slice_7, 0); slice_7 = None
        sym_size_int_11: "Sym((((s13//s9) + 1)//2))" = torch.ops.aten.sym_size.int(slice_8, 1); slice_8 = None
        mul_7: "Sym((((s9 + 1)//2))*((((s13//s9) + 1)//2)))" = sym_size_int_10 * sym_size_int_11; sym_size_int_10 = sym_size_int_11 = None
        mul_8: "Sym(256*(((s9 + 1)//2))*((((s13//s9) + 1)//2)))" = mul_7 * 256; mul_7 = None
        floordiv_1: "Sym((((s9 + 1)//2))*((((s13//s9) + 1)//2)))" = mul_8 // 256; mul_8 = None
        view_29: "f16[(((s9 + 1)//2))*((((s13//s9) + 1)//2)), 1, 256]" = torch.ops.aten.reshape.default(clone_7, [floordiv_1, 1, 256]); clone_7 = floordiv_1 = None

        # File: /workspace/networks/layers/attention.py:91 in forward, code: K = K.view(-1, bs, num_head, self.d_att).permute(1, 2, 3, 0)
        view_32: "f16[(((s9 + 1)//2))*((((s13//s9) + 1)//2)), 1, 1, 256]" = torch.ops.aten.reshape.default(view_29, [-1, 1, 1, 256])
        permute_14: "f16[1, 1, 256, (((s9 + 1)//2))*((((s13//s9) + 1)//2))]" = torch.ops.aten.permute.default(view_32, [1, 2, 3, 0]); view_32 = None

        # File: /workspace/networks/layers/attention.py:8 in multiply_by_ychunks, code: return x @ y
        sym_size_int_14: "Sym((((s9 + 1)//2))*((((s13//s9) + 1)//2)))" = torch.ops.aten.sym_size.int(view_29, 0)
        expand_5: "f16[1, 1, 256, (((s9 + 1)//2))*((((s13//s9) + 1)//2))]" = torch.ops.aten.expand.default(permute_14, [1, 1, 256, sym_size_int_14]); permute_14 = None
        view_35: "f16[1, 256, (((s9 + 1)//2))*((((s13//s9) + 1)//2))]" = torch.ops.aten.reshape.default(expand_5, [1, 256, sym_size_int_14]); expand_5 = None
        bmm_2: "f16[1, s13, (((s9 + 1)//2))*((((s13//s9) + 1)//2))]" = torch.ops.aten.bmm.default(view_34, view_35); view_34 = view_35 = None
        view_36: "f16[1, 1, s13, (((s9 + 1)//2))*((((s13//s9) + 1)//2))]" = torch.ops.aten.reshape.default(bmm_2, [1, 1, arg2_1, sym_size_int_14]); bmm_2 = None

        # File: /workspace/networks/layers/attention.py:114 in forward, code: attn = torch.softmax(QK, dim=-1)
        convert_element_type_43: "f32[1, 1, s13, (((s9 + 1)//2))*((((s13//s9) + 1)//2))]" = torch.ops.prims.convert_element_type.default(view_36, torch.float32); view_36 = None
        amax_1: "f32[1, 1, s13, 1]" = torch.ops.aten.amax.default(convert_element_type_43, [-1], True)
        sub_3: "f32[1, 1, s13, (((s9 + 1)//2))*((((s13//s9) + 1)//2))]" = torch.ops.aten.sub.Tensor(convert_element_type_43, amax_1); convert_element_type_43 = amax_1 = None
        exp_1: "f32[1, 1, s13, (((s9 + 1)//2))*((((s13//s9) + 1)//2))]" = torch.ops.aten.exp.default(sub_3); sub_3 = None
        sum_2: "f32[1, 1, s13, 1]" = torch.ops.aten.sum.dim_IntList(exp_1, [-1], True)
        div_3: "f32[1, 1, s13, (((s9 + 1)//2))*((((s13//s9) + 1)//2))]" = torch.ops.aten.div.Tensor(exp_1, sum_2); exp_1 = sum_2 = None

        # File: /workspace/networks/layers/attention.py:14 in multiply_by_xchunks, code: return x @ y
        convert_element_type_44: "f16[1, 1, s13, (((s9 + 1)//2))*((((s13//s9) + 1)//2))]" = torch.ops.prims.convert_element_type.default(div_3, torch.float16); div_3 = None
        expand_6: "f16[1, 1, s13, (((s9 + 1)//2))*((((s13//s9) + 1)//2))]" = torch.ops.aten.expand.default(convert_element_type_44, [1, 1, arg2_1, sym_size_int_14]); convert_element_type_44 = None
        view_37: "f16[1, s13, (((s9 + 1)//2))*((((s13//s9) + 1)//2))]" = torch.ops.aten.reshape.default(expand_6, [1, arg2_1, sym_size_int_14]); expand_6 = sym_size_int_14 = None

        # File: /workspace/networks/layers/transformer.py:806 in forward, code: unfold_V = global_V.view(size_2d[0],size_2d[1],bs,cv)
        view_28: "f16[s9, (s13//s9), 1, 256]" = torch.ops.aten.reshape.default(add_7, [arg20_1, arg21_1, 1, 256]); add_7 = arg20_1 = arg21_1 = None

        # File: /workspace/networks/layers/transformer.py:808 in forward, code: global_V = unfold_V[::d,::d,:,:].reshape(-1,bs,cv)
        slice_11: "f16[((s9 + 1)//2), (s13//s9), 1, 256]" = torch.ops.aten.slice.Tensor(view_28, 0, 0, 9223372036854775807, 2); view_28 = None
        slice_12: "f16[((s9 + 1)//2), (((s13//s9) + 1)//2), 1, 256]" = torch.ops.aten.slice.Tensor(slice_11, 1, 0, 9223372036854775807, 2)
        clone_8: "f16[((s9 + 1)//2), (((s13//s9) + 1)//2), 1, 256]" = torch.ops.aten.clone.default(slice_12, memory_format = torch.contiguous_format)
        sym_size_int_12: "Sym(((s9 + 1)//2))" = torch.ops.aten.sym_size.int(slice_11, 0); slice_11 = None
        sym_size_int_13: "Sym((((s13//s9) + 1)//2))" = torch.ops.aten.sym_size.int(slice_12, 1); slice_12 = None
        mul_9: "Sym((((s9 + 1)//2))*((((s13//s9) + 1)//2)))" = sym_size_int_12 * sym_size_int_13; sym_size_int_12 = sym_size_int_13 = None
        mul_10: "Sym(256*(((s9 + 1)//2))*((((s13//s9) + 1)//2)))" = mul_9 * 256; mul_9 = None
        floordiv_2: "Sym((((s9 + 1)//2))*((((s13//s9) + 1)//2)))" = mul_10 // 256; mul_10 = None
        view_30: "f16[(((s9 + 1)//2))*((((s13//s9) + 1)//2)), 1, 256]" = torch.ops.aten.reshape.default(clone_8, [floordiv_2, 1, 256]); clone_8 = floordiv_2 = None

        # File: /workspace/networks/layers/attention.py:92 in forward, code: V = V.view(-1, bs, num_head, hidden_dim).permute(1, 2, 0, 3)
        view_33: "f16[(((s9 + 1)//2))*((((s13//s9) + 1)//2)), 1, 1, 256]" = torch.ops.aten.reshape.default(view_30, [-1, 1, 1, 256])
        permute_15: "f16[1, 1, (((s9 + 1)//2))*((((s13//s9) + 1)//2)), 256]" = torch.ops.aten.permute.default(view_33, [1, 2, 0, 3]); view_33 = None

        # File: /workspace/networks/layers/attention.py:14 in multiply_by_xchunks, code: return x @ y
        sym_size_int_16: "Sym((((s9 + 1)//2))*((((s13//s9) + 1)//2)))" = torch.ops.aten.sym_size.int(view_30, 0)
        expand_7: "f16[1, 1, (((s9 + 1)//2))*((((s13//s9) + 1)//2)), 256]" = torch.ops.aten.expand.default(permute_15, [1, 1, sym_size_int_16, 256]); permute_15 = None
        view_38: "f16[1, (((s9 + 1)//2))*((((s13//s9) + 1)//2)), 256]" = torch.ops.aten.reshape.default(expand_7, [1, sym_size_int_16, 256]); expand_7 = sym_size_int_16 = None
        bmm_3: "f16[1, s13, 256]" = torch.ops.aten.bmm.default(view_37, view_38); view_37 = view_38 = None
        view_39: "f16[1, 1, s13, 256]" = torch.ops.aten.reshape.default(bmm_3, [1, 1, arg2_1, 256]); bmm_3 = None

        # File: /workspace/networks/layers/attention.py:120 in forward, code: self.qk_chunks).permute(2, 0, 1, 3)
        permute_16: "f16[s13, 1, 1, 256]" = torch.ops.aten.permute.default(view_39, [2, 0, 1, 3]); view_39 = None

        # File: /workspace/networks/layers/attention.py:122 in forward, code: outputs = outputs.reshape(-1, bs, self.d_model)
        view_40: "f16[s13, 1, 256]" = torch.ops.aten.reshape.default(permute_16, [-1, 1, 256]); permute_16 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        view_41: "f16[s13, 256]" = torch.ops.aten.reshape.default(view_40, [arg2_1, 256]); view_40 = None
        convert_element_type_48: "f16[256, 256]" = torch.ops.prims.convert_element_type.default(arg25_1, torch.float16); arg25_1 = None
        permute_17: "f16[256, 256]" = torch.ops.aten.permute.default(convert_element_type_48, [1, 0]); convert_element_type_48 = None
        addmm_6: "f16[s13, 256]" = torch.ops.aten.addmm.default(convert_element_type_47, view_41, permute_17); convert_element_type_47 = view_41 = permute_17 = None
        view_42: "f16[s13, 1, 256]" = torch.ops.aten.reshape.default(addmm_6, [arg2_1, 1, 256]); addmm_6 = arg2_1 = None
        return (clone_4, clone_5, clone_6, add_3, add_5, view_42, getitem_4, getitem_5, view_29, view_30)
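The unattributed nodes in the middle of the graph (full_default, cat_default, constant_pad_nd_default, mm_default, slice_tensor_1) implement the 257-wide F.linear as a padded GEMM: the transposed weight is padded from [256, 257] to [256, 264], the bias is extended with seven zeros, and the extra output columns are sliced off after the matmul, presumably so the GEMM lands on alignment-friendly sizes. The following is a minimal eager-mode sketch of that pad-then-slice pattern; the names x, weight, bias, the example size s13 = 33, and the use of float32 (so the check runs on CPU, whereas the graph executes the GEMM in float16) are illustrative assumptions, not part of the captured graph.

import torch
import torch.nn.functional as F

s13 = 33                                    # stand-in for the dynamic dim s13
x = torch.randn(s13, 256)                   # plays the role of view_21
weight = torch.randn(257, 256)              # plays the role of arg23_1
bias = torch.randn(257)                     # plays the role of arg24_1

ref = F.linear(x, weight, bias)             # the linear as written in the model

# Pad the 257-wide output dimension up to 264, multiply, then slice the
# 7 padded columns back off, mirroring constant_pad_nd -> mm -> slice above.
w_pad = F.pad(weight.t(), (0, 7, 0, 0))     # [256, 257] -> [256, 264]
b_pad = torch.cat([bias, torch.zeros(7)])   # [257] -> [264]
out = (x @ w_pad + b_pad)[:, :-7]           # [s13, 264] -> [s13, 257]

print(torch.allclose(ref, out, atol=1e-4))  # identical up to rounding

Because the padded weight columns and bias entries are zero, the first 257 output columns are unaffected, so the transformation only changes the shape the matmul kernel sees, not the result the rest of the graph consumes.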