class GraphModule(torch.nn.Module):
    """TorchDynamo/Inductor-captured FX graph of a local-attention block
    traced from /workspace/networks/layers/attention.py.

    High-level flow (reconstructed from the provenance comments below):
      1. Build a boolean mask (``view_13``) that maps every query position of
         an (s2 x s3) feature map into its local window inside a zero-padded
         (s2+14) x (s3+14) global attention plane.
      2. Softmax the 225 (= 15*15) local correlation scores per query
         position (``div``), after subtracting a large penalty
         (``arg6_1 * 10000``) — presumably an invalid-position mask; verify
         against the caller.
      3. Scatter the softmaxed local attention into the padded global plane
         (``index_put_``), crop the padding back off, and apply the resulting
         global attention to the value tensor ``arg8_1`` via ``bmm``.
      4. Add a relative positional value embedding (``arg7_1`` path) and
         project with a linear layer (weight ``arg9_1``, bias ``arg10_1``).

    NOTE(review): the constants are specialized for max_dis = 7
    (window 15x15 = 225, pad 14 = 2*max_dis) and all allocations are
    hard-coded to ``cuda:0`` — this graph only runs on a CUDA machine.
    The reshape of ``arg2_1`` into ``[1, 1, 225, s2*s3]`` implies the guard
    s0*s1 == s2*s3 — TODO confirm against the guard set.

    Returns a tuple ``(view_21, div, view_13)``: the projected output
    ``f16[s2*s3, 1, 256]``, the softmaxed local attention
    ``f32[1, 1, 225, s0*s1]``, and the local->global boolean mask.
    """

    def forward(self, arg0_1: "Sym(s0)", arg1_1: "Sym(s1)", arg2_1: "f16[1, 15, 15, s0, s1]", arg3_1: "Sym(s2)", arg4_1: "Sym(s3)", arg5_1: "f16[1, 1, 225, s0*s1]", arg6_1: "f32[1, 1, 225, s0*s1]", arg7_1: "f32[1, 256, 225]", arg8_1: "f16[1, 1, 256, s2*s3]", arg9_1: "f32[256, 256]", arg10_1: "f32[256]"):
        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        # Cast the final linear layer's bias to fp16 up front for the addmm at the end.
        convert_element_type_9: "f16[256]" = torch.ops.prims.convert_element_type.default(arg10_1, torch.float16); arg10_1 = None

        # File: /workspace/networks/layers/attention.py:346 in torch_dynamo_resume_in_forward_at_346, code: qk = self.correlation_sampler(q, k).view(
        # Number of query positions: height * width of the key/value map.
        mul: "Sym(s2*s3)" = arg3_1 * arg4_1

        # File: /workspace/networks/layers/attention.py:399 in local2global, code: torch.arange(0, pad_height, device=local_attn.device),
        add_1: "Sym(s2 + 14)" = arg3_1 + 14  # pad_height = height + 2*max_dis

        # File: /workspace/networks/layers/attention.py:400 in local2global, code: torch.arange(0, pad_width, device=local_attn.device)
        add_3: "Sym(s3 + 14)" = arg4_1 + 14  # pad_width = width + 2*max_dis

        # File: /workspace/networks/layers/attention.py:416 in local2global, code: global_attn = torch.zeros(
        # Zero-initialized padded global attention plane; scattered into below.
        full: "f32[1, 1, s2*s3, s2 + 14, s3 + 14]" = torch.ops.aten.full.default([1, 1, mul, add_1, add_3], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)

        # File: /workspace/networks/layers/attention.py:403 in local2global, code: torch.arange(0, height, device=local_attn.device),
        iota_2: "i64[s2]" = torch.ops.prims.iota.default(arg3_1, start = 0, step = 1, dtype = torch.int64, device = device(type='cuda', index=0), requires_grad = False)

        # File: /workspace/networks/layers/attention.py:402 in local2global, code: qy, qx = torch.meshgrid([
        # Query-row coordinates broadcast over the (height, width) grid.
        view_7: "i64[s2, 1]" = torch.ops.aten.reshape.default(iota_2, [-1, 1]); iota_2 = None
        expand_2: "i64[s2, s3]" = torch.ops.aten.expand.default(view_7, [arg3_1, arg4_1]); view_7 = None

        # File: /workspace/networks/layers/attention.py:407 in local2global, code: offset_y = qy.reshape(-1, 1) - ky.reshape(1, -1) + self.max_dis
        clone_1: "i64[s2, s3]" = torch.ops.aten.clone.default(expand_2, memory_format = torch.contiguous_format); expand_2 = None
        view_9: "i64[s2*s3, 1]" = torch.ops.aten.reshape.default(clone_1, [mul, 1]); clone_1 = None

        # File: /workspace/networks/layers/attention.py:399 in local2global, code: torch.arange(0, pad_height, device=local_attn.device),
        iota: "i64[s2 + 14]" = torch.ops.prims.iota.default(add_1, start = 0, step = 1, dtype = torch.int64, device = device(type='cuda', index=0), requires_grad = False)

        # File: /workspace/networks/layers/attention.py:398 in local2global, code: ky, kx = torch.meshgrid([
        # Key-row coordinates over the padded (pad_height, pad_width) grid.
        view_5: "i64[s2 + 14, 1]" = torch.ops.aten.reshape.default(iota, [-1, 1]); iota = None
        expand: "i64[s2 + 14, s3 + 14]" = torch.ops.aten.expand.default(view_5, [add_1, add_3]); view_5 = None

        # File: /workspace/networks/layers/attention.py:407 in local2global, code: offset_y = qy.reshape(-1, 1) - ky.reshape(1, -1) + self.max_dis
        clone_2: "i64[s2 + 14, s3 + 14]" = torch.ops.aten.clone.default(expand, memory_format = torch.contiguous_format); expand = None
        mul_3: "Sym(s2*s3 + 14*s2 + 14*s3 + 196)" = add_1 * add_3  # pad_height * pad_width
        view_10: "i64[1, s2*s3 + 14*s2 + 14*s3 + 196]" = torch.ops.aten.reshape.default(clone_2, [1, mul_3]); clone_2 = None
        # offset_y = qy - ky + max_dis (max_dis = 7), for every (query, padded key) pair.
        sub_10: "i64[s2*s3, s2*s3 + 14*s2 + 14*s3 + 196]" = torch.ops.aten.sub.Tensor(view_9, view_10); view_9 = view_10 = None
        add_7: "i64[s2*s3, s2*s3 + 14*s2 + 14*s3 + 196]" = torch.ops.aten.add.Tensor(sub_10, 7); sub_10 = None

        # File: /workspace/networks/layers/attention.py:410 in local2global, code: local_mask = (offset_y.abs() <= self.max_dis) & (offset_x.abs() <=
        abs_1: "i64[s2*s3, s2*s3 + 14*s2 + 14*s3 + 196]" = torch.ops.aten.abs.default(add_7); add_7 = None
        le: "b8[s2*s3, s2*s3 + 14*s2 + 14*s3 + 196]" = torch.ops.aten.le.Scalar(abs_1, 7); abs_1 = None

        # File: /workspace/networks/layers/attention.py:404 in local2global, code: torch.arange(0, width, device=local_attn.device)
        iota_3: "i64[s3]" = torch.ops.prims.iota.default(arg4_1, start = 0, step = 1, dtype = torch.int64, device = device(type='cuda', index=0), requires_grad = False)

        # File: /workspace/networks/layers/attention.py:402 in local2global, code: qy, qx = torch.meshgrid([
        # Query-column coordinates; same meshgrid pattern as the rows above.
        view_8: "i64[1, s3]" = torch.ops.aten.reshape.default(iota_3, [1, -1]); iota_3 = None
        expand_3: "i64[s2, s3]" = torch.ops.aten.expand.default(view_8, [arg3_1, arg4_1]); view_8 = arg3_1 = arg4_1 = None

        # File: /workspace/networks/layers/attention.py:408 in local2global, code: offset_x = qx.reshape(-1, 1) - kx.reshape(1, -1) + self.max_dis
        clone_3: "i64[s2, s3]" = torch.ops.aten.clone.default(expand_3, memory_format = torch.contiguous_format); expand_3 = None
        view_11: "i64[s2*s3, 1]" = torch.ops.aten.reshape.default(clone_3, [mul, 1]); clone_3 = None

        # File: /workspace/networks/layers/attention.py:400 in local2global, code: torch.arange(0, pad_width, device=local_attn.device)
        iota_1: "i64[s3 + 14]" = torch.ops.prims.iota.default(add_3, start = 0, step = 1, dtype = torch.int64, device = device(type='cuda', index=0), requires_grad = False)

        # File: /workspace/networks/layers/attention.py:398 in local2global, code: ky, kx = torch.meshgrid([
        view_6: "i64[1, s3 + 14]" = torch.ops.aten.reshape.default(iota_1, [1, -1]); iota_1 = None
        expand_1: "i64[s2 + 14, s3 + 14]" = torch.ops.aten.expand.default(view_6, [add_1, add_3]); view_6 = None

        # File: /workspace/networks/layers/attention.py:408 in local2global, code: offset_x = qx.reshape(-1, 1) - kx.reshape(1, -1) + self.max_dis
        clone_4: "i64[s2 + 14, s3 + 14]" = torch.ops.aten.clone.default(expand_1, memory_format = torch.contiguous_format); expand_1 = None
        view_12: "i64[1, s2*s3 + 14*s2 + 14*s3 + 196]" = torch.ops.aten.reshape.default(clone_4, [1, mul_3]); clone_4 = mul_3 = None
        sub_11: "i64[s2*s3, s2*s3 + 14*s2 + 14*s3 + 196]" = torch.ops.aten.sub.Tensor(view_11, view_12); view_11 = view_12 = None
        add_8: "i64[s2*s3, s2*s3 + 14*s2 + 14*s3 + 196]" = torch.ops.aten.add.Tensor(sub_11, 7); sub_11 = None

        # File: /workspace/networks/layers/attention.py:410 in local2global, code: local_mask = (offset_y.abs() <= self.max_dis) & (offset_x.abs() <=
        abs_2: "i64[s2*s3, s2*s3 + 14*s2 + 14*s3 + 196]" = torch.ops.aten.abs.default(add_8); add_8 = None
        le_1: "b8[s2*s3, s2*s3 + 14*s2 + 14*s3 + 196]" = torch.ops.aten.le.Scalar(abs_2, 7); abs_2 = None
        # True exactly where the padded key cell is within +/-7 of the query in both axes.
        bitwise_and: "b8[s2*s3, s2*s3 + 14*s2 + 14*s3 + 196]" = torch.ops.aten.bitwise_and.Tensor(le, le_1); le = le_1 = None

        # File: /workspace/networks/layers/attention.py:412 in local2global, code: local_mask = local_mask.view(1, 1, height * width, pad_height,
        view_13: "b8[1, 1, s2*s3, s2 + 14, s3 + 14]" = torch.ops.aten.reshape.default(bitwise_and, [1, 1, mul, add_1, add_3]); bitwise_and = add_1 = add_3 = None

        # File: /workspace/networks/layers/attention.py:419 in local2global, code: global_attn[local_mask.expand(batch_size, self.num_head,
        expand_4: "b8[1, 1, s2*s3, s2 + 14, s3 + 14]" = torch.ops.aten.expand.default(view_13, [1, 1, -1, -1, -1])

        # File: /workspace/networks/layers/attention.py:346 in torch_dynamo_resume_in_forward_at_346, code: qk = self.correlation_sampler(q, k).view(
        # Flatten the 15x15 correlation window to 225 scores per query position.
        view: "f16[1, 1, 225, s0*s1]" = torch.ops.aten.reshape.default(arg2_1, [1, 1, 225, mul]); arg2_1 = None
        add: "f16[1, 1, 225, s0*s1]" = torch.ops.aten.add.Tensor(view, arg5_1); arg5_1 = None
        # Subtract a large penalty scaled from arg6_1 — presumably masks invalid
        # positions out of the softmax; verify against attention.py.
        mul_1: "f32[1, 1, 225, s0*s1]" = torch.ops.aten.mul.Tensor(arg6_1, 10000.0); arg6_1 = None
        sub: "f32[1, 1, 225, s0*s1]" = torch.ops.aten.sub.Tensor(add, mul_1); add = mul_1 = None
        # Numerically-stable softmax over the 225 window positions (dim 2).
        amax: "f32[1, 1, 1, s0*s1]" = torch.ops.aten.amax.default(sub, [2], True)
        sub_1: "f32[1, 1, 225, s0*s1]" = torch.ops.aten.sub.Tensor(sub, amax); sub = amax = None
        exp: "f32[1, 1, 225, s0*s1]" = torch.ops.aten.exp.default(sub_1); sub_1 = None
        sum_1: "f32[1, 1, 1, s0*s1]" = torch.ops.aten.sum.dim_IntList(exp, [2], True)
        div: "f32[1, 1, 225, s0*s1]" = torch.ops.aten.div.Tensor(exp, sum_1); exp = sum_1 = None

        # File: /workspace/networks/layers/attention.py:420 in local2global, code: -1, -1, -1)] = local_attn.transpose(
        permute_5: "f32[1, 1, s0*s1, 225]" = torch.ops.aten.permute.default(div, [0, 1, 3, 2])

        # File: /workspace/networks/layers/attention.py:421 in local2global, code: -1, -2).reshape(-1)
        clone_5: "f32[1, 1, s0*s1, 225]" = torch.ops.aten.clone.default(permute_5, memory_format = torch.contiguous_format); permute_5 = None

        # File: /workspace/networks/layers/attention.py:370 in torch_dynamo_resume_in_forward_at_346, code: self.relative_emb_v)
        sym_size_int: "Sym(s0*s1)" = torch.ops.aten.sym_size.int(view, 3); view = None

        # File: /workspace/networks/layers/attention.py:421 in local2global, code: -1, -2).reshape(-1)
        mul_8: "Sym(225*s0*s1)" = sym_size_int * 225
        floordiv_8: "Sym(225*s0*s1)" = mul_8 // 1; mul_8 = None  # no-op floordiv artifact of the trace
        view_14: "f32[225*s0*s1]" = torch.ops.aten.reshape.default(clone_5, [floordiv_8]); clone_5 = floordiv_8 = None

        # File: /workspace/networks/layers/attention.py:419 in local2global, code: global_attn[local_mask.expand(batch_size, self.num_head,
        # Scatter the flat local attention into the masked cells of the padded plane.
        index_put: "f32[1, 1, s2*s3, s2 + 14, s3 + 14]" = torch.ops.aten.index_put_.default(full, [expand_4], view_14); full = expand_4 = view_14 = None

        # File: /workspace/networks/layers/attention.py:423 in local2global, code: self.max_dis:-self.max_dis].reshape(
        # Crop the max_dis padding off both spatial axes, then flatten to a
        # dense (query, key) attention matrix.
        slice_9: "f32[1, 1, s2*s3, s2, s3 + 14]" = torch.ops.aten.slice.Tensor(index_put, 3, 7, -7); index_put = None
        slice_10: "f32[1, 1, s2*s3, s2, s3]" = torch.ops.aten.slice.Tensor(slice_9, 4, 7, -7); slice_9 = None
        clone_6: "f32[1, 1, s2*s3, s2, s3]" = torch.ops.aten.clone.default(slice_10, memory_format = torch.contiguous_format); slice_10 = None
        view_15: "f32[1, 1, s2*s3, s2*s3]" = torch.ops.aten.reshape.default(clone_6, [1, 1, mul, mul]); clone_6 = None
        convert_element_type_6: "f16[1, 1, s2*s3, s2*s3]" = torch.ops.prims.convert_element_type.default(view_15, torch.float16); view_15 = None
        # global_attn @ v^T as a batched matmul in fp16.
        expand_5: "f16[1, 1, s2*s3, s2*s3]" = torch.ops.aten.expand.default(convert_element_type_6, [1, 1, mul, mul]); convert_element_type_6 = None
        view_16: "f16[1, s2*s3, s2*s3]" = torch.ops.aten.reshape.default(expand_5, [1, mul, mul]); expand_5 = None
        permute_6: "f16[1, 1, s2*s3, 256]" = torch.ops.aten.permute.default(arg8_1, [0, 1, 3, 2]); arg8_1 = None
        expand_6: "f16[1, 1, s2*s3, 256]" = torch.ops.aten.expand.default(permute_6, [1, 1, mul, 256]); permute_6 = None
        view_17: "f16[1, s2*s3, 256]" = torch.ops.aten.reshape.default(expand_6, [1, mul, 256]); expand_6 = None
        bmm_1: "f16[1, s2*s3, 256]" = torch.ops.aten.bmm.default(view_16, view_17); view_16 = view_17 = None
        view_18: "f16[1, 1, s2*s3, 256]" = torch.ops.aten.reshape.default(bmm_1, [1, 1, mul, 256]); bmm_1 = None

        # File: /workspace/networks/layers/attention.py:370 in torch_dynamo_resume_in_forward_at_346, code: self.relative_emb_v)
        # Relative positional value embedding: per-position (1x225) @ (225x256)
        # contraction, expressed as a bmm after trace-generated permutes.
        convert_element_type_2: "f16[1, 1, 225, s0*s1]" = torch.ops.prims.convert_element_type.default(div, torch.float16)
        unsqueeze: "f16[1, 1, 225, s0*s1, 1]" = torch.ops.aten.unsqueeze.default(convert_element_type_2, 4); convert_element_type_2 = None
        permute: "f16[1, 1, s0*s1, 1, 225]" = torch.ops.aten.permute.default(unsqueeze, [0, 1, 3, 4, 2]); unsqueeze = None
        permute_2: "f16[s0*s1, 225, 1, 1, 1]" = torch.ops.aten.permute.default(permute, [2, 4, 0, 1, 3]); permute = None
        view_1: "f16[1, s0*s1, 225]" = torch.ops.aten.reshape.default(permute_2, [1, sym_size_int, 225]); permute_2 = None
        convert_element_type_3: "f16[1, 256, 225]" = torch.ops.prims.convert_element_type.default(arg7_1, torch.float16); arg7_1 = None
        unsqueeze_1: "f16[1, 256, 225, 1]" = torch.ops.aten.unsqueeze.default(convert_element_type_3, 3); convert_element_type_3 = None
        unsqueeze_2: "f16[1, 256, 225, 1, 1]" = torch.ops.aten.unsqueeze.default(unsqueeze_1, 4); unsqueeze_1 = None
        permute_1: "f16[1, 1, 1, 256, 225]" = torch.ops.aten.permute.default(unsqueeze_2, [3, 0, 4, 1, 2]); unsqueeze_2 = None
        permute_3: "f16[225, 1, 1, 256, 1]" = torch.ops.aten.permute.default(permute_1, [4, 0, 1, 3, 2]); permute_1 = None
        view_2: "f16[1, 225, 256]" = torch.ops.aten.reshape.default(permute_3, [1, 225, 256]); permute_3 = None
        bmm: "f16[1, s0*s1, 256]" = torch.ops.aten.bmm.default(view_1, view_2); view_2 = None
        view_3: "f16[s0*s1, 1, 1, 1, 256]" = torch.ops.aten.reshape.default(bmm, [sym_size_int, 1, 1, 1, 256]); bmm = sym_size_int = None
        permute_4: "f16[1, 1, s0*s1, 256, 1]" = torch.ops.aten.permute.default(view_3, [2, 3, 0, 4, 1]); view_3 = None
        sym_size_int_1: "Sym(s0*s1)" = torch.ops.aten.sym_size.int(view_1, 1); view_1 = None
        view_4: "f16[1, 1, s0*s1, 256]" = torch.ops.aten.reshape.default(permute_4, [1, 1, sym_size_int_1, 256]); permute_4 = sym_size_int_1 = None

        # File: /workspace/networks/layers/attention.py:378 in torch_dynamo_resume_in_forward_at_346, code: 3).reshape(h * w, n, c)
        # attention output + relative value embedding, then to (h*w, n, c) layout.
        add_9: "f16[1, 1, s2*s3, 256]" = torch.ops.aten.add.Tensor(view_18, view_4); view_18 = view_4 = None
        permute_7: "f16[s2*s3, 1, 1, 256]" = torch.ops.aten.permute.default(add_9, [2, 0, 1, 3]); add_9 = None

        # File: /workspace/networks/layers/attention.py:379 in torch_dynamo_resume_in_forward_at_346, code: else:
        view_19: "f16[s2*s3, 1, 256]" = torch.ops.aten.reshape.default(permute_7, [mul, 1, 256]); permute_7 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        # Output projection: addmm(bias, x, W^T) in fp16.
        view_20: "f16[s2*s3, 256]" = torch.ops.aten.reshape.default(view_19, [mul, 256]); view_19 = None
        convert_element_type_10: "f16[256, 256]" = torch.ops.prims.convert_element_type.default(arg9_1, torch.float16); arg9_1 = None
        permute_8: "f16[256, 256]" = torch.ops.aten.permute.default(convert_element_type_10, [1, 0]); convert_element_type_10 = None
        addmm: "f16[s2*s3, 256]" = torch.ops.aten.addmm.default(convert_element_type_9, view_20, permute_8); convert_element_type_9 = view_20 = permute_8 = None
        view_21: "f16[s2*s3, 1, 256]" = torch.ops.aten.reshape.default(addmm, [mul, 1, 256]); addmm = mul = None
        return (view_21, div, view_13)