class GraphModule(torch.nn.Module):
    def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", L_stack0_: "f16[1, 15, 15, s0, s1][225*s0*s1, 15*s0*s1, s0*s1, s1, 1]cuda:0", L_h_: "Sym(s10)", L_w_: "Sym(s11)", L_relative_emb_: "f16[1, 1, 225, s0*s1][225*s0*s1, 225*s0*s1, s0*s1, 1]cuda:0", L_qk_mask_: "f32[1, 1, 225, s0*s1][225*s0*s1, 225*s0*s1, s0*s1, 1]cuda:0", L_self_parameters_relative_emb_v_: "f32[1, 256, 225][57600, 225, 1]cuda:0", s6: "Sym(s6)", s7: "Sym(s7)", s8: "Sym(s8)", s9: "Sym(s9)", L_self_local_mask: "b8[1, 1, s6, s7, s8][0, 0, s9, s8, 1]cuda:0", L_self_last_size_2d_0_: "Sym(s10)", L_self_last_size_2d_1_: "Sym(s11)", L_v_: "f16[1, 1, 256, s10*s11][256*s10*s11, 256*s10*s11, s10*s11, 1]cuda:0", L_self_modules_projection_parameters_weight_: "f32[256, 256][256, 1]cuda:0", L_self_modules_projection_parameters_bias_: "f32[256][1]cuda:0"):
        l_stack0_ = L_stack0_
        l_h_ = L_h_
        l_w_ = L_w_
        l_relative_emb_ = L_relative_emb_
        l_qk_mask_ = L_qk_mask_
        l_self_parameters_relative_emb_v_ = L_self_parameters_relative_emb_v_
        l_self_local_mask = L_self_local_mask
        l_self_last_size_2d_0_ = L_self_last_size_2d_0_
        l_self_last_size_2d_1_ = L_self_last_size_2d_1_
        l_v_ = L_v_
        l_self_modules_projection_parameters_weight_ = L_self_modules_projection_parameters_weight_
        l_self_modules_projection_parameters_bias_ = L_self_modules_projection_parameters_bias_

        # File: /workspace/networks/layers/attention.py:347 in torch_dynamo_resume_in_forward_at_346, code: n, self.num_head, self.window_size * self.window_size, h * w)
        mul: "Sym(s10*s11)" = l_h_ * l_w_

        # File: /workspace/networks/layers/attention.py:346 in torch_dynamo_resume_in_forward_at_346, code: qk = self.correlation_sampler(q, k).view(
        qk: "f16[1, 1, 225, s0*s1][225*s0*s1, 225*s0*s1, s0*s1, 1]cuda:0" = l_stack0_.view(1, 1, 225, mul); l_stack0_ = mul = None

        #
        qk_1: "f16[1, 1, 225, s0*s1][225*s0*s1, 225*s0*s1, s0*s1, 1]cuda:0" = qk + l_relative_emb_; qk = l_relative_emb_ = None

        #
        mul_1: "f32[1, 1, 225, s0*s1][225*s0*s1, 225*s0*s1, s0*s1, 1]cuda:0" = l_qk_mask_ * 10000.0; l_qk_mask_ = None
        qk_1 -= mul_1; qk_2: "f16[1, 1, 225, s0*s1][225*s0*s1, 225*s0*s1, s0*s1, 1]cuda:0" = qk_1; qk_1 = mul_1 = None

        #
        local_attn: "f32[1, 1, 225, s0*s1][225*s0*s1, 225*s0*s1, s0*s1, 1]cuda:0" = torch.softmax(qk_2, dim = 2); qk_2 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/dropout.py:59 in forward, code: return F.dropout(input, self.p, self.training, self.inplace)
        local_attn_1: "f32[1, 1, 225, s0*s1][225*s0*s1, 225*s0*s1, s0*s1, 1]cuda:0" = torch.nn.functional.dropout(local_attn, 0.0, False, False); local_attn = None

        # File: /workspace/networks/layers/attention.py:370 in torch_dynamo_resume_in_forward_at_346, code: self.relative_emb_v)
        agg_bias: "f16[1, 1, s0*s1, 256][256, 256, 256, 1]cuda:0" = torch.functional.einsum('bhwn,hcw->bhnc', local_attn_1, l_self_parameters_relative_emb_v_); l_self_parameters_relative_emb_v_ = None

        # File: /workspace/networks/layers/attention.py:391 in local2global, code: pad_height = height + 2 * self.max_dis
        add_1: "Sym(s10 + 14)" = l_h_ + 14

        # File: /workspace/networks/layers/attention.py:392 in local2global, code: pad_width = width + 2 * self.max_dis
        add_2: "Sym(s11 + 14)" = l_w_ + 14

        # File: /workspace/networks/layers/attention.py:417 in local2global, code: (batch_size, self.num_head, height * width, pad_height, pad_width),
        mul_2: "Sym(s10*s11)" = l_h_ * l_w_

        # File: /workspace/networks/layers/attention.py:416 in local2global, code: global_attn = torch.zeros(
        global_attn: "f32[1, 1, s10*s11, s10 + 14, s11 + 14][s10**2*s11**2 + 14*s10**2*s11 + 14*s10*s11**2 + 196*s10*s11, s10**2*s11**2 + 14*s10**2*s11 + 14*s10*s11**2 + 196*s10*s11, s10*s11 + 14*s10 + 14*s11 + 196, s11 + 14, 1]cuda:0" = torch.zeros((1, 1, mul_2, add_1, add_2), device = device(type='cuda', index=0)); mul_2 = add_1 = add_2 = None

        # File: /workspace/networks/layers/attention.py:420 in local2global, code: -1, -1, -1)] = local_attn.transpose(
        transpose: "f32[1, 1, s0*s1, 225][225*s0*s1, 225*s0*s1, 1, s0*s1]cuda:0" = local_attn_1.transpose(-1, -2)

        # File: /workspace/networks/layers/attention.py:421 in local2global, code: -1, -2).reshape(-1)
        reshape: "f32[225*s0*s1][1]cuda:0" = transpose.reshape(-1); transpose = None

        # File: /workspace/networks/layers/attention.py:419 in local2global, code: global_attn[local_mask.expand(batch_size, self.num_head,
        expand: "b8[1, 1, s6, s7, s8][0, 0, s9, s8, 1]cuda:0" = l_self_local_mask.expand(1, 1, -1, -1, -1); l_self_local_mask = None
        global_attn[expand] = reshape; setitem = global_attn; expand = reshape = None

        # File: /workspace/networks/layers/attention.py:422 in local2global, code: global_attn = global_attn[:, :, :, self.max_dis:-self.max_dis,
        getitem_4: "f32[1, 1, s10*s11, s10, s11][s10**2*s11**2 + 14*s10**2*s11 + 14*s10*s11**2 + 196*s10*s11, s10**2*s11**2 + 14*s10**2*s11 + 14*s10*s11**2 + 196*s10*s11, s10*s11 + 14*s10 + 14*s11 + 196, s11 + 14, 1]cuda:0" = global_attn[(slice(None, None, None), slice(None, None, None), slice(None, None, None), slice(7, -7, None), slice(7, -7, None))]; global_attn = None

        # File: /workspace/networks/layers/attention.py:425 in local2global, code: height * width, height * width)
        mul_3: "Sym(s10*s11)" = l_h_ * l_w_
        mul_4: "Sym(s10*s11)" = l_h_ * l_w_

        # File: /workspace/networks/layers/attention.py:423 in local2global, code: self.max_dis:-self.max_dis].reshape(
        global_attn_1: "f32[1, 1, s10*s11, s10*s11][s10**2*s11**2, s10**2*s11**2, s10*s11, 1]cuda:0" = getitem_4.reshape(1, 1, mul_3, mul_4); getitem_4 = mul_3 = mul_4 = None

        #
        transpose_1: "f16[1, 1, s10*s11, 256][256*s10*s11, 256*s10*s11, 1, s10*s11]cuda:0" = l_v_.transpose(-2, -1); l_v_ = None
        agg_value: "f16[1, 1, s10*s11, 256][256*s10*s11, 256*s10*s11, 256, 1]cuda:0" = global_attn_1 @ transpose_1; global_attn_1 = transpose_1 = None

        # File: /workspace/networks/layers/attention.py:378 in torch_dynamo_resume_in_forward_at_346, code: 3).reshape(h * w, n, c)
        add_3: "f16[1, 1, s10*s11, 256][256*s10*s11, 256*s10*s11, 256, 1]cuda:0" = agg_value + agg_bias; agg_value = agg_bias = None
        permute: "f16[s10*s11, 1, 1, 256][256, 256*s10*s11, 256*s10*s11, 1]cuda:0" = add_3.permute(2, 0, 1, 3); add_3 = None

        # File: /workspace/networks/layers/attention.py:379 in torch_dynamo_resume_in_forward_at_346, code: else:
        mul_5: "Sym(s10*s11)" = l_h_ * l_w_; l_h_ = l_w_ = None
        output: "f16[s10*s11, 1, 256][256, 256*s10*s11, 1]cuda:0" = permute.reshape(mul_5, 1, 256); permute = mul_5 = None

        # File: /opt/conda/lib/python3.11/site-packages/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        output_1: "f16[s10*s11, 1, 256][256, 256, 1]cuda:0" = torch._C._nn.linear(output, l_self_modules_projection_parameters_weight_, l_self_modules_projection_parameters_bias_); output = l_self_modules_projection_parameters_weight_ = l_self_modules_projection_parameters_bias_ = None
        return (output_1, local_attn_1)
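# For reference: a minimal eager-mode sketch of what this captured graph computes
# (local window attention with a relative value bias, scattered into a padded
# global attention map and projected). This is a reconstruction from the trace
# above, not the original module's code: the name `local_attention_eager` is
# hypothetical, and window_size=15 / max_dis=7 / num_head=1 are inferred from the
# constants in the graph (225 = 15*15, padding of 14 = 2*7, the 7:-7 crop).
# Dtypes are simplified to fp32 throughout, whereas the trace mixes f16 and f32.
import torch

def local_attention_eager(corr, relative_emb, qk_mask, relative_emb_v,
                          local_mask, v, proj_weight, proj_bias, h, w,
                          window_size=15, max_dis=7):
    # corr:           [n, window, window, h, w]   correlation volume (l_stack0_)
    # relative_emb:   [n, heads, window*window, h*w]
    # qk_mask:        [n, heads, window*window, h*w], 1.0 where a window
    #                 position falls outside the image, else 0.0
    # relative_emb_v: [heads, c, window*window]
    # local_mask:     [1, 1, h*w, h + 2*max_dis, w + 2*max_dis] boolean mask with
    #                 window*window True entries per query position
    # v:              [n, heads, c, h*w]
    n, heads, c = v.shape[0], v.shape[1], v.shape[2]

    qk = corr.view(n, heads, window_size * window_size, h * w)
    qk = qk + relative_emb
    qk = qk - qk_mask * 10000.0                     # suppress out-of-image positions
    local_attn = torch.softmax(qk, dim=2)           # softmax over the window axis

    # relative-position value bias ('bhwn,hcw->bhnc' in the trace)
    agg_bias = torch.einsum('bhwn,hcw->bhnc', local_attn, relative_emb_v)

    # scatter local windows into a zero-padded global attention map, then crop
    pad_h, pad_w = h + 2 * max_dis, w + 2 * max_dis
    global_attn = torch.zeros(n, heads, h * w, pad_h, pad_w, device=corr.device)
    global_attn[local_mask.expand(n, heads, -1, -1, -1)] = \
        local_attn.transpose(-1, -2).reshape(-1)
    global_attn = global_attn[:, :, :, max_dis:-max_dis, max_dis:-max_dis]
    global_attn = global_attn.reshape(n, heads, h * w, h * w)

    # aggregate values, add the bias, flatten heads, and apply the output projection
    agg_value = global_attn @ v.transpose(-2, -1)                  # [n, heads, h*w, c]
    out = (agg_value + agg_bias).permute(2, 0, 1, 3).reshape(h * w, n, -1)
    out = torch.nn.functional.linear(out, proj_weight, proj_bias)  # [h*w, n, heads*c]
    return out, local_attn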