class GraphModule(torch.nn.Module):
    """Readable dump of a traced FX graph for one transformer block.

    The original dump had lost its line structure and the class identifier
    (``class (torch.nn.Module):``), which made it unparseable.  The name
    ``GraphModule`` is a stand-in chosen to make the file valid Python;
    NOTE(review): confirm nothing looks this class up by another name.

    The ``# File: ...`` comments are the stack-trace annotations emitted with
    the graph; they point at the eager-mode source line that produced the
    nodes that follow.
    """

    def forward(
        self,
        arg0_1: "f32[256]",
        arg1_1: "f32[256]",
        arg2_1: "f32[4624, 1, 256]",
        arg3_1: "f32[4624, 1, 256]",
        arg4_1: "f32[256, 256]",
        arg5_1: "f32[256]",
        arg6_1: "f32[256, 256]",
        arg7_1: "f32[256]",
        arg8_1: "f32[256, 256]",
        arg9_1: "f32[256]",
        arg10_1: "f32[256, 256]",
        arg11_1: "f32[256]",
        arg12_1: "f32[256]",
        arg13_1: "f32[256]",
        arg14_1: "f32[512, 256]",
        arg15_1: "f32[512]",
        arg16_1: "f32[4624, 1, 256]",
        arg17_1: "f32[257, 256]",
        arg18_1: "f32[257]",
        arg19_1: "f32[256, 256]",
        arg20_1: "f32[256]",
    ):
        """Run the traced graph on fixed-shape inputs.

        Shapes are baked into the graph (4624 tokens == 68 * 68, batch 1,
        256 channels), so the inputs must match the annotations exactly.

        Args:
            arg0_1, arg1_1: scale and shift applied after the first
                normalisation of ``arg2_1``.
            arg2_1: input sequence; also the residual branch of ``add_3``.
            arg3_1: tensor added to the normalised input before the Q and K
                projections (``with_pos_embed`` in the trace).
            arg4_1, arg5_1: weight and bias of the Q projection.
            arg6_1, arg7_1: weight and bias of the K projection.
            arg8_1, arg9_1: weight and bias of the V projection (applied to
                the normalised input without ``arg3_1``).
            arg10_1, arg11_1: weight and bias of the first attention's output
                projection.
            arg12_1, arg13_1: scale and shift of the second normalisation.
            arg14_1, arg15_1: weight and bias of the 256 -> 512 projection
                that is split into two 256-wide halves.
            arg16_1: input of the 256 -> 257 projection (``ID_KV`` in the
                trace).
            arg17_1, arg18_1: weight and bias of that 256 -> 257 projection.
            arg19_1, arg20_1: weight and bias of the second attention's
                output projection.

        Returns:
            A 10-tuple ``(clone_2, clone_3, clone_4, add_3, add_5, view_38,
            getitem_4, getitem_5, view_24, add_7)``:

            * ``clone_2``, ``clone_3``, ``clone_4``: ``getitem_4``,
              ``view_24`` and ``add_7`` rearranged to ``[1, 256, 68, 68]``.
            * ``add_3``: ``arg2_1`` plus the first attention output (f32).
            * ``add_5``: second normalisation of ``add_3`` (f32).
            * ``view_38``: output projection of the second attention (f16).
            * ``getitem_4``, ``getitem_5``: the two halves of the 512-wide
              projection.
            * ``view_24``: ``getitem_4 * (1 + tanh(ID_K))``.
            * ``add_7``: ``getitem_5 + ID_V``.
        """
        # Statement order and the trailing ``x = None`` frees are kept exactly
        # as generated: they drop references as soon as a value is dead, which
        # matters for the two [*, 4624, 4624] attention matrices.

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/normalization.py:201 in forward, code: return F.layer_norm(
        var_mean = torch.ops.aten.var_mean.correction(arg2_1, [2], correction = 0, keepdim = True)
        getitem: "f32[4624, 1, 1]" = var_mean[0]
        getitem_1: "f32[4624, 1, 1]" = var_mean[1];  var_mean = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        convert_element_type_24: "f16[256]" = torch.ops.prims.convert_element_type.default(arg11_1, torch.float16);  arg11_1 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        convert_element_type: "f16[256]" = torch.ops.prims.convert_element_type.default(arg5_1, torch.float16);  arg5_1 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/normalization.py:201 in forward, code: return F.layer_norm(
        sub: "f32[4624, 1, 256]" = torch.ops.aten.sub.Tensor(arg2_1, getitem_1);  getitem_1 = None
        add: "f32[4624, 1, 1]" = torch.ops.aten.add.Tensor(getitem, 1e-05);  getitem = None
        rsqrt: "f32[4624, 1, 1]" = torch.ops.aten.rsqrt.default(add);  add = None
        mul: "f32[4624, 1, 256]" = torch.ops.aten.mul.Tensor(sub, rsqrt);  sub = rsqrt = None
        mul_1: "f32[4624, 1, 256]" = torch.ops.aten.mul.Tensor(mul, arg0_1);  mul = arg0_1 = None
        add_1: "f32[4624, 1, 256]" = torch.ops.aten.add.Tensor(mul_1, arg1_1);  mul_1 = arg1_1 = None

        # File: /workspace/networks/layers/transformer.py:752 in with_pos_embed, code: return tensor if pos is None else tensor + pos
        add_2: "f32[4624, 1, 256]" = torch.ops.aten.add.Tensor(add_1, arg3_1);  arg3_1 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        convert_element_type_2: "f16[4624, 1, 256]" = torch.ops.prims.convert_element_type.default(add_2, torch.float16)
        view: "f16[4624, 256]" = torch.ops.aten.reshape.default(convert_element_type_2, [4624, 256]);  convert_element_type_2 = None
        convert_element_type_1: "f16[256, 256]" = torch.ops.prims.convert_element_type.default(arg4_1, torch.float16);  arg4_1 = None
        permute: "f16[256, 256]" = torch.ops.aten.permute.default(convert_element_type_1, [1, 0]);  convert_element_type_1 = None

        # No stacktrace found for following nodes
        mm_default_2: "f16[4624, 256]" = torch.ops.aten.mm.default(view, permute);  view = permute = None
        add_tensor_2: "f16[4624, 256]" = torch.ops.aten.add.Tensor(mm_default_2, convert_element_type);  mm_default_2 = convert_element_type = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        view_1: "f16[4624, 1, 256]" = torch.ops.aten.reshape.default(add_tensor_2, [4624, 1, 256]);  add_tensor_2 = None

        # File: /workspace/networks/layers/attention.py:80 in forward, code: Q = Q / self.T
        # 5.656854249492381 is numerically sqrt(32), the per-head width used below.
        div: "f16[4624, 1, 256]" = torch.ops.aten.div.Tensor(view_1, 5.656854249492381);  view_1 = None

        # File: /workspace/networks/layers/attention.py:90 in forward, code: Q = Q.view(-1, bs, num_head, self.d_att).permute(1, 2, 0, 3)
        view_6: "f16[4624, 1, 8, 32]" = torch.ops.aten.reshape.default(div, [-1, 1, 8, 32]);  div = None
        permute_3: "f16[1, 8, 4624, 32]" = torch.ops.aten.permute.default(view_6, [1, 2, 0, 3]);  view_6 = None

        # File: /workspace/networks/layers/attention.py:8 in multiply_by_ychunks, code: return x @ y
        expand: "f16[1, 8, 4624, 32]" = torch.ops.aten.expand.default(permute_3, [1, 8, 4624, 32]);  permute_3 = None
        view_9: "f16[8, 4624, 32]" = torch.ops.aten.reshape.default(expand, [8, 4624, 32]);  expand = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        convert_element_type_6: "f16[256]" = torch.ops.prims.convert_element_type.default(arg7_1, torch.float16);  arg7_1 = None
        convert_element_type_8: "f16[4624, 1, 256]" = torch.ops.prims.convert_element_type.default(add_2, torch.float16);  add_2 = None
        view_2: "f16[4624, 256]" = torch.ops.aten.reshape.default(convert_element_type_8, [4624, 256]);  convert_element_type_8 = None
        convert_element_type_7: "f16[256, 256]" = torch.ops.prims.convert_element_type.default(arg6_1, torch.float16);  arg6_1 = None
        permute_1: "f16[256, 256]" = torch.ops.aten.permute.default(convert_element_type_7, [1, 0]);  convert_element_type_7 = None
        addmm_1: "f16[4624, 256]" = torch.ops.aten.addmm.default(convert_element_type_6, view_2, permute_1);  convert_element_type_6 = view_2 = permute_1 = None
        view_3: "f16[4624, 1, 256]" = torch.ops.aten.reshape.default(addmm_1, [4624, 1, 256]);  addmm_1 = None

        # File: /workspace/networks/layers/attention.py:91 in forward, code: K = K.view(-1, bs, num_head, self.d_att).permute(1, 2, 3, 0)
        view_7: "f16[4624, 1, 8, 32]" = torch.ops.aten.reshape.default(view_3, [-1, 1, 8, 32]);  view_3 = None
        permute_4: "f16[1, 8, 32, 4624]" = torch.ops.aten.permute.default(view_7, [1, 2, 3, 0]);  view_7 = None

        # File: /workspace/networks/layers/attention.py:8 in multiply_by_ychunks, code: return x @ y
        expand_1: "f16[1, 8, 32, 4624]" = torch.ops.aten.expand.default(permute_4, [1, 8, 32, 4624]);  permute_4 = None
        view_10: "f16[8, 32, 4624]" = torch.ops.aten.reshape.default(expand_1, [8, 32, 4624]);  expand_1 = None
        bmm: "f16[8, 4624, 4624]" = torch.ops.aten.bmm.default(view_9, view_10);  view_9 = view_10 = None
        view_11: "f16[1, 8, 4624, 4624]" = torch.ops.aten.reshape.default(bmm, [1, 8, 4624, 4624]);  bmm = None

        # File: /workspace/networks/layers/attention.py:114 in forward, code: attn = torch.softmax(QK, dim=-1)
        # The softmax is decomposed and evaluated in float32, then cast back to float16.
        convert_element_type_20: "f32[1, 8, 4624, 4624]" = torch.ops.prims.convert_element_type.default(view_11, torch.float32);  view_11 = None
        amax: "f32[1, 8, 4624, 1]" = torch.ops.aten.amax.default(convert_element_type_20, [-1], True)
        sub_1: "f32[1, 8, 4624, 4624]" = torch.ops.aten.sub.Tensor(convert_element_type_20, amax);  convert_element_type_20 = amax = None
        exp: "f32[1, 8, 4624, 4624]" = torch.ops.aten.exp.default(sub_1);  sub_1 = None
        sum_1: "f32[1, 8, 4624, 1]" = torch.ops.aten.sum.dim_IntList(exp, [-1], True)
        div_1: "f32[1, 8, 4624, 4624]" = torch.ops.aten.div.Tensor(exp, sum_1);  exp = sum_1 = None

        # File: /workspace/networks/layers/attention.py:14 in multiply_by_xchunks, code: return x @ y
        convert_element_type_21: "f16[1, 8, 4624, 4624]" = torch.ops.prims.convert_element_type.default(div_1, torch.float16);  div_1 = None
        expand_2: "f16[1, 8, 4624, 4624]" = torch.ops.aten.expand.default(convert_element_type_21, [1, 8, 4624, 4624]);  convert_element_type_21 = None
        view_12: "f16[8, 4624, 4624]" = torch.ops.aten.reshape.default(expand_2, [8, 4624, 4624]);  expand_2 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        # V is projected from add_1 (normalised input without arg3_1 added).
        convert_element_type_12: "f16[256]" = torch.ops.prims.convert_element_type.default(arg9_1, torch.float16);  arg9_1 = None
        convert_element_type_14: "f16[4624, 1, 256]" = torch.ops.prims.convert_element_type.default(add_1, torch.float16);  add_1 = None
        view_4: "f16[4624, 256]" = torch.ops.aten.reshape.default(convert_element_type_14, [4624, 256]);  convert_element_type_14 = None
        convert_element_type_13: "f16[256, 256]" = torch.ops.prims.convert_element_type.default(arg8_1, torch.float16);  arg8_1 = None
        permute_2: "f16[256, 256]" = torch.ops.aten.permute.default(convert_element_type_13, [1, 0]);  convert_element_type_13 = None
        addmm_2: "f16[4624, 256]" = torch.ops.aten.addmm.default(convert_element_type_12, view_4, permute_2);  convert_element_type_12 = view_4 = permute_2 = None
        view_5: "f16[4624, 1, 256]" = torch.ops.aten.reshape.default(addmm_2, [4624, 1, 256]);  addmm_2 = None

        # File: /workspace/networks/layers/attention.py:92 in forward, code: V = V.view(-1, bs, num_head, hidden_dim).permute(1, 2, 0, 3)
        view_8: "f16[4624, 1, 8, 32]" = torch.ops.aten.reshape.default(view_5, [-1, 1, 8, 32]);  view_5 = None
        permute_5: "f16[1, 8, 4624, 32]" = torch.ops.aten.permute.default(view_8, [1, 2, 0, 3]);  view_8 = None

        # File: /workspace/networks/layers/attention.py:14 in multiply_by_xchunks, code: return x @ y
        expand_3: "f16[1, 8, 4624, 32]" = torch.ops.aten.expand.default(permute_5, [1, 8, 4624, 32]);  permute_5 = None
        view_13: "f16[8, 4624, 32]" = torch.ops.aten.reshape.default(expand_3, [8, 4624, 32]);  expand_3 = None
        bmm_1: "f16[8, 4624, 32]" = torch.ops.aten.bmm.default(view_12, view_13);  view_12 = view_13 = None
        view_14: "f16[1, 8, 4624, 32]" = torch.ops.aten.reshape.default(bmm_1, [1, 8, 4624, 32]);  bmm_1 = None

        # File: /workspace/networks/layers/attention.py:120 in forward, code: self.qk_chunks).permute(2, 0, 1, 3)
        permute_6: "f16[4624, 1, 8, 32]" = torch.ops.aten.permute.default(view_14, [2, 0, 1, 3]);  view_14 = None

        # File: /workspace/networks/layers/attention.py:122 in forward, code: outputs = outputs.reshape(-1, bs, self.d_model)
        clone_1: "f16[4624, 1, 8, 32]" = torch.ops.aten.clone.default(permute_6, memory_format = torch.contiguous_format);  permute_6 = None
        view_15: "f16[4624, 1, 256]" = torch.ops.aten.reshape.default(clone_1, [4624, 1, 256]);  clone_1 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        view_16: "f16[4624, 256]" = torch.ops.aten.reshape.default(view_15, [4624, 256]);  view_15 = None
        convert_element_type_25: "f16[256, 256]" = torch.ops.prims.convert_element_type.default(arg10_1, torch.float16);  arg10_1 = None
        permute_7: "f16[256, 256]" = torch.ops.aten.permute.default(convert_element_type_25, [1, 0]);  convert_element_type_25 = None

        # No stacktrace found for following nodes
        mm_default_1: "f16[4624, 256]" = torch.ops.aten.mm.default(view_16, permute_7);  view_16 = permute_7 = None
        add_tensor_1: "f16[4624, 256]" = torch.ops.aten.add.Tensor(mm_default_1, convert_element_type_24);  mm_default_1 = convert_element_type_24 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        view_17: "f16[4624, 1, 256]" = torch.ops.aten.reshape.default(add_tensor_1, [4624, 1, 256]);  add_tensor_1 = None

        # File: /workspace/networks/layers/transformer.py:772 in forward, code: tgt = tgt + self.droppath(tgt2)
        # f32 residual + f16 attention output; the result is f32.
        add_3: "f32[4624, 1, 256]" = torch.ops.aten.add.Tensor(arg2_1, view_17);  arg2_1 = view_17 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/normalization.py:201 in forward, code: return F.layer_norm(
        var_mean_1 = torch.ops.aten.var_mean.correction(add_3, [2], correction = 0, keepdim = True)
        getitem_2: "f32[4624, 1, 1]" = var_mean_1[0]
        getitem_3: "f32[4624, 1, 1]" = var_mean_1[1];  var_mean_1 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        convert_element_type_29: "f16[512]" = torch.ops.prims.convert_element_type.default(arg15_1, torch.float16);  arg15_1 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/normalization.py:201 in forward, code: return F.layer_norm(
        sub_2: "f32[4624, 1, 256]" = torch.ops.aten.sub.Tensor(add_3, getitem_3);  getitem_3 = None
        add_4: "f32[4624, 1, 1]" = torch.ops.aten.add.Tensor(getitem_2, 1e-05);  getitem_2 = None
        rsqrt_1: "f32[4624, 1, 1]" = torch.ops.aten.rsqrt.default(add_4);  add_4 = None
        mul_2: "f32[4624, 1, 256]" = torch.ops.aten.mul.Tensor(sub_2, rsqrt_1);  sub_2 = rsqrt_1 = None
        mul_3: "f32[4624, 1, 256]" = torch.ops.aten.mul.Tensor(mul_2, arg12_1);  mul_2 = arg12_1 = None
        add_5: "f32[4624, 1, 256]" = torch.ops.aten.add.Tensor(mul_3, arg13_1);  mul_3 = arg13_1 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        convert_element_type_31: "f16[4624, 1, 256]" = torch.ops.prims.convert_element_type.default(add_5, torch.float16)
        view_18: "f16[4624, 256]" = torch.ops.aten.reshape.default(convert_element_type_31, [4624, 256]);  convert_element_type_31 = None
        convert_element_type_30: "f16[512, 256]" = torch.ops.prims.convert_element_type.default(arg14_1, torch.float16);  arg14_1 = None
        permute_8: "f16[256, 512]" = torch.ops.aten.permute.default(convert_element_type_30, [1, 0]);  convert_element_type_30 = None
        addmm_4: "f16[4624, 512]" = torch.ops.aten.addmm.default(convert_element_type_29, view_18, permute_8);  convert_element_type_29 = view_18 = permute_8 = None
        view_19: "f16[4624, 1, 512]" = torch.ops.aten.reshape.default(addmm_4, [4624, 1, 512]);  addmm_4 = None

        # File: /workspace/networks/layers/transformer.py:778 in forward, code: curr_QV = torch.split(curr_QV, self.d_model, dim=2)
        split = torch.ops.aten.split.Tensor(view_19, 256, 2);  view_19 = None
        getitem_4: "f16[4624, 1, 256]" = split[0]
        getitem_5: "f16[4624, 1, 256]" = split[1];  split = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        convert_element_type_35: "f16[257]" = torch.ops.prims.convert_element_type.default(arg18_1, torch.float16);  arg18_1 = None
        convert_element_type_37: "f16[4624, 1, 256]" = torch.ops.prims.convert_element_type.default(arg16_1, torch.float16);  arg16_1 = None
        view_21: "f16[4624, 256]" = torch.ops.aten.reshape.default(convert_element_type_37, [4624, 256]);  convert_element_type_37 = None
        convert_element_type_36: "f16[257, 256]" = torch.ops.prims.convert_element_type.default(arg17_1, torch.float16);  arg17_1 = None
        permute_10: "f16[256, 257]" = torch.ops.aten.permute.default(convert_element_type_36, [1, 0]);  convert_element_type_36 = None

        # No stacktrace found for following nodes
        mm_default: "f16[4624, 257]" = torch.ops.aten.mm.default(view_21, permute_10);  view_21 = permute_10 = None
        add_tensor: "f16[4624, 257]" = torch.ops.aten.add.Tensor(mm_default, convert_element_type_35);  mm_default = convert_element_type_35 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        view_22: "f16[4624, 1, 257]" = torch.ops.aten.reshape.default(add_tensor, [4624, 1, 257]);  add_tensor = None

        # File: /workspace/networks/layers/transformer.py:855 in fuse_key_value_id, code: ID_K, ID_V = torch.split(ID_KV, [self.att_nhead, self.d_model], dim=2)
        split_with_sizes = torch.ops.aten.split_with_sizes.default(view_22, [1, 256], 2);  view_22 = None
        getitem_6: "f16[4624, 1, 1]" = split_with_sizes[0]
        getitem_7: "f16[4624, 1, 256]" = split_with_sizes[1];  split_with_sizes = None

        # File: /workspace/networks/layers/basic.py:93 in seq_to_2d, code: tensor = tensor.view(h, w, n, c).permute(2, 3, 0, 1).contiguous()
        view_20: "f16[68, 68, 1, 256]" = torch.ops.aten.reshape.default(getitem_4, [68, 68, 1, 256])
        permute_9: "f16[1, 256, 68, 68]" = torch.ops.aten.permute.default(view_20, [2, 3, 0, 1]);  view_20 = None
        clone_2: "f16[1, 256, 68, 68]" = torch.ops.aten.clone.default(permute_9, memory_format = torch.contiguous_format);  permute_9 = None

        # File: /workspace/networks/layers/transformer.py:857 in fuse_key_value_id, code: K = key.view(-1, bs, self.att_nhead, self.d_model //
        view_23: "f16[4624, 1, 1, 256]" = torch.ops.aten.reshape.default(getitem_4, [-1, 1, 1, 256])

        # File: /workspace/networks/layers/transformer.py:858 in fuse_key_value_id, code: self.att_nhead) * (1 + torch.tanh(ID_K)).unsqueeze(-1)
        tanh: "f16[4624, 1, 1]" = torch.ops.aten.tanh.default(getitem_6);  getitem_6 = None
        add_6: "f16[4624, 1, 1]" = torch.ops.aten.add.Tensor(tanh, 1);  tanh = None
        unsqueeze: "f16[4624, 1, 1, 1]" = torch.ops.aten.unsqueeze.default(add_6, -1);  add_6 = None

        # File: /workspace/networks/layers/transformer.py:857 in fuse_key_value_id, code: K = key.view(-1, bs, self.att_nhead, self.d_model //
        mul_4: "f16[4624, 1, 1, 256]" = torch.ops.aten.mul.Tensor(view_23, unsqueeze);  view_23 = unsqueeze = None

        # File: /workspace/networks/layers/transformer.py:859 in fuse_key_value_id, code: K = K.view(-1, bs, self.d_model)
        view_24: "f16[4624, 1, 256]" = torch.ops.aten.reshape.default(mul_4, [-1, 1, 256]);  mul_4 = None

        # File: /workspace/networks/layers/basic.py:93 in seq_to_2d, code: tensor = tensor.view(h, w, n, c).permute(2, 3, 0, 1).contiguous()
        view_25: "f16[68, 68, 1, 256]" = torch.ops.aten.reshape.default(view_24, [68, 68, 1, 256])
        permute_11: "f16[1, 256, 68, 68]" = torch.ops.aten.permute.default(view_25, [2, 3, 0, 1]);  view_25 = None
        clone_3: "f16[1, 256, 68, 68]" = torch.ops.aten.clone.default(permute_11, memory_format = torch.contiguous_format);  permute_11 = None

        # File: /workspace/networks/layers/transformer.py:860 in fuse_key_value_id, code: V = value + ID_V
        add_7: "f16[4624, 1, 256]" = torch.ops.aten.add.Tensor(getitem_5, getitem_7);  getitem_7 = None

        # File: /workspace/networks/layers/basic.py:93 in seq_to_2d, code: tensor = tensor.view(h, w, n, c).permute(2, 3, 0, 1).contiguous()
        view_26: "f16[68, 68, 1, 256]" = torch.ops.aten.reshape.default(add_7, [68, 68, 1, 256])
        permute_12: "f16[1, 256, 68, 68]" = torch.ops.aten.permute.default(view_26, [2, 3, 0, 1]);  view_26 = None
        clone_4: "f16[1, 256, 68, 68]" = torch.ops.aten.clone.default(permute_12, memory_format = torch.contiguous_format);  permute_12 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        convert_element_type_47: "f16[256]" = torch.ops.prims.convert_element_type.default(arg20_1, torch.float16);  arg20_1 = None

        # File: /workspace/networks/layers/attention.py:80 in forward, code: Q = Q / self.T
        # 16.0 is sqrt(256): this second attention uses a single 256-wide head.
        div_2: "f16[4624, 1, 256]" = torch.ops.aten.div.Tensor(getitem_4, 16.0)

        # File: /workspace/networks/layers/attention.py:90 in forward, code: Q = Q.view(-1, bs, num_head, self.d_att).permute(1, 2, 0, 3)
        view_27: "f16[4624, 1, 1, 256]" = torch.ops.aten.reshape.default(div_2, [-1, 1, 1, 256]);  div_2 = None
        permute_13: "f16[1, 1, 4624, 256]" = torch.ops.aten.permute.default(view_27, [1, 2, 0, 3]);  view_27 = None

        # File: /workspace/networks/layers/attention.py:8 in multiply_by_ychunks, code: return x @ y
        expand_4: "f16[1, 1, 4624, 256]" = torch.ops.aten.expand.default(permute_13, [1, 1, 4624, 256]);  permute_13 = None
        view_30: "f16[1, 4624, 256]" = torch.ops.aten.reshape.default(expand_4, [1, 4624, 256]);  expand_4 = None

        # File: /workspace/networks/layers/attention.py:91 in forward, code: K = K.view(-1, bs, num_head, self.d_att).permute(1, 2, 3, 0)
        view_28: "f16[4624, 1, 1, 256]" = torch.ops.aten.reshape.default(view_24, [-1, 1, 1, 256])
        permute_14: "f16[1, 1, 256, 4624]" = torch.ops.aten.permute.default(view_28, [1, 2, 3, 0]);  view_28 = None

        # File: /workspace/networks/layers/attention.py:8 in multiply_by_ychunks, code: return x @ y
        expand_5: "f16[1, 1, 256, 4624]" = torch.ops.aten.expand.default(permute_14, [1, 1, 256, 4624]);  permute_14 = None
        view_31: "f16[1, 256, 4624]" = torch.ops.aten.reshape.default(expand_5, [1, 256, 4624]);  expand_5 = None
        bmm_2: "f16[1, 4624, 4624]" = torch.ops.aten.bmm.default(view_30, view_31);  view_30 = view_31 = None
        view_32: "f16[1, 1, 4624, 4624]" = torch.ops.aten.reshape.default(bmm_2, [1, 1, 4624, 4624]);  bmm_2 = None

        # File: /workspace/networks/layers/attention.py:114 in forward, code: attn = torch.softmax(QK, dim=-1)
        convert_element_type_43: "f32[1, 1, 4624, 4624]" = torch.ops.prims.convert_element_type.default(view_32, torch.float32);  view_32 = None
        amax_1: "f32[1, 1, 4624, 1]" = torch.ops.aten.amax.default(convert_element_type_43, [-1], True)
        sub_3: "f32[1, 1, 4624, 4624]" = torch.ops.aten.sub.Tensor(convert_element_type_43, amax_1);  convert_element_type_43 = amax_1 = None
        exp_1: "f32[1, 1, 4624, 4624]" = torch.ops.aten.exp.default(sub_3);  sub_3 = None
        sum_2: "f32[1, 1, 4624, 1]" = torch.ops.aten.sum.dim_IntList(exp_1, [-1], True)
        div_3: "f32[1, 1, 4624, 4624]" = torch.ops.aten.div.Tensor(exp_1, sum_2);  exp_1 = sum_2 = None

        # File: /workspace/networks/layers/attention.py:14 in multiply_by_xchunks, code: return x @ y
        convert_element_type_44: "f16[1, 1, 4624, 4624]" = torch.ops.prims.convert_element_type.default(div_3, torch.float16);  div_3 = None
        expand_6: "f16[1, 1, 4624, 4624]" = torch.ops.aten.expand.default(convert_element_type_44, [1, 1, 4624, 4624]);  convert_element_type_44 = None
        view_33: "f16[1, 4624, 4624]" = torch.ops.aten.reshape.default(expand_6, [1, 4624, 4624]);  expand_6 = None

        # File: /workspace/networks/layers/attention.py:92 in forward, code: V = V.view(-1, bs, num_head, hidden_dim).permute(1, 2, 0, 3)
        view_29: "f16[4624, 1, 1, 256]" = torch.ops.aten.reshape.default(add_7, [-1, 1, 1, 256])
        permute_15: "f16[1, 1, 4624, 256]" = torch.ops.aten.permute.default(view_29, [1, 2, 0, 3]);  view_29 = None

        # File: /workspace/networks/layers/attention.py:14 in multiply_by_xchunks, code: return x @ y
        expand_7: "f16[1, 1, 4624, 256]" = torch.ops.aten.expand.default(permute_15, [1, 1, 4624, 256]);  permute_15 = None
        view_34: "f16[1, 4624, 256]" = torch.ops.aten.reshape.default(expand_7, [1, 4624, 256]);  expand_7 = None
        bmm_3: "f16[1, 4624, 256]" = torch.ops.aten.bmm.default(view_33, view_34);  view_33 = view_34 = None
        view_35: "f16[1, 1, 4624, 256]" = torch.ops.aten.reshape.default(bmm_3, [1, 1, 4624, 256]);  bmm_3 = None

        # File: /workspace/networks/layers/attention.py:120 in forward, code: self.qk_chunks).permute(2, 0, 1, 3)
        permute_16: "f16[4624, 1, 1, 256]" = torch.ops.aten.permute.default(view_35, [2, 0, 1, 3]);  view_35 = None

        # File: /workspace/networks/layers/attention.py:122 in forward, code: outputs = outputs.reshape(-1, bs, self.d_model)
        view_36: "f16[4624, 1, 256]" = torch.ops.aten.reshape.default(permute_16, [-1, 1, 256]);  permute_16 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        view_37: "f16[4624, 256]" = torch.ops.aten.reshape.default(view_36, [4624, 256]);  view_36 = None
        convert_element_type_48: "f16[256, 256]" = torch.ops.prims.convert_element_type.default(arg19_1, torch.float16);  arg19_1 = None
        permute_17: "f16[256, 256]" = torch.ops.aten.permute.default(convert_element_type_48, [1, 0]);  convert_element_type_48 = None
        addmm_6: "f16[4624, 256]" = torch.ops.aten.addmm.default(convert_element_type_47, view_37, permute_17);  convert_element_type_47 = view_37 = permute_17 = None
        view_38: "f16[4624, 1, 256]" = torch.ops.aten.reshape.default(addmm_6, [4624, 1, 256]);  addmm_6 = None
        return (clone_2, clone_3, clone_4, add_3, add_5, view_38, getitem_4, getitem_5, view_24, add_7)