class GraphModule(torch.nn.Module):
    """Dynamo-captured graph for the local-attention block of
    ``networks/layers/attention.py`` (``forward`` resumed at line 346, plus
    ``local2global``), restored from a collapsed graph-code dump into valid,
    runnable Python.

    The trace specialised several sizes, which are baked in as literals:
    one head, a 15x15 window (225 offsets), ``max_dis == 7`` (padding of
    ``2 * 7 == 14``) and 256 value/projection channels. ``h`` and ``w`` are
    passed in as ints (traced as 68 x 68, i.e. ``h * w == 4624``).

    Returns ``(output, local_attn)``:
    - ``output`` has shape ``(h*w, 1, 256)``, the projected aggregated values.
    - ``local_attn`` has shape ``(1, 1, 225, h*w)``, the local softmax.

    NOTE(review): the dump's dtype annotations are consistent with the graph
    having been traced inside a ``torch.autocast`` region. Examples are an
    f16 input yielding an f32 softmax, and f32 parameters yielding f16
    einsum/matmul/linear outputs. This code performs no explicit casts, so
    those dtypes only hold under autocast. Confirm against the caller.
    """

    def forward(
        self,
        s0: "Sym(s0)",
        s1: "Sym(s1)",
        L_stack0_: "f16[1, 15, 15, s0, s1][225*s0*s1, 15*s0*s1, s0*s1, s1, 1]cuda:0",
        L_h_: "Sym(68)",
        L_w_: "Sym(68)",
        s5: "Sym(s5)",
        L_relative_emb_: "f16[1, 1, 225, s0*s1][0, 0, s5, 1]cuda:0",
        L_qk_mask_: "f32[1, 1, 225, s0*s1][225*s0*s1, 225*s0*s1, s0*s1, 1]cuda:0",
        L_self_parameters_relative_emb_v_: "f32[1, 256, 225][57600, 225, 1]cuda:0",
        L_self_local_mask: "b8[1, 1, 4624, 82, 82][0, 0, 6784, 82, 1]cuda:0",
        L_v_: "f16[1, 1, 256, 4624][1183744, 1183744, 4624, 1]cuda:0",
        L_self_modules_projection_parameters_weight_: "f32[256, 256][256, 1]cuda:0",
        L_self_modules_projection_parameters_bias_: "f32[256][1]cuda:0",
    ):
        # s0, s1 and s5 are symbolic sizes bound by the trace and are unused here.
        h, w = L_h_, L_w_
        hw = h * w
        max_dis = 7  # traced self.max_dis: padding is h + 14 / w + 14

        # ---- attention.py:346-347: correlation output viewed as
        # (n, num_head, window_size**2, h*w) == (1, 1, 225, h*w).
        qk_view = L_stack0_.view(1, 1, 225, hw)
        scores = qk_view + L_relative_emb_  # fresh tensor; the input view is untouched
        # In-place on purpose: the result keeps `scores`' dtype (f16 in the
        # trace) instead of promoting to the f32 dtype of the scaled mask.
        scores -= L_qk_mask_ * 10000.0
        del qk_view
        local_attn = torch.softmax(scores, dim=2)
        del scores
        # Traced Dropout(p=0.0) in eval mode (training=False, inplace=False).
        local_attn = torch.nn.functional.dropout(local_attn, 0.0, False, False)

        # ---- attention.py:370: bias term from the per-offset value embedding,
        # shape (1, 1, h*w, 256).
        agg_bias = torch.einsum(
            "bhwn,hcw->bhnc", local_attn, L_self_parameters_relative_emb_v_
        )

        # ---- local2global (attention.py:391-423): scatter each position's
        # 15x15 window of weights into a padded (h+14, w+14) map, then crop
        # the padding to get a dense (h*w, h*w) attention matrix.
        pad_h = h + 2 * max_dis
        pad_w = w + 2 * max_dis
        # The dump hard-coded cuda:0 via FX's global `device`. Allocating on
        # local_attn's device is identical for the traced inputs, and the
        # masked assignment below needs both tensors on one device anyway.
        global_attn = torch.zeros((1, 1, hw, pad_h, pad_w), device=local_attn.device)
        # The boolean mask must select exactly local_attn.numel() (225 * h * w)
        # elements, or the masked assignment raises a shape mismatch.
        window_values = local_attn.transpose(-1, -2).reshape(-1)
        global_attn[L_self_local_mask.expand(1, 1, -1, -1, -1)] = window_values
        del window_values
        # Slicing gives a non-contiguous view; reshape copies it, after which
        # the padded buffer is released.
        global_attn = global_attn[:, :, :, max_dis:-max_dis, max_dis:-max_dis].reshape(
            1, 1, hw, hw
        )

        # ---- attention.py:378-379: aggregate values, add the bias, and
        # lay out as (h*w, n, c).
        agg_value = global_attn @ L_v_.transpose(-2, -1)
        del global_attn
        aggregated = (agg_value + agg_bias).permute(2, 0, 1, 3)
        del agg_value, agg_bias
        output = aggregated.reshape(hw, 1, 256)
        del aggregated

        # Output projection (nn.Linear).
        output = torch.nn.functional.linear(
            output,
            L_self_modules_projection_parameters_weight_,
            L_self_modules_projection_parameters_bias_,
        )
        return (output, local_attn)