class GraphModule(torch.nn.Module):
    def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", L_stack0_: "f16[1, 15, 15, s0, s1][225*s0*s1, 15*s0*s1, s0*s1, s1, 1]cuda:0", L_h_: "Sym(s2)", L_w_: "Sym(s3)", L_relative_emb_: "f16[1, 1, 225, s0*s1][225*s0*s1, 225*s0*s1, s0*s1, 1]cuda:0", L_qk_mask_: "f32[1, 1, 225, s0*s1][225*s0*s1, 225*s0*s1, s0*s1, 1]cuda:0", L_self_parameters_relative_emb_v_: "f32[1, 256, 225][57600, 225, 1]cuda:0", L_v_: "f16[1, 1, 256, s2*s3][256*s2*s3, 256*s2*s3, s2*s3, 1]cuda:0", L_self_modules_projection_parameters_weight_: "f32[256, 256][256, 1]cuda:0", L_self_modules_projection_parameters_bias_: "f32[256][1]cuda:0"):
        l_stack0_ = L_stack0_
        l_h_ = L_h_
        l_w_ = L_w_
        l_relative_emb_ = L_relative_emb_
        l_qk_mask_ = L_qk_mask_
        l_self_parameters_relative_emb_v_ = L_self_parameters_relative_emb_v_
        l_v_ = L_v_
        l_self_modules_projection_parameters_weight_ = L_self_modules_projection_parameters_weight_
        l_self_modules_projection_parameters_bias_ = L_self_modules_projection_parameters_bias_

        # File: /workspace/networks/layers/attention.py:347 in torch_dynamo_resume_in_forward_at_346, code: n, self.num_head, self.window_size * self.window_size, h * w)
        mul: "Sym(s2*s3)" = l_h_ * l_w_

        # File: /workspace/networks/layers/attention.py:346 in torch_dynamo_resume_in_forward_at_346, code: qk = self.correlation_sampler(q, k).view(
        qk: "f16[1, 1, 225, s0*s1][225*s0*s1, 225*s0*s1, s0*s1, 1]cuda:0" = l_stack0_.view(1, 1, 225, mul); l_stack0_ = mul = None

        #
        qk_1: "f16[1, 1, 225, s0*s1][225*s0*s1, 225*s0*s1, s0*s1, 1]cuda:0" = qk + l_relative_emb_; qk = l_relative_emb_ = None

        #
        mul_1: "f32[1, 1, 225, s0*s1][225*s0*s1, 225*s0*s1, s0*s1, 1]cuda:0" = l_qk_mask_ * 10000.0; l_qk_mask_ = None
        qk_1 -= mul_1; qk_2: "f16[1, 1, 225, s0*s1][225*s0*s1, 225*s0*s1, s0*s1, 1]cuda:0" = qk_1; qk_1 = mul_1 = None

        #
        local_attn: "f32[1, 1, 225, s0*s1][225*s0*s1, 225*s0*s1, s0*s1, 1]cuda:0" = torch.softmax(qk_2, dim = 2); qk_2 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/dropout.py:59 in forward, code: return F.dropout(input, self.p, self.training, self.inplace)
        local_attn_1: "f32[1, 1, 225, s0*s1][225*s0*s1, 225*s0*s1, s0*s1, 1]cuda:0" = torch.nn.functional.dropout(local_attn, 0.0, False, False); local_attn = None

        # File: /workspace/networks/layers/attention.py:370 in torch_dynamo_resume_in_forward_at_346, code: self.relative_emb_v)
        agg_bias: "f16[1, 1, s0*s1, 256][256, 256, 256, 1]cuda:0" = torch.functional.einsum('bhwn,hcw->bhnc', local_attn_1, l_self_parameters_relative_emb_v_); l_self_parameters_relative_emb_v_ = None

        # File: /workspace/networks/layers/attention.py:391 in local2global, code: pad_height = height + 2 * self.max_dis
        add_1: "Sym(s2 + 14)" = l_h_ + 14

        # File: /workspace/networks/layers/attention.py:392 in local2global, code: pad_width = width + 2 * self.max_dis
        add_2: "Sym(s3 + 14)" = l_w_ + 14

        # File: /workspace/networks/layers/attention.py:399 in local2global, code: torch.arange(0, pad_height, device=local_attn.device),
        arange: "i64[s2 + 14][1]cuda:0" = torch.arange(0, add_1, device = device(type='cuda', index=0))

        # File: /workspace/networks/layers/attention.py:400 in local2global, code: torch.arange(0, pad_width, device=local_attn.device)
        arange_1: "i64[s3 + 14][1]cuda:0" = torch.arange(0, add_2, device = device(type='cuda', index=0))

        # File: /workspace/networks/layers/attention.py:398 in local2global, code: ky, kx = torch.meshgrid([
        meshgrid = torch.functional.meshgrid([arange, arange_1]); arange = arange_1 = None
        ky: "i64[s2 + 14, s3 + 14][1, 0]cuda:0" = meshgrid[0]
        kx: "i64[s2 + 14, s3 + 14][0, 1]cuda:0" = meshgrid[1]; meshgrid = None

        # File: /workspace/networks/layers/attention.py:403 in local2global, code: torch.arange(0, height, device=local_attn.device),
        arange_2: "i64[s2][1]cuda:0" = torch.arange(0, l_h_, device = device(type='cuda', index=0))

        # File: /workspace/networks/layers/attention.py:404 in local2global, code: torch.arange(0, width, device=local_attn.device)
        arange_3: "i64[s3][1]cuda:0" = torch.arange(0, l_w_, device = device(type='cuda', index=0))

        # File: /workspace/networks/layers/attention.py:402 in local2global, code: qy, qx = torch.meshgrid([
        meshgrid_1 = torch.functional.meshgrid([arange_2, arange_3]); arange_2 = arange_3 = None
        qy: "i64[s2, s3][1, 0]cuda:0" = meshgrid_1[0]
        qx: "i64[s2, s3][0, 1]cuda:0" = meshgrid_1[1]; meshgrid_1 = None

        # File: /workspace/networks/layers/attention.py:407 in local2global, code: offset_y = qy.reshape(-1, 1) - ky.reshape(1, -1) + self.max_dis
        reshape: "i64[s2*s3, 1][1, 1]cuda:0" = qy.reshape(-1, 1); qy = None
        reshape_1: "i64[1, s2*s3 + 14*s2 + 14*s3 + 196][s2*s3 + 14*s2 + 14*s3 + 196, 1]cuda:0" = ky.reshape(1, -1); ky = None
        sub: "i64[s2*s3, s2*s3 + 14*s2 + 14*s3 + 196][s2*s3 + 14*s2 + 14*s3 + 196, 1]cuda:0" = reshape - reshape_1; reshape = reshape_1 = None
        offset_y: "i64[s2*s3, s2*s3 + 14*s2 + 14*s3 + 196][s2*s3 + 14*s2 + 14*s3 + 196, 1]cuda:0" = sub + 7; sub = None

        # File: /workspace/networks/layers/attention.py:408 in local2global, code: offset_x = qx.reshape(-1, 1) - kx.reshape(1, -1) + self.max_dis
        reshape_2: "i64[s2*s3, 1][1, 1]cuda:0" = qx.reshape(-1, 1); qx = None
        reshape_3: "i64[1, s2*s3 + 14*s2 + 14*s3 + 196][s2*s3 + 14*s2 + 14*s3 + 196, 1]cuda:0" = kx.reshape(1, -1); kx = None
        sub_1: "i64[s2*s3, s2*s3 + 14*s2 + 14*s3 + 196][s2*s3 + 14*s2 + 14*s3 + 196, 1]cuda:0" = reshape_2 - reshape_3; reshape_2 = reshape_3 = None
        offset_x: "i64[s2*s3, s2*s3 + 14*s2 + 14*s3 + 196][s2*s3 + 14*s2 + 14*s3 + 196, 1]cuda:0" = sub_1 + 7; sub_1 = None

        # File: /workspace/networks/layers/attention.py:410 in local2global, code: local_mask = (offset_y.abs() <= self.max_dis) & (offset_x.abs() <=
        abs_1: "i64[s2*s3, s2*s3 + 14*s2 + 14*s3 + 196][s2*s3 + 14*s2 + 14*s3 + 196, 1]cuda:0" = offset_y.abs(); offset_y = None
        le: "b8[s2*s3, s2*s3 + 14*s2 + 14*s3 + 196][s2*s3 + 14*s2 + 14*s3 + 196, 1]cuda:0" = abs_1 <= 7; abs_1 = None
        abs_2: "i64[s2*s3, s2*s3 + 14*s2 + 14*s3 + 196][s2*s3 + 14*s2 + 14*s3 + 196, 1]cuda:0" = offset_x.abs(); offset_x = None
        le_1: "b8[s2*s3, s2*s3 + 14*s2 + 14*s3 + 196][s2*s3 + 14*s2 + 14*s3 + 196, 1]cuda:0" = abs_2 <= 7; abs_2 = None
        local_mask: "b8[s2*s3, s2*s3 + 14*s2 + 14*s3 + 196][s2*s3 + 14*s2 + 14*s3 + 196, 1]cuda:0" = le & le_1; le = le_1 = None

        # File: /workspace/networks/layers/attention.py:412 in local2global, code: local_mask = local_mask.view(1, 1, height * width, pad_height,
        mul_2: "Sym(s2*s3)" = l_h_ * l_w_
        local_mask_1: "b8[1, 1, s2*s3, s2 + 14, s3 + 14][s2**2*s3**2 + 14*s2**2*s3 + 14*s2*s3**2 + 196*s2*s3, s2**2*s3**2 + 14*s2**2*s3 + 14*s2*s3**2 + 196*s2*s3, s2*s3 + 14*s2 + 14*s3 + 196, s3 + 14, 1]cuda:0" = local_mask.view(1, 1, mul_2, add_1, add_2); local_mask = mul_2 = None

        # File: /workspace/networks/layers/attention.py:417 in local2global, code: (batch_size, self.num_head, height * width, pad_height, pad_width),
        mul_3: "Sym(s2*s3)" = l_h_ * l_w_

        # File: /workspace/networks/layers/attention.py:416 in local2global, code: global_attn = torch.zeros(
        global_attn: "f32[1, 1, s2*s3, s2 + 14, s3 + 14][s2**2*s3**2 + 14*s2**2*s3 + 14*s2*s3**2 + 196*s2*s3, s2**2*s3**2 + 14*s2**2*s3 + 14*s2*s3**2 + 196*s2*s3, s2*s3 + 14*s2 + 14*s3 + 196, s3 + 14, 1]cuda:0" = torch.zeros((1, 1, mul_3, add_1, add_2), device = device(type='cuda', index=0)); mul_3 = add_1 = add_2 = None

        # File: /workspace/networks/layers/attention.py:420 in local2global, code: -1, -1, -1)] = local_attn.transpose(
        transpose: "f32[1, 1, s0*s1, 225][225*s0*s1, 225*s0*s1, 1, s0*s1]cuda:0" = local_attn_1.transpose(-1, -2)

        # File: /workspace/networks/layers/attention.py:421 in local2global, code: -1, -2).reshape(-1)
        reshape_4: "f32[225*s0*s1][1]cuda:0" = transpose.reshape(-1); transpose = None

        # File: /workspace/networks/layers/attention.py:419 in local2global, code: global_attn[local_mask.expand(batch_size, self.num_head,
        expand: "b8[1, 1, s2*s3, s2 + 14, s3 + 14][s2**2*s3**2 + 14*s2**2*s3 + 14*s2*s3**2 + 196*s2*s3, s2**2*s3**2 + 14*s2**2*s3 + 14*s2*s3**2 + 196*s2*s3, s2*s3 + 14*s2 + 14*s3 + 196, s3 + 14, 1]cuda:0" = local_mask_1.expand(1, 1, -1, -1, -1)
        global_attn[expand] = reshape_4; setitem = global_attn; expand = reshape_4 = None

        # File: /workspace/networks/layers/attention.py:422 in local2global, code: global_attn = global_attn[:, :, :, self.max_dis:-self.max_dis,
        getitem_8: "f32[1, 1, s2*s3, s2, s3][s2**2*s3**2 + 14*s2**2*s3 + 14*s2*s3**2 + 196*s2*s3, s2**2*s3**2 + 14*s2**2*s3 + 14*s2*s3**2 + 196*s2*s3, s2*s3 + 14*s2 + 14*s3 + 196, s3 + 14, 1]cuda:0" = global_attn[(slice(None, None, None), slice(None, None, None), slice(None, None, None), slice(7, -7, None), slice(7, -7, None))]; global_attn = None

        # File: /workspace/networks/layers/attention.py:425 in local2global, code: height * width, height * width)
        mul_4: "Sym(s2*s3)" = l_h_ * l_w_
        mul_5: "Sym(s2*s3)" = l_h_ * l_w_

        # File: /workspace/networks/layers/attention.py:423 in local2global, code: self.max_dis:-self.max_dis].reshape(
        global_attn_1: "f32[1, 1, s2*s3, s2*s3][s2**2*s3**2, s2**2*s3**2, s2*s3, 1]cuda:0" = getitem_8.reshape(1, 1, mul_4, mul_5); getitem_8 = mul_4 = mul_5 = None

        #
        transpose_1: "f16[1, 1, s2*s3, 256][256*s2*s3, 256*s2*s3, 1, s2*s3]cuda:0" = l_v_.transpose(-2, -1); l_v_ = None
        agg_value: "f16[1, 1, s2*s3, 256][256*s2*s3, 256*s2*s3, 256, 1]cuda:0" = global_attn_1 @ transpose_1; global_attn_1 = transpose_1 = None

        # File: /workspace/networks/layers/attention.py:378 in torch_dynamo_resume_in_forward_at_346, code: 3).reshape(h * w, n, c)
        add_5: "f16[1, 1, s2*s3, 256][256*s2*s3, 256*s2*s3, 256, 1]cuda:0" = agg_value + agg_bias; agg_value = agg_bias = None
        permute: "f16[s2*s3, 1, 1, 256][256, 256*s2*s3, 256*s2*s3, 1]cuda:0" = add_5.permute(2, 0, 1, 3); add_5 = None

        # File: /workspace/networks/layers/attention.py:379 in torch_dynamo_resume_in_forward_at_346, code: else:
        mul_6: "Sym(s2*s3)" = l_h_ * l_w_; l_h_ = l_w_ = None
        output: "f16[s2*s3, 1, 256][256, 256*s2*s3, 1]cuda:0" = permute.reshape(mul_6, 1, 256); permute = mul_6 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        output_1: "f16[s2*s3, 1, 256][256, 256, 1]cuda:0" = torch._C._nn.linear(output, l_self_modules_projection_parameters_weight_, l_self_modules_projection_parameters_bias_); output = l_self_modules_projection_parameters_weight_ = l_self_modules_projection_parameters_bias_ = None
        return (output_1, local_attn_1, local_mask_1)
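
# ---------------------------------------------------------------------------
# A minimal eager-mode sketch of the local2global step that the traced
# forward() above encodes. This is a reading of the graph, not the code from
# networks/layers/attention.py; the function name and the assumed
# (B, heads, 225, H*W) input layout are illustrative. A 15x15 windowed
# attention map is scattered into a zero-initialised map over a grid padded
# by max_dis=7 on each side, the padding is cropped, and the result is
# flattened into a dense [H*W, H*W] attention matrix.
import torch

def local2global_sketch(local_attn: torch.Tensor, height: int, width: int,
                        max_dis: int = 7) -> torch.Tensor:
    # local_attn: [B, heads, (2*max_dis + 1)**2, height*width]  (225 = 15*15 here)
    B, heads, window, _ = local_attn.shape
    pad_h, pad_w = height + 2 * max_dis, width + 2 * max_dis

    # Key coordinates on the padded grid, query coordinates on the original grid.
    ky, kx = torch.meshgrid(
        torch.arange(0, pad_h, device=local_attn.device),
        torch.arange(0, pad_w, device=local_attn.device), indexing="ij")
    qy, qx = torch.meshgrid(
        torch.arange(0, height, device=local_attn.device),
        torch.arange(0, width, device=local_attn.device), indexing="ij")

    # Each query attends to keys within +/- max_dis in both directions, i.e.
    # exactly `window` positions per query on the padded grid.
    offset_y = qy.reshape(-1, 1) - ky.reshape(1, -1) + max_dis
    offset_x = qx.reshape(-1, 1) - kx.reshape(1, -1) + max_dis
    local_mask = (offset_y.abs() <= max_dis) & (offset_x.abs() <= max_dis)
    local_mask = local_mask.view(1, 1, height * width, pad_h, pad_w)

    # Scatter the local scores into the padded global map, crop the padding,
    # and flatten to a dense attention matrix.
    global_attn = torch.zeros(
        (B, heads, height * width, pad_h, pad_w), device=local_attn.device)
    global_attn[local_mask.expand(B, heads, -1, -1, -1)] = \
        local_attn.transpose(-1, -2).reshape(-1)
    return global_attn[:, :, :, max_dis:-max_dis, max_dis:-max_dis].reshape(
        B, heads, height * width, height * width)

# In the graph above, this dense matrix is then multiplied by the transposed
# value tensor (global_attn_1 @ transpose_1) and added to the einsum-produced
# relative-position bias before the final output projection.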