class Repro(torch.nn.Module):
    """FX graph captured by TorchDynamo: local attention -> global attention -> linear.

    The graph was recorded from ``networks/layers/attention.py`` (see the
    section comments) at one fixed set of shapes; the string annotations on
    ``forward`` are those recorded shapes. Spatial size, window size, channel
    count, device and the low-precision compute dtype are derived from the
    inputs, so the recorded call (cuda:0, float16) executes the same ops as
    captured while other sizes/devices can also be run. The leading batch and
    head dimensions are still assumed to be 1, exactly as in the recording.
    """

    def forward(
        self,
        arg0_1: "f16[1, 15, 15, 68, 68]",
        arg1_1: "f16[1, 1, 225, 4624]",
        arg2_1: "f32[1, 1, 225, 4624]",
        arg3_1: "f32[1, 256, 225]",
        arg4_1: "f16[1, 1, 256, 4624]",
        arg5_1: "f32[256, 256]",
        arg6_1: "f32[256]",
    ):
        """Run the captured graph.

        Args:
            arg0_1: ``[1, win, win, H, W]`` local correlation logits; reshaped
                to ``[1, 1, win*win, H*W]`` (attention.py:346).
            arg1_1: ``[1, 1, win*win, H*W]`` term added to the logits.
            arg2_1: ``[1, 1, win*win, H*W]`` tensor scaled by ``10000.0`` and
                subtracted from the logits (presumably a 0/1 ignore mask --
                confirm against the caller).
            arg3_1: ``[1, C, win*win]`` tensor contracted with the local
                attention (recorded from ``self.relative_emb_v``,
                attention.py:370).
            arg4_1: ``[1, 1, C, H*W]`` values multiplied by the dense global
                attention; its dtype is the compute dtype of the matmuls.
            arg5_1: ``[out, C]`` linear weight.
            arg6_1: ``[out]`` linear bias.

        Returns:
            ``(projected, probs, local_mask)``: the linear output
            ``[H*W, 1, out]``; the softmax over the window axis
            ``[1, 1, win*win, H*W]`` (float32 for the recorded inputs); and
            the boolean mask ``[1, 1, H*W, H + 2*max_dis, W + 2*max_dis]``.

        Raises:
            ValueError: If the local window is not square with an odd size.
        """
        aten = torch.ops.aten
        prims = torch.ops.prims

        _, win_h, win_w, height, width = arg0_1.shape
        if win_h != win_w or win_h % 2 == 0:
            raise ValueError(
                f"expected an odd square local window, got {win_h}x{win_w}"
            )
        max_dis = (win_h - 1) // 2  # 7 for the recorded 15x15 window
        window_area = win_h * win_w
        num_query = height * width
        pad_height = height + 2 * max_dis
        pad_width = width + 2 * max_dis
        num_key = pad_height * pad_width
        channels = arg4_1.shape[2]
        out_features = arg5_1.shape[0]
        # The recording hard-coded cuda:0 and float16; taking both from the
        # inputs is identical for the recorded call.
        device = arg0_1.device
        compute_dtype = arg4_1.dtype
        iota_kwargs = {
            "start": 0,
            "step": 1,
            "dtype": torch.int64,
            "device": device,
            "requires_grad": False,
        }

        # linear.py:116 -- cast the float32 bias for F.linear.
        bias = prims.convert_element_type.default(arg6_1, compute_dtype)
        del arg6_1

        # attention.py:416 -- zero-filled dense attention over the padded grid.
        full_default = aten.full.default(
            [1, 1, num_query, pad_height, pad_width],
            0,
            dtype=torch.float32,
            layout=torch.strided,
            device=device,
            pin_memory=False,
        )

        # attention.py:398-412 -- local_mask[q, ky, kx] is True iff
        # |qy - ky + max_dis| <= max_dis and |qx - kx + max_dis| <= max_dis,
        # i.e. (ky, kx) lies in [qy, qy + 2*max_dis] x [qx, qx + 2*max_dis]
        # of the padded key grid.
        qy = prims.iota.default(height, **iota_kwargs)
        qy = aten.reshape.default(qy, [-1, 1])
        qy = aten.expand.default(qy, [height, width])
        qy = aten.clone.default(qy, memory_format=torch.contiguous_format)
        qy = aten.reshape.default(qy, [num_query, 1])
        ky = prims.iota.default(pad_height, **iota_kwargs)
        ky = aten.reshape.default(ky, [-1, 1])
        ky = aten.expand.default(ky, [pad_height, pad_width])
        ky = aten.clone.default(ky, memory_format=torch.contiguous_format)
        ky = aten.reshape.default(ky, [1, num_key])
        offset_y = aten.sub.Tensor(qy, ky)
        del qy, ky
        offset_y = aten.add.Tensor(offset_y, max_dis)
        offset_y = aten.abs.default(offset_y)
        mask_y = aten.le.Scalar(offset_y, max_dis)
        del offset_y

        qx = prims.iota.default(width, **iota_kwargs)
        qx = aten.reshape.default(qx, [1, -1])
        qx = aten.expand.default(qx, [height, width])
        qx = aten.clone.default(qx, memory_format=torch.contiguous_format)
        qx = aten.reshape.default(qx, [num_query, 1])
        kx = prims.iota.default(pad_width, **iota_kwargs)
        kx = aten.reshape.default(kx, [1, -1])
        kx = aten.expand.default(kx, [pad_height, pad_width])
        kx = aten.clone.default(kx, memory_format=torch.contiguous_format)
        kx = aten.reshape.default(kx, [1, num_key])
        offset_x = aten.sub.Tensor(qx, kx)
        del qx, kx
        offset_x = aten.add.Tensor(offset_x, max_dis)
        offset_x = aten.abs.default(offset_x)
        mask_x = aten.le.Scalar(offset_x, max_dis)
        del offset_x

        local_mask = aten.bitwise_and.Tensor(mask_y, mask_x)
        del mask_y, mask_x
        local_mask = aten.reshape.default(
            local_mask, [1, 1, num_query, pad_height, pad_width]
        )
        mask_expanded = aten.expand.default(local_mask, [1, 1, -1, -1, -1])

        # attention.py:346 -- masked softmax over the window axis (dim 2).
        logits = aten.reshape.default(arg0_1, [1, 1, window_area, num_query])
        del arg0_1
        logits = aten.add.Tensor(logits, arg1_1)
        del arg1_1
        penalty = aten.mul.Tensor(arg2_1, 10000.0)
        del arg2_1
        logits = aten.sub.Tensor(logits, penalty)
        del penalty
        row_max = aten.amax.default(logits, [2], True)
        shifted = aten.sub.Tensor(logits, row_max)
        del logits, row_max
        exp = aten.exp.default(shifted)
        del shifted
        denom = aten.sum.dim_IntList(exp, [2], True)
        div = aten.div.Tensor(exp, denom)
        del exp, denom

        # attention.py:419-421 -- scatter each query's window probabilities
        # into the True positions of its mask row (row-major order).
        flat = aten.permute.default(div, [0, 1, 3, 2])
        flat = aten.clone.default(flat, memory_format=torch.contiguous_format)
        flat = aten.reshape.default(flat, [num_query * window_area])
        global_attn = aten.index_put_.default(full_default, [mask_expanded], flat)
        del full_default, mask_expanded, flat

        # attention.py:423 -- crop the padding. The end is written as
        # ``pad - max_dis`` (not ``-max_dis``) so max_dis == 0 crops nothing.
        global_attn = aten.slice.Tensor(global_attn, 3, max_dis, pad_height - max_dis)
        global_attn = aten.slice.Tensor(global_attn, 4, max_dis, pad_width - max_dis)
        global_attn = aten.clone.default(
            global_attn, memory_format=torch.contiguous_format
        )
        global_attn = aten.reshape.default(global_attn, [1, 1, num_query, num_query])
        global_attn = prims.convert_element_type.default(global_attn, compute_dtype)
        global_attn = aten.expand.default(global_attn, [1, 1, num_query, num_query])
        global_attn = aten.reshape.default(global_attn, [1, num_query, num_query])
        values = aten.permute.default(arg4_1, [0, 1, 3, 2])
        del arg4_1
        values = aten.expand.default(values, [1, 1, num_query, channels])
        values = aten.reshape.default(values, [1, num_query, channels])
        global_out = aten.bmm.default(global_attn, values)
        del global_attn, values
        global_out = aten.reshape.default(global_out, [1, 1, num_query, channels])

        # attention.py:370 -- local attention times the relative embedding.
        local_attn = prims.convert_element_type.default(div, compute_dtype)
        local_attn = aten.unsqueeze.default(local_attn, 4)
        local_attn = aten.permute.default(local_attn, [0, 1, 3, 4, 2])
        local_attn = aten.permute.default(local_attn, [2, 4, 0, 1, 3])
        local_attn = aten.reshape.default(local_attn, [1, num_query, window_area])
        rel_emb = prims.convert_element_type.default(arg3_1, compute_dtype)
        del arg3_1
        rel_emb = aten.unsqueeze.default(rel_emb, 3)
        rel_emb = aten.unsqueeze.default(rel_emb, 4)
        rel_emb = aten.permute.default(rel_emb, [3, 0, 4, 1, 2])
        rel_emb = aten.permute.default(rel_emb, [4, 0, 1, 3, 2])
        rel_emb = aten.reshape.default(rel_emb, [1, window_area, channels])
        rel_out = aten.bmm.default(local_attn, rel_emb)
        del local_attn, rel_emb
        rel_out = aten.reshape.default(rel_out, [num_query, 1, 1, 1, channels])
        rel_out = aten.permute.default(rel_out, [2, 3, 0, 4, 1])
        rel_out = aten.reshape.default(rel_out, [1, 1, num_query, channels])

        # attention.py:378-379 -- sum both terms, lay out as (h*w, n, c).
        combined = aten.add.Tensor(global_out, rel_out)
        del global_out, rel_out
        combined = aten.permute.default(combined, [2, 0, 1, 3])
        combined = aten.reshape.default(combined, [num_query, 1, channels])

        # linear.py:116 -- F.linear(input, weight, bias) as addmm.
        combined = aten.reshape.default(combined, [num_query, channels])
        weight_t = prims.convert_element_type.default(arg5_1, compute_dtype)
        del arg5_1
        weight_t = aten.permute.default(weight_t, [1, 0])
        projected = aten.addmm.default(bias, combined, weight_t)
        del bias, combined, weight_t
        projected = aten.reshape.default(projected, [num_query, 1, out_features])
        return (projected, div, local_mask)