class GraphModule(torch.nn.Module):
    """Traced ATen graph for one resumed segment of an attention ``forward``.

    This is a flattened, machine-generated graph (one statement per node).
    Each intermediate is released (``name = None``) right after its last use
    so that large buffers can be freed early; keep that ordering when editing.

    The string annotations record the dtype/shape captured at trace time.
    Several shapes are specialised to the traced resolution (``4624``, ``82``,
    ``68``, ``225``, ``256``), and the scratch buffer is allocated on
    ``cuda:0`` exactly as captured.
    """

    def forward(
        self,
        arg0_1: "Sym(s0)",
        arg1_1: "Sym(s1)",
        arg2_1: "f16[1, 15, 15, s0, s1]",
        arg3_1: "Sym(68)",
        arg4_1: "Sym(68)",
        arg5_1: "Sym(s5)",
        arg6_1: "f16[1, 1, 225, s0*s1]",
        arg7_1: "f32[1, 1, 225, s0*s1]",
        arg8_1: "f32[1, 256, 225]",
        arg9_1: "b8[1, 1, 4624, 82, 82]",
        arg10_1: "f16[1, 1, 256, 4624]",
        arg11_1: "f32[256, 256]",
        arg12_1: "f32[256]",
    ):
        """Run the traced graph.

        Steps, in node order:

        1. View ``arg2_1`` as ``[1, 1, 225, 4624]``, add ``arg6_1`` and
           subtract ``arg7_1 * 10000.0``; round the result through fp16.
        2. Softmax over dim 2 in fp32 (amax / sub / exp / sum / div), then a
           ``clone`` (the dropout call at this point is a plain copy in this
           trace).
        3. Batched matmul of the fp16 attention with ``arg8_1`` cast to fp16.
        4. Scatter the transposed attention into a zero
           ``[1, 1, 4624, 82, 82]`` buffer through the boolean mask
           ``arg9_1``, crop 7 elements from both ends of the last two dims,
           view as ``[1, 1, 4624, 4624]`` and batched-matmul with ``arg10_1``
           transposed on its last two dims.
        5. Sum the two matmul results, move the 4624 axis to the front and
           apply ``addmm`` with ``arg12_1`` as bias and ``arg11_1``
           (transposed) as weight, both cast to fp16.

        ``arg0_1``, ``arg1_1``, ``arg3_1``, ``arg4_1`` and ``arg5_1`` are
        symbolic-size inputs that no node in this graph reads.

        Returns:
            ``(view_12, clone)``: the ``f16[4624, 1, 256]`` projected output
            and the ``f32[1, 1, 225, s0*s1]`` attention probabilities.
        """
        # File: /workspace/networks/layers/attention.py:346 in torch_dynamo_resume_in_forward_at_346, code: qk = self.correlation_sampler(q, k).view(
        view: "f16[1, 1, 225, s0*s1]" = torch.ops.aten.view.default(arg2_1, [1, 1, 225, 4624]); arg2_1 = None

        # Add arg6_1 to the reshaped input.
        add: "f16[1, 1, 225, s0*s1]" = torch.ops.aten.add.Tensor(view, arg6_1); arg6_1 = None

        # Subtract arg7_1 scaled by 1e4, then round the result to fp16.
        mul: "f32[1, 1, 225, s0*s1]" = torch.ops.aten.mul.Tensor(arg7_1, 10000.0); arg7_1 = None
        sub: "f32[1, 1, 225, s0*s1]" = torch.ops.aten.sub.Tensor(add, mul); add = mul = None
        convert_element_type: "f16[1, 1, 225, s0*s1]" = torch.ops.prims.convert_element_type.default(sub, torch.float16); sub = None

        # Softmax over dim 2, computed in fp32 with max-subtraction.
        convert_element_type_1: "f32[1, 1, 225, s0*s1]" = torch.ops.prims.convert_element_type.default(convert_element_type, torch.float32); convert_element_type = None
        amax: "f32[1, 1, 1, s0*s1]" = torch.ops.aten.amax.default(convert_element_type_1, [2], True)
        sub_1: "f32[1, 1, 225, s0*s1]" = torch.ops.aten.sub.Tensor(convert_element_type_1, amax); convert_element_type_1 = amax = None
        exp: "f32[1, 1, 225, s0*s1]" = torch.ops.aten.exp.default(sub_1); sub_1 = None
        sum_1: "f32[1, 1, 1, s0*s1]" = torch.ops.aten.sum.dim_IntList(exp, [2], True)
        div: "f32[1, 1, 225, s0*s1]" = torch.ops.aten.div.Tensor(exp, sum_1); exp = sum_1 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/dropout.py:59 in forward, code: return F.dropout(input, self.p, self.training, self.inplace)
        clone: "f32[1, 1, 225, s0*s1]" = torch.ops.aten.clone.default(div); div = None

        # File: /workspace/networks/layers/attention.py:370 in torch_dynamo_resume_in_forward_at_346, code: self.relative_emb_v)
        convert_element_type_2: "f16[1, 1, 225, s0*s1]" = torch.ops.prims.convert_element_type.default(clone, torch.float16)
        convert_element_type_3: "f16[1, 256, 225]" = torch.ops.prims.convert_element_type.default(arg8_1, torch.float16); arg8_1 = None
        unsqueeze: "f16[1, 1, 225, s0*s1, 1]" = torch.ops.aten.unsqueeze.default(convert_element_type_2, 4); convert_element_type_2 = None
        permute: "f16[1, 1, s0*s1, 1, 225]" = torch.ops.aten.permute.default(unsqueeze, [0, 1, 3, 4, 2]); unsqueeze = None
        unsqueeze_1: "f16[1, 256, 225, 1]" = torch.ops.aten.unsqueeze.default(convert_element_type_3, 3); convert_element_type_3 = None
        unsqueeze_2: "f16[1, 256, 225, 1, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze_1, 4); unsqueeze_1 = None
        permute_1: "f16[1, 1, 1, 256, 225]" = torch.ops.aten.permute.default(unsqueeze_2, [3, 0, 4, 1, 2]); unsqueeze_2 = None
        permute_2: "f16[s0*s1, 225, 1, 1, 1]" = torch.ops.aten.permute.default(permute, [2, 4, 0, 1, 3]); permute = None
        sym_size_int: "Sym(s0*s1)" = torch.ops.aten.sym_size.int(view, 3); view = None
        view_1: "f16[1, s0*s1, 225]" = torch.ops.aten.view.default(permute_2, [1, sym_size_int, 225]); permute_2 = None
        permute_3: "f16[225, 1, 1, 256, 1]" = torch.ops.aten.permute.default(permute_1, [4, 0, 1, 3, 2]); permute_1 = None
        view_2: "f16[1, 225, 256]" = torch.ops.aten.view.default(permute_3, [1, 225, 256]); permute_3 = None
        bmm: "f16[1, s0*s1, 256]" = torch.ops.aten.bmm.default(view_1, view_2); view_2 = None
        view_3: "f16[s0*s1, 1, 1, 1, 256]" = torch.ops.aten.view.default(bmm, [sym_size_int, 1, 1, 1, 256]); bmm = None
        permute_4: "f16[1, 1, s0*s1, 256, 1]" = torch.ops.aten.permute.default(view_3, [2, 3, 0, 4, 1]); view_3 = None
        sym_size_int_1: "Sym(s0*s1)" = torch.ops.aten.sym_size.int(view_1, 1); view_1 = None
        view_4: "f16[1, 1, s0*s1, 256]" = torch.ops.aten.view.default(permute_4, [1, 1, sym_size_int_1, 256]); permute_4 = sym_size_int_1 = None

        # File: /workspace/networks/layers/attention.py:416 in local2global, code: global_attn = torch.zeros(
        # The device is the one captured at trace time (cuda:0).
        full: "f32[1, 1, 4624, 82, 82]" = torch.ops.aten.full.default([1, 1, 4624, 82, 82], 0, dtype = torch.float32, layout = torch.strided, device = torch.device("cuda", 0), pin_memory = False)

        # File: /workspace/networks/layers/attention.py:420 in local2global, code: -1, -1, -1)] = local_attn.transpose(
        permute_5: "f32[1, 1, s0*s1, 225]" = torch.ops.aten.permute.default(clone, [0, 1, 3, 2])

        # File: /workspace/networks/layers/attention.py:421 in local2global, code: -1, -2).reshape(-1)
        clone_1: "f32[1, 1, s0*s1, 225]" = torch.ops.aten.clone.default(permute_5, memory_format = torch.contiguous_format); permute_5 = None
        mul_1: "Sym(225*s0*s1)" = sym_size_int * 225; sym_size_int = None
        floordiv: "Sym(225*s0*s1)" = mul_1 // 1; mul_1 = None
        view_5: "f32[225*s0*s1]" = torch.ops.aten.view.default(clone_1, [floordiv]); clone_1 = floordiv = None

        # File: /workspace/networks/layers/attention.py:419 in local2global, code: global_attn[local_mask.expand(batch_size, self.num_head,
        expand: "b8[1, 1, 4624, 82, 82]" = torch.ops.aten.expand.default(arg9_1, [1, 1, -1, -1, -1]); arg9_1 = None
        index_put: "f32[1, 1, 4624, 82, 82]" = torch.ops.aten.index_put.default(full, [expand], view_5); full = expand = view_5 = None

        # File: /workspace/networks/layers/attention.py:423 in local2global, code: self.max_dis:-self.max_dis].reshape(
        slice_6: "f32[1, 1, 4624, 82, 82]" = torch.ops.aten.slice.Tensor(index_put, 0, 0, 9223372036854775807); index_put = None
        slice_7: "f32[1, 1, 4624, 82, 82]" = torch.ops.aten.slice.Tensor(slice_6, 1, 0, 9223372036854775807); slice_6 = None
        slice_8: "f32[1, 1, 4624, 82, 82]" = torch.ops.aten.slice.Tensor(slice_7, 2, 0, 9223372036854775807); slice_7 = None
        slice_9: "f32[1, 1, 4624, 68, 82]" = torch.ops.aten.slice.Tensor(slice_8, 3, 7, -7); slice_8 = None
        slice_10: "f32[1, 1, 4624, 68, 68]" = torch.ops.aten.slice.Tensor(slice_9, 4, 7, -7); slice_9 = None
        clone_2: "f32[1, 1, 4624, 68, 68]" = torch.ops.aten.clone.default(slice_10, memory_format = torch.contiguous_format); slice_10 = None
        view_6: "f32[1, 1, 4624, 4624]" = torch.ops.aten.view.default(clone_2, [1, 1, 4624, 4624]); clone_2 = None

        # Batched matmul of the cropped attention (as fp16) with arg10_1 transposed.
        permute_6: "f16[1, 1, 4624, 256]" = torch.ops.aten.permute.default(arg10_1, [0, 1, 3, 2]); arg10_1 = None
        convert_element_type_6: "f16[1, 1, 4624, 4624]" = torch.ops.prims.convert_element_type.default(view_6, torch.float16); view_6 = None
        expand_1: "f16[1, 1, 4624, 4624]" = torch.ops.aten.expand.default(convert_element_type_6, [1, 1, 4624, 4624]); convert_element_type_6 = None
        view_7: "f16[1, 4624, 4624]" = torch.ops.aten.view.default(expand_1, [1, 4624, 4624]); expand_1 = None
        expand_2: "f16[1, 1, 4624, 256]" = torch.ops.aten.expand.default(permute_6, [1, 1, 4624, 256]); permute_6 = None
        view_8: "f16[1, 4624, 256]" = torch.ops.aten.view.default(expand_2, [1, 4624, 256]); expand_2 = None
        bmm_1: "f16[1, 4624, 256]" = torch.ops.aten.bmm.default(view_7, view_8); view_7 = view_8 = None
        view_9: "f16[1, 1, 4624, 256]" = torch.ops.aten.view.default(bmm_1, [1, 1, 4624, 256]); bmm_1 = None

        # File: /workspace/networks/layers/attention.py:378 in torch_dynamo_resume_in_forward_at_346, code: 3).reshape(h * w, n, c)
        add_1: "f16[1, 1, 4624, 256]" = torch.ops.aten.add.Tensor(view_9, view_4); view_9 = view_4 = None
        permute_7: "f16[4624, 1, 1, 256]" = torch.ops.aten.permute.default(add_1, [2, 0, 1, 3]); add_1 = None

        # File: /workspace/networks/layers/attention.py:379 in torch_dynamo_resume_in_forward_at_346, code: else:
        view_10: "f16[4624, 1, 256]" = torch.ops.aten.view.default(permute_7, [4624, 1, 256]); permute_7 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        convert_element_type_9: "f16[256]" = torch.ops.prims.convert_element_type.default(arg12_1, torch.float16); arg12_1 = None
        convert_element_type_10: "f16[256, 256]" = torch.ops.prims.convert_element_type.default(arg11_1, torch.float16); arg11_1 = None
        view_11: "f16[4624, 256]" = torch.ops.aten.view.default(view_10, [4624, 256]); view_10 = None
        permute_8: "f16[256, 256]" = torch.ops.aten.permute.default(convert_element_type_10, [1, 0]); convert_element_type_10 = None
        addmm: "f16[4624, 256]" = torch.ops.aten.addmm.default(convert_element_type_9, view_11, permute_8); convert_element_type_9 = view_11 = permute_8 = None
        view_12: "f16[4624, 1, 256]" = torch.ops.aten.view.default(addmm, [4624, 1, 256]); addmm = None
        return (view_12, clone)