# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align

from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()
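# This file is TorchInductor's generated Python wrapper for the compiled
# graph: each async_compile.triton(...) assignment below registers one fused
# Triton kernel whose source is embedded as a string and compiled off the
# main thread.  The buffer-allocating call() function that launches these
# kernels normally follows the kernel definitions (after
# async_compile.wait(globals())) later in the file.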


# kernel path: /tmp/torchinductor_root/ut/cutum24mmaatgtfsv4woxycqlzbsdogt2u54xhctaoi4vh3kdezb.py
# Source Nodes: [x], Original ATen: [aten._to_copy, aten.convolution]
# x => convert_element_type, convert_element_type_1, convert_element_type_2, convolution
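# The kernel below is a plain elementwise fp32 -> fp16 cast of the patch-embed
# convolution input (xnumel = 3551232 = 3*1088*1088, consistent with a
# 1x3x1088x1088 image).  Loads and stores are unmasked because the element
# count is divisible by every candidate XBLOCK.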
triton_poi_fused__to_copy_convolution_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[4194304], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_convolution_0', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 3551232
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')

import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
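# A minimal sketch of how the generated call() function (which normally
# follows the kernel definitions, after async_compile.wait(globals()) has
# resolved them) launches the cast kernel above.  `inp_fp32` and `out_fp16`
# are hypothetical CUDA tensors, not buffers allocated by this graph.
def _example_launch_to_copy(inp_fp32, out_fp16):
    stream0 = get_raw_stream(0)
    # 1D launch: grid(3551232) chooses enough program instances to cover
    # every element at the autotuned XBLOCK.
    triton_poi_fused__to_copy_convolution_0.run(
        inp_fp32, out_fp16, 3551232, grid=grid(3551232), stream=stream0)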


# kernel path: /tmp/torchinductor_root/5e/c5egziobbbb7h3rfbwmjwo2duj6cruzzve6uugmk3p43rvyyzsbe.py
# Source Nodes: [x], Original ATen: [aten._to_copy, aten.convolution]
# x => convert_element_type, convert_element_type_1, convert_element_type_2, convolution
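# Same elementwise fp32 -> fp16 cast, here for the convolution weight:
# 6144 = 128*3*4*4, consistent with a 128-channel, 4x4-patch embedding kernel.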
triton_poi_fused__to_copy_convolution_1 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[8192], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_convolution_1', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 6144
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/nv/cnvgdx7r4kd34wufst4iduuf34fhbxltx3bb4j6q4f62cqpgs6pw.py
# Source Nodes: [x], Original ATen: [aten._to_copy, aten.convolution]
# x => convert_element_type, convert_element_type_1, convert_element_type_2, convolution
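# Casts the 128-element convolution bias to fp16; with such a small xnumel the
# loads and stores are guarded by xmask.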
triton_poi_fused__to_copy_convolution_2 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[128], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_convolution_2', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/qw/cqw52v5yidn33sc72vds5xzut4fwzpkiznabfnmbghkdfdx4dv2c.py
# Source Nodes: [x_2], Original ATen: [aten._to_copy, aten.native_layer_norm]
# x_2 => clone, convert_element_type_3, var_mean
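# Reduction for the post-patch-embed LayerNorm: for each of the 73984
# (= 272*272) spatial positions it adds the conv bias, upcasts to fp32 and
# accumulates mean and M2 over the 128 channels with Welford's online
# algorithm, writing the per-position mean and sum of squared deviations for
# the normalization kernel that follows.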
triton_red_fused__to_copy_native_layer_norm_3 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.reduction(
    size_hints=[131072, 128],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_native_layer_norm_3', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 73984
    rnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    tmp5_mean = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp5_m2 = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp5_weight = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (73984*r1)), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tmp0 + tmp1
        tmp3 = tmp2.to(tl.float32)
        tmp4 = tl.broadcast_to(tmp3, [XBLOCK, RBLOCK])
        tmp5_mean_next, tmp5_m2_next, tmp5_weight_next = triton_helpers.welford_reduce(
            tmp4, tmp5_mean, tmp5_m2, tmp5_weight, roffset == 0
        )
        tmp5_mean = tl.where(rmask & xmask, tmp5_mean_next, tmp5_mean)
        tmp5_m2 = tl.where(rmask & xmask, tmp5_m2_next, tmp5_m2)
        tmp5_weight = tl.where(rmask & xmask, tmp5_weight_next, tmp5_weight)
    tmp5_tmp, tmp6_tmp, tmp7_tmp = triton_helpers.welford(
        tmp5_mean, tmp5_m2, tmp5_weight, 1
    )
    tmp5 = tmp5_tmp[:, None]
    tmp6 = tmp6_tmp[:, None]
    tmp7 = tmp7_tmp[:, None]
    tl.store(out_ptr0 + (x0), tmp5, xmask)
    tl.store(out_ptr1 + (x0), tmp6, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/z7/cz7aixplgjdoqi3iyr5pyvohfufykasw4xp6tdzqkrnniybuo3jw.py
# Source Nodes: [x_2], Original ATen: [aten._to_copy, aten.native_layer_norm]
# x_2 => add, add_1, clone, convert_element_type_3, mul, mul_1, rsqrt, sub, var_mean
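# Normalization step of the same LayerNorm: var = M2/128, then
# (x - mean) * rsqrt(var + 1e-5) * gamma + beta per channel.  The 9469952
# (= 73984*128) elements are laid out channels-first (x1 = channel,
# x0 = spatial position), while the statistics from the previous kernel are
# per position.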
triton_poi_fused__to_copy_native_layer_norm_4 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[16777216], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_native_layer_norm_4', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 9469952
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = xindex
    x1 = (xindex // 73984)
    x0 = xindex % 73984
    tmp0 = tl.load(in_ptr0 + (x2), None).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last').to(tl.float32)
    tmp4 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last')
    tmp6 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
    tmp13 = tl.load(in_ptr4 + (x1), None, eviction_policy='evict_last')
    tmp15 = tl.load(in_ptr5 + (x1), None, eviction_policy='evict_last')
    tmp2 = tmp0 + tmp1
    tmp3 = tmp2.to(tl.float32)
    tmp5 = tmp3 - tmp4
    tmp7 = 128.0
    tmp8 = tmp6 / tmp7
    tmp9 = 1e-05
    tmp10 = tmp8 + tmp9
    tmp11 = libdevice.rsqrt(tmp10)
    tmp12 = tmp5 * tmp11
    tmp14 = tmp12 * tmp13
    tmp16 = tmp14 + tmp15
    tl.store(out_ptr0 + (x2), tmp16, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/zd/czdqvqzahblfj2alw5l62ev6piz5bnkywru3bk67m4u7dfxo6ukh.py
# Source Nodes: [x_7], Original ATen: [aten.native_layer_norm]
# x_7 => var_mean_1
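# Welford mean/variance for the transformer block's first LayerNorm (norm1),
# again per spatial position over 128 channels, this time reading an fp32
# channels-first input with no bias to add.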
triton_red_fused_native_layer_norm_5 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.reduction(
    size_hints=[131072, 128],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_native_layer_norm_5', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 2, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 73984
    rnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    tmp2_mean = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp2_m2 = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp2_weight = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (73984*r1)), rmask & xmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
        tmp2_mean_next, tmp2_m2_next, tmp2_weight_next = triton_helpers.welford_reduce(
            tmp1, tmp2_mean, tmp2_m2, tmp2_weight, roffset == 0
        )
        tmp2_mean = tl.where(rmask & xmask, tmp2_mean_next, tmp2_mean)
        tmp2_m2 = tl.where(rmask & xmask, tmp2_m2_next, tmp2_m2)
        tmp2_weight = tl.where(rmask & xmask, tmp2_weight_next, tmp2_weight)
    tmp2_tmp, tmp3_tmp, tmp4_tmp = triton_helpers.welford(
        tmp2_mean, tmp2_m2, tmp2_weight, 1
    )
    tmp2 = tmp2_tmp[:, None]
    tmp3 = tmp3_tmp[:, None]
    tmp4 = tmp4_tmp[:, None]
    tl.store(out_ptr0 + (x0), tmp2, xmask)
    tl.store(out_ptr1 + (x0), tmp3, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/uz/cuzrg4rqsfxddxjfergppynngxa5vajxfugeyyhxthnn5pkepo7w.py
# Source Nodes: [linear], Original ATen: [aten._to_copy]
# linear => convert_element_type_8
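# Fuses norm1's affine step with window partitioning for the qkv projection:
# the 272x272 map is treated as padded to 273x273 (= 39*7 per side) and split
# into 39*39 = 1521 windows of 7x7 tokens (ynumel = 74529 = 1521*49).
# Positions falling in the one-row/one-column pad are zero-filled, and the
# result is cast to fp16.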
triton_poi_fused__to_copy_6 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[131072, 128], tile_hint=TileHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp16', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 7), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_6', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 74529
    xnumel = 128
    yoffset = (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    y0 = yindex % 49
    y1 = (yindex // 49)
    x2 = xindex
    y3 = yindex
    tmp0 = (7*(y1 // 39)) + (y0 // 7)
    tmp1 = tl.full([1, 1], 272, tl.int64)
    tmp2 = tmp0 < tmp1
    tmp3 = (7*(y1 % 39)) + (y0 % 7)
    tmp4 = tmp3 < tmp1
    tmp5 = tmp2 & tmp4
    tmp6 = tl.load(in_ptr0 + ((272*((((7*(y1 % 39)) + (272*(y0 // 7)) + (1904*(y1 // 39)) + (y0 % 7)) // 272) % 272)) + (73984*x2) + (((7*(y1 % 39)) + (y0 % 7)) % 272)), tmp5 & xmask & ymask, eviction_policy='evict_last', other=0.0)
    tmp7 = tl.load(in_ptr1 + (tl.broadcast_to((7*(y1 % 39)) + (272*(y0 // 7)) + (1904*(y1 // 39)) + (y0 % 7), [XBLOCK, YBLOCK])), tmp5 & xmask & ymask, eviction_policy='evict_last', other=0.0)
    tmp8 = tmp6 - tmp7
    tmp9 = tl.load(in_ptr2 + (tl.broadcast_to((7*(y1 % 39)) + (272*(y0 // 7)) + (1904*(y1 // 39)) + (y0 % 7), [XBLOCK, YBLOCK])), tmp5 & xmask & ymask, eviction_policy='evict_last', other=0.0)
    tmp10 = 128.0
    tmp11 = tmp9 / tmp10
    tmp12 = 1e-05
    tmp13 = tmp11 + tmp12
    tmp14 = libdevice.rsqrt(tmp13)
    tmp15 = tmp8 * tmp14
    tmp16 = tl.load(in_ptr3 + (tl.broadcast_to(x2, [XBLOCK, YBLOCK])), tmp5 & xmask & ymask, eviction_policy='evict_last', other=0.0)
    tmp17 = tmp15 * tmp16
    tmp18 = tl.load(in_ptr4 + (tl.broadcast_to(x2, [XBLOCK, YBLOCK])), tmp5 & xmask & ymask, eviction_policy='evict_last', other=0.0)
    tmp19 = tmp17 + tmp18
    tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
    tmp21 = tl.where(tmp5, tmp19, tmp20)
    tmp22 = tmp21.to(tl.float32)
    tl.store(out_ptr0 + (x2 + (128*y3)), tmp22, xmask & ymask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/27/c27kmexuvaczehtobz6jqsbfjvi5xgmlmysk5r6xkxskrs6dfhtn.py
# Source Nodes: [linear], Original ATen: [aten._to_copy]
# linear => convert_element_type_7
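# Casts the fused qkv projection weight to fp16 (49152 = 384*128, i.e. three
# 128-wide outputs for q, k and v).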
triton_poi_fused__to_copy_7 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[65536], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_7', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 49152
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/gv/cgvv6tj3pm4btpyl5nxivtckwp4tnmqcbgk3oirzkoxwjd7w4qhy.py
# Source Nodes: [attn, q_1], Original ATen: [aten.clone, aten.mul]
# attn => clone_4
# q_1 => mul_6
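# Extracts Q from the fused qkv output (first 128 of the 384 channels), adds
# the q bias and pre-scales by 0.1767766952966369 = 1/sqrt(32) (4 heads of
# head_dim 32).  Output layout is (window, head, 49 tokens, 32) with each
# 1568-element slab stored at a padded stride of 1600.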
triton_poi_fused_clone_mul_8 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[16777216], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp16', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_mul_8', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 9539712
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 32
    x1 = (xindex // 32) % 49
    x2 = (xindex // 1568) % 4
    x3 = (xindex // 6272)
    x4 = xindex % 1568
    x5 = (xindex // 1568)
    tmp0 = tl.load(in_ptr0 + (x0 + (32*x2) + (384*x1) + (18816*x3)), xmask).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (x0 + (32*x2)), xmask, eviction_policy='evict_last')
    tmp2 = tmp1.to(tl.float32)
    tmp3 = tmp0 + tmp2
    tmp4 = 0.1767766952966369
    tmp5 = tmp3 * tmp4
    tl.store(out_ptr0 + (x4 + (1600*x5)), tmp5, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/3d/c3dgt2c4b5437df63yihnutvowcnecyibipqaqopntg6wcxwli7t.py
# Source Nodes: [attn], Original ATen: [aten.clone]
# attn => clone_5
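# Extracts K (qkv channels 128..255) with its bias and stores it with the 49
# token positions contiguous, i.e. already transposed so the q @ k^T scores
# can be produced by a plain batched matmul.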
triton_poi_fused_clone_9 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[262144, 64], tile_hint=TileHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp16', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_9', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 194688
    xnumel = 49
    yoffset = (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    x3 = xindex
    y2 = (yindex // 128)
    y4 = yindex % 128
    y0 = yindex % 32
    y5 = (yindex // 32)
    tmp0 = tl.load(in_ptr0 + (128 + y4 + (384*x3) + (18816*y2)), xmask & ymask, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (128 + y4), ymask, eviction_policy='evict_last')
    tmp2 = tmp1.to(tl.float32)
    tmp3 = tmp0 + tmp2
    tl.store(out_ptr0 + (x3 + (49*y0) + (1600*y5)), tmp3, xmask & ymask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/ja/cjajea6ta7qkhr6lfq7dh5ukfc3b2pbysx55ye2fbw22b5lnzzve.py
# Source Nodes: [attn_1, attn_2, matmul_1], Original ATen: [aten._softmax, aten._to_copy, aten.add]
# attn_1 => add_4
# attn_2 => amax, div_2, exp, sub_3, sum_1
# matmul_1 => convert_element_type_14
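# Persistent-reduction softmax over the 49 key positions for every
# (window, head, query) row (xnumel = 298116 = 1521*4*49): upcasts the q @ k^T
# scores to fp32, gathers the relative-position bias from a (169, 4) table
# (169 = (2*7-1)^2 offsets for a 7x7 window) via the int64 index tensor, then
# applies a max-subtracted softmax and stores fp16 attention weights with a
# padded slab stride of 2432 per (window, head).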
triton_per_fused__softmax__to_copy_add_10 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.persistent_reduction(
    size_hints=[524288, 64],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*i64', 2: '*fp32', 3: '*fp16', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__softmax__to_copy_add_10', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 298116
    rnumel = 49
    RBLOCK: tl.constexpr = 64
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = rindex < rnumel
    r3 = rindex
    x4 = xindex
    x0 = xindex % 49
    x1 = (xindex // 49) % 4
    x5 = (xindex // 49)
    tmp0 = tl.load(in_ptr0 + (r3 + (49*x4)), rmask & xmask, other=0.0).to(tl.float32)
    tmp2 = tl.load(in_ptr1 + (r3 + (49*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0)
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tl.full([XBLOCK, RBLOCK], 169, tl.int32)
    tmp4 = tmp2 + tmp3
    tmp5 = tmp2 < 0
    tmp6 = tl.where(tmp5, tmp4, tmp2)
    tl.device_assert(((0 <= tmp6) & (tmp6 < 169)) | ~(rmask & xmask), "index out of bounds: 0 <= tmp6 < 169")
    tmp8 = tl.load(in_ptr2 + (x1 + (4*tmp6)), rmask & xmask, eviction_policy='evict_last')
    tmp9 = tmp1 + tmp8
    tmp10 = tl.broadcast_to(tmp9, [XBLOCK, RBLOCK])
    tmp12 = tl.where(rmask & xmask, tmp10, float("-inf"))
    tmp13 = triton_helpers.max2(tmp12, 1)[:, None]
    tmp14 = tmp9 - tmp13
    tmp15 = tl_math.exp(tmp14)
    tmp16 = tl.broadcast_to(tmp15, [XBLOCK, RBLOCK])
    tmp18 = tl.where(rmask & xmask, tmp16, 0)
    tmp19 = tl.sum(tmp18, 1)[:, None]
    tmp20 = tmp15 / tmp19
    tmp21 = tmp20.to(tl.float32)
    tl.store(out_ptr2 + (r3 + (49*x0) + (2432*x5)), tmp21, rmask & xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/y3/cy3breg2mghmf5ikb42jodayxixk2k3gmcadoiwf3jyrnnxwruwr.py
# Source Nodes: [matmul_1], Original ATen: [aten.clone]
# matmul_1 => clone_8
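# Extracts V (qkv channels 256..383) with its bias, using the same padded
# (window, head, token, 32) layout as Q, ready for the attn @ v matmul.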
triton_poi_fused_clone_11 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[16777216], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp16', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_11', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 9539712
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 32
    x1 = (xindex // 32) % 49
    x2 = (xindex // 1568) % 4
    x3 = (xindex // 6272)
    x4 = xindex % 1568
    x5 = (xindex // 1568)
    tmp0 = tl.load(in_ptr0 + (256 + x0 + (32*x2) + (384*x1) + (18816*x3)), xmask).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (256 + x0 + (32*x2)), xmask, eviction_policy='evict_last')
    tmp2 = tmp1.to(tl.float32)
    tmp3 = tmp0 + tmp2
    tl.store(out_ptr0 + (x4 + (1600*x5)), tmp3, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/wx/cwxwgfbkicmgjn2axopqk6bihn72jskwcbv4lwok4w227xf3wkxj.py
# Source Nodes: [x_11], Original ATen: [aten.clone]
# x_11 => clone_9
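# Reorders the attention output from (window, head, token, head_dim) to
# (window, token, head*head_dim) = (1521, 49, 128) ahead of the output
# projection.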
triton_poi_fused_clone_12 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[16777216], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_12', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 9539712
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 32
    x1 = (xindex // 32) % 4
    x2 = (xindex // 128) % 49
    x3 = (xindex // 6272)
    x4 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (32*x2) + (1568*x1) + (6272*x3)), xmask).to(tl.float32)
    tl.store(out_ptr0 + (x4), tmp0, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/e5/ce5swurtaidsanunlofocegeahut6uzzovemh7x35gwa2mv3uvbw.py
# Source Nodes: [x_12], Original ATen: [aten._to_copy]
# x_12 => convert_element_type_18
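# Casts the 128x128 attention output-projection weight to fp16 (16384
# elements).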
triton_poi_fused__to_copy_13 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[16384], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_13', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 16384
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/6a/c6ak2mlzd7gs74dn6lkznsahwca2ehhzdludsreoi74bucu7ieca.py
# Source Nodes: [layer_norm_2, x_18, x_19], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
# layer_norm_2 => add_6, add_7, mul_7, mul_8, rsqrt_2, sub_4, var_mean_2
# x_18 => add_5
# x_19 => convert_element_type_24
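# Fuses the first residual connection with the block's second LayerNorm.  The
# first loop adds the projected attention output (plus bias, with the modular
# indexing on x0 undoing the earlier window partition) onto the skip path and
# accumulates Welford statistics; the second loop recomputes the same sum,
# normalizes with rsqrt(var + 1e-5), applies gamma/beta and casts to fp16 as
# the MLP input.  The two-pass form avoids materializing the fp32 sum.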
triton_red_fused__to_copy_add_native_layer_norm_14 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.reduction(
    size_hints=[131072, 128],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp16', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_native_layer_norm_14', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 8, 'num_reduction': 2, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 73984
    rnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    tmp8_mean = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp8_m2 = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp8_weight = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (73984*r1)), rmask & xmask, eviction_policy='evict_last', other=0.0)
        tmp1 = tl.load(in_ptr1 + (r1 + (128*((x0 % 272) % 7)) + (896*((x0 // 272) % 7)) + (6272*((x0 % 272) // 7)) + (244608*(x0 // 1904))), rmask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp3 = tmp2.to(tl.float32)
        tmp4 = tmp1 + tmp3
        tmp5 = tmp4.to(tl.float32)
        tmp6 = tmp0 + tmp5
        tmp7 = tl.broadcast_to(tmp6, [XBLOCK, RBLOCK])
        tmp8_mean_next, tmp8_m2_next, tmp8_weight_next = triton_helpers.welford_reduce(
            tmp7, tmp8_mean, tmp8_m2, tmp8_weight, roffset == 0
        )
        tmp8_mean = tl.where(rmask & xmask, tmp8_mean_next, tmp8_mean)
        tmp8_m2 = tl.where(rmask & xmask, tmp8_m2_next, tmp8_m2)
        tmp8_weight = tl.where(rmask & xmask, tmp8_weight_next, tmp8_weight)
    tmp8_tmp, tmp9_tmp, tmp10_tmp = triton_helpers.welford(
        tmp8_mean, tmp8_m2, tmp8_weight, 1
    )
    tmp8 = tmp8_tmp[:, None]
    tmp9 = tmp9_tmp[:, None]
    tmp10 = tmp10_tmp[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp11 = tl.load(in_ptr0 + (x0 + (73984*r1)), rmask & xmask, eviction_policy='evict_first', other=0.0)
        tmp12 = tl.load(in_ptr1 + (r1 + (128*((x0 % 272) % 7)) + (896*((x0 // 272) % 7)) + (6272*((x0 % 272) // 7)) + (244608*(x0 // 1904))), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp13 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp25 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp27 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp14 = tmp13.to(tl.float32)
        tmp15 = tmp12 + tmp14
        tmp16 = tmp15.to(tl.float32)
        tmp17 = tmp11 + tmp16
        tmp18 = tmp17 - tmp8
        tmp19 = 128.0
        tmp20 = tmp9 / tmp19
        tmp21 = 1e-05
        tmp22 = tmp20 + tmp21
        tmp23 = libdevice.rsqrt(tmp22)
        tmp24 = tmp18 * tmp23
        tmp26 = tmp24 * tmp25
        tmp28 = tmp26 + tmp27
        tmp29 = tmp28.to(tl.float32)
        tl.store(out_ptr2 + (r1 + (128*x0)), tmp29, rmask & xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/mj/cmjiy7nnzibomx47pwcnzzevlowthzxr6o6uin5k2sbiqh237ico.py
# Source Nodes: [x_19], Original ATen: [aten._to_copy]
# x_19 => convert_element_type_23
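# Casts the MLP fc1 weight to fp16 (65536 = 512*128, a 4x hidden expansion).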
triton_poi_fused__to_copy_15 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[65536], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_15', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 65536
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/u7/cu7c5ukagtxwy7y4iiqkp65cte74cratlcairghkion2ijccm4hb.py
# Source Nodes: [x_20], Original ATen: [aten.gelu]
# x_20 => add_8, convert_element_type_28, convert_element_type_29, erf, mul_10, mul_11, mul_9
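# In-place fused bias-add + exact GELU on the fc1 output (37879808 =
# 73984*512 elements): adds the bias in fp16, upcasts to fp32 and evaluates
# 0.5*x*(1 + erf(x/sqrt(2))), with 0.7071067811865476 = 1/sqrt(2), before
# storing back as fp16.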
triton_poi_fused_gelu_16 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[67108864], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_gelu_16', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 37879808
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = xindex
    x0 = xindex % 512
    tmp0 = tl.load(in_out_ptr0 + (x2), None).to(tl.float32)
    tmp1 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last')
    tmp2 = tmp1.to(tl.float32)
    tmp3 = tmp0 + tmp2
    tmp4 = tmp3.to(tl.float32)
    tmp5 = 0.5
    tmp6 = tmp4 * tmp5
    tmp7 = 0.7071067811865476
    tmp8 = tmp4 * tmp7
    tmp9 = libdevice.erf(tmp8)
    tmp10 = 1.0
    tmp11 = tmp9 + tmp10
    tmp12 = tmp6 * tmp11
    tmp13 = tmp12.to(tl.float32)
    tl.store(in_out_ptr0 + (x2), tmp13, None)
''', device_str='cuda')
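# A rough fp32 reference for the in-place kernel above (a sketch only;
# `x_fp16` and `bias_fp32` are hypothetical tensors, not graph buffers):
def _reference_bias_gelu(x_fp16, bias_fp32):
    # Match the kernel: the bias is cast to fp16 before the add, the GELU
    # itself runs in fp32, and the result is stored back in fp16.
    y = (x_fp16 + bias_fp32.half()).float()
    return (0.5 * y * (1.0 + torch.erf(y * 0.7071067811865476))).half()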


# kernel path: /tmp/torchinductor_root/ce/ccebhjg73upjqgd74bporiec6gb6m46ufzq5fs4uybxgkhzqobna.py
# Source Nodes: [x_18, x_24, x_25], Original ATen: [aten.add, aten.native_layer_norm]
# x_18 => add_5
# x_24 => add_9
# x_25 => var_mean_3
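# Recomputes x_18 = skip + attention projection, adds the MLP output (with the
# fc2 bias) to form x_24, stores that pre-norm sum channels-last for the
# residual path (out_ptr0), and accumulates the Welford statistics of the next
# LayerNorm (out_ptr1 / out_ptr2).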
triton_red_fused_add_native_layer_norm_17 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.reduction(
    size_hints=[131072, 128],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp32', 3: '*fp16', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: 'i32', 9: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_add_native_layer_norm_17', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 2, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 73984
    rnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    tmp14_mean = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp14_m2 = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp14_weight = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (x0 + (73984*r1)), rmask & xmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tl.load(in_ptr1 + (r1 + (128*((x0 % 272) % 7)) + (896*((x0 // 272) % 7)) + (6272*((x0 % 272) // 7)) + (244608*(x0 // 1904))), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp2 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp7 = tl.load(in_ptr3 + (r1 + (128*x0)), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp8 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp3 = tmp2.to(tl.float32)
        tmp4 = tmp1 + tmp3
        tmp5 = tmp4.to(tl.float32)
        tmp6 = tmp0 + tmp5
        tmp9 = tmp8.to(tl.float32)
        tmp10 = tmp7 + tmp9
        tmp11 = tmp10.to(tl.float32)
        tmp12 = tmp6 + tmp11
        tmp13 = tl.broadcast_to(tmp12, [XBLOCK, RBLOCK])
        tmp14_mean_next, tmp14_m2_next, tmp14_weight_next = triton_helpers.welford_reduce(
            tmp13, tmp14_mean, tmp14_m2, tmp14_weight, roffset == 0
        )
        tmp14_mean = tl.where(rmask & xmask, tmp14_mean_next, tmp14_mean)
        tmp14_m2 = tl.where(rmask & xmask, tmp14_m2_next, tmp14_m2)
        tmp14_weight = tl.where(rmask & xmask, tmp14_weight_next, tmp14_weight)
        tl.store(out_ptr0 + (r1 + (128*x0)), tmp12, rmask & xmask)
    tmp14_tmp, tmp15_tmp, tmp16_tmp = triton_helpers.welford(
        tmp14_mean, tmp14_m2, tmp14_weight, 1
    )
    tmp14 = tmp14_tmp[:, None]
    tmp15 = tmp15_tmp[:, None]
    tmp16 = tmp16_tmp[:, None]
    tl.store(out_ptr1 + (x0), tmp14, xmask)
    tl.store(out_ptr2 + (x0), tmp15, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/bd/cbdbrul6cgbc725lx7noa5ulqmr2r6ohwoqlh54cc2oipmh3exwr.py
# Source Nodes: [shifted_x, x_27], Original ATen: [aten.constant_pad_nd, aten.roll]
# shifted_x => index_1, index_2
# x_27 => constant_pad_nd_1
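# Prepares the shifted-window half of the block: applies the LayerNorm whose
# statistics the previous reduction produced, zero-pads the 272x272 map to
# 273x273 (273 = 39*7), and performs the cyclic shift
# torch.roll(..., shifts=(-3, -3)) through the (3 + index) % 273 addressing
# (shift size 3 = window_size // 2).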
triton_poi_fused_constant_pad_nd_roll_18 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[16777216], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_constant_pad_nd_roll_18', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 9539712
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = (xindex // 34944)
    x1 = (xindex // 128) % 273
    x0 = xindex % 128
    x4 = xindex
    tmp0 = (3 + x2) % 273
    tmp1 = tl.full([1], 272, tl.int64)
    tmp2 = tmp0 < tmp1
    tmp3 = (3 + x1) % 273
    tmp4 = tmp3 < tmp1
    tmp5 = tmp2 & tmp4
    tmp6 = tl.load(in_ptr0 + (x0 + (128*((3 + x1) % 273)) + (34816*((3 + x2) % 273))), tmp5 & xmask, other=0.0)
    tmp7 = tl.load(in_ptr1 + ((272*((3 + x2) % 273)) + ((3 + x1) % 273)), tmp5 & xmask, eviction_policy='evict_last', other=0.0)
    tmp8 = tmp6 - tmp7
    tmp9 = tl.load(in_ptr2 + ((272*((3 + x2) % 273)) + ((3 + x1) % 273)), tmp5 & xmask, eviction_policy='evict_last', other=0.0)
    tmp10 = 128.0
    tmp11 = tmp9 / tmp10
    tmp12 = 1e-05
    tmp13 = tmp11 + tmp12
    tmp14 = libdevice.rsqrt(tmp13)
    tmp15 = tmp8 * tmp14
    tmp16 = tl.load(in_ptr3 + (x0), tmp5 & xmask, eviction_policy='evict_last', other=0.0)
    tmp17 = tmp15 * tmp16
    tmp18 = tl.load(in_ptr4 + (x0), tmp5 & xmask, eviction_policy='evict_last', other=0.0)
    tmp19 = tmp17 + tmp18
    tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
    tmp21 = tl.where(tmp5, tmp19, tmp20)
    tl.store(out_ptr0 + (x4), tmp21, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/wf/cwfhqs2zm6bsjsrw2tiokspft5espgjofwlyywwnhdmmbyrpctxz.py
# Source Nodes: [linear_4], Original ATen: [aten._to_copy]
# linear_4 => convert_element_type_37
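# Partitions the shifted 273x273 map into the same 1521 windows of 7x7 tokens
# and casts to fp16 for the shifted-window qkv projection.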
triton_poi_fused__to_copy_19 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[16777216], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_19', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 9539712
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 128
    x1 = (xindex // 128) % 49
    x2 = (xindex // 6272)
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (128*(x1 % 7)) + (896*(x2 % 39)) + (34944*(x1 // 7)) + (244608*(x2 // 39))), xmask)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, xmask)
''', device_str='cuda')
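# Note: the kernel above performs the window partition for the shifted branch: the 273x273x128
# fp32 map is regrouped into 39x39 non-overlapping 7x7 windows (x2 = window index, x1 = position
# inside the window, x0 = channel) and cast to fp16 as the input of the qkv projection linear_4.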


# kernel path: /tmp/torchinductor_root/or/corkqb2j6hymgta3qqablpxtt4psncgc2x5wtkcbslwfpukdab7t.py
# Source Nodes: [img_mask, setitem, setitem_1, setitem_2, setitem_3, setitem_4], Original ATen: [aten.fill, aten.lift_fresh, aten.slice, aten.zeros]
# img_mask => full
# setitem => copy, lift_fresh_copy_2
# setitem_1 => copy_1, lift_fresh_copy_3
# setitem_2 => copy_2, lift_fresh_copy_4
# setitem_3 => copy_3, full_default, lift_fresh_copy_5
# setitem_4 => copy_4, lift_fresh_copy_6
triton_poi_fused_fill_lift_fresh_slice_zeros_20 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[131072], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0,), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_fill_lift_fresh_slice_zeros_20', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 74529
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x1 = (xindex // 273)
    x0 = xindex % 273
    x2 = xindex
    tmp0 = x1
    tmp1 = tl.full([1], 266, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.full([1], 270, tl.int64)
    tmp4 = tmp0 < tmp3
    tmp5 = tmp2 & tmp4
    tmp6 = x0
    tmp7 = tmp6 < tmp1
    tmp8 = tmp7 & tmp5
    tmp9 = 3.0
    tmp10 = tl.full(tmp9.shape, 0.0, tmp9.dtype)
    tmp11 = tl.where(tmp8, tmp9, tmp10)
    tmp12 = 0.0
    tmp13 = tl.where(tmp7, tmp11, tmp12)
    tmp14 = tl.full(tmp13.shape, 0.0, tmp13.dtype)
    tmp15 = tl.where(tmp5, tmp13, tmp14)
    tmp16 = tmp0 < tmp1
    tmp17 = tmp6 >= tmp3
    tmp18 = tmp17 & tmp16
    tmp19 = 2.0
    tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
    tmp21 = tl.where(tmp18, tmp19, tmp20)
    tmp22 = tmp16 & tmp16
    tmp23 = tmp6 >= tmp1
    tmp24 = tmp6 < tmp3
    tmp25 = tmp23 & tmp24
    tmp26 = tmp25 & tmp22
    tmp27 = 1.0
    tmp28 = tl.full(tmp27.shape, 0.0, tmp27.dtype)
    tmp29 = tl.where(tmp26, tmp27, tmp28)
    tmp30 = tmp16 & tmp22
    tmp31 = tmp7 & tmp30
    tmp32 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
    tmp33 = tl.where(tmp31, tmp12, tmp32)
    tmp34 = tl.where(tmp7, tmp33, tmp12)
    tmp35 = tl.full(tmp34.shape, 0.0, tmp34.dtype)
    tmp36 = tl.where(tmp30, tmp34, tmp35)
    tmp37 = tl.where(tmp16, tmp36, tmp12)
    tmp38 = tl.where(tmp25, tmp29, tmp37)
    tmp39 = tl.full(tmp38.shape, 0.0, tmp38.dtype)
    tmp40 = tl.where(tmp22, tmp38, tmp39)
    tmp41 = tmp7 & tmp22
    tmp42 = tl.where(tmp41, tmp12, tmp32)
    tmp43 = tl.where(tmp7, tmp42, tmp12)
    tmp44 = tl.full(tmp43.shape, 0.0, tmp43.dtype)
    tmp45 = tl.where(tmp22, tmp43, tmp44)
    tmp46 = tl.where(tmp16, tmp45, tmp12)
    tmp47 = tl.where(tmp16, tmp40, tmp46)
    tmp48 = tl.where(tmp17, tmp21, tmp47)
    tmp49 = tl.full(tmp48.shape, 0.0, tmp48.dtype)
    tmp50 = tl.where(tmp16, tmp48, tmp49)
    tmp51 = tmp25 & tmp16
    tmp52 = tl.where(tmp51, tmp27, tmp28)
    tmp53 = tl.where(tmp25, tmp52, tmp46)
    tmp54 = tl.full(tmp53.shape, 0.0, tmp53.dtype)
    tmp55 = tl.where(tmp16, tmp53, tmp54)
    tmp56 = tmp7 & tmp16
    tmp57 = tl.where(tmp56, tmp12, tmp32)
    tmp58 = tl.where(tmp7, tmp57, tmp12)
    tmp59 = tl.full(tmp58.shape, 0.0, tmp58.dtype)
    tmp60 = tl.where(tmp16, tmp58, tmp59)
    tmp61 = tl.where(tmp16, tmp60, tmp12)
    tmp62 = tl.where(tmp16, tmp55, tmp61)
    tmp63 = tl.where(tmp16, tmp50, tmp62)
    tmp64 = tl.where(tmp5, tmp15, tmp63)
    tmp65 = tmp25 & tmp5
    tmp66 = 4.0
    tmp67 = tl.full(tmp66.shape, 0.0, tmp66.dtype)
    tmp68 = tl.where(tmp65, tmp66, tmp67)
    tmp69 = tl.where(tmp25, tmp68, tmp64)
    tmp70 = tl.full(tmp69.shape, 0.0, tmp69.dtype)
    tmp71 = tl.where(tmp5, tmp69, tmp70)
    tmp72 = tl.where(tmp5, tmp71, tmp64)
    tl.store(in_out_ptr0 + (x2), tmp72, xmask)
''', device_str='cuda')
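# Note: the kernel above fills the first part of the 273x273 attention image mask (img_mask) in
# place: region labels 0-4 are written according to the row/column slices [0, 266), [266, 270)
# and [270, 273), i.e. H - window_size = 266 and H - shift_size = 270 for window size 7 and
# shift 3. Labels 5-8 are handled by the next two kernels.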


# kernel path: /tmp/torchinductor_root/js/cjs7cpqsrlg6uha5jw3ddfb244ovf73mv7lc45xezom63apdnnpq.py
# Source Nodes: [setitem_8], Original ATen: [aten.fill, aten.lift_fresh]
# setitem_8 => copy_8, lift_fresh_copy_10
triton_poi_fused_fill_lift_fresh_21 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[1024], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_fill_lift_fresh_21', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 819
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 273
    x1 = (xindex // 273)
    x2 = xindex
    tmp55 = tl.load(in_ptr0 + (73710 + x2), xmask)
    tmp0 = x0
    tmp1 = tl.full([1], 270, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = 8.0
    tmp4 = tl.full(tmp3.shape, 0.0, tmp3.dtype)
    tmp5 = tl.where(tmp2, tmp3, tmp4)
    tmp6 = 270 + x1
    tmp7 = tmp6 >= tmp1
    tmp8 = tl.full([1], 266, tl.int64)
    tmp9 = tmp0 >= tmp8
    tmp10 = tmp0 < tmp1
    tmp11 = tmp9 & tmp10
    tmp12 = tmp11 & tmp7
    tmp13 = 7.0
    tmp14 = tl.full(tmp13.shape, 0.0, tmp13.dtype)
    tmp15 = tl.where(tmp12, tmp13, tmp14)
    tmp16 = tmp7 & tmp7
    tmp17 = tmp0 < tmp8
    tmp18 = tmp17 & tmp16
    tmp19 = 6.0
    tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
    tmp21 = tl.where(tmp18, tmp19, tmp20)
    tmp22 = 0.0
    tmp23 = tl.where(tmp17, tmp21, tmp22)
    tmp24 = tl.full(tmp23.shape, 0.0, tmp23.dtype)
    tmp25 = tl.where(tmp16, tmp23, tmp24)
    tmp26 = tmp6 >= tmp8
    tmp27 = tmp6 < tmp1
    tmp28 = tmp26 & tmp27
    tmp29 = tmp28 & tmp7
    tmp30 = tmp2 & tmp29
    tmp31 = 5.0
    tmp32 = tl.full(tmp31.shape, 0.0, tmp31.dtype)
    tmp33 = tl.where(tmp30, tmp31, tmp32)
    tmp34 = tl.load(in_ptr0 + (73710 + x2), tmp29 & xmask, other=0.0)
    tmp35 = tl.where(tmp2, tmp33, tmp34)
    tmp36 = tl.full(tmp35.shape, 0.0, tmp35.dtype)
    tmp37 = tl.where(tmp29, tmp35, tmp36)
    tmp38 = tl.load(in_ptr0 + (73710 + x2), tmp7 & xmask, other=0.0)
    tmp39 = tl.where(tmp28, tmp37, tmp38)
    tmp40 = tl.where(tmp7, tmp25, tmp39)
    tmp41 = tl.where(tmp11, tmp15, tmp40)
    tmp42 = tl.full(tmp41.shape, 0.0, tmp41.dtype)
    tmp43 = tl.where(tmp7, tmp41, tmp42)
    tmp44 = tmp17 & tmp7
    tmp45 = tl.where(tmp44, tmp19, tmp20)
    tmp46 = tl.where(tmp17, tmp45, tmp22)
    tmp47 = tl.full(tmp46.shape, 0.0, tmp46.dtype)
    tmp48 = tl.where(tmp7, tmp46, tmp47)
    tmp49 = tmp2 & tmp28
    tmp50 = tl.where(tmp49, tmp31, tmp32)
    tmp51 = tl.load(in_ptr0 + (73710 + x2), tmp28 & xmask, other=0.0)
    tmp52 = tl.where(tmp2, tmp50, tmp51)
    tmp53 = tl.full(tmp52.shape, 0.0, tmp52.dtype)
    tmp54 = tl.where(tmp28, tmp52, tmp53)
    tmp56 = tl.where(tmp28, tmp54, tmp55)
    tmp57 = tl.where(tmp7, tmp48, tmp56)
    tmp58 = tl.where(tmp7, tmp43, tmp57)
    tmp59 = tl.where(tmp2, tmp5, tmp58)
    tl.store(out_ptr0 + (x2), tmp59, xmask)
''', device_str='cuda')
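# Note: the kernel above computes the mask labels for the last three rows of img_mask
# (rows 270-272, offset 73710 = 270 * 273): 6 for columns < 266, 7 for columns in [266, 270) and
# 8 for columns >= 270, written to a small 3x273 staging buffer (out_ptr0).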


# kernel path: /tmp/torchinductor_root/u5/cu5hdqskxvnzge6kx3b3irwa2qsnzqulsyxqs3z77v5oyt3nwued.py
# Source Nodes: [setitem_5, setitem_6, setitem_7], Original ATen: [aten.fill, aten.lift_fresh, aten.slice]
# setitem_5 => copy_5, lift_fresh_copy_7
# setitem_6 => copy_6, full_default_1, lift_fresh_copy_8
# setitem_7 => copy_7, lift_fresh_copy_9
triton_poi_fused_fill_lift_fresh_slice_22 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[131072], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_fill_lift_fresh_slice_22', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 74529
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x1 = (xindex // 273)
    x2 = xindex
    x0 = xindex % 273
    tmp55 = tl.load(in_out_ptr0 + (x2), xmask)
    tmp0 = x1
    tmp1 = tl.full([1], 270, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.load(in_ptr0 + ((-73710) + x2), tmp2 & xmask, other=0.0)
    tmp4 = tl.full(tmp3.shape, 0.0, tmp3.dtype)
    tmp5 = tl.where(tmp2, tmp3, tmp4)
    tmp6 = x0
    tmp7 = tl.full([1], 266, tl.int64)
    tmp8 = tmp6 >= tmp7
    tmp9 = tmp6 < tmp1
    tmp10 = tmp8 & tmp9
    tmp11 = tmp10 & tmp2
    tmp12 = 7.0
    tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
    tmp14 = tl.where(tmp11, tmp12, tmp13)
    tmp15 = tmp2 & tmp2
    tmp16 = tmp6 < tmp7
    tmp17 = tmp16 & tmp15
    tmp18 = 6.0
    tmp19 = tl.full(tmp18.shape, 0.0, tmp18.dtype)
    tmp20 = tl.where(tmp17, tmp18, tmp19)
    tmp21 = 0.0
    tmp22 = tl.where(tmp16, tmp20, tmp21)
    tmp23 = tl.full(tmp22.shape, 0.0, tmp22.dtype)
    tmp24 = tl.where(tmp15, tmp22, tmp23)
    tmp25 = tmp0 >= tmp7
    tmp26 = tmp0 < tmp1
    tmp27 = tmp25 & tmp26
    tmp28 = tmp27 & tmp2
    tmp29 = tmp6 >= tmp1
    tmp30 = tmp29 & tmp28
    tmp31 = 5.0
    tmp32 = tl.full(tmp31.shape, 0.0, tmp31.dtype)
    tmp33 = tl.where(tmp30, tmp31, tmp32)
    tmp34 = tl.load(in_out_ptr0 + (x2), tmp28 & xmask, other=0.0)
    tmp35 = tl.where(tmp29, tmp33, tmp34)
    tmp36 = tl.full(tmp35.shape, 0.0, tmp35.dtype)
    tmp37 = tl.where(tmp28, tmp35, tmp36)
    tmp38 = tl.load(in_out_ptr0 + (x2), tmp2 & xmask, other=0.0)
    tmp39 = tl.where(tmp27, tmp37, tmp38)
    tmp40 = tl.where(tmp2, tmp24, tmp39)
    tmp41 = tl.where(tmp10, tmp14, tmp40)
    tmp42 = tl.full(tmp41.shape, 0.0, tmp41.dtype)
    tmp43 = tl.where(tmp2, tmp41, tmp42)
    tmp44 = tmp16 & tmp2
    tmp45 = tl.where(tmp44, tmp18, tmp19)
    tmp46 = tl.where(tmp16, tmp45, tmp21)
    tmp47 = tl.full(tmp46.shape, 0.0, tmp46.dtype)
    tmp48 = tl.where(tmp2, tmp46, tmp47)
    tmp49 = tmp29 & tmp27
    tmp50 = tl.where(tmp49, tmp31, tmp32)
    tmp51 = tl.load(in_out_ptr0 + (x2), tmp27 & xmask, other=0.0)
    tmp52 = tl.where(tmp29, tmp50, tmp51)
    tmp53 = tl.full(tmp52.shape, 0.0, tmp52.dtype)
    tmp54 = tl.where(tmp27, tmp52, tmp53)
    tmp56 = tl.where(tmp27, tmp54, tmp55)
    tmp57 = tl.where(tmp2, tmp48, tmp56)
    tmp58 = tl.where(tmp2, tmp43, tmp57)
    tmp59 = tl.where(tmp2, tmp5, tmp58)
    tl.store(in_out_ptr0 + (x2), tmp59, xmask)
''', device_str='cuda')
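# Note: the kernel above finishes img_mask in place: rows >= 270 are overwritten with the staged
# 3x273 values from the previous kernel (loaded at offset x2 - 73710), and label 5 is written for
# rows in [266, 270) with columns >= 270 (setitem_5 through setitem_7).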


# kernel path: /tmp/torchinductor_root/7i/c7io46xdt3ptxo6ltcbckdtexmcbsso63trxe37lfrye2ltki4sg.py
# Source Nodes: [attn_6, attn_8, matmul_3], Original ATen: [aten._softmax, aten._to_copy, aten.add]
# attn_6 => add_15
# attn_8 => amax_1, div_3, exp_1, sub_6, sum_2
# matmul_3 => convert_element_type_43
triton_per_fused__softmax__to_copy_add_23 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.persistent_reduction(
    size_hints=[524288, 64],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*i64', 2: '*fp32', 3: '*fp32', 4: '*fp16', 5: 'i32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__softmax__to_copy_add_23', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 4, 'num_reduction': 2, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr3, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 298116
    rnumel = 49
    RBLOCK: tl.constexpr = 64
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = rindex < rnumel
    r3 = rindex
    x4 = xindex
    x0 = xindex % 49
    x1 = (xindex // 49) % 4
    x2 = (xindex // 196)
    x5 = (xindex // 49)
    tmp0 = tl.load(in_ptr0 + (r3 + (49*x4)), rmask & xmask, other=0.0).to(tl.float32)
    tmp2 = tl.load(in_ptr1 + (r3 + (49*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0)
    tmp10 = tl.load(in_ptr3 + ((7*(x2 % 39)) + (273*(r3 // 7)) + (1911*(x2 // 39)) + (r3 % 7)), rmask & xmask, eviction_policy='evict_last', other=0.0)
    tmp11 = tl.load(in_ptr3 + ((7*(x2 % 39)) + (273*(x0 // 7)) + (1911*(x2 // 39)) + (x0 % 7)), xmask, eviction_policy='evict_last')
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tl.full([XBLOCK, RBLOCK], 169, tl.int32)
    tmp4 = tmp2 + tmp3
    tmp5 = tmp2 < 0
    tmp6 = tl.where(tmp5, tmp4, tmp2)
    tl.device_assert(((0 <= tmp6) & (tmp6 < 169)) | ~(rmask & xmask), "index out of bounds: 0 <= tmp6 < 169")
    tmp8 = tl.load(in_ptr2 + (x1 + (4*tmp6)), rmask & xmask, eviction_policy='evict_last')
    tmp9 = tmp1 + tmp8
    tmp12 = tmp10 - tmp11
    tmp13 = 0.0
    tmp14 = tmp12 == tmp13
    tmp15 = tmp12 != tmp13
    tmp16 = -100.0
    tmp17 = tl.where(tmp15, tmp16, tmp12)
    tmp18 = tl.where(tmp14, tmp13, tmp17)
    tmp19 = tmp9 + tmp18
    tmp20 = tl.broadcast_to(tmp19, [XBLOCK, RBLOCK])
    tmp22 = tl.where(rmask & xmask, tmp20, float("-inf"))
    tmp23 = triton_helpers.max2(tmp22, 1)[:, None]
    tmp24 = tmp19 - tmp23
    tmp25 = tl_math.exp(tmp24)
    tmp26 = tl.broadcast_to(tmp25, [XBLOCK, RBLOCK])
    tmp28 = tl.where(rmask & xmask, tmp26, 0)
    tmp29 = tl.sum(tmp28, 1)[:, None]
    tmp30 = tmp25 / tmp29
    tmp31 = tmp30.to(tl.float32)
    tl.store(out_ptr3 + (r3 + (49*x0) + (2432*x5)), tmp31, rmask & xmask)
''', device_str='cuda')
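# Note: the kernel above fuses attn_6 = logits + relative position bias + shifted-window mask
# with the softmax attn_8: in_ptr1 is the 49x49 relative_position_index into the 169-entry
# (13x13), 4-head bias table in_ptr2, and in_ptr3 is img_mask, from which the pairwise mask is
# derived as 0.0 where the two window positions share a region label and -100.0 otherwise. A
# max-subtracted softmax over the 49 keys follows, cast to fp16 for matmul_3. Roughly (eager
# sketch, names assumed):
#   attn = attn + bias_table[rel_index].permute(2, 0, 1) + mask.unsqueeze(1)
#   attn = attn.softmax(dim=-1).to(torch.float16)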


# kernel path: /tmp/torchinductor_root/ag/caggjavokhkkhu3pmmyii7rbkrszu5jdeq7id5az4t423douybr2.py
# Source Nodes: [layer_norm_4, x_37, x_38], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
# layer_norm_4 => add_19, add_20, mul_15, mul_16, rsqrt_4, sub_7, var_mean_4
# x_37 => add_18
# x_38 => convert_element_type_53
triton_red_fused__to_copy_add_native_layer_norm_24 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.reduction(
    size_hints=[131072, 128],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp16', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_native_layer_norm_24', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 8, 'num_reduction': 2, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr3, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 73984
    rnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    tmp8_mean = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp8_m2 = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp8_weight = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (128*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0)
        tmp1 = tl.load(in_ptr1 + (r1 + (128*(((270 + (x0 % 272)) % 273) % 7)) + (896*(((270 + (x0 // 272)) % 273) % 7)) + (6272*(((270 + (x0 % 272)) % 273) // 7)) + (244608*(((270 + (x0 // 272)) % 273) // 7))), rmask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp3 = tmp2.to(tl.float32)
        tmp4 = tmp1 + tmp3
        tmp5 = tmp4.to(tl.float32)
        tmp6 = tmp0 + tmp5
        tmp7 = tl.broadcast_to(tmp6, [XBLOCK, RBLOCK])
        tmp8_mean_next, tmp8_m2_next, tmp8_weight_next = triton_helpers.welford_reduce(
            tmp7, tmp8_mean, tmp8_m2, tmp8_weight, roffset == 0
        )
        tmp8_mean = tl.where(rmask & xmask, tmp8_mean_next, tmp8_mean)
        tmp8_m2 = tl.where(rmask & xmask, tmp8_m2_next, tmp8_m2)
        tmp8_weight = tl.where(rmask & xmask, tmp8_weight_next, tmp8_weight)
    tmp8_tmp, tmp9_tmp, tmp10_tmp = triton_helpers.welford(
        tmp8_mean, tmp8_m2, tmp8_weight, 1
    )
    tmp8 = tmp8_tmp[:, None]
    tmp9 = tmp9_tmp[:, None]
    tmp10 = tmp10_tmp[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp11 = tl.load(in_ptr0 + (r1 + (128*x0)), rmask & xmask, eviction_policy='evict_first', other=0.0)
        tmp12 = tl.load(in_ptr1 + (r1 + (128*(((270 + (x0 % 272)) % 273) % 7)) + (896*(((270 + (x0 // 272)) % 273) % 7)) + (6272*(((270 + (x0 % 272)) % 273) // 7)) + (244608*(((270 + (x0 // 272)) % 273) // 7))), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp13 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp25 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp27 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp14 = tmp13.to(tl.float32)
        tmp15 = tmp12 + tmp14
        tmp16 = tmp15.to(tl.float32)
        tmp17 = tmp11 + tmp16
        tmp18 = tmp17 - tmp8
        tmp19 = 128.0
        tmp20 = tmp9 / tmp19
        tmp21 = 1e-05
        tmp22 = tmp20 + tmp21
        tmp23 = libdevice.rsqrt(tmp22)
        tmp24 = tmp18 * tmp23
        tmp26 = tmp24 * tmp25
        tmp28 = tmp26 + tmp27
        tmp29 = tmp28.to(tl.float32)
        tl.store(out_ptr3 + (r1 + (128*x0)), tmp29, rmask & xmask)
''', device_str='cuda')
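# Note: the kernel above is a two-pass Welford reduction over the 128 channels that fuses the
# residual add x_37 = shortcut + (attention output + projection bias), including the reverse
# cyclic shift / window-reverse addressing of the attention output (the (270 + i) % 273 index
# arithmetic), with layer_norm_4 and the fp16 cast of the MLP input x_38.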


# kernel path: /tmp/torchinductor_root/lg/clgd3dty2q3ne3lnfcn3rmpdng74ev3dyoxwhsbz5berzpswiq3j.py
# Source Nodes: [x_37, x_43, x_out], Original ATen: [aten.add, aten.native_layer_norm]
# x_37 => add_18
# x_43 => add_22
# x_out => var_mean_6
triton_per_fused_add_native_layer_norm_25 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.persistent_reduction(
    size_hints=[131072, 128],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp32', 3: '*fp16', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: 'i32', 9: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_add_native_layer_norm_25', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 4, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 73984
    rnumel = 128
    RBLOCK: tl.constexpr = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (128*x0)), rmask & xmask, other=0.0)
    tmp1 = tl.load(in_ptr1 + (r1 + (128*(((270 + (x0 % 272)) % 273) % 7)) + (896*(((270 + (x0 // 272)) % 273) % 7)) + (6272*(((270 + (x0 % 272)) % 273) // 7)) + (244608*(((270 + (x0 // 272)) % 273) // 7))), rmask & xmask, other=0.0).to(tl.float32)
    tmp2 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
    tmp7 = tl.load(in_ptr3 + (r1 + (128*x0)), rmask & xmask, other=0.0).to(tl.float32)
    tmp8 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0)
    tmp3 = tmp2.to(tl.float32)
    tmp4 = tmp1 + tmp3
    tmp5 = tmp4.to(tl.float32)
    tmp6 = tmp0 + tmp5
    tmp9 = tmp8.to(tl.float32)
    tmp10 = tmp7 + tmp9
    tmp11 = tmp10.to(tl.float32)
    tmp12 = tmp6 + tmp11
    tmp13 = tl.broadcast_to(tmp12, [XBLOCK, RBLOCK])
    tmp15 = tl.where(rmask & xmask, tmp13, 0)
    tmp16 = tl.broadcast_to(tmp13, [XBLOCK, RBLOCK])
    tmp18 = tl.where(rmask & xmask, tmp16, 0)
    tmp19 = tl.sum(tmp18, 1)[:, None]
    tmp20 = tl.full([XBLOCK, 1], 128, tl.int32)
    tmp21 = tmp20.to(tl.float32)
    tmp22 = tmp19 / tmp21
    tmp23 = tmp13 - tmp22
    tmp24 = tmp23 * tmp23
    tmp25 = tl.broadcast_to(tmp24, [XBLOCK, RBLOCK])
    tmp27 = tl.where(rmask & xmask, tmp25, 0)
    tmp28 = tl.sum(tmp27, 1)[:, None]
    tl.store(out_ptr0 + (r1 + (128*x0)), tmp12, rmask & xmask)
    tl.store(out_ptr1 + (x0), tmp22, xmask)
    tl.store(out_ptr2 + (x0), tmp28, xmask)
''', device_str='cuda')
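# Note: the kernel above recomputes x_37 = shortcut + (attention output + bias), adds the MLP
# output plus fc2 bias to form x_43, stores the fp32 sum to out_ptr0, and in the same pass
# produces the per-token mean and sum of squared deviations over the 128 channels (var_mean_6)
# used by the following x_out LayerNorm.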


# kernel path: /tmp/torchinductor_root/2t/c2tajzagxv474224mjjurbovktceibfbfuzbpihyo7vcvrjy4ggt.py
# Source Nodes: [x_47, x_48], Original ATen: [aten._to_copy, aten.native_layer_norm]
# x_47 => add_23, add_24, mul_20, mul_21, rsqrt_5, sub_8, var_mean_5
# x_48 => convert_element_type_65
triton_red_fused__to_copy_native_layer_norm_26 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.reduction(
    size_hints=[32768, 512],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp16', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_native_layer_norm_26', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 10, 'num_reduction': 2, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, out_ptr3, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 18496
    rnumel = 512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    tmp32_mean = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp32_m2 = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp32_weight = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = r1
        tmp1 = tl.full([1, 1], 0, tl.int64)
        tmp2 = tmp0 >= tmp1
        tmp3 = tl.full([1, 1], 128, tl.int64)
        tmp4 = tmp0 < tmp3
        tmp5 = tl.load(in_ptr0 + (r1 + (256*(x0 % 136)) + (69632*(x0 // 136))), rmask & tmp4 & xmask, eviction_policy='evict_last', other=0.0)
        tmp6 = tl.full(tmp5.shape, 0.0, tmp5.dtype)
        tmp7 = tl.where(tmp4, tmp5, tmp6)
        tmp8 = tmp0 >= tmp3
        tmp9 = tl.full([1, 1], 256, tl.int64)
        tmp10 = tmp0 < tmp9
        tmp11 = tmp8 & tmp10
        tmp12 = tl.load(in_ptr0 + (34688 + r1 + (256*(x0 % 136)) + (69632*(x0 // 136))), rmask & tmp11 & xmask, eviction_policy='evict_last', other=0.0)
        tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
        tmp14 = tl.where(tmp11, tmp12, tmp13)
        tmp15 = tmp0 >= tmp9
        tmp16 = tl.full([1, 1], 384, tl.int64)
        tmp17 = tmp0 < tmp16
        tmp18 = tmp15 & tmp17
        tmp19 = tl.load(in_ptr0 + ((-128) + r1 + (256*(x0 % 136)) + (69632*(x0 // 136))), rmask & tmp18 & xmask, eviction_policy='evict_last', other=0.0)
        tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
        tmp21 = tl.where(tmp18, tmp19, tmp20)
        tmp22 = tmp0 >= tmp16
        tmp23 = tl.full([1, 1], 512, tl.int64)
        tmp24 = tmp0 < tmp23
        tmp25 = tl.load(in_ptr0 + (34560 + r1 + (256*(x0 % 136)) + (69632*(x0 // 136))), rmask & tmp22 & xmask, eviction_policy='evict_last', other=0.0)
        tmp26 = tl.full(tmp25.shape, 0.0, tmp25.dtype)
        tmp27 = tl.where(tmp22, tmp25, tmp26)
        tmp28 = tl.where(tmp18, tmp21, tmp27)
        tmp29 = tl.where(tmp11, tmp14, tmp28)
        tmp30 = tl.where(tmp4, tmp7, tmp29)
        tmp31 = tl.broadcast_to(tmp30, [XBLOCK, RBLOCK])
        tmp32_mean_next, tmp32_m2_next, tmp32_weight_next = triton_helpers.welford_reduce(
            tmp31, tmp32_mean, tmp32_m2, tmp32_weight, roffset == 0
        )
        tmp32_mean = tl.where(rmask & xmask, tmp32_mean_next, tmp32_mean)
        tmp32_m2 = tl.where(rmask & xmask, tmp32_m2_next, tmp32_m2)
        tmp32_weight = tl.where(rmask & xmask, tmp32_weight_next, tmp32_weight)
    tmp32_tmp, tmp33_tmp, tmp34_tmp = triton_helpers.welford(
        tmp32_mean, tmp32_m2, tmp32_weight, 1
    )
    tmp32 = tmp32_tmp[:, None]
    tmp33 = tmp33_tmp[:, None]
    tmp34 = tmp34_tmp[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp73 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp75 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp35 = r1
        tmp36 = tl.full([1, 1], 0, tl.int64)
        tmp37 = tmp35 >= tmp36
        tmp38 = tl.full([1, 1], 128, tl.int64)
        tmp39 = tmp35 < tmp38
        tmp40 = tl.load(in_ptr0 + (r1 + (256*(x0 % 136)) + (69632*(x0 // 136))), rmask & tmp39 & xmask, eviction_policy='evict_first', other=0.0)
        tmp41 = tl.full(tmp40.shape, 0.0, tmp40.dtype)
        tmp42 = tl.where(tmp39, tmp40, tmp41)
        tmp43 = tmp35 >= tmp38
        tmp44 = tl.full([1, 1], 256, tl.int64)
        tmp45 = tmp35 < tmp44
        tmp46 = tmp43 & tmp45
        tmp47 = tl.load(in_ptr0 + (34688 + r1 + (256*(x0 % 136)) + (69632*(x0 // 136))), rmask & tmp46 & xmask, eviction_policy='evict_first', other=0.0)
        tmp48 = tl.full(tmp47.shape, 0.0, tmp47.dtype)
        tmp49 = tl.where(tmp46, tmp47, tmp48)
        tmp50 = tmp35 >= tmp44
        tmp51 = tl.full([1, 1], 384, tl.int64)
        tmp52 = tmp35 < tmp51
        tmp53 = tmp50 & tmp52
        tmp54 = tl.load(in_ptr0 + ((-128) + r1 + (256*(x0 % 136)) + (69632*(x0 // 136))), rmask & tmp53 & xmask, eviction_policy='evict_first', other=0.0)
        tmp55 = tl.full(tmp54.shape, 0.0, tmp54.dtype)
        tmp56 = tl.where(tmp53, tmp54, tmp55)
        tmp57 = tmp35 >= tmp51
        tmp58 = tl.full([1, 1], 512, tl.int64)
        tmp59 = tmp35 < tmp58
        tmp60 = tl.load(in_ptr0 + (34560 + r1 + (256*(x0 % 136)) + (69632*(x0 // 136))), rmask & tmp57 & xmask, eviction_policy='evict_first', other=0.0)
        tmp61 = tl.full(tmp60.shape, 0.0, tmp60.dtype)
        tmp62 = tl.where(tmp57, tmp60, tmp61)
        tmp63 = tl.where(tmp53, tmp56, tmp62)
        tmp64 = tl.where(tmp46, tmp49, tmp63)
        tmp65 = tl.where(tmp39, tmp42, tmp64)
        tmp66 = tmp65 - tmp32
        tmp67 = 512.0
        tmp68 = tmp33 / tmp67
        tmp69 = 1e-05
        tmp70 = tmp68 + tmp69
        tmp71 = libdevice.rsqrt(tmp70)
        tmp72 = tmp66 * tmp71
        tmp74 = tmp72 * tmp73
        tmp76 = tmp74 + tmp75
        tmp77 = tmp76.to(tl.float32)
        tl.store(out_ptr3 + (r1 + (512*x0)), tmp77, rmask & xmask)
''', device_str='cuda')
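# Note: the kernel above implements the patch-merging step: for each of the 136x136 output
# positions it gathers the four neighbouring 128-channel vectors of the 272x272 map
# (x[0::2, 0::2], x[1::2, 0::2], x[0::2, 1::2], x[1::2, 1::2]) into one 512-channel vector,
# normalizes it with a 512-channel LayerNorm (Welford reduction in the first loop, normalize and
# affine in the second), and casts to fp16 for the reduction linear x_48.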


# kernel path: /tmp/torchinductor_root/gb/cgbhaobp4brzi4bqej3wkl6sz477wwvivithdd33dfu6fcxdbvus.py
# Source Nodes: [x_48], Original ATen: [aten._to_copy]
# x_48 => convert_element_type_64
triton_poi_fused__to_copy_27 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[131072], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_27', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 131072
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/5c/c5cuctfarzmazdidadennbrjjffmqgssmpavewem5xeljfxn3dam.py
# Source Nodes: [x_50], Original ATen: [aten._to_copy, aten.native_layer_norm]
# x_50 => convert_element_type_70, var_mean_7
triton_per_fused__to_copy_native_layer_norm_28 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.persistent_reduction(
    size_hints=[32768, 256],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_native_layer_norm_28', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 1, 'num_reduction': 4, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, out_ptr0, out_ptr1, xnumel, rnumel):
    xnumel = 18496
    XBLOCK: tl.constexpr = 1
    rnumel = 256
    RBLOCK: tl.constexpr = 256
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[:]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (256*x0)), rmask & xmask, other=0.0).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.broadcast_to(tmp1, [RBLOCK])
    tmp4 = tl.where(rmask & xmask, tmp2, 0)
    tmp5 = tl.broadcast_to(tmp2, [RBLOCK])
    tmp7 = tl.where(rmask & xmask, tmp5, 0)
    tmp8 = triton_helpers.promote_to_tensor(tl.sum(tmp7, 0))
    tmp9 = tl.full([1], 256, tl.int32)
    tmp10 = tmp9.to(tl.float32)
    tmp11 = tmp8 / tmp10
    tmp12 = tmp2 - tmp11
    tmp13 = tmp12 * tmp12
    tmp14 = tl.broadcast_to(tmp13, [RBLOCK])
    tmp16 = tl.where(rmask & xmask, tmp14, 0)
    tmp17 = triton_helpers.promote_to_tensor(tl.sum(tmp16, 0))
    tl.store(out_ptr0 + (x0), tmp11, xmask)
    tl.store(out_ptr1 + (x0), tmp17, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/eu/ceuvro67qgbkaa4ffhjqtznjwlcpwtbvlkoxtlnkt25syksoejbq.py
# Source Nodes: [linear_9], Original ATen: [aten._to_copy]
# linear_9 => convert_element_type_73
triton_poi_fused__to_copy_29 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[8388608], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp16', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_29', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 5017600
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x1 = (xindex // 256) % 49
    x2 = (xindex // 12544)
    x0 = xindex % 256
    x3 = xindex
    tmp0 = (7*(x2 // 20)) + (x1 // 7)
    tmp1 = tl.full([1], 136, tl.int64)
    tmp2 = tmp0 < tmp1
    tmp3 = (7*(x2 % 20)) + (x1 % 7)
    tmp4 = tmp3 < tmp1
    tmp5 = tmp2 & tmp4
    tmp6 = tl.load(in_ptr0 + (x0 + (256*(x1 % 7)) + (1792*(x2 % 20)) + (34816*(x1 // 7)) + (243712*(x2 // 20))), tmp5, other=0.0).to(tl.float32)
    tmp7 = tmp6.to(tl.float32)
    tmp8 = tl.load(in_ptr1 + ((7*(x2 % 20)) + (136*(x1 // 7)) + (952*(x2 // 20)) + (x1 % 7)), tmp5, eviction_policy='evict_last', other=0.0)
    tmp9 = tmp7 - tmp8
    tmp10 = tl.load(in_ptr2 + ((7*(x2 % 20)) + (136*(x1 // 7)) + (952*(x2 // 20)) + (x1 % 7)), tmp5, eviction_policy='evict_last', other=0.0)
    tmp11 = 256.0
    tmp12 = tmp10 / tmp11
    tmp13 = 1e-05
    tmp14 = tmp12 + tmp13
    tmp15 = libdevice.rsqrt(tmp14)
    tmp16 = tmp9 * tmp15
    tmp17 = tl.load(in_ptr3 + (x0), tmp5, eviction_policy='evict_last', other=0.0)
    tmp18 = tmp16 * tmp17
    tmp19 = tl.load(in_ptr4 + (x0), tmp5, eviction_policy='evict_last', other=0.0)
    tmp20 = tmp18 + tmp19
    tmp21 = tl.full(tmp20.shape, 0.0, tmp20.dtype)
    tmp22 = tl.where(tmp5, tmp20, tmp21)
    tmp23 = tmp22.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp23, None)
''', device_str='cuda')
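# Note: the kernel above prepares the input of linear_9 for the next attention stage: the
# 136x136x256 map is zero-padded to 140x140 (20 windows of 7 per side), normalized with the
# precomputed 256-channel LayerNorm statistics (in_ptr1 mean, in_ptr2 sum of squared deviations
# divided by 256, eps=1e-05), scaled by in_ptr3 and shifted by in_ptr4, partitioned into 20x20
# windows of 7x7, and cast to fp16.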


# kernel path: /tmp/torchinductor_root/m7/cm7zh34yglzig2ihdzmpad6sfpayr2nhod7fd3bo4flse5huvirg.py
# Source Nodes: [linear_9], Original ATen: [aten._to_copy]
# linear_9 => convert_element_type_72
triton_poi_fused__to_copy_30 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[262144], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_30', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 196608
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/zc/czcprbvfrl4zyuc66lgt4ai6jmptvsogizmygs2osms6op7tomn3.py
# Source Nodes: [attn_10, q_5], Original ATen: [aten.clone, aten.mul]
# attn_10 => clone_30
# q_5 => mul_28
triton_poi_fused_clone_mul_31 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[8388608], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp16', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_mul_31', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 5017600
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 32
    x1 = (xindex // 32) % 49
    x2 = (xindex // 1568) % 8
    x3 = (xindex // 12544)
    x4 = xindex % 1568
    x5 = (xindex // 1568)
    tmp0 = tl.load(in_ptr0 + (x0 + (32*x2) + (768*x1) + (37632*x3)), None).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (x0 + (32*x2)), None, eviction_policy='evict_last')
    tmp2 = tmp1.to(tl.float32)
    tmp3 = tmp0 + tmp2
    tmp4 = 0.1767766952966369
    tmp5 = tmp3 * tmp4
    tl.store(out_ptr0 + (x4 + (1600*x5)), tmp5, None)
''', device_str='cuda')
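# Note: the kernel above extracts the q slice of the packed qkv projection (768 = 3 x 8 heads x
# 32 channels), adds the qkv bias, scales by 1 / sqrt(32) ~= 0.1767767 (q_5 => mul_28), and lays
# it out as (window, head, 49, 32) with a padded per-head stride of 1600 for the attn_10 batched
# matmul. The clone kernels further below do the analogous extraction for k (offset 256, stored
# transposed as (32, 49)) and v (offset 512).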


# kernel path: /tmp/torchinductor_root/e2/ce2cfq53qjwse2okyyoak6avwxvkjxqftsjmnigoql77qkqcbrrs.py
# Source Nodes: [attn_10], Original ATen: [aten.clone]
# attn_10 => clone_31
triton_poi_fused_clone_32 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[131072, 64], tile_hint=TileHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp16', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_32', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 102400
    xnumel = 49
    yoffset = (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    x3 = xindex
    y2 = (yindex // 256)
    y4 = yindex % 256
    y0 = yindex % 32
    y5 = (yindex // 32)
    tmp0 = tl.load(in_ptr0 + (256 + y4 + (768*x3) + (37632*y2)), xmask, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (256 + y4), None, eviction_policy='evict_last')
    tmp2 = tmp1.to(tl.float32)
    tmp3 = tmp0 + tmp2
    tl.store(out_ptr0 + (x3 + (49*y0) + (1600*y5)), tmp3, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/zy/czybsxl6sxk5pc55zhlieseau6jqy7zrc273232l2wqlpaymolf5.py
# Source Nodes: [attn_11, attn_12, matmul_5], Original ATen: [aten._softmax, aten._to_copy, aten.add]
# attn_11 => add_29
# attn_12 => amax_2, div_6, exp_2, sub_12, sum_3
# matmul_5 => convert_element_type_79
triton_per_fused__softmax__to_copy_add_33 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.persistent_reduction(
    size_hints=[262144, 64],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*i64', 2: '*fp32', 3: '*fp16', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__softmax__to_copy_add_33', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 156800
    rnumel = 49
    RBLOCK: tl.constexpr = 64
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = rindex < rnumel
    r3 = rindex
    x4 = xindex
    x0 = xindex % 49
    x1 = (xindex // 49) % 8
    x5 = (xindex // 49)
    tmp0 = tl.load(in_ptr0 + (r3 + (49*x4)), rmask & xmask, other=0.0).to(tl.float32)
    tmp2 = tl.load(in_ptr1 + (r3 + (49*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0)
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tl.full([XBLOCK, RBLOCK], 169, tl.int32)
    tmp4 = tmp2 + tmp3
    tmp5 = tmp2 < 0
    tmp6 = tl.where(tmp5, tmp4, tmp2)
    tl.device_assert(((0 <= tmp6) & (tmp6 < 169)) | ~(rmask & xmask), "index out of bounds: 0 <= tmp6 < 169")
    tmp8 = tl.load(in_ptr2 + (x1 + (8*tmp6)), rmask & xmask, eviction_policy='evict_last')
    tmp9 = tmp1 + tmp8
    tmp10 = tl.broadcast_to(tmp9, [XBLOCK, RBLOCK])
    tmp12 = tl.where(rmask & xmask, tmp10, float("-inf"))
    tmp13 = triton_helpers.max2(tmp12, 1)[:, None]
    tmp14 = tmp9 - tmp13
    tmp15 = tl_math.exp(tmp14)
    tmp16 = tl.broadcast_to(tmp15, [XBLOCK, RBLOCK])
    tmp18 = tl.where(rmask & xmask, tmp16, 0)
    tmp19 = tl.sum(tmp18, 1)[:, None]
    tmp20 = tmp15 / tmp19
    tmp21 = tmp20.to(tl.float32)
    tl.store(out_ptr2 + (r3 + (49*x0) + (2432*x5)), tmp21, rmask & xmask)
''', device_str='cuda')
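# Note: the kernel above fuses attn_11 = logits + relative position bias (169-entry table with 8
# heads, gathered through the 49x49 relative_position_index) with the softmax attn_12 over the
# 49 keys and the fp16 cast for matmul_5. Unlike the earlier masked variant, no shift mask is
# added, consistent with a non-shifted window-attention block.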


# kernel path: /tmp/torchinductor_root/du/cdubf7lve33gjbq64z4vakpxvsuh4hfbnswhoxywkh4mxo4oybs2.py
# Source Nodes: [matmul_5], Original ATen: [aten.clone]
# matmul_5 => clone_34
triton_poi_fused_clone_34 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[8388608], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp16', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_34', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 5017600
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 32
    x1 = (xindex // 32) % 49
    x2 = (xindex // 1568) % 8
    x3 = (xindex // 12544)
    x4 = xindex % 1568
    x5 = (xindex // 1568)
    tmp0 = tl.load(in_ptr0 + (512 + x0 + (32*x2) + (768*x1) + (37632*x3)), None).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (512 + x0 + (32*x2)), None, eviction_policy='evict_last')
    tmp2 = tmp1.to(tl.float32)
    tmp3 = tmp0 + tmp2
    tl.store(out_ptr0 + (x4 + (1600*x5)), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/id/cid673u2q44fnref24kvgclfxy4xefbos3kmpw6yx3axmr4zh7pe.py
# Source Nodes: [x_54], Original ATen: [aten.clone]
# x_54 => clone_35
triton_poi_fused_clone_35 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[8388608], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_35', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 5017600
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 32
    x1 = (xindex // 32) % 8
    x2 = (xindex // 256) % 49
    x3 = (xindex // 12544)
    x4 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (32*x2) + (1568*x1) + (12544*x3)), None).to(tl.float32)
    tl.store(out_ptr0 + (x4), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/gd/cgd5tsctovridqlbm4jovksrawiwqsk2hdwl7hi35gmu56kuxwy7.py
# Source Nodes: [layer_norm_8, x_61, x_62], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
# layer_norm_8 => add_31, add_32, convert_element_type_87, mul_29, mul_30, rsqrt_8, sub_13, var_mean_8
# x_61 => add_30
# x_62 => convert_element_type_90
triton_red_fused__to_copy_add_native_layer_norm_36 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.reduction(
    size_hints=[32768, 256],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp16', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_native_layer_norm_36', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 8, 'num_reduction': 2, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 18496
    rnumel = 256
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    tmp8_mean = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp8_m2 = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp8_weight = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (256*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr1 + (r1 + (256*((x0 % 136) % 7)) + (1792*((x0 // 136) % 7)) + (12544*((x0 % 136) // 7)) + (250880*(x0 // 952))), rmask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp3 = tmp2.to(tl.float32)
        tmp4 = tmp1 + tmp3
        tmp5 = tmp0 + tmp4
        tmp6 = tmp5.to(tl.float32)
        tmp7 = tl.broadcast_to(tmp6, [XBLOCK, RBLOCK])
        tmp8_mean_next, tmp8_m2_next, tmp8_weight_next = triton_helpers.welford_reduce(
            tmp7, tmp8_mean, tmp8_m2, tmp8_weight, roffset == 0
        )
        tmp8_mean = tl.where(rmask & xmask, tmp8_mean_next, tmp8_mean)
        tmp8_m2 = tl.where(rmask & xmask, tmp8_m2_next, tmp8_m2)
        tmp8_weight = tl.where(rmask & xmask, tmp8_weight_next, tmp8_weight)
    tmp8_tmp, tmp9_tmp, tmp10_tmp = triton_helpers.welford(
        tmp8_mean, tmp8_m2, tmp8_weight, 1
    )
    tmp8 = tmp8_tmp[:, None]
    tmp9 = tmp9_tmp[:, None]
    tmp10 = tmp10_tmp[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp11 = tl.load(in_ptr0 + (r1 + (256*x0)), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp12 = tl.load(in_ptr1 + (r1 + (256*((x0 % 136) % 7)) + (1792*((x0 // 136) % 7)) + (12544*((x0 % 136) // 7)) + (250880*(x0 // 952))), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp13 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp25 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp27 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp14 = tmp13.to(tl.float32)
        tmp15 = tmp12 + tmp14
        tmp16 = tmp11 + tmp15
        tmp17 = tmp16.to(tl.float32)
        tmp18 = tmp17 - tmp8
        tmp19 = 256.0
        tmp20 = tmp9 / tmp19
        tmp21 = 1e-05
        tmp22 = tmp20 + tmp21
        tmp23 = libdevice.rsqrt(tmp22)
        tmp24 = tmp18 * tmp23
        tmp26 = tmp24 * tmp25
        tmp28 = tmp26 + tmp27
        tmp29 = tmp28.to(tl.float32)
        tl.store(out_ptr2 + (r1 + (256*x0)), tmp29, rmask & xmask)
''', device_str='cuda')
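
# Reference sketch only (illustration, never called): the reduction kernel above fuses the
# residual add (skip + window-reversed attention output + projection bias) with the following
# LayerNorm. It makes two passes over each 256-channel row: the first accumulates mean and
# variance with a Welford reduction, the second normalizes, applies the affine parameters and
# stores fp16 for the next linear. Names and shapes below are assumptions.
def _reference_add_layer_norm(skip, attn_out, proj_bias, ln_weight, ln_bias, eps=1e-05):
    # skip, attn_out: (..., 256) fp16; proj_bias, ln_weight, ln_bias: (256,) fp32
    import torch
    y = skip.float() + (attn_out.float() + proj_bias)
    y = torch.nn.functional.layer_norm(y, (y.shape[-1],), ln_weight, ln_bias, eps)
    return y.half()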


# kernel path: /tmp/torchinductor_root/vy/cvywxyrmptehw4ywu65ox3yzbfmlde2dw2jy3subo7aa44exaykq.py
# Source Nodes: [x_62], Original ATen: [aten._to_copy]
# x_62 => convert_element_type_89
triton_poi_fused__to_copy_37 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[262144], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_37', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 262144
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/nk/cnkppgvyewrogbxy2ewbwul5osjk3plsam5hnrnavfi2z6kblkyk.py
# Source Nodes: [x_63], Original ATen: [aten.gelu]
# x_63 => add_33, convert_element_type_94, convert_element_type_95, erf_2, mul_31, mul_32, mul_33
triton_poi_fused_gelu_38 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[33554432], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_gelu_38', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 18939904
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = xindex
    x0 = xindex % 1024
    tmp0 = tl.load(in_out_ptr0 + (x2), None).to(tl.float32)
    tmp1 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last')
    tmp2 = tmp1.to(tl.float32)
    tmp3 = tmp0 + tmp2
    tmp4 = tmp3.to(tl.float32)
    tmp5 = 0.5
    tmp6 = tmp4 * tmp5
    tmp7 = 0.7071067811865476
    tmp8 = tmp4 * tmp7
    tmp9 = libdevice.erf(tmp8)
    tmp10 = 1.0
    tmp11 = tmp9 + tmp10
    tmp12 = tmp6 * tmp11
    tmp13 = tmp12.to(tl.float32)
    tl.store(in_out_ptr0 + (x2), tmp13, None)
''', device_str='cuda')
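
# Reference sketch only (illustration, never called): the pointwise kernel above fuses the
# fc1 bias add with an exact, erf-based GELU, 0.5 * x * (1 + erf(x / sqrt(2))), computed in
# fp32 and stored back in place as fp16. Argument names below are assumptions.
def _reference_bias_gelu(x, bias):
    # x: (..., 1024) fp16 matmul output without bias; bias: (1024,) fp32
    import torch
    y = x.float() + bias
    return (0.5 * y * (1.0 + torch.erf(y * 0.7071067811865476))).half()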


# kernel path: /tmp/torchinductor_root/mt/cmthogcbw6ytpaza2hq7tqd6jxon7w65cu5n467yareft33sxr42.py
# Source Nodes: [x_61, x_67, x_68], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
# x_61 => add_30
# x_67 => add_34
# x_68 => convert_element_type_101, var_mean_9
triton_per_fused__to_copy_add_native_layer_norm_39 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.persistent_reduction(
    size_hints=[32768, 256],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: 'i32', 8: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_native_layer_norm_39', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': True, 'num_load': 5, 'num_reduction': 4, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, out_ptr1, xnumel, rnumel):
    xnumel = 18496
    XBLOCK: tl.constexpr = 1
    rnumel = 256
    RBLOCK: tl.constexpr = 256
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[:]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (256*x0)), rmask & xmask, other=0.0).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (r1 + (256*((x0 % 136) % 7)) + (1792*((x0 // 136) % 7)) + (12544*((x0 % 136) // 7)) + (250880*(x0 // 952))), rmask & xmask, other=0.0).to(tl.float32)
    tmp2 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
    tmp6 = tl.load(in_out_ptr0 + (r1 + (256*x0)), rmask & xmask, other=0.0).to(tl.float32)
    tmp7 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0)
    tmp3 = tmp2.to(tl.float32)
    tmp4 = tmp1 + tmp3
    tmp5 = tmp0 + tmp4
    tmp8 = tmp7.to(tl.float32)
    tmp9 = tmp6 + tmp8
    tmp10 = tmp5 + tmp9
    tmp11 = tmp10.to(tl.float32)
    tmp12 = tl.broadcast_to(tmp11, [RBLOCK])
    tmp14 = tl.where(rmask & xmask, tmp12, 0)
    tmp15 = tl.broadcast_to(tmp12, [RBLOCK])
    tmp17 = tl.where(rmask & xmask, tmp15, 0)
    tmp18 = triton_helpers.promote_to_tensor(tl.sum(tmp17, 0))
    tmp19 = tl.full([1], 256, tl.int32)
    tmp20 = tmp19.to(tl.float32)
    tmp21 = tmp18 / tmp20
    tmp22 = tmp12 - tmp21
    tmp23 = tmp22 * tmp22
    tmp24 = tl.broadcast_to(tmp23, [RBLOCK])
    tmp26 = tl.where(rmask & xmask, tmp24, 0)
    tmp27 = triton_helpers.promote_to_tensor(tl.sum(tmp26, 0))
    tl.store(in_out_ptr0 + (r1 + (256*x0)), tmp10, rmask & xmask)
    tl.store(out_ptr0 + (x0), tmp21, xmask)
    tl.store(out_ptr1 + (x0), tmp27, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/nf/cnfaj3ajf5aqknow4tehtkbrxhtd7hg3eo5w2rldjul7cns6445j.py
# Source Nodes: [shifted_x_1, x_70], Original ATen: [aten.constant_pad_nd, aten.roll]
# shifted_x_1 => index_7, index_8
# x_70 => constant_pad_nd_3
triton_poi_fused_constant_pad_nd_roll_40 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[8388608], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_constant_pad_nd_roll_40', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 5017600
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = (xindex // 35840)
    x1 = (xindex // 256) % 140
    x0 = xindex % 256
    x4 = xindex
    tmp0 = (3 + x2) % 140
    tmp1 = tl.full([1], 136, tl.int64)
    tmp2 = tmp0 < tmp1
    tmp3 = (3 + x1) % 140
    tmp4 = tmp3 < tmp1
    tmp5 = tmp2 & tmp4
    tmp6 = tl.load(in_ptr0 + (x0 + (256*((3 + x1) % 140)) + (34816*((3 + x2) % 140))), tmp5, other=0.0).to(tl.float32)
    tmp7 = tmp6.to(tl.float32)
    tmp8 = tl.load(in_ptr1 + ((136*((3 + x2) % 140)) + ((3 + x1) % 140)), tmp5, eviction_policy='evict_last', other=0.0)
    tmp9 = tmp7 - tmp8
    tmp10 = tl.load(in_ptr2 + ((136*((3 + x2) % 140)) + ((3 + x1) % 140)), tmp5, eviction_policy='evict_last', other=0.0)
    tmp11 = 256.0
    tmp12 = tmp10 / tmp11
    tmp13 = 1e-05
    tmp14 = tmp12 + tmp13
    tmp15 = libdevice.rsqrt(tmp14)
    tmp16 = tmp9 * tmp15
    tmp17 = tl.load(in_ptr3 + (x0), tmp5, eviction_policy='evict_last', other=0.0)
    tmp18 = tmp16 * tmp17
    tmp19 = tl.load(in_ptr4 + (x0), tmp5, eviction_policy='evict_last', other=0.0)
    tmp20 = tmp18 + tmp19
    tmp21 = tl.full(tmp20.shape, 0.0, tmp20.dtype)
    tmp22 = tl.where(tmp5, tmp20, tmp21)
    tl.store(out_ptr0 + (x4), tmp22, None)
''', device_str='cuda')
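
# Reference sketch only (illustration, never called): the kernel above fuses three steps that
# precede the shifted-window attention block: it applies the LayerNorm whose mean/variance were
# produced by the previous reduction (in_ptr1/in_ptr2, affine in in_ptr3/in_ptr4), zero-pads the
# 136x136 map to 140x140 (a multiple of the window size 7), and cyclically shifts it by -3 along
# both spatial axes (the "(3 + x) % 140" indexing). The sketch below covers only the pad + roll
# part; the channels-last layout and names are assumptions.
def _reference_pad_and_shift(x, pad_to=140, shift=3):
    # x: (B, 136, 136, 256) already-normalized fp32 feature map
    import torch
    import torch.nn.functional as F
    pad = pad_to - x.shape[1]                                      # 140 - 136 = 4
    x = F.pad(x, (0, 0, 0, pad, 0, pad))                           # -> (B, 140, 140, 256), zero padded
    return torch.roll(x, shifts=(-shift, -shift), dims=(1, 2))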


# kernel path: /tmp/torchinductor_root/wt/cwthqbmye7k6z3qf2osgm2rmstbobfi2vu64ljr6xtk6lwp2yfij.py
# Source Nodes: [linear_13], Original ATen: [aten._to_copy]
# linear_13 => convert_element_type_104
triton_poi_fused__to_copy_41 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[8388608], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_41', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 5017600
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 256
    x1 = (xindex // 256) % 49
    x2 = (xindex // 12544)
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (256*(x1 % 7)) + (1792*(x2 % 20)) + (35840*(x1 // 7)) + (250880*(x2 // 20))), None)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')
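
# Reference sketch only (illustration, never called): the kernel above partitions the rolled,
# padded 140x140 map into non-overlapping 7x7 windows (20 x 20 = 400 windows per image) and
# casts the result to fp16 for the qkv linear that follows. Layout assumptions as noted.
def _reference_window_partition(x, window_size=7):
    # x: (B, 140, 140, 256) fp32 channels-last feature map
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size * window_size, C)
    return windows.half()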


# kernel path: /tmp/torchinductor_root/xb/cxbmto2fl747s6zog6a2j7e767nbetwcpmuu6g7lvjzgwbc622kq.py
# Source Nodes: [img_mask_1, setitem_10, setitem_11, setitem_12, setitem_13, setitem_9], Original ATen: [aten.fill, aten.lift_fresh, aten.slice, aten.zeros]
# img_mask_1 => full_1
# setitem_10 => copy_10, lift_fresh_copy_14
# setitem_11 => copy_11, lift_fresh_copy_15
# setitem_12 => copy_12, full_default_4, lift_fresh_copy_16
# setitem_13 => copy_13, lift_fresh_copy_17
# setitem_9 => copy_9, lift_fresh_copy_13
triton_poi_fused_fill_lift_fresh_slice_zeros_42 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[32768], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_fill_lift_fresh_slice_zeros_42', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 19600
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x1 = (xindex // 140)
    x0 = xindex % 140
    x2 = xindex
    tmp0 = x1
    tmp1 = tl.full([1], 133, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.full([1], 137, tl.int64)
    tmp4 = tmp0 < tmp3
    tmp5 = tmp2 & tmp4
    tmp6 = x0
    tmp7 = tmp6 < tmp1
    tmp8 = tmp7 & tmp5
    tmp9 = 3.0
    tmp10 = tl.full(tmp9.shape, 0.0, tmp9.dtype)
    tmp11 = tl.where(tmp8, tmp9, tmp10)
    tmp12 = 0.0
    tmp13 = tl.where(tmp7, tmp11, tmp12)
    tmp14 = tl.full(tmp13.shape, 0.0, tmp13.dtype)
    tmp15 = tl.where(tmp5, tmp13, tmp14)
    tmp16 = tmp0 < tmp1
    tmp17 = tmp6 >= tmp3
    tmp18 = tmp17 & tmp16
    tmp19 = 2.0
    tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
    tmp21 = tl.where(tmp18, tmp19, tmp20)
    tmp22 = tmp16 & tmp16
    tmp23 = tmp6 >= tmp1
    tmp24 = tmp6 < tmp3
    tmp25 = tmp23 & tmp24
    tmp26 = tmp25 & tmp22
    tmp27 = 1.0
    tmp28 = tl.full(tmp27.shape, 0.0, tmp27.dtype)
    tmp29 = tl.where(tmp26, tmp27, tmp28)
    tmp30 = tmp16 & tmp22
    tmp31 = tmp7 & tmp30
    tmp32 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
    tmp33 = tl.where(tmp31, tmp12, tmp32)
    tmp34 = tl.where(tmp7, tmp33, tmp12)
    tmp35 = tl.full(tmp34.shape, 0.0, tmp34.dtype)
    tmp36 = tl.where(tmp30, tmp34, tmp35)
    tmp37 = tl.where(tmp16, tmp36, tmp12)
    tmp38 = tl.where(tmp25, tmp29, tmp37)
    tmp39 = tl.full(tmp38.shape, 0.0, tmp38.dtype)
    tmp40 = tl.where(tmp22, tmp38, tmp39)
    tmp41 = tmp7 & tmp22
    tmp42 = tl.where(tmp41, tmp12, tmp32)
    tmp43 = tl.where(tmp7, tmp42, tmp12)
    tmp44 = tl.full(tmp43.shape, 0.0, tmp43.dtype)
    tmp45 = tl.where(tmp22, tmp43, tmp44)
    tmp46 = tl.where(tmp16, tmp45, tmp12)
    tmp47 = tl.where(tmp16, tmp40, tmp46)
    tmp48 = tl.where(tmp17, tmp21, tmp47)
    tmp49 = tl.full(tmp48.shape, 0.0, tmp48.dtype)
    tmp50 = tl.where(tmp16, tmp48, tmp49)
    tmp51 = tmp25 & tmp16
    tmp52 = tl.where(tmp51, tmp27, tmp28)
    tmp53 = tl.where(tmp25, tmp52, tmp46)
    tmp54 = tl.full(tmp53.shape, 0.0, tmp53.dtype)
    tmp55 = tl.where(tmp16, tmp53, tmp54)
    tmp56 = tmp7 & tmp16
    tmp57 = tl.where(tmp56, tmp12, tmp32)
    tmp58 = tl.where(tmp7, tmp57, tmp12)
    tmp59 = tl.full(tmp58.shape, 0.0, tmp58.dtype)
    tmp60 = tl.where(tmp16, tmp58, tmp59)
    tmp61 = tl.where(tmp16, tmp60, tmp12)
    tmp62 = tl.where(tmp16, tmp55, tmp61)
    tmp63 = tl.where(tmp16, tmp50, tmp62)
    tmp64 = tl.where(tmp5, tmp15, tmp63)
    tmp65 = tmp25 & tmp5
    tmp66 = 4.0
    tmp67 = tl.full(tmp66.shape, 0.0, tmp66.dtype)
    tmp68 = tl.where(tmp65, tmp66, tmp67)
    tmp69 = tl.where(tmp25, tmp68, tmp64)
    tmp70 = tl.full(tmp69.shape, 0.0, tmp69.dtype)
    tmp71 = tl.where(tmp5, tmp69, tmp70)
    tmp72 = tl.where(tmp5, tmp71, tmp64)
    tl.store(in_out_ptr0 + (x2), tmp72, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/7z/c7ze75vhgc4kguv4mmbm47unz47zkwxnsdyj2jnp2sgkef4vzr75.py
# Source Nodes: [setitem_17], Original ATen: [aten.fill, aten.lift_fresh]
# setitem_17 => copy_17, lift_fresh_copy_21
triton_poi_fused_fill_lift_fresh_43 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[512], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_fill_lift_fresh_43', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 420
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 140
    x1 = (xindex // 140)
    x2 = xindex
    tmp55 = tl.load(in_ptr0 + (19180 + x2), xmask)
    tmp0 = x0
    tmp1 = tl.full([1], 137, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = 8.0
    tmp4 = tl.full(tmp3.shape, 0.0, tmp3.dtype)
    tmp5 = tl.where(tmp2, tmp3, tmp4)
    tmp6 = 137 + x1
    tmp7 = tmp6 >= tmp1
    tmp8 = tl.full([1], 133, tl.int64)
    tmp9 = tmp0 >= tmp8
    tmp10 = tmp0 < tmp1
    tmp11 = tmp9 & tmp10
    tmp12 = tmp11 & tmp7
    tmp13 = 7.0
    tmp14 = tl.full(tmp13.shape, 0.0, tmp13.dtype)
    tmp15 = tl.where(tmp12, tmp13, tmp14)
    tmp16 = tmp7 & tmp7
    tmp17 = tmp0 < tmp8
    tmp18 = tmp17 & tmp16
    tmp19 = 6.0
    tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
    tmp21 = tl.where(tmp18, tmp19, tmp20)
    tmp22 = 0.0
    tmp23 = tl.where(tmp17, tmp21, tmp22)
    tmp24 = tl.full(tmp23.shape, 0.0, tmp23.dtype)
    tmp25 = tl.where(tmp16, tmp23, tmp24)
    tmp26 = tmp6 >= tmp8
    tmp27 = tmp6 < tmp1
    tmp28 = tmp26 & tmp27
    tmp29 = tmp28 & tmp7
    tmp30 = tmp2 & tmp29
    tmp31 = 5.0
    tmp32 = tl.full(tmp31.shape, 0.0, tmp31.dtype)
    tmp33 = tl.where(tmp30, tmp31, tmp32)
    tmp34 = tl.load(in_ptr0 + (19180 + x2), tmp29 & xmask, other=0.0)
    tmp35 = tl.where(tmp2, tmp33, tmp34)
    tmp36 = tl.full(tmp35.shape, 0.0, tmp35.dtype)
    tmp37 = tl.where(tmp29, tmp35, tmp36)
    tmp38 = tl.load(in_ptr0 + (19180 + x2), tmp7 & xmask, other=0.0)
    tmp39 = tl.where(tmp28, tmp37, tmp38)
    tmp40 = tl.where(tmp7, tmp25, tmp39)
    tmp41 = tl.where(tmp11, tmp15, tmp40)
    tmp42 = tl.full(tmp41.shape, 0.0, tmp41.dtype)
    tmp43 = tl.where(tmp7, tmp41, tmp42)
    tmp44 = tmp17 & tmp7
    tmp45 = tl.where(tmp44, tmp19, tmp20)
    tmp46 = tl.where(tmp17, tmp45, tmp22)
    tmp47 = tl.full(tmp46.shape, 0.0, tmp46.dtype)
    tmp48 = tl.where(tmp7, tmp46, tmp47)
    tmp49 = tmp2 & tmp28
    tmp50 = tl.where(tmp49, tmp31, tmp32)
    tmp51 = tl.load(in_ptr0 + (19180 + x2), tmp28 & xmask, other=0.0)
    tmp52 = tl.where(tmp2, tmp50, tmp51)
    tmp53 = tl.full(tmp52.shape, 0.0, tmp52.dtype)
    tmp54 = tl.where(tmp28, tmp52, tmp53)
    tmp56 = tl.where(tmp28, tmp54, tmp55)
    tmp57 = tl.where(tmp7, tmp48, tmp56)
    tmp58 = tl.where(tmp7, tmp43, tmp57)
    tmp59 = tl.where(tmp2, tmp5, tmp58)
    tl.store(out_ptr0 + (x2), tmp59, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/7z/c7zt444wgwxirjwbsyxeqljxxhoqggbgugx7kf4kefsaruhixphf.py
# Source Nodes: [setitem_14, setitem_15, setitem_16], Original ATen: [aten.fill, aten.lift_fresh, aten.slice]
# setitem_14 => copy_14, lift_fresh_copy_18
# setitem_15 => copy_15, full_default_5, lift_fresh_copy_19
# setitem_16 => copy_16, lift_fresh_copy_20
triton_poi_fused_fill_lift_fresh_slice_44 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[32768], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_fill_lift_fresh_slice_44', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 19600
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x1 = (xindex // 140)
    x2 = xindex
    x0 = xindex % 140
    tmp55 = tl.load(in_out_ptr0 + (x2), xmask)
    tmp0 = x1
    tmp1 = tl.full([1], 137, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.load(in_ptr0 + ((-19180) + x2), tmp2 & xmask, other=0.0)
    tmp4 = tl.full(tmp3.shape, 0.0, tmp3.dtype)
    tmp5 = tl.where(tmp2, tmp3, tmp4)
    tmp6 = x0
    tmp7 = tl.full([1], 133, tl.int64)
    tmp8 = tmp6 >= tmp7
    tmp9 = tmp6 < tmp1
    tmp10 = tmp8 & tmp9
    tmp11 = tmp10 & tmp2
    tmp12 = 7.0
    tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
    tmp14 = tl.where(tmp11, tmp12, tmp13)
    tmp15 = tmp2 & tmp2
    tmp16 = tmp6 < tmp7
    tmp17 = tmp16 & tmp15
    tmp18 = 6.0
    tmp19 = tl.full(tmp18.shape, 0.0, tmp18.dtype)
    tmp20 = tl.where(tmp17, tmp18, tmp19)
    tmp21 = 0.0
    tmp22 = tl.where(tmp16, tmp20, tmp21)
    tmp23 = tl.full(tmp22.shape, 0.0, tmp22.dtype)
    tmp24 = tl.where(tmp15, tmp22, tmp23)
    tmp25 = tmp0 >= tmp7
    tmp26 = tmp0 < tmp1
    tmp27 = tmp25 & tmp26
    tmp28 = tmp27 & tmp2
    tmp29 = tmp6 >= tmp1
    tmp30 = tmp29 & tmp28
    tmp31 = 5.0
    tmp32 = tl.full(tmp31.shape, 0.0, tmp31.dtype)
    tmp33 = tl.where(tmp30, tmp31, tmp32)
    tmp34 = tl.load(in_out_ptr0 + (x2), tmp28 & xmask, other=0.0)
    tmp35 = tl.where(tmp29, tmp33, tmp34)
    tmp36 = tl.full(tmp35.shape, 0.0, tmp35.dtype)
    tmp37 = tl.where(tmp28, tmp35, tmp36)
    tmp38 = tl.load(in_out_ptr0 + (x2), tmp2 & xmask, other=0.0)
    tmp39 = tl.where(tmp27, tmp37, tmp38)
    tmp40 = tl.where(tmp2, tmp24, tmp39)
    tmp41 = tl.where(tmp10, tmp14, tmp40)
    tmp42 = tl.full(tmp41.shape, 0.0, tmp41.dtype)
    tmp43 = tl.where(tmp2, tmp41, tmp42)
    tmp44 = tmp16 & tmp2
    tmp45 = tl.where(tmp44, tmp18, tmp19)
    tmp46 = tl.where(tmp16, tmp45, tmp21)
    tmp47 = tl.full(tmp46.shape, 0.0, tmp46.dtype)
    tmp48 = tl.where(tmp2, tmp46, tmp47)
    tmp49 = tmp29 & tmp27
    tmp50 = tl.where(tmp49, tmp31, tmp32)
    tmp51 = tl.load(in_out_ptr0 + (x2), tmp27 & xmask, other=0.0)
    tmp52 = tl.where(tmp29, tmp50, tmp51)
    tmp53 = tl.full(tmp52.shape, 0.0, tmp52.dtype)
    tmp54 = tl.where(tmp27, tmp52, tmp53)
    tmp56 = tl.where(tmp27, tmp54, tmp55)
    tmp57 = tl.where(tmp2, tmp48, tmp56)
    tmp58 = tl.where(tmp2, tmp43, tmp57)
    tmp59 = tl.where(tmp2, tmp5, tmp58)
    tl.store(in_out_ptr0 + (x2), tmp59, xmask)
''', device_str='cuda')
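
# Reference sketch only (illustration, never called): the three fill/slice kernels above
# (suffixes _42, _43 and _44) appear to build the Swin-style attention-region mask for the
# shifted windows: a 140x140 map whose entries are region indices 0..8, written slice by slice
# (the constants 133 = 140 - 7 and 137 = 140 - 3 are the window-size and shift-size boundaries).
# Below is the usual eager-mode construction of such a mask; names and defaults are assumptions.
def _reference_img_mask(H=140, W=140, window_size=7, shift_size=3):
    import torch
    img_mask = torch.zeros(1, H, W, 1)
    h_slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
    w_slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1
    return img_mask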


# kernel path: /tmp/torchinductor_root/bh/cbhjqxz5ihwt4mxdvd3tbvekognxv56zu43p77uwnzktcpmbcvo7.py
# Source Nodes: [attn_16, attn_18, matmul_7], Original ATen: [aten._softmax, aten._to_copy, aten.add]
# attn_16 => add_40
# attn_18 => amax_3, div_7, exp_3, sub_15, sum_4
# matmul_7 => convert_element_type_110
triton_per_fused__softmax__to_copy_add_45 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.persistent_reduction(
    size_hints=[262144, 64],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*i64', 2: '*fp32', 3: '*fp32', 4: '*fp16', 5: 'i32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__softmax__to_copy_add_45', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 4, 'num_reduction': 2, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr3, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 156800
    rnumel = 49
    RBLOCK: tl.constexpr = 64
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = rindex < rnumel
    r3 = rindex
    x4 = xindex
    x0 = xindex % 49
    x1 = (xindex // 49) % 8
    x2 = (xindex // 392)
    x5 = (xindex // 49)
    tmp0 = tl.load(in_ptr0 + (r3 + (49*x4)), rmask & xmask, other=0.0).to(tl.float32)
    tmp2 = tl.load(in_ptr1 + (r3 + (49*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0)
    tmp10 = tl.load(in_ptr3 + ((7*(x2 % 20)) + (140*(r3 // 7)) + (980*(x2 // 20)) + (r3 % 7)), rmask & xmask, eviction_policy='evict_last', other=0.0)
    tmp11 = tl.load(in_ptr3 + ((7*(x2 % 20)) + (140*(x0 // 7)) + (980*(x2 // 20)) + (x0 % 7)), xmask, eviction_policy='evict_last')
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tl.full([XBLOCK, RBLOCK], 169, tl.int32)
    tmp4 = tmp2 + tmp3
    tmp5 = tmp2 < 0
    tmp6 = tl.where(tmp5, tmp4, tmp2)
    tl.device_assert(((0 <= tmp6) & (tmp6 < 169)) | ~(rmask & xmask), "index out of bounds: 0 <= tmp6 < 169")
    tmp8 = tl.load(in_ptr2 + (x1 + (8*tmp6)), rmask & xmask, eviction_policy='evict_last')
    tmp9 = tmp1 + tmp8
    tmp12 = tmp10 - tmp11
    tmp13 = 0.0
    tmp14 = tmp12 == tmp13
    tmp15 = tmp12 != tmp13
    tmp16 = -100.0
    tmp17 = tl.where(tmp15, tmp16, tmp12)
    tmp18 = tl.where(tmp14, tmp13, tmp17)
    tmp19 = tmp9 + tmp18
    tmp20 = tl.broadcast_to(tmp19, [XBLOCK, RBLOCK])
    tmp22 = tl.where(rmask & xmask, tmp20, float("-inf"))
    tmp23 = triton_helpers.max2(tmp22, 1)[:, None]
    tmp24 = tmp19 - tmp23
    tmp25 = tl_math.exp(tmp24)
    tmp26 = tl.broadcast_to(tmp25, [XBLOCK, RBLOCK])
    tmp28 = tl.where(rmask & xmask, tmp26, 0)
    tmp29 = tl.sum(tmp28, 1)[:, None]
    tmp30 = tmp25 / tmp29
    tmp31 = tmp30.to(tl.float32)
    tl.store(out_ptr3 + (r3 + (49*x0) + (2432*x5)), tmp31, rmask & xmask)
''', device_str='cuda')
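
# Reference sketch only (illustration, never called): the persistent reduction above fuses the
# relative position bias lookup (a (169, 8) table indexed by a 49x49 index map), the additive
# shifted-window mask (0 where two positions share a region index, -100 otherwise), and the
# row-wise softmax over the 49 keys. Shapes and names below are assumptions.
def _reference_window_attn_softmax(attn, bias_table, bias_index, region_idx):
    # attn:       (B*nW, 8, 49, 49) fp16 raw q @ k^T scores
    # bias_table: (169, 8) fp32; bias_index: (49, 49) int64
    # region_idx: (nW, 49) fp32 region indices, i.e. the img_mask above after window partition
    import torch
    bias = bias_table[bias_index.view(-1)].view(49, 49, 8).permute(2, 0, 1)    # (8, 49, 49)
    mask = region_idx.unsqueeze(1) - region_idx.unsqueeze(2)                   # (nW, 49, 49)
    mask = torch.where(mask != 0, torch.full_like(mask, -100.0), torch.zeros_like(mask))
    nW = region_idx.shape[0]
    x = attn.float() + bias                                                    # broadcast over batch
    x = x.view(-1, nW, 8, 49, 49) + mask.unsqueeze(1).unsqueeze(0)
    return torch.softmax(x.view(-1, 8, 49, 49), dim=-1).half()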


# kernel path: /tmp/torchinductor_root/f2/cf2npb2ejacogxpb6cdcfuv4z5bmltrg2zqsllg7hjzccj546id4.py
# Source Nodes: [layer_norm_10, x_80, x_81], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
# layer_norm_10 => add_44, add_45, convert_element_type_118, mul_37, mul_38, rsqrt_10, sub_16, var_mean_10
# x_80 => add_43
# x_81 => convert_element_type_121
triton_red_fused__to_copy_add_native_layer_norm_46 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.reduction(
    size_hints=[32768, 256],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp16', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_native_layer_norm_46', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 8, 'num_reduction': 2, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr3, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 18496
    rnumel = 256
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    tmp8_mean = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp8_m2 = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp8_weight = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (256*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr1 + (r1 + (256*(((137 + (x0 % 136)) % 140) % 7)) + (1792*(((137 + (x0 // 136)) % 140) % 7)) + (12544*(((137 + (x0 % 136)) % 140) // 7)) + (250880*(((137 + (x0 // 136)) % 140) // 7))), rmask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp3 = tmp2.to(tl.float32)
        tmp4 = tmp1 + tmp3
        tmp5 = tmp0 + tmp4
        tmp6 = tmp5.to(tl.float32)
        tmp7 = tl.broadcast_to(tmp6, [XBLOCK, RBLOCK])
        tmp8_mean_next, tmp8_m2_next, tmp8_weight_next = triton_helpers.welford_reduce(
            tmp7, tmp8_mean, tmp8_m2, tmp8_weight, roffset == 0
        )
        tmp8_mean = tl.where(rmask & xmask, tmp8_mean_next, tmp8_mean)
        tmp8_m2 = tl.where(rmask & xmask, tmp8_m2_next, tmp8_m2)
        tmp8_weight = tl.where(rmask & xmask, tmp8_weight_next, tmp8_weight)
    tmp8_tmp, tmp9_tmp, tmp10_tmp = triton_helpers.welford(
        tmp8_mean, tmp8_m2, tmp8_weight, 1
    )
    tmp8 = tmp8_tmp[:, None]
    tmp9 = tmp9_tmp[:, None]
    tmp10 = tmp10_tmp[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp11 = tl.load(in_ptr0 + (r1 + (256*x0)), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp12 = tl.load(in_ptr1 + (r1 + (256*(((137 + (x0 % 136)) % 140) % 7)) + (1792*(((137 + (x0 // 136)) % 140) % 7)) + (12544*(((137 + (x0 % 136)) % 140) // 7)) + (250880*(((137 + (x0 // 136)) % 140) // 7))), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp13 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp25 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp27 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp14 = tmp13.to(tl.float32)
        tmp15 = tmp12 + tmp14
        tmp16 = tmp11 + tmp15
        tmp17 = tmp16.to(tl.float32)
        tmp18 = tmp17 - tmp8
        tmp19 = 256.0
        tmp20 = tmp9 / tmp19
        tmp21 = 1e-05
        tmp22 = tmp20 + tmp21
        tmp23 = libdevice.rsqrt(tmp22)
        tmp24 = tmp18 * tmp23
        tmp26 = tmp24 * tmp25
        tmp28 = tmp26 + tmp27
        tmp29 = tmp28.to(tl.float32)
        tl.store(out_ptr3 + (r1 + (256*x0)), tmp29, rmask & xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/qa/cqaaamnieqtt5cxfgvrjon2locvummefm6c6uw5vo5jw43ghlrb6.py
# Source Nodes: [x_80, x_86, x_out_1], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
# x_80 => add_43
# x_86 => add_47
# x_out_1 => convert_element_type_137, var_mean_12
triton_per_fused__to_copy_add_native_layer_norm_47 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.persistent_reduction(
    size_hints=[32768, 256],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: 'i32', 8: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_native_layer_norm_47', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': True, 'num_load': 5, 'num_reduction': 4, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, out_ptr1, xnumel, rnumel):
    xnumel = 18496
    XBLOCK: tl.constexpr = 1
    rnumel = 256
    RBLOCK: tl.constexpr = 256
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[:]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (256*x0)), rmask & xmask, other=0.0).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (r1 + (256*(((137 + (x0 % 136)) % 140) % 7)) + (1792*(((137 + (x0 // 136)) % 140) % 7)) + (12544*(((137 + (x0 % 136)) % 140) // 7)) + (250880*(((137 + (x0 // 136)) % 140) // 7))), rmask & xmask, other=0.0).to(tl.float32)
    tmp2 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
    tmp6 = tl.load(in_out_ptr0 + (r1 + (256*x0)), rmask & xmask, other=0.0).to(tl.float32)
    tmp7 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0)
    tmp3 = tmp2.to(tl.float32)
    tmp4 = tmp1 + tmp3
    tmp5 = tmp0 + tmp4
    tmp8 = tmp7.to(tl.float32)
    tmp9 = tmp6 + tmp8
    tmp10 = tmp5 + tmp9
    tmp11 = tmp10.to(tl.float32)
    tmp12 = tl.broadcast_to(tmp11, [RBLOCK])
    tmp14 = tl.where(rmask & xmask, tmp12, 0)
    tmp15 = tl.broadcast_to(tmp12, [RBLOCK])
    tmp17 = tl.where(rmask & xmask, tmp15, 0)
    tmp18 = triton_helpers.promote_to_tensor(tl.sum(tmp17, 0))
    tmp19 = tl.full([1], 256, tl.int32)
    tmp20 = tmp19.to(tl.float32)
    tmp21 = tmp18 / tmp20
    tmp22 = tmp12 - tmp21
    tmp23 = tmp22 * tmp22
    tmp24 = tl.broadcast_to(tmp23, [RBLOCK])
    tmp26 = tl.where(rmask & xmask, tmp24, 0)
    tmp27 = triton_helpers.promote_to_tensor(tl.sum(tmp26, 0))
    tl.store(in_out_ptr0 + (r1 + (256*x0)), tmp10, rmask & xmask)
    tl.store(out_ptr0 + (x0), tmp21, xmask)
    tl.store(out_ptr1 + (x0), tmp27, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/vi/cvi4wbzakyffipye66rug6dumlcgc7kxs22zwyjq3yzsdisedy7v.py
# Source Nodes: [x_90, x_91], Original ATen: [aten._to_copy, aten.native_layer_norm]
# x_90 => add_48, add_49, convert_element_type_132, mul_42, mul_43, rsqrt_11, sub_17, var_mean_11
# x_91 => convert_element_type_134
triton_red_fused__to_copy_native_layer_norm_48 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.reduction(
    size_hints=[8192, 1024],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp32', 3: '*fp16', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_native_layer_norm_48', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 10, 'num_reduction': 2, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, out_ptr3, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 4624
    rnumel = 1024
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    tmp33_mean = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp33_m2 = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp33_weight = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = r1
        tmp1 = tl.full([1, 1], 0, tl.int64)
        tmp2 = tmp0 >= tmp1
        tmp3 = tl.full([1, 1], 256, tl.int64)
        tmp4 = tmp0 < tmp3
        tmp5 = tl.load(in_ptr0 + (r1 + (512*(x0 % 68)) + (69632*(x0 // 68))), rmask & tmp4 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp6 = tl.full(tmp5.shape, 0.0, tmp5.dtype)
        tmp7 = tl.where(tmp4, tmp5, tmp6)
        tmp8 = tmp0 >= tmp3
        tmp9 = tl.full([1, 1], 512, tl.int64)
        tmp10 = tmp0 < tmp9
        tmp11 = tmp8 & tmp10
        tmp12 = tl.load(in_ptr0 + (34560 + r1 + (512*(x0 % 68)) + (69632*(x0 // 68))), rmask & tmp11 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
        tmp14 = tl.where(tmp11, tmp12, tmp13)
        tmp15 = tmp0 >= tmp9
        tmp16 = tl.full([1, 1], 768, tl.int64)
        tmp17 = tmp0 < tmp16
        tmp18 = tmp15 & tmp17
        tmp19 = tl.load(in_ptr0 + ((-256) + r1 + (512*(x0 % 68)) + (69632*(x0 // 68))), rmask & tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
        tmp21 = tl.where(tmp18, tmp19, tmp20)
        tmp22 = tmp0 >= tmp16
        tmp23 = tl.full([1, 1], 1024, tl.int64)
        tmp24 = tmp0 < tmp23
        tmp25 = tl.load(in_ptr0 + (34304 + r1 + (512*(x0 % 68)) + (69632*(x0 // 68))), rmask & tmp22 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp26 = tl.full(tmp25.shape, 0.0, tmp25.dtype)
        tmp27 = tl.where(tmp22, tmp25, tmp26)
        tmp28 = tl.where(tmp18, tmp21, tmp27)
        tmp29 = tl.where(tmp11, tmp14, tmp28)
        tmp30 = tl.where(tmp4, tmp7, tmp29)
        tmp31 = tmp30.to(tl.float32)
        tmp32 = tl.broadcast_to(tmp31, [XBLOCK, RBLOCK])
        tmp33_mean_next, tmp33_m2_next, tmp33_weight_next = triton_helpers.welford_reduce(
            tmp32, tmp33_mean, tmp33_m2, tmp33_weight, roffset == 0
        )
        tmp33_mean = tl.where(rmask & xmask, tmp33_mean_next, tmp33_mean)
        tmp33_m2 = tl.where(rmask & xmask, tmp33_m2_next, tmp33_m2)
        tmp33_weight = tl.where(rmask & xmask, tmp33_weight_next, tmp33_weight)
    tmp33_tmp, tmp34_tmp, tmp35_tmp = triton_helpers.welford(
        tmp33_mean, tmp33_m2, tmp33_weight, 1
    )
    tmp33 = tmp33_tmp[:, None]
    tmp34 = tmp34_tmp[:, None]
    tmp35 = tmp35_tmp[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp75 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp77 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp36 = r1
        tmp37 = tl.full([1, 1], 0, tl.int64)
        tmp38 = tmp36 >= tmp37
        tmp39 = tl.full([1, 1], 256, tl.int64)
        tmp40 = tmp36 < tmp39
        tmp41 = tl.load(in_ptr0 + (r1 + (512*(x0 % 68)) + (69632*(x0 // 68))), rmask & tmp40 & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp42 = tl.full(tmp41.shape, 0.0, tmp41.dtype)
        tmp43 = tl.where(tmp40, tmp41, tmp42)
        tmp44 = tmp36 >= tmp39
        tmp45 = tl.full([1, 1], 512, tl.int64)
        tmp46 = tmp36 < tmp45
        tmp47 = tmp44 & tmp46
        tmp48 = tl.load(in_ptr0 + (34560 + r1 + (512*(x0 % 68)) + (69632*(x0 // 68))), rmask & tmp47 & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp49 = tl.full(tmp48.shape, 0.0, tmp48.dtype)
        tmp50 = tl.where(tmp47, tmp48, tmp49)
        tmp51 = tmp36 >= tmp45
        tmp52 = tl.full([1, 1], 768, tl.int64)
        tmp53 = tmp36 < tmp52
        tmp54 = tmp51 & tmp53
        tmp55 = tl.load(in_ptr0 + ((-256) + r1 + (512*(x0 % 68)) + (69632*(x0 // 68))), rmask & tmp54 & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp56 = tl.full(tmp55.shape, 0.0, tmp55.dtype)
        tmp57 = tl.where(tmp54, tmp55, tmp56)
        tmp58 = tmp36 >= tmp52
        tmp59 = tl.full([1, 1], 1024, tl.int64)
        tmp60 = tmp36 < tmp59
        tmp61 = tl.load(in_ptr0 + (34304 + r1 + (512*(x0 % 68)) + (69632*(x0 // 68))), rmask & tmp58 & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp62 = tl.full(tmp61.shape, 0.0, tmp61.dtype)
        tmp63 = tl.where(tmp58, tmp61, tmp62)
        tmp64 = tl.where(tmp54, tmp57, tmp63)
        tmp65 = tl.where(tmp47, tmp50, tmp64)
        tmp66 = tl.where(tmp40, tmp43, tmp65)
        tmp67 = tmp66.to(tl.float32)
        tmp68 = tmp67 - tmp33
        tmp69 = 1024.0
        tmp70 = tmp34 / tmp69
        tmp71 = 1e-05
        tmp72 = tmp70 + tmp71
        tmp73 = libdevice.rsqrt(tmp72)
        tmp74 = tmp68 * tmp73
        tmp76 = tmp74 * tmp75
        tmp78 = tmp76 + tmp77
        tmp79 = tmp78.to(tl.float32)
        tl.store(out_ptr3 + (r1 + (1024*x0)), tmp79, rmask & xmask)
''', device_str='cuda')
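
# Reference sketch only (illustration, never called): the reduction above implements the
# patch-merging step between stages: the four loads gather the (2i, 2j), (2i+1, 2j), (2i, 2j+1)
# and (2i+1, 2j+1) neighbors of the 136x136, 256-channel map, concatenate them into 1024
# channels, and normalize with a LayerNorm over the concatenated channels using the same
# two-pass Welford scheme as the earlier norm kernels. Names, layout and the fp16 output cast
# below are assumptions.
def _reference_patch_merging_norm(x, ln_weight, ln_bias, eps=1e-05):
    # x: (B, 136, 136, 256) fp16; ln_weight, ln_bias: (1024,) fp32
    import torch
    x0 = x[:, 0::2, 0::2, :]
    x1 = x[:, 1::2, 0::2, :]
    x2 = x[:, 0::2, 1::2, :]
    x3 = x[:, 1::2, 1::2, :]
    y = torch.cat([x0, x1, x2, x3], dim=-1).reshape(x.shape[0], -1, 4 * x.shape[-1])
    y = torch.nn.functional.layer_norm(y.float(), (y.shape[-1],), ln_weight, ln_bias, eps)
    return y.half()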


# kernel path: /tmp/torchinductor_root/2b/c2br45eimgqwhlwhtljpzxkxy33i7wtdmeqadpkf2nk76k6jqfxd.py
# Source Nodes: [x_91], Original ATen: [aten._to_copy]
# x_91 => convert_element_type_133
triton_poi_fused__to_copy_49 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[524288], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_49', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 524288
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/um/cumnariyjk3qtyk262n4xg2k375rwdvxooxdl5onswzasc5zkkpc.py
# Source Nodes: [x_93], Original ATen: [aten._to_copy, aten.native_layer_norm]
# x_93 => convert_element_type_140, var_mean_13
triton_per_fused__to_copy_native_layer_norm_50 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.persistent_reduction(
    size_hints=[8192, 512],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_native_layer_norm_50', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 1, 'num_reduction': 4, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, out_ptr0, out_ptr1, xnumel, rnumel):
    xnumel = 4624
    XBLOCK: tl.constexpr = 1
    rnumel = 512
    RBLOCK: tl.constexpr = 512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[:]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (512*x0)), rmask & xmask, other=0.0).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.broadcast_to(tmp1, [RBLOCK])
    tmp4 = tl.where(rmask & xmask, tmp2, 0)
    tmp5 = tl.broadcast_to(tmp2, [RBLOCK])
    tmp7 = tl.where(rmask & xmask, tmp5, 0)
    tmp8 = triton_helpers.promote_to_tensor(tl.sum(tmp7, 0))
    tmp9 = tl.full([1], 512, tl.int32)
    tmp10 = tmp9.to(tl.float32)
    tmp11 = tmp8 / tmp10
    tmp12 = tmp2 - tmp11
    tmp13 = tmp12 * tmp12
    tmp14 = tl.broadcast_to(tmp13, [RBLOCK])
    tmp16 = tl.where(rmask & xmask, tmp14, 0)
    tmp17 = triton_helpers.promote_to_tensor(tl.sum(tmp16, 0))
    tl.store(out_ptr0 + (x0), tmp11, xmask)
    tl.store(out_ptr1 + (x0), tmp17, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/g3/cg3khvtgpcqpjxr7l476vsoytoskgybzulj6ccz6csy5y4spmsvy.py
# Source Nodes: [linear_18], Original ATen: [aten._to_copy]
# linear_18 => convert_element_type_143
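# Note: finishes the LayerNorm started above (divide the m2 sum by 512, add
# 1e-05, rsqrt, affine with in_ptr3/in_ptr4) and partitions the 68x68 map into
# 10x10 windows of 7x7 tokens, zero-filling positions past row/col 67, then
# casts to fp16 as the input of the fused qkv projection (linear_18).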
triton_poi_fused__to_copy_51 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[4194304], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp16', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_51', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2508800
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x1 = (xindex // 512) % 49
    x2 = (xindex // 25088)
    x0 = xindex % 512
    x3 = xindex
    tmp0 = (7*(x2 // 10)) + (x1 // 7)
    tmp1 = tl.full([1], 68, tl.int64)
    tmp2 = tmp0 < tmp1
    tmp3 = (7*(x2 % 10)) + (x1 % 7)
    tmp4 = tmp3 < tmp1
    tmp5 = tmp2 & tmp4
    tmp6 = tl.load(in_ptr0 + (x0 + (512*(x1 % 7)) + (3584*(x2 % 10)) + (34816*(x1 // 7)) + (243712*(x2 // 10))), tmp5, other=0.0).to(tl.float32)
    tmp7 = tmp6.to(tl.float32)
    tmp8 = tl.load(in_ptr1 + ((7*(x2 % 10)) + (68*(x1 // 7)) + (476*(x2 // 10)) + (x1 % 7)), tmp5, eviction_policy='evict_last', other=0.0)
    tmp9 = tmp7 - tmp8
    tmp10 = tl.load(in_ptr2 + ((7*(x2 % 10)) + (68*(x1 // 7)) + (476*(x2 // 10)) + (x1 % 7)), tmp5, eviction_policy='evict_last', other=0.0)
    tmp11 = 512.0
    tmp12 = tmp10 / tmp11
    tmp13 = 1e-05
    tmp14 = tmp12 + tmp13
    tmp15 = libdevice.rsqrt(tmp14)
    tmp16 = tmp9 * tmp15
    tmp17 = tl.load(in_ptr3 + (x0), tmp5, eviction_policy='evict_last', other=0.0)
    tmp18 = tmp16 * tmp17
    tmp19 = tl.load(in_ptr4 + (x0), tmp5, eviction_policy='evict_last', other=0.0)
    tmp20 = tmp18 + tmp19
    tmp21 = tl.full(tmp20.shape, 0.0, tmp20.dtype)
    tmp22 = tl.where(tmp5, tmp20, tmp21)
    tmp23 = tmp22.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp23, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/7h/c7h2rh5tddfvhz4ra3gissuywjohe7jgaicl73dtmurv236yclh4.py
# Source Nodes: [linear_18], Original ATen: [aten._to_copy]
# linear_18 => convert_element_type_142
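# Note: plain fp32 -> fp16 cast of 786432 = 1536 x 512 elements; most likely
# the fused qkv projection weight being downcast ahead of the mm for linear_18.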
triton_poi_fused__to_copy_52 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[1048576], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_52', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 786432
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/lj/cljyipqxaos2lyhrcrywearlyzfxb2hz6o2kxzgjx4qc6pmpeudj.py
# Source Nodes: [attn_20, q_9], Original ATen: [aten.clone, aten.mul]
# attn_20 => clone_56
# q_9 => mul_50
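# Note: adds the qkv bias to the Q slice (first 512 of the fused 1536-wide qkv
# output) and pre-scales it by 0.1767766952966369 ~= 1/sqrt(32), the per-head
# dimension for 16 heads of a 512-dim embedding; the result is written with a
# padded row stride of 1600 for the batched matmul that follows.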
triton_poi_fused_clone_mul_53 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[4194304], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp16', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_mul_53', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2508800
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 32
    x1 = (xindex // 32) % 49
    x2 = (xindex // 1568) % 16
    x3 = (xindex // 25088)
    x4 = xindex % 1568
    x5 = (xindex // 1568)
    tmp0 = tl.load(in_ptr0 + (x0 + (32*x2) + (1536*x1) + (75264*x3)), None).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (x0 + (32*x2)), None, eviction_policy='evict_last')
    tmp2 = tmp1.to(tl.float32)
    tmp3 = tmp0 + tmp2
    tmp4 = 0.1767766952966369
    tmp5 = tmp3 * tmp4
    tl.store(out_ptr0 + (x4 + (1600*x5)), tmp5, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/av/cavjexsophvnr2bav6uoax3inj3rzrgvwxxuzbozsbhi5b7t7wxk.py
# Source Nodes: [attn_20], Original ATen: [aten.clone]
# attn_20 => clone_57
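# Note: extracts the K slice (offset 512 in the fused qkv buffer), adds its
# bias and stores it in (head_dim, token) order, i.e. already transposed so
# the following bmm computes q @ k^T per window and head.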
triton_poi_fused_clone_54 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[65536, 64], tile_hint=TileHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp16', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_54', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 51200
    xnumel = 49
    yoffset = (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    x3 = xindex
    y2 = (yindex // 512)
    y4 = yindex % 512
    y0 = yindex % 32
    y5 = (yindex // 32)
    tmp0 = tl.load(in_ptr0 + (512 + y4 + (1536*x3) + (75264*y2)), xmask, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (512 + y4), None, eviction_policy='evict_last')
    tmp2 = tmp1.to(tl.float32)
    tmp3 = tmp0 + tmp2
    tl.store(out_ptr0 + (x3 + (49*y0) + (1600*y5)), tmp3, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/2m/c2mbsjkqnezozcymlpyhsrwbu3d475l7jp7llp3av6ycshavi5cp.py
# Source Nodes: [attn_21, attn_22, matmul_9], Original ATen: [aten._softmax, aten._to_copy, aten.add]
# attn_21 => add_54
# attn_22 => amax_4, div_10, exp_4, sub_21, sum_5
# matmul_9 => convert_element_type_149
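# Note: windowed attention softmax with relative position bias. in_ptr1 is the
# [49, 49] int64 relative-index table and in_ptr2 the [169, 16] bias table
# (169 = (2*7-1)^2 offsets for a 7x7 window, 16 heads). The softmax runs in
# fp32 with the usual max-subtraction for stability, then casts back to fp16.
# Rough eager sketch: softmax(q @ k^T * scale + bias[rel_index], dim=-1).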
triton_per_fused__softmax__to_copy_add_55 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.persistent_reduction(
    size_hints=[131072, 64],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*i64', 2: '*fp32', 3: '*fp16', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__softmax__to_copy_add_55', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 78400
    rnumel = 49
    RBLOCK: tl.constexpr = 64
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = rindex < rnumel
    r3 = rindex
    x4 = xindex
    x0 = xindex % 49
    x1 = (xindex // 49) % 16
    x5 = (xindex // 49)
    tmp0 = tl.load(in_ptr0 + (r3 + (49*x4)), rmask & xmask, other=0.0).to(tl.float32)
    tmp2 = tl.load(in_ptr1 + (r3 + (49*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0)
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tl.full([XBLOCK, RBLOCK], 169, tl.int32)
    tmp4 = tmp2 + tmp3
    tmp5 = tmp2 < 0
    tmp6 = tl.where(tmp5, tmp4, tmp2)
    tl.device_assert(((0 <= tmp6) & (tmp6 < 169)) | ~(rmask & xmask), "index out of bounds: 0 <= tmp6 < 169")
    tmp8 = tl.load(in_ptr2 + (x1 + (16*tmp6)), rmask & xmask, eviction_policy='evict_last')
    tmp9 = tmp1 + tmp8
    tmp10 = tl.broadcast_to(tmp9, [XBLOCK, RBLOCK])
    tmp12 = tl.where(rmask & xmask, tmp10, float("-inf"))
    tmp13 = triton_helpers.max2(tmp12, 1)[:, None]
    tmp14 = tmp9 - tmp13
    tmp15 = tl_math.exp(tmp14)
    tmp16 = tl.broadcast_to(tmp15, [XBLOCK, RBLOCK])
    tmp18 = tl.where(rmask & xmask, tmp16, 0)
    tmp19 = tl.sum(tmp18, 1)[:, None]
    tmp20 = tmp15 / tmp19
    tmp21 = tmp20.to(tl.float32)
    tl.store(out_ptr2 + (r3 + (49*x0) + (2432*x5)), tmp21, rmask & xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/lk/clkuywa7pywq2gszyicfrwhbcan7mq2hjz7jsjx7fbjhf2kv6ywf.py
# Source Nodes: [matmul_9], Original ATen: [aten.clone]
# matmul_9 => clone_60
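# Note: extracts the V slice (offset 1024 in the fused qkv buffer) and adds its
# bias, again into the padded 1600-stride layout consumed by the attention bmm.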
triton_poi_fused_clone_56 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[4194304], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp16', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_56', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2508800
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 32
    x1 = (xindex // 32) % 49
    x2 = (xindex // 1568) % 16
    x3 = (xindex // 25088)
    x4 = xindex % 1568
    x5 = (xindex // 1568)
    tmp0 = tl.load(in_ptr0 + (1024 + x0 + (32*x2) + (1536*x1) + (75264*x3)), None).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (1024 + x0 + (32*x2)), None, eviction_policy='evict_last')
    tmp2 = tmp1.to(tl.float32)
    tmp3 = tmp0 + tmp2
    tl.store(out_ptr0 + (x4 + (1600*x5)), tmp3, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/dg/cdg5noei5zaw6qlukk4dp5ch354ru3lymtetv3ljqnq626zlqxdh.py
# Source Nodes: [x_97], Original ATen: [aten.clone]
# x_97 => clone_61
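# Note: permutes the attention output from (window, head, token, head_dim)
# back to a contiguous (window, token, head * head_dim) = (..., 49, 512)
# layout ahead of the output projection.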
triton_poi_fused_clone_57 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[4194304], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_57', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2508800
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 32
    x1 = (xindex // 32) % 16
    x2 = (xindex // 512) % 49
    x3 = (xindex // 25088)
    x4 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (32*x2) + (1568*x1) + (25088*x3)), None).to(tl.float32)
    tl.store(out_ptr0 + (x4), tmp0, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/t2/ct2dp3mhidfvqpjtzrn2dhky572jcvj66o5hnov3dfawuo5qorjc.py
# Source Nodes: [layer_norm_14, x_104, x_105], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
# layer_norm_14 => add_56, add_57, convert_element_type_157, mul_51, mul_52, rsqrt_14, sub_22, var_mean_14
# x_104 => add_55
# x_105 => convert_element_type_160
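# Note: two-pass Welford reduction. It reverses the 7x7 window partition, adds
# the attention branch (plus its projection bias, in_ptr2) to the residual
# stream (in_ptr0), then recomputes LayerNorm statistics and applies
# weight/bias to produce the fp16 input of the MLP (x_105).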
triton_red_fused__to_copy_add_native_layer_norm_58 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.reduction(
    size_hints=[8192, 512],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp16', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_native_layer_norm_58', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 8, 'num_reduction': 2, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 4624
    rnumel = 512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    tmp8_mean = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp8_m2 = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp8_weight = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (512*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr1 + (r1 + (512*((x0 % 68) % 7)) + (3584*((x0 // 68) % 7)) + (25088*((x0 % 68) // 7)) + (250880*(x0 // 476))), rmask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp3 = tmp2.to(tl.float32)
        tmp4 = tmp1 + tmp3
        tmp5 = tmp0 + tmp4
        tmp6 = tmp5.to(tl.float32)
        tmp7 = tl.broadcast_to(tmp6, [XBLOCK, RBLOCK])
        tmp8_mean_next, tmp8_m2_next, tmp8_weight_next = triton_helpers.welford_reduce(
            tmp7, tmp8_mean, tmp8_m2, tmp8_weight, roffset == 0
        )
        tmp8_mean = tl.where(rmask & xmask, tmp8_mean_next, tmp8_mean)
        tmp8_m2 = tl.where(rmask & xmask, tmp8_m2_next, tmp8_m2)
        tmp8_weight = tl.where(rmask & xmask, tmp8_weight_next, tmp8_weight)
    tmp8_tmp, tmp9_tmp, tmp10_tmp = triton_helpers.welford(
        tmp8_mean, tmp8_m2, tmp8_weight, 1
    )
    tmp8 = tmp8_tmp[:, None]
    tmp9 = tmp9_tmp[:, None]
    tmp10 = tmp10_tmp[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp11 = tl.load(in_ptr0 + (r1 + (512*x0)), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp12 = tl.load(in_ptr1 + (r1 + (512*((x0 % 68) % 7)) + (3584*((x0 // 68) % 7)) + (25088*((x0 % 68) // 7)) + (250880*(x0 // 476))), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp13 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp25 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp27 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp14 = tmp13.to(tl.float32)
        tmp15 = tmp12 + tmp14
        tmp16 = tmp11 + tmp15
        tmp17 = tmp16.to(tl.float32)
        tmp18 = tmp17 - tmp8
        tmp19 = 512.0
        tmp20 = tmp9 / tmp19
        tmp21 = 1e-05
        tmp22 = tmp20 + tmp21
        tmp23 = libdevice.rsqrt(tmp22)
        tmp24 = tmp18 * tmp23
        tmp26 = tmp24 * tmp25
        tmp28 = tmp26 + tmp27
        tmp29 = tmp28.to(tl.float32)
        tl.store(out_ptr2 + (r1 + (512*x0)), tmp29, rmask & xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/np/cnp6uoxnm72l4yszb432o46mj4nxiei7ycexyxlr2x6ni274sotg.py
# Source Nodes: [x_105], Original ATen: [aten._to_copy]
# x_105 => convert_element_type_159
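# Note: plain fp32 -> fp16 cast of 1048576 = 2048 x 512 elements; most likely
# the first MLP weight being downcast ahead of the mm for x_105.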
triton_poi_fused__to_copy_59 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[1048576], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_59', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 1048576
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/nq/cnqdp7jcikbji2dqz7a5oy3qn777ber3u2a6msqynur5lblmly63.py
# Source Nodes: [x_106], Original ATen: [aten.gelu]
# x_106 => add_58, convert_element_type_164, convert_element_type_165, erf_4, mul_53, mul_54, mul_55
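# Note: exact (erf-based) GELU fused with the hidden-layer bias add, computed
# in fp32 and stored back to fp16 in place; 9469952 = 4624 x 2048, i.e. the
# MLP expands the 512-dim embedding by a factor of 4.
# Rough eager sketch: y = 0.5 * x * (1 + erf(x / sqrt(2))).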
triton_poi_fused_gelu_60 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[16777216], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_gelu_60', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 9469952
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = xindex
    x0 = xindex % 2048
    tmp0 = tl.load(in_out_ptr0 + (x2), None).to(tl.float32)
    tmp1 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last')
    tmp2 = tmp1.to(tl.float32)
    tmp3 = tmp0 + tmp2
    tmp4 = tmp3.to(tl.float32)
    tmp5 = 0.5
    tmp6 = tmp4 * tmp5
    tmp7 = 0.7071067811865476
    tmp8 = tmp4 * tmp7
    tmp9 = libdevice.erf(tmp8)
    tmp10 = 1.0
    tmp11 = tmp9 + tmp10
    tmp12 = tmp6 * tmp11
    tmp13 = tmp12.to(tl.float32)
    tl.store(in_out_ptr0 + (x2), tmp13, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/qh/cqhfnjwnhzm6zoxn4h6l4ih7jzgbrmhwlxlccqxcasplc2pj2oio.py
# Source Nodes: [x_104, x_110, x_111], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
# x_104 => add_55
# x_110 => add_59
# x_111 => convert_element_type_171, var_mean_15
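# Note: fuses both residual additions of the block (attention branch with its
# bias, MLP branch with its bias) in place into in_out_ptr0, and emits the
# mean / sum-of-squared-deviations pair (var_mean_15) for the next LayerNorm.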
triton_per_fused__to_copy_add_native_layer_norm_61 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.persistent_reduction(
    size_hints=[8192, 512],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: 'i32', 8: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_native_layer_norm_61', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': True, 'num_load': 5, 'num_reduction': 4, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, out_ptr1, xnumel, rnumel):
    xnumel = 4624
    XBLOCK: tl.constexpr = 1
    rnumel = 512
    RBLOCK: tl.constexpr = 512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[:]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (512*x0)), rmask & xmask, other=0.0).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (r1 + (512*((x0 % 68) % 7)) + (3584*((x0 // 68) % 7)) + (25088*((x0 % 68) // 7)) + (250880*(x0 // 476))), rmask & xmask, other=0.0).to(tl.float32)
    tmp2 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
    tmp6 = tl.load(in_out_ptr0 + (r1 + (512*x0)), rmask & xmask, other=0.0).to(tl.float32)
    tmp7 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0)
    tmp3 = tmp2.to(tl.float32)
    tmp4 = tmp1 + tmp3
    tmp5 = tmp0 + tmp4
    tmp8 = tmp7.to(tl.float32)
    tmp9 = tmp6 + tmp8
    tmp10 = tmp5 + tmp9
    tmp11 = tmp10.to(tl.float32)
    tmp12 = tl.broadcast_to(tmp11, [RBLOCK])
    tmp14 = tl.where(rmask & xmask, tmp12, 0)
    tmp15 = tl.broadcast_to(tmp12, [RBLOCK])
    tmp17 = tl.where(rmask & xmask, tmp15, 0)
    tmp18 = triton_helpers.promote_to_tensor(tl.sum(tmp17, 0))
    tmp19 = tl.full([1], 512, tl.int32)
    tmp20 = tmp19.to(tl.float32)
    tmp21 = tmp18 / tmp20
    tmp22 = tmp12 - tmp21
    tmp23 = tmp22 * tmp22
    tmp24 = tl.broadcast_to(tmp23, [RBLOCK])
    tmp26 = tl.where(rmask & xmask, tmp24, 0)
    tmp27 = triton_helpers.promote_to_tensor(tl.sum(tmp26, 0))
    tl.store(in_out_ptr0 + (r1 + (512*x0)), tmp10, rmask & xmask)
    tl.store(out_ptr0 + (x0), tmp21, xmask)
    tl.store(out_ptr1 + (x0), tmp27, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/d6/cd647osl6hgyj5l4klpv44ixju2cu5kkyigf4txwmejkih3db2aj.py
# Source Nodes: [shifted_x_2, x_113], Original ATen: [aten.constant_pad_nd, aten.roll]
# shifted_x_2 => index_13, index_14
# x_113 => constant_pad_nd_5
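# Note: applies the LayerNorm affine and then performs the cyclic shift used
# by shifted-window (Swin-style) attention: the "(3 + x) % 70" indexing rolls
# the 68x68 map by -3 in both spatial dims while padding it out to 70x70,
# zero-filling wherever the source index falls past 67.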
triton_poi_fused_constant_pad_nd_roll_62 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[4194304], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_constant_pad_nd_roll_62', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2508800
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = (xindex // 35840)
    x1 = (xindex // 512) % 70
    x0 = xindex % 512
    x4 = xindex
    tmp0 = (3 + x2) % 70
    tmp1 = tl.full([1], 68, tl.int64)
    tmp2 = tmp0 < tmp1
    tmp3 = (3 + x1) % 70
    tmp4 = tmp3 < tmp1
    tmp5 = tmp2 & tmp4
    tmp6 = tl.load(in_ptr0 + (x0 + (512*((3 + x1) % 70)) + (34816*((3 + x2) % 70))), tmp5, other=0.0).to(tl.float32)
    tmp7 = tmp6.to(tl.float32)
    tmp8 = tl.load(in_ptr1 + ((68*((3 + x2) % 70)) + ((3 + x1) % 70)), tmp5, eviction_policy='evict_last', other=0.0)
    tmp9 = tmp7 - tmp8
    tmp10 = tl.load(in_ptr2 + ((68*((3 + x2) % 70)) + ((3 + x1) % 70)), tmp5, eviction_policy='evict_last', other=0.0)
    tmp11 = 512.0
    tmp12 = tmp10 / tmp11
    tmp13 = 1e-05
    tmp14 = tmp12 + tmp13
    tmp15 = libdevice.rsqrt(tmp14)
    tmp16 = tmp9 * tmp15
    tmp17 = tl.load(in_ptr3 + (x0), tmp5, eviction_policy='evict_last', other=0.0)
    tmp18 = tmp16 * tmp17
    tmp19 = tl.load(in_ptr4 + (x0), tmp5, eviction_policy='evict_last', other=0.0)
    tmp20 = tmp18 + tmp19
    tmp21 = tl.full(tmp20.shape, 0.0, tmp20.dtype)
    tmp22 = tl.where(tmp5, tmp20, tmp21)
    tl.store(out_ptr0 + (x4), tmp22, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/kf/ckfhysrliuileref3yojypgpmuck4hqbsxqymr3ij7j54y6ysdwm.py
# Source Nodes: [linear_22], Original ATen: [aten._to_copy]
# linear_22 => convert_element_type_174
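# Note: partitions the shifted, padded 70x70 map into 10x10 windows of 7x7
# tokens (row stride 35840 = 512 * 70) and casts to fp16 for the next fused
# qkv projection (linear_22).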
triton_poi_fused__to_copy_63 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[4194304], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_63', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2508800
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 512
    x1 = (xindex // 512) % 49
    x2 = (xindex // 25088)
    x3 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (512*(x1 % 7)) + (3584*(x2 % 10)) + (35840*(x1 // 7)) + (250880*(x2 // 10))), None)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/la/cla235hx3mubj6sgpoukkyfixg6bxgnvd6ngdgy3zwatq5lyqmcp.py
# Source Nodes: [img_mask_2, setitem_18, setitem_19, setitem_20, setitem_21, setitem_22], Original ATen: [aten.fill, aten.lift_fresh, aten.slice, aten.zeros]
# img_mask_2 => full_2
# setitem_18 => copy_18, lift_fresh_copy_24
# setitem_19 => copy_19, lift_fresh_copy_25
# setitem_20 => copy_20, lift_fresh_copy_26
# setitem_21 => copy_21, full_default_8, lift_fresh_copy_27
# setitem_22 => copy_22, lift_fresh_copy_28
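# Note: builds the 70x70 img_mask for the shifted windows in place, giving
# each row/column band of the shifted grid a distinct region id (0-4 in this
# kernel). The deeply nested tl.where chain is the compiler's flattening of
# the sequential slice assignments setitem_18..setitem_22.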
triton_poi_fused_fill_lift_fresh_slice_zeros_64 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[8192], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0,), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_fill_lift_fresh_slice_zeros_64', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 4900
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x1 = (xindex // 70)
    x0 = xindex % 70
    x2 = xindex
    tmp0 = x1
    tmp1 = tl.full([1], 63, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.full([1], 67, tl.int64)
    tmp4 = tmp0 < tmp3
    tmp5 = tmp2 & tmp4
    tmp6 = x0
    tmp7 = tmp6 < tmp1
    tmp8 = tmp7 & tmp5
    tmp9 = 3.0
    tmp10 = tl.full(tmp9.shape, 0.0, tmp9.dtype)
    tmp11 = tl.where(tmp8, tmp9, tmp10)
    tmp12 = 0.0
    tmp13 = tl.where(tmp7, tmp11, tmp12)
    tmp14 = tl.full(tmp13.shape, 0.0, tmp13.dtype)
    tmp15 = tl.where(tmp5, tmp13, tmp14)
    tmp16 = tmp0 < tmp1
    tmp17 = tmp6 >= tmp3
    tmp18 = tmp17 & tmp16
    tmp19 = 2.0
    tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
    tmp21 = tl.where(tmp18, tmp19, tmp20)
    tmp22 = tmp16 & tmp16
    tmp23 = tmp6 >= tmp1
    tmp24 = tmp6 < tmp3
    tmp25 = tmp23 & tmp24
    tmp26 = tmp25 & tmp22
    tmp27 = 1.0
    tmp28 = tl.full(tmp27.shape, 0.0, tmp27.dtype)
    tmp29 = tl.where(tmp26, tmp27, tmp28)
    tmp30 = tmp16 & tmp22
    tmp31 = tmp7 & tmp30
    tmp32 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
    tmp33 = tl.where(tmp31, tmp12, tmp32)
    tmp34 = tl.where(tmp7, tmp33, tmp12)
    tmp35 = tl.full(tmp34.shape, 0.0, tmp34.dtype)
    tmp36 = tl.where(tmp30, tmp34, tmp35)
    tmp37 = tl.where(tmp16, tmp36, tmp12)
    tmp38 = tl.where(tmp25, tmp29, tmp37)
    tmp39 = tl.full(tmp38.shape, 0.0, tmp38.dtype)
    tmp40 = tl.where(tmp22, tmp38, tmp39)
    tmp41 = tmp7 & tmp22
    tmp42 = tl.where(tmp41, tmp12, tmp32)
    tmp43 = tl.where(tmp7, tmp42, tmp12)
    tmp44 = tl.full(tmp43.shape, 0.0, tmp43.dtype)
    tmp45 = tl.where(tmp22, tmp43, tmp44)
    tmp46 = tl.where(tmp16, tmp45, tmp12)
    tmp47 = tl.where(tmp16, tmp40, tmp46)
    tmp48 = tl.where(tmp17, tmp21, tmp47)
    tmp49 = tl.full(tmp48.shape, 0.0, tmp48.dtype)
    tmp50 = tl.where(tmp16, tmp48, tmp49)
    tmp51 = tmp25 & tmp16
    tmp52 = tl.where(tmp51, tmp27, tmp28)
    tmp53 = tl.where(tmp25, tmp52, tmp46)
    tmp54 = tl.full(tmp53.shape, 0.0, tmp53.dtype)
    tmp55 = tl.where(tmp16, tmp53, tmp54)
    tmp56 = tmp7 & tmp16
    tmp57 = tl.where(tmp56, tmp12, tmp32)
    tmp58 = tl.where(tmp7, tmp57, tmp12)
    tmp59 = tl.full(tmp58.shape, 0.0, tmp58.dtype)
    tmp60 = tl.where(tmp16, tmp58, tmp59)
    tmp61 = tl.where(tmp16, tmp60, tmp12)
    tmp62 = tl.where(tmp16, tmp55, tmp61)
    tmp63 = tl.where(tmp16, tmp50, tmp62)
    tmp64 = tl.where(tmp5, tmp15, tmp63)
    tmp65 = tmp25 & tmp5
    tmp66 = 4.0
    tmp67 = tl.full(tmp66.shape, 0.0, tmp66.dtype)
    tmp68 = tl.where(tmp65, tmp66, tmp67)
    tmp69 = tl.where(tmp25, tmp68, tmp64)
    tmp70 = tl.full(tmp69.shape, 0.0, tmp69.dtype)
    tmp71 = tl.where(tmp5, tmp69, tmp70)
    tmp72 = tl.where(tmp5, tmp71, tmp64)
    tl.store(in_out_ptr0 + (x2), tmp72, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/iy/ciyf4skxx4pkmks5oyxbr7ildelf2g5isnbw6b6bh6z7lglqwl7j.py
# Source Nodes: [setitem_26], Original ATen: [aten.fill, aten.lift_fresh]
# setitem_26 => copy_26, lift_fresh_copy_32
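# Note: appears to fill the last row band of the mask (offset 4690 = 67 * 70,
# 210 elements) with the remaining region ids (6-8 for this band), reading
# back already-written values where no new assignment applies.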
triton_poi_fused_fill_lift_fresh_65 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[256], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_fill_lift_fresh_65', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 210
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 70
    x1 = (xindex // 70)
    x2 = xindex
    tmp55 = tl.load(in_ptr0 + (4690 + x2), xmask)
    tmp0 = x0
    tmp1 = tl.full([1], 67, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = 8.0
    tmp4 = tl.full(tmp3.shape, 0.0, tmp3.dtype)
    tmp5 = tl.where(tmp2, tmp3, tmp4)
    tmp6 = 67 + x1
    tmp7 = tmp6 >= tmp1
    tmp8 = tl.full([1], 63, tl.int64)
    tmp9 = tmp0 >= tmp8
    tmp10 = tmp0 < tmp1
    tmp11 = tmp9 & tmp10
    tmp12 = tmp11 & tmp7
    tmp13 = 7.0
    tmp14 = tl.full(tmp13.shape, 0.0, tmp13.dtype)
    tmp15 = tl.where(tmp12, tmp13, tmp14)
    tmp16 = tmp7 & tmp7
    tmp17 = tmp0 < tmp8
    tmp18 = tmp17 & tmp16
    tmp19 = 6.0
    tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
    tmp21 = tl.where(tmp18, tmp19, tmp20)
    tmp22 = 0.0
    tmp23 = tl.where(tmp17, tmp21, tmp22)
    tmp24 = tl.full(tmp23.shape, 0.0, tmp23.dtype)
    tmp25 = tl.where(tmp16, tmp23, tmp24)
    tmp26 = tmp6 >= tmp8
    tmp27 = tmp6 < tmp1
    tmp28 = tmp26 & tmp27
    tmp29 = tmp28 & tmp7
    tmp30 = tmp2 & tmp29
    tmp31 = 5.0
    tmp32 = tl.full(tmp31.shape, 0.0, tmp31.dtype)
    tmp33 = tl.where(tmp30, tmp31, tmp32)
    tmp34 = tl.load(in_ptr0 + (4690 + x2), tmp29 & xmask, other=0.0)
    tmp35 = tl.where(tmp2, tmp33, tmp34)
    tmp36 = tl.full(tmp35.shape, 0.0, tmp35.dtype)
    tmp37 = tl.where(tmp29, tmp35, tmp36)
    tmp38 = tl.load(in_ptr0 + (4690 + x2), tmp7 & xmask, other=0.0)
    tmp39 = tl.where(tmp28, tmp37, tmp38)
    tmp40 = tl.where(tmp7, tmp25, tmp39)
    tmp41 = tl.where(tmp11, tmp15, tmp40)
    tmp42 = tl.full(tmp41.shape, 0.0, tmp41.dtype)
    tmp43 = tl.where(tmp7, tmp41, tmp42)
    tmp44 = tmp17 & tmp7
    tmp45 = tl.where(tmp44, tmp19, tmp20)
    tmp46 = tl.where(tmp17, tmp45, tmp22)
    tmp47 = tl.full(tmp46.shape, 0.0, tmp46.dtype)
    tmp48 = tl.where(tmp7, tmp46, tmp47)
    tmp49 = tmp2 & tmp28
    tmp50 = tl.where(tmp49, tmp31, tmp32)
    tmp51 = tl.load(in_ptr0 + (4690 + x2), tmp28 & xmask, other=0.0)
    tmp52 = tl.where(tmp2, tmp50, tmp51)
    tmp53 = tl.full(tmp52.shape, 0.0, tmp52.dtype)
    tmp54 = tl.where(tmp28, tmp52, tmp53)
    tmp56 = tl.where(tmp28, tmp54, tmp55)
    tmp57 = tl.where(tmp7, tmp48, tmp56)
    tmp58 = tl.where(tmp7, tmp43, tmp57)
    tmp59 = tl.where(tmp2, tmp5, tmp58)
    tl.store(out_ptr0 + (x2), tmp59, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/wo/cwo26dzkhl3aq3ehh4clegihnjhiug4uevjligspazwbynwcz3bd.py
# Source Nodes: [setitem_23, setitem_24, setitem_25], Original ATen: [aten.fill, aten.lift_fresh, aten.slice]
# setitem_23 => copy_23, lift_fresh_copy_29
# setitem_24 => copy_24, full_default_9, lift_fresh_copy_30
# setitem_25 => copy_25, lift_fresh_copy_31
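# Note: in-place continuation of the mask construction (setitem_23..25); it
# appears to merge the tail rows computed by the previous kernel back into the
# full 70x70 region-id map.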
triton_poi_fused_fill_lift_fresh_slice_66 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[8192], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_fill_lift_fresh_slice_66', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 4900
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x1 = (xindex // 70)
    x2 = xindex
    x0 = xindex % 70
    tmp55 = tl.load(in_out_ptr0 + (x2), xmask)
    tmp0 = x1
    tmp1 = tl.full([1], 67, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.load(in_ptr0 + ((-4690) + x2), tmp2 & xmask, other=0.0)
    tmp4 = tl.full(tmp3.shape, 0.0, tmp3.dtype)
    tmp5 = tl.where(tmp2, tmp3, tmp4)
    tmp6 = x0
    tmp7 = tl.full([1], 63, tl.int64)
    tmp8 = tmp6 >= tmp7
    tmp9 = tmp6 < tmp1
    tmp10 = tmp8 & tmp9
    tmp11 = tmp10 & tmp2
    tmp12 = 7.0
    tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
    tmp14 = tl.where(tmp11, tmp12, tmp13)
    tmp15 = tmp2 & tmp2
    tmp16 = tmp6 < tmp7
    tmp17 = tmp16 & tmp15
    tmp18 = 6.0
    tmp19 = tl.full(tmp18.shape, 0.0, tmp18.dtype)
    tmp20 = tl.where(tmp17, tmp18, tmp19)
    tmp21 = 0.0
    tmp22 = tl.where(tmp16, tmp20, tmp21)
    tmp23 = tl.full(tmp22.shape, 0.0, tmp22.dtype)
    tmp24 = tl.where(tmp15, tmp22, tmp23)
    tmp25 = tmp0 >= tmp7
    tmp26 = tmp0 < tmp1
    tmp27 = tmp25 & tmp26
    tmp28 = tmp27 & tmp2
    tmp29 = tmp6 >= tmp1
    tmp30 = tmp29 & tmp28
    tmp31 = 5.0
    tmp32 = tl.full(tmp31.shape, 0.0, tmp31.dtype)
    tmp33 = tl.where(tmp30, tmp31, tmp32)
    tmp34 = tl.load(in_out_ptr0 + (x2), tmp28 & xmask, other=0.0)
    tmp35 = tl.where(tmp29, tmp33, tmp34)
    tmp36 = tl.full(tmp35.shape, 0.0, tmp35.dtype)
    tmp37 = tl.where(tmp28, tmp35, tmp36)
    tmp38 = tl.load(in_out_ptr0 + (x2), tmp2 & xmask, other=0.0)
    tmp39 = tl.where(tmp27, tmp37, tmp38)
    tmp40 = tl.where(tmp2, tmp24, tmp39)
    tmp41 = tl.where(tmp10, tmp14, tmp40)
    tmp42 = tl.full(tmp41.shape, 0.0, tmp41.dtype)
    tmp43 = tl.where(tmp2, tmp41, tmp42)
    tmp44 = tmp16 & tmp2
    tmp45 = tl.where(tmp44, tmp18, tmp19)
    tmp46 = tl.where(tmp16, tmp45, tmp21)
    tmp47 = tl.full(tmp46.shape, 0.0, tmp46.dtype)
    tmp48 = tl.where(tmp2, tmp46, tmp47)
    tmp49 = tmp29 & tmp27
    tmp50 = tl.where(tmp49, tmp31, tmp32)
    tmp51 = tl.load(in_out_ptr0 + (x2), tmp27 & xmask, other=0.0)
    tmp52 = tl.where(tmp29, tmp50, tmp51)
    tmp53 = tl.full(tmp52.shape, 0.0, tmp52.dtype)
    tmp54 = tl.where(tmp27, tmp52, tmp53)
    tmp56 = tl.where(tmp27, tmp54, tmp55)
    tmp57 = tl.where(tmp2, tmp48, tmp56)
    tmp58 = tl.where(tmp2, tmp43, tmp57)
    tmp59 = tl.where(tmp2, tmp5, tmp58)
    tl.store(in_out_ptr0 + (x2), tmp59, xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/6z/c6zbanuvwcqjxdnkk3jjviuhk6fvpanrlsj4nzjsrldlmzsgyaei.py
# Source Nodes: [attn_26, attn_28, matmul_11], Original ATen: [aten._softmax, aten._to_copy, aten.add]
# attn_26 => add_65
# attn_28 => amax_5, div_11, exp_5, sub_24, sum_6
# matmul_11 => convert_element_type_180
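# Note: softmax for the shifted windows. Same relative-position-bias gather as
# before, plus the attention mask derived from img_mask (in_ptr3): wherever
# query and key carry different region ids the logit is offset by -100.0
# before the softmax, effectively blocking attention across the wrapped-around
# regions introduced by the cyclic shift.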
triton_per_fused__softmax__to_copy_add_67 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.persistent_reduction(
    size_hints=[131072, 64],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*i64', 2: '*fp32', 3: '*fp32', 4: '*fp16', 5: 'i32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__softmax__to_copy_add_67', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 4, 'num_reduction': 2, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr3, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 78400
    rnumel = 49
    RBLOCK: tl.constexpr = 64
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = rindex < rnumel
    r3 = rindex
    x4 = xindex
    x0 = xindex % 49
    x1 = (xindex // 49) % 16
    x2 = (xindex // 784)
    x5 = (xindex // 49)
    tmp0 = tl.load(in_ptr0 + (r3 + (49*x4)), rmask & xmask, other=0.0).to(tl.float32)
    tmp2 = tl.load(in_ptr1 + (r3 + (49*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0)
    tmp10 = tl.load(in_ptr3 + ((7*(x2 % 10)) + (70*(r3 // 7)) + (490*(x2 // 10)) + (r3 % 7)), rmask & xmask, eviction_policy='evict_last', other=0.0)
    tmp11 = tl.load(in_ptr3 + ((7*(x2 % 10)) + (70*(x0 // 7)) + (490*(x2 // 10)) + (x0 % 7)), xmask, eviction_policy='evict_last')
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tl.full([XBLOCK, RBLOCK], 169, tl.int32)
    tmp4 = tmp2 + tmp3
    tmp5 = tmp2 < 0
    tmp6 = tl.where(tmp5, tmp4, tmp2)
    tl.device_assert(((0 <= tmp6) & (tmp6 < 169)) | ~(rmask & xmask), "index out of bounds: 0 <= tmp6 < 169")
    tmp8 = tl.load(in_ptr2 + (x1 + (16*tmp6)), rmask & xmask, eviction_policy='evict_last')
    tmp9 = tmp1 + tmp8
    tmp12 = tmp10 - tmp11
    tmp13 = 0.0
    tmp14 = tmp12 == tmp13
    tmp15 = tmp12 != tmp13
    tmp16 = -100.0
    tmp17 = tl.where(tmp15, tmp16, tmp12)
    tmp18 = tl.where(tmp14, tmp13, tmp17)
    tmp19 = tmp9 + tmp18
    tmp20 = tl.broadcast_to(tmp19, [XBLOCK, RBLOCK])
    tmp22 = tl.where(rmask & xmask, tmp20, float("-inf"))
    tmp23 = triton_helpers.max2(tmp22, 1)[:, None]
    tmp24 = tmp19 - tmp23
    tmp25 = tl_math.exp(tmp24)
    tmp26 = tl.broadcast_to(tmp25, [XBLOCK, RBLOCK])
    tmp28 = tl.where(rmask & xmask, tmp26, 0)
    tmp29 = tl.sum(tmp28, 1)[:, None]
    tmp30 = tmp25 / tmp29
    tmp31 = tmp30.to(tl.float32)
    tl.store(out_ptr3 + (r3 + (49*x0) + (2432*x5)), tmp31, rmask & xmask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/6x/c6xasej3mmacmhpbp2ayc2wmpde5tmj3algi7vysqjsht3g5bcoy.py
# Source Nodes: [layer_norm_16, x_123, x_124], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
# layer_norm_16 => add_69, add_70, convert_element_type_188, mul_59, mul_60, rsqrt_16, sub_25, var_mean_16
# x_123 => add_68
# x_124 => convert_element_type_191
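# Note: residual add fused with the reverse cyclic shift; the "(67 + x) % 70"
# indexing undoes the earlier roll by -3 and crops back to the 68x68 extent,
# after which the LayerNorm statistics and affine for the next MLP input
# (x_124) are computed with a two-pass Welford reduction.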
triton_red_fused__to_copy_add_native_layer_norm_68 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.reduction(
    size_hints=[8192, 512],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp16', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_native_layer_norm_68', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 8, 'num_reduction': 2, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr3, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 4624
    rnumel = 512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    tmp8_mean = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp8_m2 = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp8_weight = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (512*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr1 + (r1 + (512*(((67 + (x0 % 68)) % 70) % 7)) + (3584*(((67 + (x0 // 68)) % 70) % 7)) + (25088*(((67 + (x0 % 68)) % 70) // 7)) + (250880*(((67 + (x0 // 68)) % 70) // 7))), rmask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp3 = tmp2.to(tl.float32)
        tmp4 = tmp1 + tmp3
        tmp5 = tmp0 + tmp4
        tmp6 = tmp5.to(tl.float32)
        tmp7 = tl.broadcast_to(tmp6, [XBLOCK, RBLOCK])
        tmp8_mean_next, tmp8_m2_next, tmp8_weight_next = triton_helpers.welford_reduce(
            tmp7, tmp8_mean, tmp8_m2, tmp8_weight, roffset == 0
        )
        tmp8_mean = tl.where(rmask & xmask, tmp8_mean_next, tmp8_mean)
        tmp8_m2 = tl.where(rmask & xmask, tmp8_m2_next, tmp8_m2)
        tmp8_weight = tl.where(rmask & xmask, tmp8_weight_next, tmp8_weight)
    tmp8_tmp, tmp9_tmp, tmp10_tmp = triton_helpers.welford(
        tmp8_mean, tmp8_m2, tmp8_weight, 1
    )
    tmp8 = tmp8_tmp[:, None]
    tmp9 = tmp9_tmp[:, None]
    tmp10 = tmp10_tmp[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp11 = tl.load(in_ptr0 + (r1 + (512*x0)), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp12 = tl.load(in_ptr1 + (r1 + (512*(((67 + (x0 % 68)) % 70) % 7)) + (3584*(((67 + (x0 // 68)) % 70) % 7)) + (25088*(((67 + (x0 % 68)) % 70) // 7)) + (250880*(((67 + (x0 // 68)) % 70) // 7))), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp13 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp25 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp27 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp14 = tmp13.to(tl.float32)
        tmp15 = tmp12 + tmp14
        tmp16 = tmp11 + tmp15
        tmp17 = tmp16.to(tl.float32)
        tmp18 = tmp17 - tmp8
        tmp19 = 512.0
        tmp20 = tmp9 / tmp19
        tmp21 = 1e-05
        tmp22 = tmp20 + tmp21
        tmp23 = libdevice.rsqrt(tmp22)
        tmp24 = tmp18 * tmp23
        tmp26 = tmp24 * tmp25
        tmp28 = tmp26 + tmp27
        tmp29 = tmp28.to(tl.float32)
        tl.store(out_ptr3 + (r1 + (512*x0)), tmp29, rmask & xmask)
''', device_str='cuda')
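# ---------------------------------------------------------------------------
# Reference note (illustrative only, not used by the compiled graph): the
# reduction kernel above is a two-pass fused residual-add + LayerNorm over the
# 512-channel dimension. The first loop accumulates the per-row mean and M2
# with a Welford reduction; the second loop re-reads the same inputs and
# applies (x - mean) * rsqrt(var + 1e-5) * weight + bias, storing fp16. The
# real kernel also folds the shifted-window re-indexing of one residual stream
# and a broadcast bias into its load addresses. A minimal eager-mode sketch
# with hypothetical names:
def _reference_add_layer_norm(x, residual, weight, bias, eps=1e-05):
    s = (x + residual).float()
    mean = s.mean(dim=-1, keepdim=True)
    var = s.var(dim=-1, unbiased=False, keepdim=True)
    return ((s - mean) * torch.rsqrt(var + eps) * weight + bias).half()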


# kernel path: /tmp/torchinductor_root/rv/crvoqdpaf6qnpdaldpzwrrgi7gxvyresscwiboxzh75r7txhfnnf.py
# Source Nodes: [x_123, x_129, x_130], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
# x_123 => add_68
# x_129 => add_72
# x_130 => convert_element_type_202, var_mean_17
triton_per_fused__to_copy_add_native_layer_norm_69 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.persistent_reduction(
    size_hints=[8192, 512],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: 'i32', 8: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_native_layer_norm_69', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': True, 'num_load': 5, 'num_reduction': 4, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, out_ptr1, xnumel, rnumel):
    xnumel = 4624
    XBLOCK: tl.constexpr = 1
    rnumel = 512
    RBLOCK: tl.constexpr = 512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[:]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (512*x0)), rmask & xmask, other=0.0).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (r1 + (512*(((67 + (x0 % 68)) % 70) % 7)) + (3584*(((67 + (x0 // 68)) % 70) % 7)) + (25088*(((67 + (x0 % 68)) % 70) // 7)) + (250880*(((67 + (x0 // 68)) % 70) // 7))), rmask & xmask, other=0.0).to(tl.float32)
    tmp2 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
    tmp6 = tl.load(in_out_ptr0 + (r1 + (512*x0)), rmask & xmask, other=0.0).to(tl.float32)
    tmp7 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0)
    tmp3 = tmp2.to(tl.float32)
    tmp4 = tmp1 + tmp3
    tmp5 = tmp0 + tmp4
    tmp8 = tmp7.to(tl.float32)
    tmp9 = tmp6 + tmp8
    tmp10 = tmp5 + tmp9
    tmp11 = tmp10.to(tl.float32)
    tmp12 = tl.broadcast_to(tmp11, [RBLOCK])
    tmp14 = tl.where(rmask & xmask, tmp12, 0)
    tmp15 = tl.broadcast_to(tmp12, [RBLOCK])
    tmp17 = tl.where(rmask & xmask, tmp15, 0)
    tmp18 = triton_helpers.promote_to_tensor(tl.sum(tmp17, 0))
    tmp19 = tl.full([1], 512, tl.int32)
    tmp20 = tmp19.to(tl.float32)
    tmp21 = tmp18 / tmp20
    tmp22 = tmp12 - tmp21
    tmp23 = tmp22 * tmp22
    tmp24 = tl.broadcast_to(tmp23, [RBLOCK])
    tmp26 = tl.where(rmask & xmask, tmp24, 0)
    tmp27 = triton_helpers.promote_to_tensor(tl.sum(tmp26, 0))
    tl.store(in_out_ptr0 + (r1 + (512*x0)), tmp10, rmask & xmask)
    tl.store(out_ptr0 + (x0), tmp21, xmask)
    tl.store(out_ptr1 + (x0), tmp27, xmask)
''', device_str='cuda')
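# Note: the persistent-reduction kernel above is the single-pass counterpart
# of the kernel before it. It accumulates the residual sum in place through
# in_out_ptr0 and writes only the per-row statistics (mean to out_ptr0, sum of
# squared deviations to out_ptr1); the rsqrt/scale/shift of the layer norm is
# applied by a later pointwise kernel.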


# kernel path: /tmp/torchinductor_root/i3/ci3pm4qxa5bdompouf64ojpj35sx6k5ehnbqcwoemqypakfg6p26.py
# Source Nodes: [x_419, x_425, x_out_2], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
# x_419 => add_236
# x_425 => add_240
# x_out_2 => convert_element_type_698, var_mean_49
triton_per_fused__to_copy_add_native_layer_norm_70 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.persistent_reduction(
    size_hints=[8192, 512],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp32', 3: '*fp16', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: 'i32', 9: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_native_layer_norm_70', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 5, 'num_reduction': 4, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr1, out_ptr2, xnumel, rnumel):
    xnumel = 4624
    XBLOCK: tl.constexpr = 1
    rnumel = 512
    RBLOCK: tl.constexpr = 512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[:]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (512*x0)), rmask & xmask, other=0.0).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (r1 + (512*(((67 + (x0 % 68)) % 70) % 7)) + (3584*(((67 + (x0 // 68)) % 70) % 7)) + (25088*(((67 + (x0 % 68)) % 70) // 7)) + (250880*(((67 + (x0 // 68)) % 70) // 7))), rmask & xmask, other=0.0).to(tl.float32)
    tmp2 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
    tmp6 = tl.load(in_ptr3 + (r1 + (512*x0)), rmask & xmask, other=0.0).to(tl.float32)
    tmp7 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0)
    tmp3 = tmp2.to(tl.float32)
    tmp4 = tmp1 + tmp3
    tmp5 = tmp0 + tmp4
    tmp8 = tmp7.to(tl.float32)
    tmp9 = tmp6 + tmp8
    tmp10 = tmp5 + tmp9
    tmp11 = tmp10.to(tl.float32)
    tmp12 = tl.broadcast_to(tmp11, [RBLOCK])
    tmp14 = tl.where(rmask & xmask, tmp12, 0)
    tmp15 = tl.broadcast_to(tmp12, [RBLOCK])
    tmp17 = tl.where(rmask & xmask, tmp15, 0)
    tmp18 = triton_helpers.promote_to_tensor(tl.sum(tmp17, 0))
    tmp19 = tl.full([1], 512, tl.int32)
    tmp20 = tmp19.to(tl.float32)
    tmp21 = tmp18 / tmp20
    tmp22 = tmp12 - tmp21
    tmp23 = tmp22 * tmp22
    tmp24 = tl.broadcast_to(tmp23, [RBLOCK])
    tmp26 = tl.where(rmask & xmask, tmp24, 0)
    tmp27 = triton_helpers.promote_to_tensor(tl.sum(tmp26, 0))
    tl.store(out_ptr0 + (r1 + (512*x0)), tmp11, rmask & xmask)
    tl.store(out_ptr1 + (x0), tmp21, xmask)
    tl.store(out_ptr2 + (x0), tmp27, xmask)
''', device_str='cuda')
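# Note: same statistics computation as the kernel above, but out of place: the
# fp32 residual sum goes to out_ptr0 and the mean / squared-deviation sum to
# out_ptr1 / out_ptr2. Per the source-node comment this produces the inputs of
# the final `x_out_2` layer norm (var_mean_49).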


# kernel path: /tmp/torchinductor_root/uz/cuzwnuuhvp4xvn2ygexj3suttmbl2cv7nhkqtwzvahp3hcuf5ncx.py
# Source Nodes: [out], Original ATen: [aten.clone]
# out => clone_27
triton_poi_fused_clone_71 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[128, 131072], tile_hint=TileHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_71', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 128
    xnumel = 73984
    yoffset = (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    x1 = xindex
    y0 = yindex
    tmp0 = tl.load(in_ptr0 + (y0 + (128*x1)), xmask & ymask, eviction_policy='evict_last')
    tmp1 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last')
    tmp3 = tl.load(in_ptr2 + (x1), xmask, eviction_policy='evict_last')
    tmp10 = tl.load(in_ptr3 + (y0), ymask, eviction_policy='evict_last')
    tmp12 = tl.load(in_ptr4 + (y0), ymask, eviction_policy='evict_last')
    tmp2 = tmp0 - tmp1
    tmp4 = 128.0
    tmp5 = tmp3 / tmp4
    tmp6 = 1e-05
    tmp7 = tmp5 + tmp6
    tmp8 = libdevice.rsqrt(tmp7)
    tmp9 = tmp2 * tmp8
    tmp11 = tmp9 * tmp10
    tmp13 = tmp11 + tmp12
    tl.store(out_ptr0 + (x1 + (73984*y0)), tmp13, xmask & ymask)
''', device_str='cuda')
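# Note: the pointwise kernel above finalises a layer norm and transposes the
# result in one pass for the `out` feature map: it reads the pre-normalised
# values in (token, channel) order at y0 + 128*x1, applies
# (x - mean) * rsqrt(var_sum/128 + 1e-5) * weight + bias, and stores at
# x1 + 73984*y0, i.e. a channels-first (128, 73984) layout.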


# kernel path: /tmp/torchinductor_root/5s/c5s7m6ozgfpwmvadtkkx5yxo36i4abkudvpqhbmoam4ymufqrcsm.py
# Source Nodes: [out_1], Original ATen: [aten.clone]
# out_1 => clone_53
triton_poi_fused_clone_72 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[256, 32768], tile_hint=TileHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_72', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 256
    xnumel = 18496
    yoffset = (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    x1 = xindex
    y0 = yindex
    tmp0 = tl.load(in_ptr0 + (y0 + (256*x1)), xmask & ymask, eviction_policy='evict_last').to(tl.float32)
    tmp2 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last')
    tmp4 = tl.load(in_ptr2 + (x1), xmask, eviction_policy='evict_last')
    tmp11 = tl.load(in_ptr3 + (y0), ymask, eviction_policy='evict_last')
    tmp13 = tl.load(in_ptr4 + (y0), ymask, eviction_policy='evict_last')
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tmp1 - tmp2
    tmp5 = 256.0
    tmp6 = tmp4 / tmp5
    tmp7 = 1e-05
    tmp8 = tmp6 + tmp7
    tmp9 = libdevice.rsqrt(tmp8)
    tmp10 = tmp3 * tmp9
    tmp12 = tmp10 * tmp11
    tmp14 = tmp12 + tmp13
    tl.store(out_ptr0 + (x1 + (18496*y0)), tmp14, xmask & ymask)
''', device_str='cuda')


# kernel path: /tmp/torchinductor_root/tc/ctcwm3wctt5jaq4f7ptheotzpz5uo4d7jkikv5w3fzgbb6we22en.py
# Source Nodes: [out_2], Original ATen: [aten.clone]
# out_2 => clone_271
triton_poi_fused_clone_73 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[512, 8192], tile_hint=TileHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_73', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 512
    xnumel = 4624
    yoffset = (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    x1 = xindex
    y0 = yindex
    tmp0 = tl.load(in_ptr0 + (y0 + (512*x1)), xmask & ymask, eviction_policy='evict_last')
    tmp1 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last')
    tmp3 = tl.load(in_ptr2 + (x1), xmask, eviction_policy='evict_last')
    tmp10 = tl.load(in_ptr3 + (y0), ymask, eviction_policy='evict_last')
    tmp12 = tl.load(in_ptr4 + (y0), ymask, eviction_policy='evict_last')
    tmp2 = tmp0 - tmp1
    tmp4 = 512.0
    tmp5 = tmp3 / tmp4
    tmp6 = 1e-05
    tmp7 = tmp5 + tmp6
    tmp8 = libdevice.rsqrt(tmp7)
    tmp9 = tmp2 * tmp8
    tmp11 = tmp9 * tmp10
    tmp13 = tmp11 + tmp12
    tl.store(out_ptr0 + (x1 + (4624*y0)), tmp13, xmask & ymask)
''', device_str='cuda')
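# Note: kernels 72 and 73 repeat the normalise-and-transpose pattern of kernel
# 71 for the 256-channel (18496 positions) and 512-channel (4624 positions)
# feature maps, materialising the channels-first `out_1` and `out_2` tensors
# named in the source-node comments above.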


async_compile.wait(globals())
del async_compile

def call(args):
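    # Unpack the 325 graph inputs (the input image plus weights, biases,
    # relative-position bias tables and index tensors), validate their sizes
    # and strides, then replay the compiled kernel schedule on CUDA device 0.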
    arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1, arg115_1, arg116_1, arg117_1, arg118_1, arg119_1, arg120_1, arg121_1, arg122_1, arg123_1, arg124_1, arg125_1, arg126_1, arg127_1, arg128_1, arg129_1, arg130_1, arg131_1, arg132_1, arg133_1, arg134_1, arg135_1, arg136_1, arg137_1, arg138_1, arg139_1, arg140_1, arg141_1, arg142_1, arg143_1, arg144_1, arg145_1, arg146_1, arg147_1, arg148_1, arg149_1, arg150_1, arg151_1, arg152_1, arg153_1, arg154_1, arg155_1, arg156_1, arg157_1, arg158_1, arg159_1, arg160_1, arg161_1, arg162_1, arg163_1, arg164_1, arg165_1, arg166_1, arg167_1, arg168_1, arg169_1, arg170_1, arg171_1, arg172_1, arg173_1, arg174_1, arg175_1, arg176_1, arg177_1, arg178_1, arg179_1, arg180_1, arg181_1, arg182_1, arg183_1, arg184_1, arg185_1, arg186_1, arg187_1, arg188_1, arg189_1, arg190_1, arg191_1, arg192_1, arg193_1, arg194_1, arg195_1, arg196_1, arg197_1, arg198_1, arg199_1, arg200_1, arg201_1, arg202_1, arg203_1, arg204_1, arg205_1, arg206_1, arg207_1, arg208_1, arg209_1, arg210_1, arg211_1, arg212_1, arg213_1, arg214_1, arg215_1, arg216_1, arg217_1, arg218_1, arg219_1, arg220_1, arg221_1, arg222_1, arg223_1, arg224_1, arg225_1, arg226_1, arg227_1, arg228_1, arg229_1, arg230_1, arg231_1, arg232_1, arg233_1, arg234_1, arg235_1, arg236_1, arg237_1, arg238_1, arg239_1, arg240_1, arg241_1, arg242_1, arg243_1, arg244_1, arg245_1, arg246_1, arg247_1, arg248_1, arg249_1, arg250_1, arg251_1, arg252_1, arg253_1, arg254_1, arg255_1, arg256_1, arg257_1, arg258_1, arg259_1, arg260_1, arg261_1, arg262_1, arg263_1, arg264_1, arg265_1, arg266_1, arg267_1, arg268_1, arg269_1, arg270_1, arg271_1, arg272_1, arg273_1, arg274_1, arg275_1, arg276_1, arg277_1, arg278_1, arg279_1, arg280_1, arg281_1, arg282_1, arg283_1, arg284_1, arg285_1, arg286_1, arg287_1, arg288_1, arg289_1, arg290_1, arg291_1, arg292_1, arg293_1, arg294_1, arg295_1, arg296_1, arg297_1, arg298_1, arg299_1, arg300_1, arg301_1, arg302_1, arg303_1, arg304_1, arg305_1, arg306_1, arg307_1, arg308_1, arg309_1, arg310_1, arg311_1, arg312_1, arg313_1, arg314_1, arg315_1, arg316_1, arg317_1, arg318_1, arg319_1, arg320_1, arg321_1, arg322_1, arg323_1, arg324_1 = args
    args.clear()
    assert_size_stride(arg0_1, (1, 3, 1088, 1088), (3551232, 1183744, 1088, 1))
    assert_size_stride(arg1_1, (128, 3, 4, 4), (48, 16, 4, 1))
    assert_size_stride(arg2_1, (128, ), (1, ))
    assert_size_stride(arg3_1, (128, ), (1, ))
    assert_size_stride(arg4_1, (128, ), (1, ))
    assert_size_stride(arg5_1, (128, ), (1, ))
    assert_size_stride(arg6_1, (128, ), (1, ))
    assert_size_stride(arg7_1, (384, 128), (128, 1))
    assert_size_stride(arg8_1, (384, ), (1, ))
    assert_size_stride(arg9_1, (169, 4), (4, 1))
    assert_size_stride(arg10_1, (49, 49), (49, 1))
    assert_size_stride(arg11_1, (128, 128), (128, 1))
    assert_size_stride(arg12_1, (128, ), (1, ))
    assert_size_stride(arg13_1, (128, ), (1, ))
    assert_size_stride(arg14_1, (128, ), (1, ))
    assert_size_stride(arg15_1, (512, 128), (128, 1))
    assert_size_stride(arg16_1, (512, ), (1, ))
    assert_size_stride(arg17_1, (128, 512), (512, 1))
    assert_size_stride(arg18_1, (128, ), (1, ))
    assert_size_stride(arg19_1, (128, ), (1, ))
    assert_size_stride(arg20_1, (128, ), (1, ))
    assert_size_stride(arg21_1, (384, 128), (128, 1))
    assert_size_stride(arg22_1, (384, ), (1, ))
    assert_size_stride(arg23_1, (169, 4), (4, 1))
    assert_size_stride(arg24_1, (49, 49), (49, 1))
    assert_size_stride(arg25_1, (128, 128), (128, 1))
    assert_size_stride(arg26_1, (128, ), (1, ))
    assert_size_stride(arg27_1, (128, ), (1, ))
    assert_size_stride(arg28_1, (128, ), (1, ))
    assert_size_stride(arg29_1, (512, 128), (128, 1))
    assert_size_stride(arg30_1, (512, ), (1, ))
    assert_size_stride(arg31_1, (128, 512), (512, 1))
    assert_size_stride(arg32_1, (128, ), (1, ))
    assert_size_stride(arg33_1, (512, ), (1, ))
    assert_size_stride(arg34_1, (512, ), (1, ))
    assert_size_stride(arg35_1, (256, 512), (512, 1))
    assert_size_stride(arg36_1, (128, ), (1, ))
    assert_size_stride(arg37_1, (128, ), (1, ))
    assert_size_stride(arg38_1, (256, ), (1, ))
    assert_size_stride(arg39_1, (256, ), (1, ))
    assert_size_stride(arg40_1, (768, 256), (256, 1))
    assert_size_stride(arg41_1, (768, ), (1, ))
    assert_size_stride(arg42_1, (169, 8), (8, 1))
    assert_size_stride(arg43_1, (49, 49), (49, 1))
    assert_size_stride(arg44_1, (256, 256), (256, 1))
    assert_size_stride(arg45_1, (256, ), (1, ))
    assert_size_stride(arg46_1, (256, ), (1, ))
    assert_size_stride(arg47_1, (256, ), (1, ))
    assert_size_stride(arg48_1, (1024, 256), (256, 1))
    assert_size_stride(arg49_1, (1024, ), (1, ))
    assert_size_stride(arg50_1, (256, 1024), (1024, 1))
    assert_size_stride(arg51_1, (256, ), (1, ))
    assert_size_stride(arg52_1, (256, ), (1, ))
    assert_size_stride(arg53_1, (256, ), (1, ))
    assert_size_stride(arg54_1, (768, 256), (256, 1))
    assert_size_stride(arg55_1, (768, ), (1, ))
    assert_size_stride(arg56_1, (169, 8), (8, 1))
    assert_size_stride(arg57_1, (49, 49), (49, 1))
    assert_size_stride(arg58_1, (256, 256), (256, 1))
    assert_size_stride(arg59_1, (256, ), (1, ))
    assert_size_stride(arg60_1, (256, ), (1, ))
    assert_size_stride(arg61_1, (256, ), (1, ))
    assert_size_stride(arg62_1, (1024, 256), (256, 1))
    assert_size_stride(arg63_1, (1024, ), (1, ))
    assert_size_stride(arg64_1, (256, 1024), (1024, 1))
    assert_size_stride(arg65_1, (256, ), (1, ))
    assert_size_stride(arg66_1, (1024, ), (1, ))
    assert_size_stride(arg67_1, (1024, ), (1, ))
    assert_size_stride(arg68_1, (512, 1024), (1024, 1))
    assert_size_stride(arg69_1, (256, ), (1, ))
    assert_size_stride(arg70_1, (256, ), (1, ))
    assert_size_stride(arg71_1, (512, ), (1, ))
    assert_size_stride(arg72_1, (512, ), (1, ))
    assert_size_stride(arg73_1, (1536, 512), (512, 1))
    assert_size_stride(arg74_1, (1536, ), (1, ))
    assert_size_stride(arg75_1, (169, 16), (16, 1))
    assert_size_stride(arg76_1, (49, 49), (49, 1))
    assert_size_stride(arg77_1, (512, 512), (512, 1))
    assert_size_stride(arg78_1, (512, ), (1, ))
    assert_size_stride(arg79_1, (512, ), (1, ))
    assert_size_stride(arg80_1, (512, ), (1, ))
    assert_size_stride(arg81_1, (2048, 512), (512, 1))
    assert_size_stride(arg82_1, (2048, ), (1, ))
    assert_size_stride(arg83_1, (512, 2048), (2048, 1))
    assert_size_stride(arg84_1, (512, ), (1, ))
    assert_size_stride(arg85_1, (512, ), (1, ))
    assert_size_stride(arg86_1, (512, ), (1, ))
    assert_size_stride(arg87_1, (1536, 512), (512, 1))
    assert_size_stride(arg88_1, (1536, ), (1, ))
    assert_size_stride(arg89_1, (169, 16), (16, 1))
    assert_size_stride(arg90_1, (49, 49), (49, 1))
    assert_size_stride(arg91_1, (512, 512), (512, 1))
    assert_size_stride(arg92_1, (512, ), (1, ))
    assert_size_stride(arg93_1, (512, ), (1, ))
    assert_size_stride(arg94_1, (512, ), (1, ))
    assert_size_stride(arg95_1, (2048, 512), (512, 1))
    assert_size_stride(arg96_1, (2048, ), (1, ))
    assert_size_stride(arg97_1, (512, 2048), (2048, 1))
    assert_size_stride(arg98_1, (512, ), (1, ))
    assert_size_stride(arg99_1, (512, ), (1, ))
    assert_size_stride(arg100_1, (512, ), (1, ))
    assert_size_stride(arg101_1, (1536, 512), (512, 1))
    assert_size_stride(arg102_1, (1536, ), (1, ))
    assert_size_stride(arg103_1, (169, 16), (16, 1))
    assert_size_stride(arg104_1, (49, 49), (49, 1))
    assert_size_stride(arg105_1, (512, 512), (512, 1))
    assert_size_stride(arg106_1, (512, ), (1, ))
    assert_size_stride(arg107_1, (512, ), (1, ))
    assert_size_stride(arg108_1, (512, ), (1, ))
    assert_size_stride(arg109_1, (2048, 512), (512, 1))
    assert_size_stride(arg110_1, (2048, ), (1, ))
    assert_size_stride(arg111_1, (512, 2048), (2048, 1))
    assert_size_stride(arg112_1, (512, ), (1, ))
    assert_size_stride(arg113_1, (512, ), (1, ))
    assert_size_stride(arg114_1, (512, ), (1, ))
    assert_size_stride(arg115_1, (1536, 512), (512, 1))
    assert_size_stride(arg116_1, (1536, ), (1, ))
    assert_size_stride(arg117_1, (169, 16), (16, 1))
    assert_size_stride(arg118_1, (49, 49), (49, 1))
    assert_size_stride(arg119_1, (512, 512), (512, 1))
    assert_size_stride(arg120_1, (512, ), (1, ))
    assert_size_stride(arg121_1, (512, ), (1, ))
    assert_size_stride(arg122_1, (512, ), (1, ))
    assert_size_stride(arg123_1, (2048, 512), (512, 1))
    assert_size_stride(arg124_1, (2048, ), (1, ))
    assert_size_stride(arg125_1, (512, 2048), (2048, 1))
    assert_size_stride(arg126_1, (512, ), (1, ))
    assert_size_stride(arg127_1, (512, ), (1, ))
    assert_size_stride(arg128_1, (512, ), (1, ))
    assert_size_stride(arg129_1, (1536, 512), (512, 1))
    assert_size_stride(arg130_1, (1536, ), (1, ))
    assert_size_stride(arg131_1, (169, 16), (16, 1))
    assert_size_stride(arg132_1, (49, 49), (49, 1))
    assert_size_stride(arg133_1, (512, 512), (512, 1))
    assert_size_stride(arg134_1, (512, ), (1, ))
    assert_size_stride(arg135_1, (512, ), (1, ))
    assert_size_stride(arg136_1, (512, ), (1, ))
    assert_size_stride(arg137_1, (2048, 512), (512, 1))
    assert_size_stride(arg138_1, (2048, ), (1, ))
    assert_size_stride(arg139_1, (512, 2048), (2048, 1))
    assert_size_stride(arg140_1, (512, ), (1, ))
    assert_size_stride(arg141_1, (512, ), (1, ))
    assert_size_stride(arg142_1, (512, ), (1, ))
    assert_size_stride(arg143_1, (1536, 512), (512, 1))
    assert_size_stride(arg144_1, (1536, ), (1, ))
    assert_size_stride(arg145_1, (169, 16), (16, 1))
    assert_size_stride(arg146_1, (49, 49), (49, 1))
    assert_size_stride(arg147_1, (512, 512), (512, 1))
    assert_size_stride(arg148_1, (512, ), (1, ))
    assert_size_stride(arg149_1, (512, ), (1, ))
    assert_size_stride(arg150_1, (512, ), (1, ))
    assert_size_stride(arg151_1, (2048, 512), (512, 1))
    assert_size_stride(arg152_1, (2048, ), (1, ))
    assert_size_stride(arg153_1, (512, 2048), (2048, 1))
    assert_size_stride(arg154_1, (512, ), (1, ))
    assert_size_stride(arg155_1, (512, ), (1, ))
    assert_size_stride(arg156_1, (512, ), (1, ))
    assert_size_stride(arg157_1, (1536, 512), (512, 1))
    assert_size_stride(arg158_1, (1536, ), (1, ))
    assert_size_stride(arg159_1, (169, 16), (16, 1))
    assert_size_stride(arg160_1, (49, 49), (49, 1))
    assert_size_stride(arg161_1, (512, 512), (512, 1))
    assert_size_stride(arg162_1, (512, ), (1, ))
    assert_size_stride(arg163_1, (512, ), (1, ))
    assert_size_stride(arg164_1, (512, ), (1, ))
    assert_size_stride(arg165_1, (2048, 512), (512, 1))
    assert_size_stride(arg166_1, (2048, ), (1, ))
    assert_size_stride(arg167_1, (512, 2048), (2048, 1))
    assert_size_stride(arg168_1, (512, ), (1, ))
    assert_size_stride(arg169_1, (512, ), (1, ))
    assert_size_stride(arg170_1, (512, ), (1, ))
    assert_size_stride(arg171_1, (1536, 512), (512, 1))
    assert_size_stride(arg172_1, (1536, ), (1, ))
    assert_size_stride(arg173_1, (169, 16), (16, 1))
    assert_size_stride(arg174_1, (49, 49), (49, 1))
    assert_size_stride(arg175_1, (512, 512), (512, 1))
    assert_size_stride(arg176_1, (512, ), (1, ))
    assert_size_stride(arg177_1, (512, ), (1, ))
    assert_size_stride(arg178_1, (512, ), (1, ))
    assert_size_stride(arg179_1, (2048, 512), (512, 1))
    assert_size_stride(arg180_1, (2048, ), (1, ))
    assert_size_stride(arg181_1, (512, 2048), (2048, 1))
    assert_size_stride(arg182_1, (512, ), (1, ))
    assert_size_stride(arg183_1, (512, ), (1, ))
    assert_size_stride(arg184_1, (512, ), (1, ))
    assert_size_stride(arg185_1, (1536, 512), (512, 1))
    assert_size_stride(arg186_1, (1536, ), (1, ))
    assert_size_stride(arg187_1, (169, 16), (16, 1))
    assert_size_stride(arg188_1, (49, 49), (49, 1))
    assert_size_stride(arg189_1, (512, 512), (512, 1))
    assert_size_stride(arg190_1, (512, ), (1, ))
    assert_size_stride(arg191_1, (512, ), (1, ))
    assert_size_stride(arg192_1, (512, ), (1, ))
    assert_size_stride(arg193_1, (2048, 512), (512, 1))
    assert_size_stride(arg194_1, (2048, ), (1, ))
    assert_size_stride(arg195_1, (512, 2048), (2048, 1))
    assert_size_stride(arg196_1, (512, ), (1, ))
    assert_size_stride(arg197_1, (512, ), (1, ))
    assert_size_stride(arg198_1, (512, ), (1, ))
    assert_size_stride(arg199_1, (1536, 512), (512, 1))
    assert_size_stride(arg200_1, (1536, ), (1, ))
    assert_size_stride(arg201_1, (169, 16), (16, 1))
    assert_size_stride(arg202_1, (49, 49), (49, 1))
    assert_size_stride(arg203_1, (512, 512), (512, 1))
    assert_size_stride(arg204_1, (512, ), (1, ))
    assert_size_stride(arg205_1, (512, ), (1, ))
    assert_size_stride(arg206_1, (512, ), (1, ))
    assert_size_stride(arg207_1, (2048, 512), (512, 1))
    assert_size_stride(arg208_1, (2048, ), (1, ))
    assert_size_stride(arg209_1, (512, 2048), (2048, 1))
    assert_size_stride(arg210_1, (512, ), (1, ))
    assert_size_stride(arg211_1, (512, ), (1, ))
    assert_size_stride(arg212_1, (512, ), (1, ))
    assert_size_stride(arg213_1, (1536, 512), (512, 1))
    assert_size_stride(arg214_1, (1536, ), (1, ))
    assert_size_stride(arg215_1, (169, 16), (16, 1))
    assert_size_stride(arg216_1, (49, 49), (49, 1))
    assert_size_stride(arg217_1, (512, 512), (512, 1))
    assert_size_stride(arg218_1, (512, ), (1, ))
    assert_size_stride(arg219_1, (512, ), (1, ))
    assert_size_stride(arg220_1, (512, ), (1, ))
    assert_size_stride(arg221_1, (2048, 512), (512, 1))
    assert_size_stride(arg222_1, (2048, ), (1, ))
    assert_size_stride(arg223_1, (512, 2048), (2048, 1))
    assert_size_stride(arg224_1, (512, ), (1, ))
    assert_size_stride(arg225_1, (512, ), (1, ))
    assert_size_stride(arg226_1, (512, ), (1, ))
    assert_size_stride(arg227_1, (1536, 512), (512, 1))
    assert_size_stride(arg228_1, (1536, ), (1, ))
    assert_size_stride(arg229_1, (169, 16), (16, 1))
    assert_size_stride(arg230_1, (49, 49), (49, 1))
    assert_size_stride(arg231_1, (512, 512), (512, 1))
    assert_size_stride(arg232_1, (512, ), (1, ))
    assert_size_stride(arg233_1, (512, ), (1, ))
    assert_size_stride(arg234_1, (512, ), (1, ))
    assert_size_stride(arg235_1, (2048, 512), (512, 1))
    assert_size_stride(arg236_1, (2048, ), (1, ))
    assert_size_stride(arg237_1, (512, 2048), (2048, 1))
    assert_size_stride(arg238_1, (512, ), (1, ))
    assert_size_stride(arg239_1, (512, ), (1, ))
    assert_size_stride(arg240_1, (512, ), (1, ))
    assert_size_stride(arg241_1, (1536, 512), (512, 1))
    assert_size_stride(arg242_1, (1536, ), (1, ))
    assert_size_stride(arg243_1, (169, 16), (16, 1))
    assert_size_stride(arg244_1, (49, 49), (49, 1))
    assert_size_stride(arg245_1, (512, 512), (512, 1))
    assert_size_stride(arg246_1, (512, ), (1, ))
    assert_size_stride(arg247_1, (512, ), (1, ))
    assert_size_stride(arg248_1, (512, ), (1, ))
    assert_size_stride(arg249_1, (2048, 512), (512, 1))
    assert_size_stride(arg250_1, (2048, ), (1, ))
    assert_size_stride(arg251_1, (512, 2048), (2048, 1))
    assert_size_stride(arg252_1, (512, ), (1, ))
    assert_size_stride(arg253_1, (512, ), (1, ))
    assert_size_stride(arg254_1, (512, ), (1, ))
    assert_size_stride(arg255_1, (1536, 512), (512, 1))
    assert_size_stride(arg256_1, (1536, ), (1, ))
    assert_size_stride(arg257_1, (169, 16), (16, 1))
    assert_size_stride(arg258_1, (49, 49), (49, 1))
    assert_size_stride(arg259_1, (512, 512), (512, 1))
    assert_size_stride(arg260_1, (512, ), (1, ))
    assert_size_stride(arg261_1, (512, ), (1, ))
    assert_size_stride(arg262_1, (512, ), (1, ))
    assert_size_stride(arg263_1, (2048, 512), (512, 1))
    assert_size_stride(arg264_1, (2048, ), (1, ))
    assert_size_stride(arg265_1, (512, 2048), (2048, 1))
    assert_size_stride(arg266_1, (512, ), (1, ))
    assert_size_stride(arg267_1, (512, ), (1, ))
    assert_size_stride(arg268_1, (512, ), (1, ))
    assert_size_stride(arg269_1, (1536, 512), (512, 1))
    assert_size_stride(arg270_1, (1536, ), (1, ))
    assert_size_stride(arg271_1, (169, 16), (16, 1))
    assert_size_stride(arg272_1, (49, 49), (49, 1))
    assert_size_stride(arg273_1, (512, 512), (512, 1))
    assert_size_stride(arg274_1, (512, ), (1, ))
    assert_size_stride(arg275_1, (512, ), (1, ))
    assert_size_stride(arg276_1, (512, ), (1, ))
    assert_size_stride(arg277_1, (2048, 512), (512, 1))
    assert_size_stride(arg278_1, (2048, ), (1, ))
    assert_size_stride(arg279_1, (512, 2048), (2048, 1))
    assert_size_stride(arg280_1, (512, ), (1, ))
    assert_size_stride(arg281_1, (512, ), (1, ))
    assert_size_stride(arg282_1, (512, ), (1, ))
    assert_size_stride(arg283_1, (1536, 512), (512, 1))
    assert_size_stride(arg284_1, (1536, ), (1, ))
    assert_size_stride(arg285_1, (169, 16), (16, 1))
    assert_size_stride(arg286_1, (49, 49), (49, 1))
    assert_size_stride(arg287_1, (512, 512), (512, 1))
    assert_size_stride(arg288_1, (512, ), (1, ))
    assert_size_stride(arg289_1, (512, ), (1, ))
    assert_size_stride(arg290_1, (512, ), (1, ))
    assert_size_stride(arg291_1, (2048, 512), (512, 1))
    assert_size_stride(arg292_1, (2048, ), (1, ))
    assert_size_stride(arg293_1, (512, 2048), (2048, 1))
    assert_size_stride(arg294_1, (512, ), (1, ))
    assert_size_stride(arg295_1, (512, ), (1, ))
    assert_size_stride(arg296_1, (512, ), (1, ))
    assert_size_stride(arg297_1, (1536, 512), (512, 1))
    assert_size_stride(arg298_1, (1536, ), (1, ))
    assert_size_stride(arg299_1, (169, 16), (16, 1))
    assert_size_stride(arg300_1, (49, 49), (49, 1))
    assert_size_stride(arg301_1, (512, 512), (512, 1))
    assert_size_stride(arg302_1, (512, ), (1, ))
    assert_size_stride(arg303_1, (512, ), (1, ))
    assert_size_stride(arg304_1, (512, ), (1, ))
    assert_size_stride(arg305_1, (2048, 512), (512, 1))
    assert_size_stride(arg306_1, (2048, ), (1, ))
    assert_size_stride(arg307_1, (512, 2048), (2048, 1))
    assert_size_stride(arg308_1, (512, ), (1, ))
    assert_size_stride(arg309_1, (512, ), (1, ))
    assert_size_stride(arg310_1, (512, ), (1, ))
    assert_size_stride(arg311_1, (1536, 512), (512, 1))
    assert_size_stride(arg312_1, (1536, ), (1, ))
    assert_size_stride(arg313_1, (169, 16), (16, 1))
    assert_size_stride(arg314_1, (49, 49), (49, 1))
    assert_size_stride(arg315_1, (512, 512), (512, 1))
    assert_size_stride(arg316_1, (512, ), (1, ))
    assert_size_stride(arg317_1, (512, ), (1, ))
    assert_size_stride(arg318_1, (512, ), (1, ))
    assert_size_stride(arg319_1, (2048, 512), (512, 1))
    assert_size_stride(arg320_1, (2048, ), (1, ))
    assert_size_stride(arg321_1, (512, 2048), (2048, 1))
    assert_size_stride(arg322_1, (512, ), (1, ))
    assert_size_stride(arg323_1, (512, ), (1, ))
    assert_size_stride(arg324_1, (512, ), (1, ))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((1, 3, 1088, 1088), (3551232, 1183744, 1088, 1), torch.float16)
        # Source Nodes: [x], Original ATen: [aten._to_copy, aten.convolution]
        stream0 = get_raw_stream(0)
        triton_poi_fused__to_copy_convolution_0.run(arg0_1, buf0, 3551232, grid=grid(3551232), stream=stream0)
        del arg0_1
        buf1 = empty_strided_cuda((128, 3, 4, 4), (48, 16, 4, 1), torch.float16)
        # Source Nodes: [x], Original ATen: [aten._to_copy, aten.convolution]
        triton_poi_fused__to_copy_convolution_1.run(arg1_1, buf1, 6144, grid=grid(6144), stream=stream0)
        del arg1_1
        buf2 = empty_strided_cuda((128, ), (1, ), torch.float16)
        # Source Nodes: [x], Original ATen: [aten._to_copy, aten.convolution]
        triton_poi_fused__to_copy_convolution_2.run(arg2_1, buf2, 128, grid=grid(128), stream=stream0)
        del arg2_1
        # Source Nodes: [x], Original ATen: [aten._to_copy, aten.convolution]
        buf3 = extern_kernels.convolution(buf0, buf1, stride=(4, 4), padding=(0, 0), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
        assert_size_stride(buf3, (1, 128, 272, 272), (9469952, 73984, 272, 1))
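        # buf3 is the patch-embedding-style convolution output: a 4x4,
        # stride-4 conv over the 1088x1088 fp16 input yields a 128-channel
        # 272x272 map, i.e. 73984 positions that the following layer-norm
        # kernels treat as tokens.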
        del buf0
        del buf1
        buf4 = empty_strided_cuda((1, 73984, 1), (73984, 1, 73984), torch.float32)
        buf5 = empty_strided_cuda((1, 73984, 1), (73984, 1, 73984), torch.float32)
        # Source Nodes: [x_2], Original ATen: [aten._to_copy, aten.native_layer_norm]
        triton_red_fused__to_copy_native_layer_norm_3.run(buf3, buf2, buf4, buf5, 73984, 128, grid=grid(73984), stream=stream0)
        buf7 = empty_strided_cuda((1, 73984, 128), (9469952, 1, 73984), torch.float32)
        # Source Nodes: [x_2], Original ATen: [aten._to_copy, aten.native_layer_norm]
        triton_poi_fused__to_copy_native_layer_norm_4.run(buf3, buf2, buf4, buf5, arg3_1, arg4_1, buf7, 9469952, grid=grid(9469952), stream=stream0)
        del arg3_1
        del arg4_1
        del buf2
        buf8 = buf5; del buf5  # reuse
        buf9 = buf4; del buf4  # reuse
        # Source Nodes: [x_7], Original ATen: [aten.native_layer_norm]
        triton_red_fused_native_layer_norm_5.run(buf7, buf8, buf9, 73984, 128, grid=grid(73984), stream=stream0)
        buf11 = empty_strided_cuda((1521, 49, 128), (6272, 128, 1), torch.float16)
        # Source Nodes: [linear], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_6.run(buf7, buf8, buf9, arg5_1, arg6_1, buf11, 74529, 128, grid=grid(74529, 128), stream=stream0)
        del arg5_1
        del arg6_1
        buf12 = empty_strided_cuda((384, 128), (128, 1), torch.float16)
        # Source Nodes: [linear], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_7.run(arg7_1, buf12, 49152, grid=grid(49152), stream=stream0)
        del arg7_1
        buf13 = empty_strided_cuda((74529, 384), (384, 1), torch.float16)
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf11, (74529, 128), (128, 1), 0), reinterpret_tensor(buf12, (128, 384), (1, 128), 0), out=buf13)
        buf14 = empty_strided_cuda((1521, 4, 49, 32), (6400, 1600, 32, 1), torch.float16)
        # Source Nodes: [attn, q_1], Original ATen: [aten.clone, aten.mul]
        triton_poi_fused_clone_mul_8.run(buf13, arg8_1, buf14, 9539712, grid=grid(9539712), stream=stream0)
        buf15 = empty_strided_cuda((1521, 4, 32, 49), (6400, 1600, 49, 1), torch.float16)
        # Source Nodes: [attn], Original ATen: [aten.clone]
        triton_poi_fused_clone_9.run(buf13, arg8_1, buf15, 194688, 49, grid=grid(194688, 49), stream=stream0)
        buf16 = empty_strided_cuda((6084, 49, 49), (2401, 49, 1), torch.float16)
        # Source Nodes: [attn], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf14, (6084, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf15, (6084, 32, 49), (1600, 49, 1), 0), out=buf16)
        buf19 = empty_strided_cuda((1521, 4, 49, 49), (9728, 2432, 49, 1), torch.float16)
        # Source Nodes: [attn_1, attn_2, matmul_1], Original ATen: [aten._softmax, aten._to_copy, aten.add]
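        # grid(298116) = 1521 windows * 4 heads * 49 query positions; each
        # program row-reduces over the 49 keys of its window and writes the
        # softmax probabilities into buf19 (head stride 2432, row stride 49).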
        triton_per_fused__softmax__to_copy_add_10.run(buf16, arg10_1, arg9_1, buf19, 298116, 49, grid=grid(298116), stream=stream0)
        del arg10_1
        del arg9_1
        buf20 = reinterpret_tensor(buf15, (1521, 4, 49, 32), (6400, 1600, 32, 1), 0); del buf15  # reuse
        # Source Nodes: [matmul_1], Original ATen: [aten.clone]
        triton_poi_fused_clone_11.run(buf13, arg8_1, buf20, 9539712, grid=grid(9539712), stream=stream0)
        del arg8_1
        buf21 = reinterpret_tensor(buf11, (6084, 49, 32), (1568, 32, 1), 0); del buf11  # reuse
        # Source Nodes: [matmul_1], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf19, (6084, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf20, (6084, 49, 32), (1600, 32, 1), 0), out=buf21)
        buf22 = empty_strided_cuda((1521, 49, 4, 32), (6272, 128, 32, 1), torch.float16)
        # Source Nodes: [x_11], Original ATen: [aten.clone]
        triton_poi_fused_clone_12.run(buf21, buf22, 9539712, grid=grid(9539712), stream=stream0)
        buf23 = empty_strided_cuda((128, 128), (128, 1), torch.float16)
        # Source Nodes: [x_12], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_13.run(arg11_1, buf23, 16384, grid=grid(16384), stream=stream0)
        del arg11_1
        buf24 = reinterpret_tensor(buf21, (74529, 128), (128, 1), 0); del buf21  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf22, (74529, 128), (128, 1), 0), reinterpret_tensor(buf23, (128, 128), (1, 128), 0), out=buf24)
        buf28 = reinterpret_tensor(buf3, (1, 73984, 128), (9469952, 128, 1), 0); del buf3  # reuse
        # Source Nodes: [layer_norm_2, x_18, x_19], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_red_fused__to_copy_add_native_layer_norm_14.run(buf7, buf24, arg12_1, arg13_1, arg14_1, buf28, 73984, 128, grid=grid(73984), stream=stream0)
        del arg13_1
        del arg14_1
        buf29 = empty_strided_cuda((512, 128), (128, 1), torch.float16)
        # Source Nodes: [x_19], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_15.run(arg15_1, buf29, 65536, grid=grid(65536), stream=stream0)
        del arg15_1
        buf30 = empty_strided_cuda((73984, 512), (512, 1), torch.float16)
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf28, (73984, 128), (128, 1), 0), reinterpret_tensor(buf29, (128, 512), (1, 128), 0), out=buf30)
        buf31 = reinterpret_tensor(buf30, (1, 73984, 512), (37879808, 512, 1), 0); del buf30  # reuse
        # Source Nodes: [x_20], Original ATen: [aten.gelu]
        triton_poi_fused_gelu_16.run(buf31, arg16_1, 37879808, grid=grid(37879808), stream=stream0)
        del arg16_1
        buf32 = reinterpret_tensor(buf29, (128, 512), (512, 1), 0); del buf29  # reuse
        # Source Nodes: [x_22], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_15.run(arg17_1, buf32, 65536, grid=grid(65536), stream=stream0)
        del arg17_1
        buf33 = reinterpret_tensor(buf28, (73984, 128), (128, 1), 0); del buf28  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf31, (73984, 512), (512, 1), 0), reinterpret_tensor(buf32, (512, 128), (1, 512), 0), out=buf33)
        buf34 = empty_strided_cuda((1, 73984, 128), (9469952, 128, 1), torch.float32)
        buf35 = buf9; del buf9  # reuse
        buf36 = buf8; del buf8  # reuse
        # Source Nodes: [x_18, x_24, x_25], Original ATen: [aten.add, aten.native_layer_norm]
        triton_red_fused_add_native_layer_norm_17.run(buf7, buf24, arg12_1, buf33, arg18_1, buf34, buf35, buf36, 73984, 128, grid=grid(73984), stream=stream0)
        del arg12_1
        del arg18_1
        buf38 = empty_strided_cuda((1, 273, 273, 128), (9539712, 34944, 128, 1), torch.float32)
        # Source Nodes: [shifted_x, x_27], Original ATen: [aten.constant_pad_nd, aten.roll]
        triton_poi_fused_constant_pad_nd_roll_18.run(buf34, buf35, buf36, arg19_1, arg20_1, buf38, 9539712, grid=grid(9539712), stream=stream0)
        del arg19_1
        del arg20_1
        buf39 = reinterpret_tensor(buf24, (1521, 49, 128), (6272, 128, 1), 0); del buf24  # reuse
        # Source Nodes: [linear_4], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_19.run(buf38, buf39, 9539712, grid=grid(9539712), stream=stream0)
        del buf38
        buf40 = buf12; del buf12  # reuse
        # Source Nodes: [linear_4], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_7.run(arg21_1, buf40, 49152, grid=grid(49152), stream=stream0)
        del arg21_1
        buf41 = buf13; del buf13  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf39, (74529, 128), (128, 1), 0), reinterpret_tensor(buf40, (128, 384), (1, 128), 0), out=buf41)
        del buf40
        buf42 = buf20; del buf20  # reuse
        # Source Nodes: [attn_4, q_3], Original ATen: [aten.clone, aten.mul]
        triton_poi_fused_clone_mul_8.run(buf41, arg22_1, buf42, 9539712, grid=grid(9539712), stream=stream0)
        buf43 = reinterpret_tensor(buf14, (1521, 4, 32, 49), (6400, 1600, 49, 1), 0); del buf14  # reuse
        # Source Nodes: [attn_4], Original ATen: [aten.clone]
        triton_poi_fused_clone_9.run(buf41, arg22_1, buf43, 194688, 49, grid=grid(194688, 49), stream=stream0)
        buf44 = buf16; del buf16  # reuse
        # Source Nodes: [attn_4], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf42, (6084, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf43, (6084, 32, 49), (1600, 49, 1), 0), out=buf44)
        del buf42
        buf45 = empty_strided_cuda((1, 273, 273, 1), (74560, 273, 1, 1), torch.float32)
        buf46 = reinterpret_tensor(buf45, (1, 273, 273, 1), (74560, 273, 1, 74560), 0); del buf45  # reuse
        # Source Nodes: [img_mask, setitem, setitem_1, setitem_2, setitem_3, setitem_4], Original ATen: [aten.fill, aten.lift_fresh, aten.slice, aten.zeros]
        triton_poi_fused_fill_lift_fresh_slice_zeros_20.run(buf46, 74529, grid=grid(74529), stream=stream0)
        buf47 = empty_strided_cuda((1, 3, 273, 1), (819, 273, 1, 819), torch.float32)
        # Source Nodes: [setitem_8], Original ATen: [aten.fill, aten.lift_fresh]
        triton_poi_fused_fill_lift_fresh_21.run(buf46, buf47, 819, grid=grid(819), stream=stream0)
        buf48 = reinterpret_tensor(buf46, (1, 273, 273, 1), (74560, 273, 1, 1), 0); del buf46  # reuse
        # Source Nodes: [setitem_5, setitem_6, setitem_7], Original ATen: [aten.fill, aten.lift_fresh, aten.slice]
        triton_poi_fused_fill_lift_fresh_slice_22.run(buf48, buf47, 74529, grid=grid(74529), stream=stream0)
        del buf47
        buf52 = buf19; del buf19  # reuse
        # Source Nodes: [attn_6, attn_8, matmul_3], Original ATen: [aten._softmax, aten._to_copy, aten.add]
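        # Same fused softmax as above, but this launch also passes buf48, the
        # shifted-window image mask assembled by the fill/slice kernels just
        # before it, so attention between positions from different original
        # windows is suppressed (presumably via a large negative additive
        # term, mirroring kernel 67 earlier in this file) before normalisation.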
        triton_per_fused__softmax__to_copy_add_23.run(buf44, arg24_1, arg23_1, buf48, buf52, 298116, 49, grid=grid(298116), stream=stream0)
        del arg23_1
        del arg24_1
        del buf44
        del buf48
        buf53 = reinterpret_tensor(buf43, (1521, 4, 49, 32), (6400, 1600, 32, 1), 0); del buf43  # reuse
        # Source Nodes: [matmul_3], Original ATen: [aten.clone]
        triton_poi_fused_clone_11.run(buf41, arg22_1, buf53, 9539712, grid=grid(9539712), stream=stream0)
        del arg22_1
        del buf41
        buf54 = reinterpret_tensor(buf39, (6084, 49, 32), (1568, 32, 1), 0); del buf39  # reuse
        # Source Nodes: [matmul_3], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf52, (6084, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf53, (6084, 49, 32), (1600, 32, 1), 0), out=buf54)
        del buf52
        del buf53
        buf55 = buf22; del buf22  # reuse
        # Source Nodes: [x_29], Original ATen: [aten.clone]
        triton_poi_fused_clone_12.run(buf54, buf55, 9539712, grid=grid(9539712), stream=stream0)
        buf56 = buf23; del buf23  # reuse
        # Source Nodes: [x_30], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_13.run(arg25_1, buf56, 16384, grid=grid(16384), stream=stream0)
        del arg25_1
        buf57 = reinterpret_tensor(buf54, (74529, 128), (128, 1), 0); del buf54  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf55, (74529, 128), (128, 1), 0), reinterpret_tensor(buf56, (128, 128), (1, 128), 0), out=buf57)
        del buf55
        del buf56
        buf62 = reinterpret_tensor(buf33, (1, 73984, 128), (9469952, 128, 1), 0); del buf33  # reuse
        # Source Nodes: [layer_norm_4, x_37, x_38], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_red_fused__to_copy_add_native_layer_norm_24.run(buf34, buf57, arg26_1, arg27_1, arg28_1, buf62, 73984, 128, grid=grid(73984), stream=stream0)
        del arg27_1
        del arg28_1
        buf63 = reinterpret_tensor(buf32, (512, 128), (128, 1), 0); del buf32  # reuse
        # Source Nodes: [x_38], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_15.run(arg29_1, buf63, 65536, grid=grid(65536), stream=stream0)
        del arg29_1
        buf64 = reinterpret_tensor(buf31, (73984, 512), (512, 1), 0); del buf31  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf62, (73984, 128), (128, 1), 0), reinterpret_tensor(buf63, (128, 512), (1, 128), 0), out=buf64)
        buf65 = reinterpret_tensor(buf64, (1, 73984, 512), (37879808, 512, 1), 0); del buf64  # reuse
        # Source Nodes: [x_39], Original ATen: [aten.gelu]
        triton_poi_fused_gelu_16.run(buf65, arg30_1, 37879808, grid=grid(37879808), stream=stream0)
        del arg30_1
        buf66 = reinterpret_tensor(buf63, (128, 512), (512, 1), 0); del buf63  # reuse
        # Source Nodes: [x_41], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_15.run(arg31_1, buf66, 65536, grid=grid(65536), stream=stream0)
        del arg31_1
        buf67 = reinterpret_tensor(buf62, (73984, 128), (128, 1), 0); del buf62  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf65, (73984, 512), (512, 1), 0), reinterpret_tensor(buf66, (512, 128), (1, 512), 0), out=buf67)
        del buf65
        buf68 = reinterpret_tensor(buf7, (1, 73984, 128), (9469952, 128, 1), 0); del buf7  # reuse
        buf72 = buf36; del buf36  # reuse
        buf73 = buf35; del buf35  # reuse
        # Source Nodes: [x_37, x_43, x_out], Original ATen: [aten.add, aten.native_layer_norm]
        triton_per_fused_add_native_layer_norm_25.run(buf34, buf57, arg26_1, buf67, arg32_1, buf68, buf72, buf73, 73984, 128, grid=grid(73984), stream=stream0)
        del arg26_1
        del arg32_1
        del buf57
        buf76 = reinterpret_tensor(buf67, (1, 18496, 512), (9469952, 512, 1), 0); del buf67  # reuse
        # Source Nodes: [x_47, x_48], Original ATen: [aten._to_copy, aten.native_layer_norm]
        triton_red_fused__to_copy_native_layer_norm_26.run(buf68, arg33_1, arg34_1, buf76, 18496, 512, grid=grid(18496), stream=stream0)
        del arg33_1
        del arg34_1
        buf77 = empty_strided_cuda((256, 512), (512, 1), torch.float16)
        # Source Nodes: [x_48], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_27.run(arg35_1, buf77, 131072, grid=grid(131072), stream=stream0)
        del arg35_1
        buf78 = empty_strided_cuda((18496, 256), (256, 1), torch.float16)
        # Source Nodes: [x_48], Original ATen: [aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf76, (18496, 512), (512, 1), 0), reinterpret_tensor(buf77, (512, 256), (1, 512), 0), out=buf78)
        del buf77
        buf79 = empty_strided_cuda((1, 18496, 1), (18496, 1, 18496), torch.float32)
        buf80 = empty_strided_cuda((1, 18496, 1), (18496, 1, 18496), torch.float32)
        # Source Nodes: [x_50], Original ATen: [aten._to_copy, aten.native_layer_norm]
        triton_per_fused__to_copy_native_layer_norm_28.run(buf78, buf79, buf80, 18496, 256, grid=grid(18496), stream=stream0)
        buf82 = empty_strided_cuda((400, 49, 256), (12544, 256, 1), torch.float16)
        # Source Nodes: [linear_9], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_29.run(buf78, buf79, buf80, arg38_1, arg39_1, buf82, 5017600, grid=grid(5017600), stream=stream0)
        del arg38_1
        del arg39_1
        buf83 = empty_strided_cuda((768, 256), (256, 1), torch.float16)
        # Source Nodes: [linear_9], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_30.run(arg40_1, buf83, 196608, grid=grid(196608), stream=stream0)
        del arg40_1
        buf84 = empty_strided_cuda((19600, 768), (768, 1), torch.float16)
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf82, (19600, 256), (256, 1), 0), reinterpret_tensor(buf83, (256, 768), (1, 256), 0), out=buf84)
        buf85 = empty_strided_cuda((400, 8, 49, 32), (12800, 1600, 32, 1), torch.float16)
        # Source Nodes: [attn_10, q_5], Original ATen: [aten.clone, aten.mul]
        triton_poi_fused_clone_mul_31.run(buf84, arg41_1, buf85, 5017600, grid=grid(5017600), stream=stream0)
        buf86 = empty_strided_cuda((400, 8, 32, 49), (12800, 1600, 49, 1), torch.float16)
        # Source Nodes: [attn_10], Original ATen: [aten.clone]
        triton_poi_fused_clone_32.run(buf84, arg41_1, buf86, 102400, 49, grid=grid(102400, 49), stream=stream0)
        buf87 = empty_strided_cuda((3200, 49, 49), (2401, 49, 1), torch.float16)
        # Source Nodes: [attn_10], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf85, (3200, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf86, (3200, 32, 49), (1600, 49, 1), 0), out=buf87)
        buf90 = empty_strided_cuda((400, 8, 49, 49), (19456, 2432, 49, 1), torch.float16)
        # Source Nodes: [attn_11, attn_12, matmul_5], Original ATen: [aten._softmax, aten._to_copy, aten.add]
        triton_per_fused__softmax__to_copy_add_33.run(buf87, arg43_1, arg42_1, buf90, 156800, 49, grid=grid(156800), stream=stream0)
        del arg42_1
        del arg43_1
        buf91 = reinterpret_tensor(buf86, (400, 8, 49, 32), (12800, 1600, 32, 1), 0); del buf86  # reuse
        # Source Nodes: [matmul_5], Original ATen: [aten.clone]
        triton_poi_fused_clone_34.run(buf84, arg41_1, buf91, 5017600, grid=grid(5017600), stream=stream0)
        del arg41_1
        buf92 = reinterpret_tensor(buf82, (3200, 49, 32), (1568, 32, 1), 0); del buf82  # reuse
        # Source Nodes: [matmul_5], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf90, (3200, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf91, (3200, 49, 32), (1600, 32, 1), 0), out=buf92)
        buf93 = empty_strided_cuda((400, 49, 8, 32), (12544, 256, 32, 1), torch.float16)
        # Source Nodes: [x_54], Original ATen: [aten.clone]
        triton_poi_fused_clone_35.run(buf92, buf93, 5017600, grid=grid(5017600), stream=stream0)
        buf94 = reinterpret_tensor(buf66, (256, 256), (256, 1), 0); del buf66  # reuse
        # Source Nodes: [x_55], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_15.run(arg44_1, buf94, 65536, grid=grid(65536), stream=stream0)
        del arg44_1
        buf95 = reinterpret_tensor(buf92, (19600, 256), (256, 1), 0); del buf92  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf93, (19600, 256), (256, 1), 0), reinterpret_tensor(buf94, (256, 256), (1, 256), 0), out=buf95)
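        # NOTE: buf99..buf107 form the block's MLP: residual add + layer norm (layer_norm_8),
        # a 256->1024 linear, GELU, a 1024->256 linear, then another residual add whose
        # layer-norm statistics (buf106/buf107) are kept for the next block.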
        buf99 = empty_strided_cuda((1, 18496, 256), (4734976, 256, 1), torch.float16)
        # Source Nodes: [layer_norm_8, x_61, x_62], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_red_fused__to_copy_add_native_layer_norm_36.run(buf78, buf95, arg45_1, arg46_1, arg47_1, buf99, 18496, 256, grid=grid(18496), stream=stream0)
        del arg46_1
        del arg47_1
        buf100 = empty_strided_cuda((1024, 256), (256, 1), torch.float16)
        # Source Nodes: [x_62], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_37.run(arg48_1, buf100, 262144, grid=grid(262144), stream=stream0)
        del arg48_1
        buf101 = empty_strided_cuda((18496, 1024), (1024, 1), torch.float16)
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf99, (18496, 256), (256, 1), 0), reinterpret_tensor(buf100, (256, 1024), (1, 256), 0), out=buf101)
        buf102 = reinterpret_tensor(buf101, (1, 18496, 1024), (18939904, 1024, 1), 0); del buf101  # reuse
        # Source Nodes: [x_63], Original ATen: [aten.gelu]
        triton_poi_fused_gelu_38.run(buf102, arg49_1, 18939904, grid=grid(18939904), stream=stream0)
        del arg49_1
        buf103 = reinterpret_tensor(buf100, (256, 1024), (1024, 1), 0); del buf100  # reuse
        # Source Nodes: [x_65], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_37.run(arg50_1, buf103, 262144, grid=grid(262144), stream=stream0)
        del arg50_1
        buf104 = reinterpret_tensor(buf99, (18496, 256), (256, 1), 0); del buf99  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf102, (18496, 1024), (1024, 1), 0), reinterpret_tensor(buf103, (1024, 256), (1, 1024), 0), out=buf104)
        buf105 = reinterpret_tensor(buf104, (1, 18496, 256), (4734976, 256, 1), 0); del buf104  # reuse
        buf106 = buf80; del buf80  # reuse
        buf107 = buf79; del buf79  # reuse
        # Source Nodes: [x_61, x_67, x_68], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_per_fused__to_copy_add_native_layer_norm_39.run(buf105, buf78, buf95, arg45_1, arg51_1, buf106, buf107, 18496, 256, grid=grid(18496), stream=stream0)
        del arg45_1
        del arg51_1
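        # NOTE: the constant_pad_nd + roll below pad the 136x136 map to 140x140 (which splits
        # evenly into 20x20 = 400 windows of 7x7) and cyclically shift it, i.e. the usual
        # preparation for a shifted-window attention block before re-partitioning into windows.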
        buf109 = empty_strided_cuda((1, 140, 140, 256), (5017600, 35840, 256, 1), torch.float32)
        # Source Nodes: [shifted_x_1, x_70], Original ATen: [aten.constant_pad_nd, aten.roll]
        triton_poi_fused_constant_pad_nd_roll_40.run(buf105, buf106, buf107, arg52_1, arg53_1, buf109, 5017600, grid=grid(5017600), stream=stream0)
        del arg52_1
        del arg53_1
        buf110 = reinterpret_tensor(buf95, (400, 49, 256), (12544, 256, 1), 0); del buf95  # reuse
        # Source Nodes: [linear_13], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_41.run(buf109, buf110, 5017600, grid=grid(5017600), stream=stream0)
        del buf109
        buf111 = buf83; del buf83  # reuse
        # Source Nodes: [linear_13], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_30.run(arg54_1, buf111, 196608, grid=grid(196608), stream=stream0)
        del arg54_1
        buf112 = buf84; del buf84  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf110, (19600, 256), (256, 1), 0), reinterpret_tensor(buf111, (256, 768), (1, 256), 0), out=buf112)
        del buf111
        buf113 = buf91; del buf91  # reuse
        # Source Nodes: [attn_14, q_7], Original ATen: [aten.clone, aten.mul]
        triton_poi_fused_clone_mul_31.run(buf112, arg55_1, buf113, 5017600, grid=grid(5017600), stream=stream0)
        buf114 = reinterpret_tensor(buf85, (400, 8, 32, 49), (12800, 1600, 49, 1), 0); del buf85  # reuse
        # Source Nodes: [attn_14], Original ATen: [aten.clone]
        triton_poi_fused_clone_32.run(buf112, arg55_1, buf114, 102400, 49, grid=grid(102400, 49), stream=stream0)
        buf115 = buf87; del buf87  # reuse
        # Source Nodes: [attn_14], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf113, (3200, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf114, (3200, 32, 49), (1600, 49, 1), 0), out=buf115)
        del buf113
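        # NOTE: buf116..buf119 appear to materialize the attention mask for the shifted
        # windows (the img_mask fill/slice kernels on the 140x140 grid); the fused softmax
        # below consumes it together with the relative-position bias terms (arg56_1/arg57_1).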
        buf116 = empty_strided_cuda((1, 140, 140, 1), (19616, 140, 1, 1), torch.float32)
        buf117 = reinterpret_tensor(buf116, (1, 140, 140, 1), (19616, 140, 1, 19616), 0); del buf116  # reuse
        # Source Nodes: [img_mask_1, setitem_10, setitem_11, setitem_12, setitem_13, setitem_9], Original ATen: [aten.fill, aten.lift_fresh, aten.slice, aten.zeros]
        triton_poi_fused_fill_lift_fresh_slice_zeros_42.run(buf117, 19600, grid=grid(19600), stream=stream0)
        buf118 = empty_strided_cuda((1, 3, 140, 1), (420, 140, 1, 420), torch.float32)
        # Source Nodes: [setitem_17], Original ATen: [aten.fill, aten.lift_fresh]
        triton_poi_fused_fill_lift_fresh_43.run(buf117, buf118, 420, grid=grid(420), stream=stream0)
        buf119 = reinterpret_tensor(buf117, (1, 140, 140, 1), (19616, 140, 1, 1), 0); del buf117  # reuse
        # Source Nodes: [setitem_14, setitem_15, setitem_16], Original ATen: [aten.fill, aten.lift_fresh, aten.slice]
        triton_poi_fused_fill_lift_fresh_slice_44.run(buf119, buf118, 19600, grid=grid(19600), stream=stream0)
        del buf118
        buf123 = buf90; del buf90  # reuse
        # Source Nodes: [attn_16, attn_18, matmul_7], Original ATen: [aten._softmax, aten._to_copy, aten.add]
        triton_per_fused__softmax__to_copy_add_45.run(buf115, arg57_1, arg56_1, buf119, buf123, 156800, 49, grid=grid(156800), stream=stream0)
        del arg56_1
        del arg57_1
        del buf115
        del buf119
        buf124 = reinterpret_tensor(buf114, (400, 8, 49, 32), (12800, 1600, 32, 1), 0); del buf114  # reuse
        # Source Nodes: [matmul_7], Original ATen: [aten.clone]
        triton_poi_fused_clone_34.run(buf112, arg55_1, buf124, 5017600, grid=grid(5017600), stream=stream0)
        del arg55_1
        del buf112
        buf125 = reinterpret_tensor(buf110, (3200, 49, 32), (1568, 32, 1), 0); del buf110  # reuse
        # Source Nodes: [matmul_7], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf123, (3200, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf124, (3200, 49, 32), (1600, 32, 1), 0), out=buf125)
        del buf123
        del buf124
        buf126 = buf93; del buf93  # reuse
        # Source Nodes: [x_72], Original ATen: [aten.clone]
        triton_poi_fused_clone_35.run(buf125, buf126, 5017600, grid=grid(5017600), stream=stream0)
        buf127 = buf94; del buf94  # reuse
        # Source Nodes: [x_73], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_15.run(arg58_1, buf127, 65536, grid=grid(65536), stream=stream0)
        del arg58_1
        buf128 = reinterpret_tensor(buf125, (19600, 256), (256, 1), 0); del buf125  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf126, (19600, 256), (256, 1), 0), reinterpret_tensor(buf127, (256, 256), (1, 256), 0), out=buf128)
        del buf126
        del buf127
        buf133 = reinterpret_tensor(buf78, (1, 18496, 256), (4734976, 256, 1), 0); del buf78  # reuse
        # Source Nodes: [layer_norm_10, x_80, x_81], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_red_fused__to_copy_add_native_layer_norm_46.run(buf105, buf128, arg59_1, arg60_1, arg61_1, buf133, 18496, 256, grid=grid(18496), stream=stream0)
        del arg60_1
        del arg61_1
        buf134 = reinterpret_tensor(buf103, (1024, 256), (256, 1), 0); del buf103  # reuse
        # Source Nodes: [x_81], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_37.run(arg62_1, buf134, 262144, grid=grid(262144), stream=stream0)
        del arg62_1
        buf135 = reinterpret_tensor(buf102, (18496, 1024), (1024, 1), 0); del buf102  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf133, (18496, 256), (256, 1), 0), reinterpret_tensor(buf134, (256, 1024), (1, 256), 0), out=buf135)
        buf136 = reinterpret_tensor(buf135, (1, 18496, 1024), (18939904, 1024, 1), 0); del buf135  # reuse
        # Source Nodes: [x_82], Original ATen: [aten.gelu]
        triton_poi_fused_gelu_38.run(buf136, arg63_1, 18939904, grid=grid(18939904), stream=stream0)
        del arg63_1
        buf137 = reinterpret_tensor(buf134, (256, 1024), (1024, 1), 0); del buf134  # reuse
        # Source Nodes: [x_84], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_37.run(arg64_1, buf137, 262144, grid=grid(262144), stream=stream0)
        del arg64_1
        buf138 = reinterpret_tensor(buf133, (18496, 256), (256, 1), 0); del buf133  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf136, (18496, 1024), (1024, 1), 0), reinterpret_tensor(buf137, (1024, 256), (1, 1024), 0), out=buf138)
        del buf136
        buf139 = reinterpret_tensor(buf138, (1, 18496, 256), (4734976, 256, 1), 0); del buf138  # reuse
        buf143 = buf107; del buf107  # reuse
        buf144 = buf106; del buf106  # reuse
        # Source Nodes: [x_80, x_86, x_out_1], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_per_fused__to_copy_add_native_layer_norm_47.run(buf139, buf105, buf128, arg59_1, arg65_1, buf143, buf144, 18496, 256, grid=grid(18496), stream=stream0)
        del arg59_1
        del arg65_1
        del buf128
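        # NOTE: second downsampling step, analogous to the one above: 18496 tokens (136x136)
        # -> 4624 tokens (68x68), channels 256 -> 512, via a layer norm over 1024 gathered
        # features and a 1024->512 linear (buf147..buf149). Again inferred from the shapes.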
        buf147 = reinterpret_tensor(buf105, (1, 4624, 1024), (4734976, 1024, 1), 0); del buf105  # reuse
        # Source Nodes: [x_90, x_91], Original ATen: [aten._to_copy, aten.native_layer_norm]
        triton_red_fused__to_copy_native_layer_norm_48.run(buf139, arg66_1, arg67_1, buf147, 4624, 1024, grid=grid(4624), stream=stream0)
        del arg66_1
        del arg67_1
        buf148 = empty_strided_cuda((512, 1024), (1024, 1), torch.float16)
        # Source Nodes: [x_91], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_49.run(arg68_1, buf148, 524288, grid=grid(524288), stream=stream0)
        del arg68_1
        buf149 = empty_strided_cuda((4624, 512), (512, 1), torch.float16)
        # Source Nodes: [x_91], Original ATen: [aten.mm]
        extern_kernels.mm(reinterpret_tensor(buf147, (4624, 1024), (1024, 1), 0), reinterpret_tensor(buf148, (1024, 512), (1, 1024), 0), out=buf149)
        del buf147
        del buf148
        buf150 = empty_strided_cuda((1, 4624, 1), (4640, 1, 4640), torch.float32)
        buf151 = empty_strided_cuda((1, 4624, 1), (4640, 1, 4640), torch.float32)
        # Source Nodes: [x_93], Original ATen: [aten._to_copy, aten.native_layer_norm]
        triton_per_fused__to_copy_native_layer_norm_50.run(buf149, buf150, buf151, 4624, 512, grid=grid(4624), stream=stream0)
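        # NOTE: this stage runs windowed attention at 512 channels: a 512->1536 QKV
        # projection, the 68x68 map handled as 100 windows of 49 tokens (padded to 70x70 so
        # it tiles into 7x7 windows), 16 heads of dim 32, a relative-position bias in the
        # fused softmax, and a 512->512 output projection. Counts inferred from tensor shapes.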
        buf153 = empty_strided_cuda((100, 49, 512), (25088, 512, 1), torch.float16)
        # Source Nodes: [linear_18], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_51.run(buf149, buf150, buf151, arg71_1, arg72_1, buf153, 2508800, grid=grid(2508800), stream=stream0)
        del arg71_1
        del arg72_1
        buf154 = empty_strided_cuda((1536, 512), (512, 1), torch.float16)
        # Source Nodes: [linear_18], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_52.run(arg73_1, buf154, 786432, grid=grid(786432), stream=stream0)
        del arg73_1
        buf155 = empty_strided_cuda((4900, 1536), (1536, 1), torch.float16)
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf153, (4900, 512), (512, 1), 0), reinterpret_tensor(buf154, (512, 1536), (1, 512), 0), out=buf155)
        buf156 = empty_strided_cuda((100, 16, 49, 32), (25600, 1600, 32, 1), torch.float16)
        # Source Nodes: [attn_20, q_9], Original ATen: [aten.clone, aten.mul]
        triton_poi_fused_clone_mul_53.run(buf155, arg74_1, buf156, 2508800, grid=grid(2508800), stream=stream0)
        buf157 = empty_strided_cuda((100, 16, 32, 49), (25600, 1600, 49, 1), torch.float16)
        # Source Nodes: [attn_20], Original ATen: [aten.clone]
        triton_poi_fused_clone_54.run(buf155, arg74_1, buf157, 51200, 49, grid=grid(51200, 49), stream=stream0)
        buf158 = empty_strided_cuda((1600, 49, 49), (2401, 49, 1), torch.float16)
        # Source Nodes: [attn_20], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf156, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf157, (1600, 32, 49), (1600, 49, 1), 0), out=buf158)
        buf161 = empty_strided_cuda((100, 16, 49, 49), (38912, 2432, 49, 1), torch.float16)
        # Source Nodes: [attn_21, attn_22, matmul_9], Original ATen: [aten._softmax, aten._to_copy, aten.add]
        triton_per_fused__softmax__to_copy_add_55.run(buf158, arg76_1, arg75_1, buf161, 78400, 49, grid=grid(78400), stream=stream0)
        del arg75_1
        del arg76_1
        buf162 = reinterpret_tensor(buf157, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf157  # reuse
        # Source Nodes: [matmul_9], Original ATen: [aten.clone]
        triton_poi_fused_clone_56.run(buf155, arg74_1, buf162, 2508800, grid=grid(2508800), stream=stream0)
        del arg74_1
        buf163 = reinterpret_tensor(buf153, (1600, 49, 32), (1568, 32, 1), 0); del buf153  # reuse
        # Source Nodes: [matmul_9], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf161, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf162, (1600, 49, 32), (1600, 32, 1), 0), out=buf163)
        buf164 = empty_strided_cuda((100, 49, 16, 32), (25088, 512, 32, 1), torch.float16)
        # Source Nodes: [x_97], Original ATen: [aten.clone]
        triton_poi_fused_clone_57.run(buf163, buf164, 2508800, grid=grid(2508800), stream=stream0)
        buf165 = reinterpret_tensor(buf137, (512, 512), (512, 1), 0); del buf137  # reuse
        # Source Nodes: [x_98], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_37.run(arg77_1, buf165, 262144, grid=grid(262144), stream=stream0)
        del arg77_1
        buf166 = reinterpret_tensor(buf163, (4900, 512), (512, 1), 0); del buf163  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf164, (4900, 512), (512, 1), 0), reinterpret_tensor(buf165, (512, 512), (1, 512), 0), out=buf166)
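        # NOTE: from here the stage repeats the same block pattern, alternating plain window
        # attention (fused softmax _55) with shifted-window attention (fused softmax _67,
        # which also takes the buf190 mask), each followed by a 512->2048->512 GELU MLP with
        # residual adds and layer norms; only the buffer/argument indices change.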
        buf170 = empty_strided_cuda((1, 4624, 512), (2367488, 512, 1), torch.float16)
        # Source Nodes: [layer_norm_14, x_104, x_105], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_red_fused__to_copy_add_native_layer_norm_58.run(buf149, buf166, arg78_1, arg79_1, arg80_1, buf170, 4624, 512, grid=grid(4624), stream=stream0)
        del arg79_1
        del arg80_1
        buf171 = empty_strided_cuda((2048, 512), (512, 1), torch.float16)
        # Source Nodes: [x_105], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg81_1, buf171, 1048576, grid=grid(1048576), stream=stream0)
        del arg81_1
        buf172 = reinterpret_tensor(buf76, (4624, 2048), (2048, 1), 0); del buf76  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf170, (4624, 512), (512, 1), 0), reinterpret_tensor(buf171, (512, 2048), (1, 512), 0), out=buf172)
        buf173 = reinterpret_tensor(buf172, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf172  # reuse
        # Source Nodes: [x_106], Original ATen: [aten.gelu]
        triton_poi_fused_gelu_60.run(buf173, arg82_1, 9469952, grid=grid(9469952), stream=stream0)
        del arg82_1
        buf174 = reinterpret_tensor(buf171, (512, 2048), (2048, 1), 0); del buf171  # reuse
        # Source Nodes: [x_108], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg83_1, buf174, 1048576, grid=grid(1048576), stream=stream0)
        del arg83_1
        buf175 = reinterpret_tensor(buf170, (4624, 512), (512, 1), 0); del buf170  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf173, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf174, (2048, 512), (1, 2048), 0), out=buf175)
        buf176 = reinterpret_tensor(buf175, (1, 4624, 512), (2367488, 512, 1), 0); del buf175  # reuse
        buf177 = buf151; del buf151  # reuse
        buf178 = buf150; del buf150  # reuse
        # Source Nodes: [x_104, x_110, x_111], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_per_fused__to_copy_add_native_layer_norm_61.run(buf176, buf149, buf166, arg78_1, arg84_1, buf177, buf178, 4624, 512, grid=grid(4624), stream=stream0)
        del arg78_1
        del arg84_1
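        # NOTE: pad 68x68 -> 70x70 and roll, mirroring the shifted-window preparation done at
        # the 256-channel stage above (buf180 holds the shifted, padded map).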
        buf180 = empty_strided_cuda((1, 70, 70, 512), (2508800, 35840, 512, 1), torch.float32)
        # Source Nodes: [shifted_x_2, x_113], Original ATen: [aten.constant_pad_nd, aten.roll]
        triton_poi_fused_constant_pad_nd_roll_62.run(buf176, buf177, buf178, arg85_1, arg86_1, buf180, 2508800, grid=grid(2508800), stream=stream0)
        del arg85_1
        del arg86_1
        buf181 = reinterpret_tensor(buf166, (100, 49, 512), (25088, 512, 1), 0); del buf166  # reuse
        # Source Nodes: [linear_22], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_63.run(buf180, buf181, 2508800, grid=grid(2508800), stream=stream0)
        buf182 = buf154; del buf154  # reuse
        # Source Nodes: [linear_22], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_52.run(arg87_1, buf182, 786432, grid=grid(786432), stream=stream0)
        del arg87_1
        buf183 = buf155; del buf155  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf181, (4900, 512), (512, 1), 0), reinterpret_tensor(buf182, (512, 1536), (1, 512), 0), out=buf183)
        buf184 = buf162; del buf162  # reuse
        # Source Nodes: [attn_24, q_11], Original ATen: [aten.clone, aten.mul]
        triton_poi_fused_clone_mul_53.run(buf183, arg88_1, buf184, 2508800, grid=grid(2508800), stream=stream0)
        buf185 = reinterpret_tensor(buf156, (100, 16, 32, 49), (25600, 1600, 49, 1), 0); del buf156  # reuse
        # Source Nodes: [attn_24], Original ATen: [aten.clone]
        triton_poi_fused_clone_54.run(buf183, arg88_1, buf185, 51200, 49, grid=grid(51200, 49), stream=stream0)
        buf186 = buf158; del buf158  # reuse
        # Source Nodes: [attn_24], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf184, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf185, (1600, 32, 49), (1600, 49, 1), 0), out=buf186)
        buf187 = empty_strided_cuda((1, 70, 70, 1), (4928, 70, 1, 1), torch.float32)
        buf188 = reinterpret_tensor(buf187, (1, 70, 70, 1), (4928, 70, 1, 4928), 0); del buf187  # reuse
        # Source Nodes: [img_mask_2, setitem_18, setitem_19, setitem_20, setitem_21, setitem_22], Original ATen: [aten.fill, aten.lift_fresh, aten.slice, aten.zeros]
        triton_poi_fused_fill_lift_fresh_slice_zeros_64.run(buf188, 4900, grid=grid(4900), stream=stream0)
        buf189 = empty_strided_cuda((1, 3, 70, 1), (210, 70, 1, 210), torch.float32)
        # Source Nodes: [setitem_26], Original ATen: [aten.fill, aten.lift_fresh]
        triton_poi_fused_fill_lift_fresh_65.run(buf188, buf189, 210, grid=grid(210), stream=stream0)
        buf190 = reinterpret_tensor(buf188, (1, 70, 70, 1), (4928, 70, 1, 1), 0); del buf188  # reuse
        # Source Nodes: [setitem_23, setitem_24, setitem_25], Original ATen: [aten.fill, aten.lift_fresh, aten.slice]
        triton_poi_fused_fill_lift_fresh_slice_66.run(buf190, buf189, 4900, grid=grid(4900), stream=stream0)
        del buf189
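        # NOTE: buf190 (the 70x70 shifted-window mask) is not freed here, presumably because
        # the later shifted blocks reuse it in their _67 softmax calls further down.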
        buf194 = buf161; del buf161  # reuse
        # Source Nodes: [attn_26, attn_28, matmul_11], Original ATen: [aten._softmax, aten._to_copy, aten.add]
        triton_per_fused__softmax__to_copy_add_67.run(buf186, arg90_1, arg89_1, buf190, buf194, 78400, 49, grid=grid(78400), stream=stream0)
        del arg89_1
        del arg90_1
        buf195 = reinterpret_tensor(buf185, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf185  # reuse
        # Source Nodes: [matmul_11], Original ATen: [aten.clone]
        triton_poi_fused_clone_56.run(buf183, arg88_1, buf195, 2508800, grid=grid(2508800), stream=stream0)
        del arg88_1
        buf196 = reinterpret_tensor(buf181, (1600, 49, 32), (1568, 32, 1), 0); del buf181  # reuse
        # Source Nodes: [matmul_11], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf194, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf195, (1600, 49, 32), (1600, 32, 1), 0), out=buf196)
        buf197 = buf164; del buf164  # reuse
        # Source Nodes: [x_115], Original ATen: [aten.clone]
        triton_poi_fused_clone_57.run(buf196, buf197, 2508800, grid=grid(2508800), stream=stream0)
        buf198 = buf165; del buf165  # reuse
        # Source Nodes: [x_116], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_37.run(arg91_1, buf198, 262144, grid=grid(262144), stream=stream0)
        del arg91_1
        buf199 = reinterpret_tensor(buf196, (4900, 512), (512, 1), 0); del buf196  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf197, (4900, 512), (512, 1), 0), reinterpret_tensor(buf198, (512, 512), (1, 512), 0), out=buf199)
        buf204 = reinterpret_tensor(buf149, (1, 4624, 512), (2367488, 512, 1), 0); del buf149  # reuse
        # Source Nodes: [layer_norm_16, x_123, x_124], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_red_fused__to_copy_add_native_layer_norm_68.run(buf176, buf199, arg92_1, arg93_1, arg94_1, buf204, 4624, 512, grid=grid(4624), stream=stream0)
        del arg93_1
        del arg94_1
        buf205 = reinterpret_tensor(buf174, (2048, 512), (512, 1), 0); del buf174  # reuse
        # Source Nodes: [x_124], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg95_1, buf205, 1048576, grid=grid(1048576), stream=stream0)
        del arg95_1
        buf206 = reinterpret_tensor(buf173, (4624, 2048), (2048, 1), 0); del buf173  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf204, (4624, 512), (512, 1), 0), reinterpret_tensor(buf205, (512, 2048), (1, 512), 0), out=buf206)
        buf207 = reinterpret_tensor(buf206, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf206  # reuse
        # Source Nodes: [x_125], Original ATen: [aten.gelu]
        triton_poi_fused_gelu_60.run(buf207, arg96_1, 9469952, grid=grid(9469952), stream=stream0)
        del arg96_1
        buf208 = reinterpret_tensor(buf205, (512, 2048), (2048, 1), 0); del buf205  # reuse
        # Source Nodes: [x_127], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg97_1, buf208, 1048576, grid=grid(1048576), stream=stream0)
        del arg97_1
        buf209 = reinterpret_tensor(buf204, (4624, 512), (512, 1), 0); del buf204  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf207, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf208, (2048, 512), (1, 2048), 0), out=buf209)
        buf210 = reinterpret_tensor(buf209, (1, 4624, 512), (2367488, 512, 1), 0); del buf209  # reuse
        buf211 = buf178; del buf178  # reuse
        buf212 = buf177; del buf177  # reuse
        # Source Nodes: [x_123, x_129, x_130], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_per_fused__to_copy_add_native_layer_norm_69.run(buf210, buf176, buf199, arg92_1, arg98_1, buf211, buf212, 4624, 512, grid=grid(4624), stream=stream0)
        del arg92_1
        del arg98_1
        buf214 = reinterpret_tensor(buf199, (100, 49, 512), (25088, 512, 1), 0); del buf199  # reuse
        # Source Nodes: [linear_26], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_51.run(buf210, buf211, buf212, arg99_1, arg100_1, buf214, 2508800, grid=grid(2508800), stream=stream0)
        del arg100_1
        del arg99_1
        buf215 = buf182; del buf182  # reuse
        # Source Nodes: [linear_26], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_52.run(arg101_1, buf215, 786432, grid=grid(786432), stream=stream0)
        del arg101_1
        buf216 = buf183; del buf183  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf214, (4900, 512), (512, 1), 0), reinterpret_tensor(buf215, (512, 1536), (1, 512), 0), out=buf216)
        buf217 = buf195; del buf195  # reuse
        # Source Nodes: [attn_30, q_13], Original ATen: [aten.clone, aten.mul]
        triton_poi_fused_clone_mul_53.run(buf216, arg102_1, buf217, 2508800, grid=grid(2508800), stream=stream0)
        buf218 = reinterpret_tensor(buf184, (100, 16, 32, 49), (25600, 1600, 49, 1), 0); del buf184  # reuse
        # Source Nodes: [attn_30], Original ATen: [aten.clone]
        triton_poi_fused_clone_54.run(buf216, arg102_1, buf218, 51200, 49, grid=grid(51200, 49), stream=stream0)
        buf219 = buf186; del buf186  # reuse
        # Source Nodes: [attn_30], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf217, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf218, (1600, 32, 49), (1600, 49, 1), 0), out=buf219)
        buf222 = buf194; del buf194  # reuse
        # Source Nodes: [attn_31, attn_32, matmul_13], Original ATen: [aten._softmax, aten._to_copy, aten.add]
        triton_per_fused__softmax__to_copy_add_55.run(buf219, arg104_1, arg103_1, buf222, 78400, 49, grid=grid(78400), stream=stream0)
        del arg103_1
        del arg104_1
        buf223 = reinterpret_tensor(buf218, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf218  # reuse
        # Source Nodes: [matmul_13], Original ATen: [aten.clone]
        triton_poi_fused_clone_56.run(buf216, arg102_1, buf223, 2508800, grid=grid(2508800), stream=stream0)
        del arg102_1
        buf224 = reinterpret_tensor(buf214, (1600, 49, 32), (1568, 32, 1), 0); del buf214  # reuse
        # Source Nodes: [matmul_13], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf222, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf223, (1600, 49, 32), (1600, 32, 1), 0), out=buf224)
        buf225 = buf197; del buf197  # reuse
        # Source Nodes: [x_134], Original ATen: [aten.clone]
        triton_poi_fused_clone_57.run(buf224, buf225, 2508800, grid=grid(2508800), stream=stream0)
        buf226 = buf198; del buf198  # reuse
        # Source Nodes: [x_135], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_37.run(arg105_1, buf226, 262144, grid=grid(262144), stream=stream0)
        del arg105_1
        buf227 = reinterpret_tensor(buf224, (4900, 512), (512, 1), 0); del buf224  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf225, (4900, 512), (512, 1), 0), reinterpret_tensor(buf226, (512, 512), (1, 512), 0), out=buf227)
        buf231 = buf176; del buf176  # reuse
        # Source Nodes: [layer_norm_18, x_141, x_142], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_red_fused__to_copy_add_native_layer_norm_58.run(buf210, buf227, arg106_1, arg107_1, arg108_1, buf231, 4624, 512, grid=grid(4624), stream=stream0)
        del arg107_1
        del arg108_1
        buf232 = reinterpret_tensor(buf208, (2048, 512), (512, 1), 0); del buf208  # reuse
        # Source Nodes: [x_142], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg109_1, buf232, 1048576, grid=grid(1048576), stream=stream0)
        del arg109_1
        buf233 = reinterpret_tensor(buf207, (4624, 2048), (2048, 1), 0); del buf207  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf231, (4624, 512), (512, 1), 0), reinterpret_tensor(buf232, (512, 2048), (1, 512), 0), out=buf233)
        buf234 = reinterpret_tensor(buf233, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf233  # reuse
        # Source Nodes: [x_143], Original ATen: [aten.gelu]
        triton_poi_fused_gelu_60.run(buf234, arg110_1, 9469952, grid=grid(9469952), stream=stream0)
        del arg110_1
        buf235 = reinterpret_tensor(buf232, (512, 2048), (2048, 1), 0); del buf232  # reuse
        # Source Nodes: [x_145], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg111_1, buf235, 1048576, grid=grid(1048576), stream=stream0)
        del arg111_1
        buf236 = reinterpret_tensor(buf231, (4624, 512), (512, 1), 0); del buf231  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf234, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf235, (2048, 512), (1, 2048), 0), out=buf236)
        buf237 = reinterpret_tensor(buf236, (1, 4624, 512), (2367488, 512, 1), 0); del buf236  # reuse
        buf238 = buf212; del buf212  # reuse
        buf239 = buf211; del buf211  # reuse
        # Source Nodes: [x_141, x_147, x_148], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_per_fused__to_copy_add_native_layer_norm_61.run(buf237, buf210, buf227, arg106_1, arg112_1, buf238, buf239, 4624, 512, grid=grid(4624), stream=stream0)
        del arg106_1
        del arg112_1
        buf241 = buf180; del buf180  # reuse
        # Source Nodes: [shifted_x_3, x_150], Original ATen: [aten.constant_pad_nd, aten.roll]
        triton_poi_fused_constant_pad_nd_roll_62.run(buf237, buf238, buf239, arg113_1, arg114_1, buf241, 2508800, grid=grid(2508800), stream=stream0)
        del arg113_1
        del arg114_1
        buf242 = reinterpret_tensor(buf227, (100, 49, 512), (25088, 512, 1), 0); del buf227  # reuse
        # Source Nodes: [linear_30], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_63.run(buf241, buf242, 2508800, grid=grid(2508800), stream=stream0)
        buf243 = buf215; del buf215  # reuse
        # Source Nodes: [linear_30], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_52.run(arg115_1, buf243, 786432, grid=grid(786432), stream=stream0)
        del arg115_1
        buf244 = buf216; del buf216  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf242, (4900, 512), (512, 1), 0), reinterpret_tensor(buf243, (512, 1536), (1, 512), 0), out=buf244)
        buf245 = buf223; del buf223  # reuse
        # Source Nodes: [attn_34, q_15], Original ATen: [aten.clone, aten.mul]
        triton_poi_fused_clone_mul_53.run(buf244, arg116_1, buf245, 2508800, grid=grid(2508800), stream=stream0)
        buf246 = reinterpret_tensor(buf217, (100, 16, 32, 49), (25600, 1600, 49, 1), 0); del buf217  # reuse
        # Source Nodes: [attn_34], Original ATen: [aten.clone]
        triton_poi_fused_clone_54.run(buf244, arg116_1, buf246, 51200, 49, grid=grid(51200, 49), stream=stream0)
        buf247 = buf219; del buf219  # reuse
        # Source Nodes: [attn_34], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf245, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf246, (1600, 32, 49), (1600, 49, 1), 0), out=buf247)
        buf251 = buf222; del buf222  # reuse
        # Source Nodes: [attn_36, attn_38, matmul_15], Original ATen: [aten._softmax, aten._to_copy, aten.add]
        triton_per_fused__softmax__to_copy_add_67.run(buf247, arg118_1, arg117_1, buf190, buf251, 78400, 49, grid=grid(78400), stream=stream0)
        del arg117_1
        del arg118_1
        buf252 = reinterpret_tensor(buf246, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf246  # reuse
        # Source Nodes: [matmul_15], Original ATen: [aten.clone]
        triton_poi_fused_clone_56.run(buf244, arg116_1, buf252, 2508800, grid=grid(2508800), stream=stream0)
        del arg116_1
        buf253 = reinterpret_tensor(buf242, (1600, 49, 32), (1568, 32, 1), 0); del buf242  # reuse
        # Source Nodes: [matmul_15], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf251, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf252, (1600, 49, 32), (1600, 32, 1), 0), out=buf253)
        buf254 = buf225; del buf225  # reuse
        # Source Nodes: [x_152], Original ATen: [aten.clone]
        triton_poi_fused_clone_57.run(buf253, buf254, 2508800, grid=grid(2508800), stream=stream0)
        buf255 = buf226; del buf226  # reuse
        # Source Nodes: [x_153], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_37.run(arg119_1, buf255, 262144, grid=grid(262144), stream=stream0)
        del arg119_1
        buf256 = reinterpret_tensor(buf253, (4900, 512), (512, 1), 0); del buf253  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf254, (4900, 512), (512, 1), 0), reinterpret_tensor(buf255, (512, 512), (1, 512), 0), out=buf256)
        buf261 = buf210; del buf210  # reuse
        # Source Nodes: [layer_norm_20, x_160, x_161], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_red_fused__to_copy_add_native_layer_norm_68.run(buf237, buf256, arg120_1, arg121_1, arg122_1, buf261, 4624, 512, grid=grid(4624), stream=stream0)
        del arg121_1
        del arg122_1
        buf262 = reinterpret_tensor(buf235, (2048, 512), (512, 1), 0); del buf235  # reuse
        # Source Nodes: [x_161], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg123_1, buf262, 1048576, grid=grid(1048576), stream=stream0)
        del arg123_1
        buf263 = reinterpret_tensor(buf234, (4624, 2048), (2048, 1), 0); del buf234  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf261, (4624, 512), (512, 1), 0), reinterpret_tensor(buf262, (512, 2048), (1, 512), 0), out=buf263)
        buf264 = reinterpret_tensor(buf263, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf263  # reuse
        # Source Nodes: [x_162], Original ATen: [aten.gelu]
        triton_poi_fused_gelu_60.run(buf264, arg124_1, 9469952, grid=grid(9469952), stream=stream0)
        del arg124_1
        buf265 = reinterpret_tensor(buf262, (512, 2048), (2048, 1), 0); del buf262  # reuse
        # Source Nodes: [x_164], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg125_1, buf265, 1048576, grid=grid(1048576), stream=stream0)
        del arg125_1
        buf266 = reinterpret_tensor(buf261, (4624, 512), (512, 1), 0); del buf261  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf264, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf265, (2048, 512), (1, 2048), 0), out=buf266)
        buf267 = reinterpret_tensor(buf266, (1, 4624, 512), (2367488, 512, 1), 0); del buf266  # reuse
        buf268 = buf239; del buf239  # reuse
        buf269 = buf238; del buf238  # reuse
        # Source Nodes: [x_160, x_166, x_167], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_per_fused__to_copy_add_native_layer_norm_69.run(buf267, buf237, buf256, arg120_1, arg126_1, buf268, buf269, 4624, 512, grid=grid(4624), stream=stream0)
        del arg120_1
        del arg126_1
        buf271 = reinterpret_tensor(buf256, (100, 49, 512), (25088, 512, 1), 0); del buf256  # reuse
        # Source Nodes: [linear_34], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_51.run(buf267, buf268, buf269, arg127_1, arg128_1, buf271, 2508800, grid=grid(2508800), stream=stream0)
        del arg127_1
        del arg128_1
        buf272 = buf243; del buf243  # reuse
        # Source Nodes: [linear_34], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_52.run(arg129_1, buf272, 786432, grid=grid(786432), stream=stream0)
        del arg129_1
        buf273 = buf244; del buf244  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf271, (4900, 512), (512, 1), 0), reinterpret_tensor(buf272, (512, 1536), (1, 512), 0), out=buf273)
        buf274 = buf252; del buf252  # reuse
        # Source Nodes: [attn_40, q_17], Original ATen: [aten.clone, aten.mul]
        triton_poi_fused_clone_mul_53.run(buf273, arg130_1, buf274, 2508800, grid=grid(2508800), stream=stream0)
        buf275 = reinterpret_tensor(buf245, (100, 16, 32, 49), (25600, 1600, 49, 1), 0); del buf245  # reuse
        # Source Nodes: [attn_40], Original ATen: [aten.clone]
        triton_poi_fused_clone_54.run(buf273, arg130_1, buf275, 51200, 49, grid=grid(51200, 49), stream=stream0)
        buf276 = buf247; del buf247  # reuse
        # Source Nodes: [attn_40], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf274, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf275, (1600, 32, 49), (1600, 49, 1), 0), out=buf276)
        buf279 = buf251; del buf251  # reuse
        # Source Nodes: [attn_41, attn_42, matmul_17], Original ATen: [aten._softmax, aten._to_copy, aten.add]
        triton_per_fused__softmax__to_copy_add_55.run(buf276, arg132_1, arg131_1, buf279, 78400, 49, grid=grid(78400), stream=stream0)
        del arg131_1
        del arg132_1
        buf280 = reinterpret_tensor(buf275, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf275  # reuse
        # Source Nodes: [matmul_17], Original ATen: [aten.clone]
        triton_poi_fused_clone_56.run(buf273, arg130_1, buf280, 2508800, grid=grid(2508800), stream=stream0)
        del arg130_1
        buf281 = reinterpret_tensor(buf271, (1600, 49, 32), (1568, 32, 1), 0); del buf271  # reuse
        # Source Nodes: [matmul_17], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf279, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf280, (1600, 49, 32), (1600, 32, 1), 0), out=buf281)
        buf282 = buf254; del buf254  # reuse
        # Source Nodes: [x_171], Original ATen: [aten.clone]
        triton_poi_fused_clone_57.run(buf281, buf282, 2508800, grid=grid(2508800), stream=stream0)
        buf283 = buf255; del buf255  # reuse
        # Source Nodes: [x_172], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_37.run(arg133_1, buf283, 262144, grid=grid(262144), stream=stream0)
        del arg133_1
        buf284 = reinterpret_tensor(buf281, (4900, 512), (512, 1), 0); del buf281  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf282, (4900, 512), (512, 1), 0), reinterpret_tensor(buf283, (512, 512), (1, 512), 0), out=buf284)
        buf288 = buf237; del buf237  # reuse
        # Source Nodes: [layer_norm_22, x_178, x_179], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_red_fused__to_copy_add_native_layer_norm_58.run(buf267, buf284, arg134_1, arg135_1, arg136_1, buf288, 4624, 512, grid=grid(4624), stream=stream0)
        del arg135_1
        del arg136_1
        buf289 = reinterpret_tensor(buf265, (2048, 512), (512, 1), 0); del buf265  # reuse
        # Source Nodes: [x_179], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg137_1, buf289, 1048576, grid=grid(1048576), stream=stream0)
        del arg137_1
        buf290 = reinterpret_tensor(buf264, (4624, 2048), (2048, 1), 0); del buf264  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf288, (4624, 512), (512, 1), 0), reinterpret_tensor(buf289, (512, 2048), (1, 512), 0), out=buf290)
        buf291 = reinterpret_tensor(buf290, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf290  # reuse
        # Source Nodes: [x_180], Original ATen: [aten.gelu]
        triton_poi_fused_gelu_60.run(buf291, arg138_1, 9469952, grid=grid(9469952), stream=stream0)
        del arg138_1
        buf292 = reinterpret_tensor(buf289, (512, 2048), (2048, 1), 0); del buf289  # reuse
        # Source Nodes: [x_182], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg139_1, buf292, 1048576, grid=grid(1048576), stream=stream0)
        del arg139_1
        buf293 = reinterpret_tensor(buf288, (4624, 512), (512, 1), 0); del buf288  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf291, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf292, (2048, 512), (1, 2048), 0), out=buf293)
        buf294 = reinterpret_tensor(buf293, (1, 4624, 512), (2367488, 512, 1), 0); del buf293  # reuse
        buf295 = buf269; del buf269  # reuse
        buf296 = buf268; del buf268  # reuse
        # Source Nodes: [x_178, x_184, x_185], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_per_fused__to_copy_add_native_layer_norm_61.run(buf294, buf267, buf284, arg134_1, arg140_1, buf295, buf296, 4624, 512, grid=grid(4624), stream=stream0)
        del arg134_1
        del arg140_1
        buf298 = buf241; del buf241  # reuse
        # Source Nodes: [shifted_x_4, x_187], Original ATen: [aten.constant_pad_nd, aten.roll]
        triton_poi_fused_constant_pad_nd_roll_62.run(buf294, buf295, buf296, arg141_1, arg142_1, buf298, 2508800, grid=grid(2508800), stream=stream0)
        del arg141_1
        del arg142_1
        buf299 = reinterpret_tensor(buf284, (100, 49, 512), (25088, 512, 1), 0); del buf284  # reuse
        # Source Nodes: [linear_38], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_63.run(buf298, buf299, 2508800, grid=grid(2508800), stream=stream0)
        buf300 = buf272; del buf272  # reuse
        # Source Nodes: [linear_38], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_52.run(arg143_1, buf300, 786432, grid=grid(786432), stream=stream0)
        del arg143_1
        buf301 = buf273; del buf273  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf299, (4900, 512), (512, 1), 0), reinterpret_tensor(buf300, (512, 1536), (1, 512), 0), out=buf301)
        buf302 = buf280; del buf280  # reuse
        # Source Nodes: [attn_44, q_19], Original ATen: [aten.clone, aten.mul]
        triton_poi_fused_clone_mul_53.run(buf301, arg144_1, buf302, 2508800, grid=grid(2508800), stream=stream0)
        buf303 = reinterpret_tensor(buf274, (100, 16, 32, 49), (25600, 1600, 49, 1), 0); del buf274  # reuse
        # Source Nodes: [attn_44], Original ATen: [aten.clone]
        triton_poi_fused_clone_54.run(buf301, arg144_1, buf303, 51200, 49, grid=grid(51200, 49), stream=stream0)
        buf304 = buf276; del buf276  # reuse
        # Source Nodes: [attn_44], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf302, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf303, (1600, 32, 49), (1600, 49, 1), 0), out=buf304)
        buf308 = buf279; del buf279  # reuse
        # Source Nodes: [attn_46, attn_48, matmul_19], Original ATen: [aten._softmax, aten._to_copy, aten.add]
        triton_per_fused__softmax__to_copy_add_67.run(buf304, arg146_1, arg145_1, buf190, buf308, 78400, 49, grid=grid(78400), stream=stream0)
        del arg145_1
        del arg146_1
        buf309 = reinterpret_tensor(buf303, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf303  # reuse
        # Source Nodes: [matmul_19], Original ATen: [aten.clone]
        triton_poi_fused_clone_56.run(buf301, arg144_1, buf309, 2508800, grid=grid(2508800), stream=stream0)
        del arg144_1
        buf310 = reinterpret_tensor(buf299, (1600, 49, 32), (1568, 32, 1), 0); del buf299  # reuse
        # Source Nodes: [matmul_19], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf308, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf309, (1600, 49, 32), (1600, 32, 1), 0), out=buf310)
        buf311 = buf282; del buf282  # reuse
        # Source Nodes: [x_189], Original ATen: [aten.clone]
        triton_poi_fused_clone_57.run(buf310, buf311, 2508800, grid=grid(2508800), stream=stream0)
        buf312 = buf283; del buf283  # reuse
        # Source Nodes: [x_190], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_37.run(arg147_1, buf312, 262144, grid=grid(262144), stream=stream0)
        del arg147_1
        buf313 = reinterpret_tensor(buf310, (4900, 512), (512, 1), 0); del buf310  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf311, (4900, 512), (512, 1), 0), reinterpret_tensor(buf312, (512, 512), (1, 512), 0), out=buf313)
        buf318 = buf267; del buf267  # reuse
        # Source Nodes: [layer_norm_24, x_197, x_198], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_red_fused__to_copy_add_native_layer_norm_68.run(buf294, buf313, arg148_1, arg149_1, arg150_1, buf318, 4624, 512, grid=grid(4624), stream=stream0)
        del arg149_1
        del arg150_1
        buf319 = reinterpret_tensor(buf292, (2048, 512), (512, 1), 0); del buf292  # reuse
        # Source Nodes: [x_198], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg151_1, buf319, 1048576, grid=grid(1048576), stream=stream0)
        del arg151_1
        buf320 = reinterpret_tensor(buf291, (4624, 2048), (2048, 1), 0); del buf291  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf318, (4624, 512), (512, 1), 0), reinterpret_tensor(buf319, (512, 2048), (1, 512), 0), out=buf320)
        buf321 = reinterpret_tensor(buf320, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf320  # reuse
        # Source Nodes: [x_199], Original ATen: [aten.gelu]
        triton_poi_fused_gelu_60.run(buf321, arg152_1, 9469952, grid=grid(9469952), stream=stream0)
        del arg152_1
        buf322 = reinterpret_tensor(buf319, (512, 2048), (2048, 1), 0); del buf319  # reuse
        # Source Nodes: [x_201], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg153_1, buf322, 1048576, grid=grid(1048576), stream=stream0)
        del arg153_1
        buf323 = reinterpret_tensor(buf318, (4624, 512), (512, 1), 0); del buf318  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf321, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf322, (2048, 512), (1, 2048), 0), out=buf323)
        buf324 = reinterpret_tensor(buf323, (1, 4624, 512), (2367488, 512, 1), 0); del buf323  # reuse
        buf325 = buf296; del buf296  # reuse
        buf326 = buf295; del buf295  # reuse
        # Source Nodes: [x_197, x_203, x_204], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_per_fused__to_copy_add_native_layer_norm_69.run(buf324, buf294, buf313, arg148_1, arg154_1, buf325, buf326, 4624, 512, grid=grid(4624), stream=stream0)
        del arg148_1
        del arg154_1
        buf328 = reinterpret_tensor(buf313, (100, 49, 512), (25088, 512, 1), 0); del buf313  # reuse
        # Source Nodes: [linear_42], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_51.run(buf324, buf325, buf326, arg155_1, arg156_1, buf328, 2508800, grid=grid(2508800), stream=stream0)
        del arg155_1
        del arg156_1
        buf329 = buf300; del buf300  # reuse
        # Source Nodes: [linear_42], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_52.run(arg157_1, buf329, 786432, grid=grid(786432), stream=stream0)
        del arg157_1
        buf330 = buf301; del buf301  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf328, (4900, 512), (512, 1), 0), reinterpret_tensor(buf329, (512, 1536), (1, 512), 0), out=buf330)
        buf331 = buf309; del buf309  # reuse
        # Source Nodes: [attn_50, q_21], Original ATen: [aten.clone, aten.mul]
        triton_poi_fused_clone_mul_53.run(buf330, arg158_1, buf331, 2508800, grid=grid(2508800), stream=stream0)
        buf332 = reinterpret_tensor(buf302, (100, 16, 32, 49), (25600, 1600, 49, 1), 0); del buf302  # reuse
        # Source Nodes: [attn_50], Original ATen: [aten.clone]
        triton_poi_fused_clone_54.run(buf330, arg158_1, buf332, 51200, 49, grid=grid(51200, 49), stream=stream0)
        buf333 = buf304; del buf304  # reuse
        # Source Nodes: [attn_50], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf331, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf332, (1600, 32, 49), (1600, 49, 1), 0), out=buf333)
        buf336 = buf308; del buf308  # reuse
        # Source Nodes: [attn_51, attn_52, matmul_21], Original ATen: [aten._softmax, aten._to_copy, aten.add]
        triton_per_fused__softmax__to_copy_add_55.run(buf333, arg160_1, arg159_1, buf336, 78400, 49, grid=grid(78400), stream=stream0)
        del arg159_1
        del arg160_1
        buf337 = reinterpret_tensor(buf332, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf332  # reuse
        # Source Nodes: [matmul_21], Original ATen: [aten.clone]
        triton_poi_fused_clone_56.run(buf330, arg158_1, buf337, 2508800, grid=grid(2508800), stream=stream0)
        del arg158_1
        buf338 = reinterpret_tensor(buf328, (1600, 49, 32), (1568, 32, 1), 0); del buf328  # reuse
        # Source Nodes: [matmul_21], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf336, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf337, (1600, 49, 32), (1600, 32, 1), 0), out=buf338)
        buf339 = buf311; del buf311  # reuse
        # Source Nodes: [x_208], Original ATen: [aten.clone]
        triton_poi_fused_clone_57.run(buf338, buf339, 2508800, grid=grid(2508800), stream=stream0)
        buf340 = buf312; del buf312  # reuse
        # Source Nodes: [x_209], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_37.run(arg161_1, buf340, 262144, grid=grid(262144), stream=stream0)
        del arg161_1
        buf341 = reinterpret_tensor(buf338, (4900, 512), (512, 1), 0); del buf338  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf339, (4900, 512), (512, 1), 0), reinterpret_tensor(buf340, (512, 512), (1, 512), 0), out=buf341)
        buf345 = buf294; del buf294  # reuse
        # Source Nodes: [layer_norm_26, x_215, x_216], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_red_fused__to_copy_add_native_layer_norm_58.run(buf324, buf341, arg162_1, arg163_1, arg164_1, buf345, 4624, 512, grid=grid(4624), stream=stream0)
        del arg163_1
        del arg164_1
        buf346 = reinterpret_tensor(buf322, (2048, 512), (512, 1), 0); del buf322  # reuse
        # Source Nodes: [x_216], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg165_1, buf346, 1048576, grid=grid(1048576), stream=stream0)
        del arg165_1
        buf347 = reinterpret_tensor(buf321, (4624, 2048), (2048, 1), 0); del buf321  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf345, (4624, 512), (512, 1), 0), reinterpret_tensor(buf346, (512, 2048), (1, 512), 0), out=buf347)
        buf348 = reinterpret_tensor(buf347, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf347  # reuse
        # Source Nodes: [x_217], Original ATen: [aten.gelu]
        triton_poi_fused_gelu_60.run(buf348, arg166_1, 9469952, grid=grid(9469952), stream=stream0)
        del arg166_1
        buf349 = reinterpret_tensor(buf346, (512, 2048), (2048, 1), 0); del buf346  # reuse
        # Source Nodes: [x_219], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg167_1, buf349, 1048576, grid=grid(1048576), stream=stream0)
        del arg167_1
        buf350 = reinterpret_tensor(buf345, (4624, 512), (512, 1), 0); del buf345  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf348, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf349, (2048, 512), (1, 2048), 0), out=buf350)
        buf351 = reinterpret_tensor(buf350, (1, 4624, 512), (2367488, 512, 1), 0); del buf350  # reuse
        buf352 = buf326; del buf326  # reuse
        buf353 = buf325; del buf325  # reuse
        # Source Nodes: [x_215, x_221, x_222], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_per_fused__to_copy_add_native_layer_norm_61.run(buf351, buf324, buf341, arg162_1, arg168_1, buf352, buf353, 4624, 512, grid=grid(4624), stream=stream0)
        del arg162_1
        del arg168_1
        buf355 = buf298; del buf298  # reuse
        # Source Nodes: [shifted_x_5, x_224], Original ATen: [aten.constant_pad_nd, aten.roll]
        triton_poi_fused_constant_pad_nd_roll_62.run(buf351, buf352, buf353, arg169_1, arg170_1, buf355, 2508800, grid=grid(2508800), stream=stream0)
        del arg169_1
        del arg170_1
        buf356 = reinterpret_tensor(buf341, (100, 49, 512), (25088, 512, 1), 0); del buf341  # reuse
        # Source Nodes: [linear_46], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_63.run(buf355, buf356, 2508800, grid=grid(2508800), stream=stream0)
        buf357 = buf329; del buf329  # reuse
        # Source Nodes: [linear_46], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_52.run(arg171_1, buf357, 786432, grid=grid(786432), stream=stream0)
        del arg171_1
        buf358 = buf330; del buf330  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf356, (4900, 512), (512, 1), 0), reinterpret_tensor(buf357, (512, 1536), (1, 512), 0), out=buf358)
        buf359 = buf337; del buf337  # reuse
        # Source Nodes: [attn_54, q_23], Original ATen: [aten.clone, aten.mul]
        triton_poi_fused_clone_mul_53.run(buf358, arg172_1, buf359, 2508800, grid=grid(2508800), stream=stream0)
        buf360 = reinterpret_tensor(buf331, (100, 16, 32, 49), (25600, 1600, 49, 1), 0); del buf331  # reuse
        # Source Nodes: [attn_54], Original ATen: [aten.clone]
        triton_poi_fused_clone_54.run(buf358, arg172_1, buf360, 51200, 49, grid=grid(51200, 49), stream=stream0)
        buf361 = buf333; del buf333  # reuse
        # Source Nodes: [attn_54], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf359, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf360, (1600, 32, 49), (1600, 49, 1), 0), out=buf361)
        buf365 = buf336; del buf336  # reuse
        # Source Nodes: [attn_56, attn_58, matmul_23], Original ATen: [aten._softmax, aten._to_copy, aten.add]
        triton_per_fused__softmax__to_copy_add_67.run(buf361, arg174_1, arg173_1, buf190, buf365, 78400, 49, grid=grid(78400), stream=stream0)
        del arg173_1
        del arg174_1
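        # Shifted blocks use this masked softmax variant: judging by its arguments it folds in
        # what look like the relative-position-bias table/index (arg173_1/arg174_1) plus buf190,
        # most likely the precomputed attention mask for the shifted windows.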
        buf366 = reinterpret_tensor(buf360, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf360  # reuse
        # Source Nodes: [matmul_23], Original ATen: [aten.clone]
        triton_poi_fused_clone_56.run(buf358, arg172_1, buf366, 2508800, grid=grid(2508800), stream=stream0)
        del arg172_1
        buf367 = reinterpret_tensor(buf356, (1600, 49, 32), (1568, 32, 1), 0); del buf356  # reuse
        # Source Nodes: [matmul_23], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf365, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf366, (1600, 49, 32), (1600, 32, 1), 0), out=buf367)
        buf368 = buf339; del buf339  # reuse
        # Source Nodes: [x_226], Original ATen: [aten.clone]
        triton_poi_fused_clone_57.run(buf367, buf368, 2508800, grid=grid(2508800), stream=stream0)
        buf369 = buf340; del buf340  # reuse
        # Source Nodes: [x_227], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_37.run(arg175_1, buf369, 262144, grid=grid(262144), stream=stream0)
        del arg175_1
        buf370 = reinterpret_tensor(buf367, (4900, 512), (512, 1), 0); del buf367  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf368, (4900, 512), (512, 1), 0), reinterpret_tensor(buf369, (512, 512), (1, 512), 0), out=buf370)
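        # The 512x512 matmul above is presumably the attention output projection. The fused
        # layer_norm kernel that follows reads both the 4624-token residual and the 4900-row
        # window output, so it appears to also fold the window merge (and un-shift) back onto
        # the 68x68 grid while adding the residual and normalizing.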
        buf375 = buf324; del buf324  # reuse
        # Source Nodes: [layer_norm_28, x_234, x_235], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_red_fused__to_copy_add_native_layer_norm_68.run(buf351, buf370, arg176_1, arg177_1, arg178_1, buf375, 4624, 512, grid=grid(4624), stream=stream0)
        del arg177_1
        del arg178_1
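        # Note the memory-planning pattern used throughout: reinterpret_tensor reuses an existing
        # allocation under a new shape/stride and the old name is dropped with a "# reuse" marker,
        # so these large activations cycle through a small set of buffers.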
        buf376 = reinterpret_tensor(buf349, (2048, 512), (512, 1), 0); del buf349  # reuse
        # Source Nodes: [x_235], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg179_1, buf376, 1048576, grid=grid(1048576), stream=stream0)
        del arg179_1
        buf377 = reinterpret_tensor(buf348, (4624, 2048), (2048, 1), 0); del buf348  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf375, (4624, 512), (512, 1), 0), reinterpret_tensor(buf376, (512, 2048), (1, 512), 0), out=buf377)
        buf378 = reinterpret_tensor(buf377, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf377  # reuse
        # Source Nodes: [x_236], Original ATen: [aten.gelu]
        triton_poi_fused_gelu_60.run(buf378, arg180_1, 9469952, grid=grid(9469952), stream=stream0)
        del arg180_1
        buf379 = reinterpret_tensor(buf376, (512, 2048), (2048, 1), 0); del buf376  # reuse
        # Source Nodes: [x_238], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg181_1, buf379, 1048576, grid=grid(1048576), stream=stream0)
        del arg181_1
        buf380 = reinterpret_tensor(buf375, (4624, 512), (512, 1), 0); del buf375  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf378, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf379, (2048, 512), (1, 2048), 0), out=buf380)
        buf381 = reinterpret_tensor(buf380, (1, 4624, 512), (2367488, 512, 1), 0); del buf380  # reuse
        buf382 = buf353; del buf353  # reuse
        buf383 = buf352; del buf352  # reuse
        # Source Nodes: [x_234, x_240, x_241], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_per_fused__to_copy_add_native_layer_norm_69.run(buf381, buf351, buf370, arg176_1, arg182_1, buf382, buf383, 4624, 512, grid=grid(4624), stream=stream0)
        del arg176_1
        del arg182_1
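        # This is the non-shifted variant of the block: the windowing kernel here (_51) carries no
        # constant_pad_nd/roll in its name, and the softmax below (_55) takes no mask buffer,
        # which is what distinguishes it from the shifted-window blocks above and below.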
        buf385 = reinterpret_tensor(buf370, (100, 49, 512), (25088, 512, 1), 0); del buf370  # reuse
        # Source Nodes: [linear_50], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_51.run(buf381, buf382, buf383, arg183_1, arg184_1, buf385, 2508800, grid=grid(2508800), stream=stream0)
        del arg183_1
        del arg184_1
        buf386 = buf357; del buf357  # reuse
        # Source Nodes: [linear_50], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_52.run(arg185_1, buf386, 786432, grid=grid(786432), stream=stream0)
        del arg185_1
        buf387 = buf358; del buf358  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf385, (4900, 512), (512, 1), 0), reinterpret_tensor(buf386, (512, 1536), (1, 512), 0), out=buf387)
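        # Q and K are prepared by plain copy kernels: the q clone fuses the per-head scale
        # (aten.mul), and k is laid out as (windows*heads, 32, 49) so q @ k^T can be issued as an
        # ordinary batched GEMM (extern_kernels.bmm) right after.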
        buf388 = buf366; del buf366  # reuse
        # Source Nodes: [attn_60, q_25], Original ATen: [aten.clone, aten.mul]
        triton_poi_fused_clone_mul_53.run(buf387, arg186_1, buf388, 2508800, grid=grid(2508800), stream=stream0)
        buf389 = reinterpret_tensor(buf359, (100, 16, 32, 49), (25600, 1600, 49, 1), 0); del buf359  # reuse
        # Source Nodes: [attn_60], Original ATen: [aten.clone]
        triton_poi_fused_clone_54.run(buf387, arg186_1, buf389, 51200, 49, grid=grid(51200, 49), stream=stream0)
        buf390 = buf361; del buf361  # reuse
        # Source Nodes: [attn_60], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf388, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf389, (1600, 32, 49), (1600, 49, 1), 0), out=buf390)
        buf393 = buf365; del buf365  # reuse
        # Source Nodes: [attn_61, attn_62, matmul_25], Original ATen: [aten._softmax, aten._to_copy, aten.add]
        triton_per_fused__softmax__to_copy_add_55.run(buf390, arg188_1, arg187_1, buf393, 78400, 49, grid=grid(78400), stream=stream0)
        del arg187_1
        del arg188_1
        buf394 = reinterpret_tensor(buf389, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf389  # reuse
        # Source Nodes: [matmul_25], Original ATen: [aten.clone]
        triton_poi_fused_clone_56.run(buf387, arg186_1, buf394, 2508800, grid=grid(2508800), stream=stream0)
        del arg186_1
        buf395 = reinterpret_tensor(buf385, (1600, 49, 32), (1568, 32, 1), 0); del buf385  # reuse
        # Source Nodes: [matmul_25], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf393, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf394, (1600, 49, 32), (1600, 32, 1), 0), out=buf395)
        buf396 = buf368; del buf368  # reuse
        # Source Nodes: [x_245], Original ATen: [aten.clone]
        triton_poi_fused_clone_57.run(buf395, buf396, 2508800, grid=grid(2508800), stream=stream0)
        buf397 = buf369; del buf369  # reuse
        # Source Nodes: [x_246], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_37.run(arg189_1, buf397, 262144, grid=grid(262144), stream=stream0)
        del arg189_1
        buf398 = reinterpret_tensor(buf395, (4900, 512), (512, 1), 0); del buf395  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf396, (4900, 512), (512, 1), 0), reinterpret_tensor(buf397, (512, 512), (1, 512), 0), out=buf398)
        buf402 = buf351; del buf351  # reuse
        # Source Nodes: [layer_norm_30, x_252, x_253], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_red_fused__to_copy_add_native_layer_norm_58.run(buf381, buf398, arg190_1, arg191_1, arg192_1, buf402, 4624, 512, grid=grid(4624), stream=stream0)
        del arg191_1
        del arg192_1
        buf403 = reinterpret_tensor(buf379, (2048, 512), (512, 1), 0); del buf379  # reuse
        # Source Nodes: [x_253], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg193_1, buf403, 1048576, grid=grid(1048576), stream=stream0)
        del arg193_1
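        # The many small _to_copy launches on the weights suggest the parameters are being cast on
        # the fly (most likely fp32 -> fp16) ahead of each GEMM, i.e. an autocast-style
        # mixed-precision path; this is inferred from the kernel names, not stated in this file.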
        buf404 = reinterpret_tensor(buf378, (4624, 2048), (2048, 1), 0); del buf378  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf402, (4624, 512), (512, 1), 0), reinterpret_tensor(buf403, (512, 2048), (1, 512), 0), out=buf404)
        buf405 = reinterpret_tensor(buf404, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf404  # reuse
        # Source Nodes: [x_254], Original ATen: [aten.gelu]
        triton_poi_fused_gelu_60.run(buf405, arg194_1, 9469952, grid=grid(9469952), stream=stream0)
        del arg194_1
        buf406 = reinterpret_tensor(buf403, (512, 2048), (2048, 1), 0); del buf403  # reuse
        # Source Nodes: [x_256], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg195_1, buf406, 1048576, grid=grid(1048576), stream=stream0)
        del arg195_1
        buf407 = reinterpret_tensor(buf402, (4624, 512), (512, 1), 0); del buf402  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf405, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf406, (2048, 512), (1, 2048), 0), out=buf407)
        buf408 = reinterpret_tensor(buf407, (1, 4624, 512), (2367488, 512, 1), 0); del buf407  # reuse
        buf409 = buf383; del buf383  # reuse
        buf410 = buf382; del buf382  # reuse
        # Source Nodes: [x_252, x_258, x_259], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_per_fused__to_copy_add_native_layer_norm_61.run(buf408, buf381, buf398, arg190_1, arg196_1, buf409, buf410, 4624, 512, grid=grid(4624), stream=stream0)
        del arg190_1
        del arg196_1
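        # Two LayerNorm strategies alternate here: the triton_red_* kernels (looped reduction)
        # emit the normalized tensor directly, while the triton_per_* kernels (persistent
        # reduction, going by inductor's naming convention) update the residual sum in place and
        # write per-row statistics (e.g. buf409/buf410) that the next block's windowing kernel
        # consumes to finish the normalization.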
        buf412 = buf355; del buf355  # reuse
        # Source Nodes: [shifted_x_6, x_261], Original ATen: [aten.constant_pad_nd, aten.roll]
        triton_poi_fused_constant_pad_nd_roll_62.run(buf408, buf409, buf410, arg197_1, arg198_1, buf412, 2508800, grid=grid(2508800), stream=stream0)
        del arg197_1
        del arg198_1
        buf413 = reinterpret_tensor(buf398, (100, 49, 512), (25088, 512, 1), 0); del buf398  # reuse
        # Source Nodes: [linear_54], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_63.run(buf412, buf413, 2508800, grid=grid(2508800), stream=stream0)
        buf414 = buf386; del buf386  # reuse
        # Source Nodes: [linear_54], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_52.run(arg199_1, buf414, 786432, grid=grid(786432), stream=stream0)
        del arg199_1
        buf415 = buf387; del buf387  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf413, (4900, 512), (512, 1), 0), reinterpret_tensor(buf414, (512, 1536), (1, 512), 0), out=buf415)
        buf416 = buf394; del buf394  # reuse
        # Source Nodes: [attn_64, q_27], Original ATen: [aten.clone, aten.mul]
        triton_poi_fused_clone_mul_53.run(buf415, arg200_1, buf416, 2508800, grid=grid(2508800), stream=stream0)
        buf417 = reinterpret_tensor(buf388, (100, 16, 32, 49), (25600, 1600, 49, 1), 0); del buf388  # reuse
        # Source Nodes: [attn_64], Original ATen: [aten.clone]
        triton_poi_fused_clone_54.run(buf415, arg200_1, buf417, 51200, 49, grid=grid(51200, 49), stream=stream0)
        buf418 = buf390; del buf390  # reuse
        # Source Nodes: [attn_64], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf416, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf417, (1600, 32, 49), (1600, 49, 1), 0), out=buf418)
        buf422 = buf393; del buf393  # reuse
        # Source Nodes: [attn_66, attn_68, matmul_27], Original ATen: [aten._softmax, aten._to_copy, aten.add]
        triton_per_fused__softmax__to_copy_add_67.run(buf418, arg202_1, arg201_1, buf190, buf422, 78400, 49, grid=grid(78400), stream=stream0)
        del arg201_1
        del arg202_1
        buf423 = reinterpret_tensor(buf417, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf417  # reuse
        # Source Nodes: [matmul_27], Original ATen: [aten.clone]
        triton_poi_fused_clone_56.run(buf415, arg200_1, buf423, 2508800, grid=grid(2508800), stream=stream0)
        del arg200_1
        buf424 = reinterpret_tensor(buf413, (1600, 49, 32), (1568, 32, 1), 0); del buf413  # reuse
        # Source Nodes: [matmul_27], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf422, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf423, (1600, 49, 32), (1600, 32, 1), 0), out=buf424)
        buf425 = buf396; del buf396  # reuse
        # Source Nodes: [x_263], Original ATen: [aten.clone]
        triton_poi_fused_clone_57.run(buf424, buf425, 2508800, grid=grid(2508800), stream=stream0)
        buf426 = buf397; del buf397  # reuse
        # Source Nodes: [x_264], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_37.run(arg203_1, buf426, 262144, grid=grid(262144), stream=stream0)
        del arg203_1
        buf427 = reinterpret_tensor(buf424, (4900, 512), (512, 1), 0); del buf424  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf425, (4900, 512), (512, 1), 0), reinterpret_tensor(buf426, (512, 512), (1, 512), 0), out=buf427)
        buf432 = buf381; del buf381  # reuse
        # Source Nodes: [layer_norm_32, x_271, x_272], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_red_fused__to_copy_add_native_layer_norm_68.run(buf408, buf427, arg204_1, arg205_1, arg206_1, buf432, 4624, 512, grid=grid(4624), stream=stream0)
        del arg205_1
        del arg206_1
        buf433 = reinterpret_tensor(buf406, (2048, 512), (512, 1), 0); del buf406  # reuse
        # Source Nodes: [x_272], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg207_1, buf433, 1048576, grid=grid(1048576), stream=stream0)
        del arg207_1
        buf434 = reinterpret_tensor(buf405, (4624, 2048), (2048, 1), 0); del buf405  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf432, (4624, 512), (512, 1), 0), reinterpret_tensor(buf433, (512, 2048), (1, 512), 0), out=buf434)
        buf435 = reinterpret_tensor(buf434, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf434  # reuse
        # Source Nodes: [x_273], Original ATen: [aten.gelu]
        triton_poi_fused_gelu_60.run(buf435, arg208_1, 9469952, grid=grid(9469952), stream=stream0)
        del arg208_1
        buf436 = reinterpret_tensor(buf433, (512, 2048), (2048, 1), 0); del buf433  # reuse
        # Source Nodes: [x_275], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg209_1, buf436, 1048576, grid=grid(1048576), stream=stream0)
        del arg209_1
        buf437 = reinterpret_tensor(buf432, (4624, 512), (512, 1), 0); del buf432  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf435, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf436, (2048, 512), (1, 2048), 0), out=buf437)
        buf438 = reinterpret_tensor(buf437, (1, 4624, 512), (2367488, 512, 1), 0); del buf437  # reuse
        buf439 = buf410; del buf410  # reuse
        buf440 = buf409; del buf409  # reuse
        # Source Nodes: [x_271, x_277, x_278], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_per_fused__to_copy_add_native_layer_norm_69.run(buf438, buf408, buf427, arg204_1, arg210_1, buf439, buf440, 4624, 512, grid=grid(4624), stream=stream0)
        del arg204_1
        del arg210_1
        buf442 = reinterpret_tensor(buf427, (100, 49, 512), (25088, 512, 1), 0); del buf427  # reuse
        # Source Nodes: [linear_58], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_51.run(buf438, buf439, buf440, arg211_1, arg212_1, buf442, 2508800, grid=grid(2508800), stream=stream0)
        del arg211_1
        del arg212_1
        buf443 = buf414; del buf414  # reuse
        # Source Nodes: [linear_58], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_52.run(arg213_1, buf443, 786432, grid=grid(786432), stream=stream0)
        del arg213_1
        buf444 = buf415; del buf415  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf442, (4900, 512), (512, 1), 0), reinterpret_tensor(buf443, (512, 1536), (1, 512), 0), out=buf444)
        buf445 = buf423; del buf423  # reuse
        # Source Nodes: [attn_70, q_29], Original ATen: [aten.clone, aten.mul]
        triton_poi_fused_clone_mul_53.run(buf444, arg214_1, buf445, 2508800, grid=grid(2508800), stream=stream0)
        buf446 = reinterpret_tensor(buf416, (100, 16, 32, 49), (25600, 1600, 49, 1), 0); del buf416  # reuse
        # Source Nodes: [attn_70], Original ATen: [aten.clone]
        triton_poi_fused_clone_54.run(buf444, arg214_1, buf446, 51200, 49, grid=grid(51200, 49), stream=stream0)
        buf447 = buf418; del buf418  # reuse
        # Source Nodes: [attn_70], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf445, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf446, (1600, 32, 49), (1600, 49, 1), 0), out=buf447)
        buf450 = buf422; del buf422  # reuse
        # Source Nodes: [attn_71, attn_72, matmul_29], Original ATen: [aten._softmax, aten._to_copy, aten.add]
        triton_per_fused__softmax__to_copy_add_55.run(buf447, arg216_1, arg215_1, buf450, 78400, 49, grid=grid(78400), stream=stream0)
        del arg215_1
        del arg216_1
        buf451 = reinterpret_tensor(buf446, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf446  # reuse
        # Source Nodes: [matmul_29], Original ATen: [aten.clone]
        triton_poi_fused_clone_56.run(buf444, arg214_1, buf451, 2508800, grid=grid(2508800), stream=stream0)
        del arg214_1
        buf452 = reinterpret_tensor(buf442, (1600, 49, 32), (1568, 32, 1), 0); del buf442  # reuse
        # Source Nodes: [matmul_29], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf450, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf451, (1600, 49, 32), (1600, 32, 1), 0), out=buf452)
        buf453 = buf425; del buf425  # reuse
        # Source Nodes: [x_282], Original ATen: [aten.clone]
        triton_poi_fused_clone_57.run(buf452, buf453, 2508800, grid=grid(2508800), stream=stream0)
        buf454 = buf426; del buf426  # reuse
        # Source Nodes: [x_283], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_37.run(arg217_1, buf454, 262144, grid=grid(262144), stream=stream0)
        del arg217_1
        buf455 = reinterpret_tensor(buf452, (4900, 512), (512, 1), 0); del buf452  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf453, (4900, 512), (512, 1), 0), reinterpret_tensor(buf454, (512, 512), (1, 512), 0), out=buf455)
        buf459 = buf408; del buf408  # reuse
        # Source Nodes: [layer_norm_34, x_289, x_290], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_red_fused__to_copy_add_native_layer_norm_58.run(buf438, buf455, arg218_1, arg219_1, arg220_1, buf459, 4624, 512, grid=grid(4624), stream=stream0)
        del arg219_1
        del arg220_1
        buf460 = reinterpret_tensor(buf436, (2048, 512), (512, 1), 0); del buf436  # reuse
        # Source Nodes: [x_290], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg221_1, buf460, 1048576, grid=grid(1048576), stream=stream0)
        del arg221_1
        buf461 = reinterpret_tensor(buf435, (4624, 2048), (2048, 1), 0); del buf435  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf459, (4624, 512), (512, 1), 0), reinterpret_tensor(buf460, (512, 2048), (1, 512), 0), out=buf461)
        buf462 = reinterpret_tensor(buf461, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf461  # reuse
        # Source Nodes: [x_291], Original ATen: [aten.gelu]
        triton_poi_fused_gelu_60.run(buf462, arg222_1, 9469952, grid=grid(9469952), stream=stream0)
        del arg222_1
        buf463 = reinterpret_tensor(buf460, (512, 2048), (2048, 1), 0); del buf460  # reuse
        # Source Nodes: [x_293], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg223_1, buf463, 1048576, grid=grid(1048576), stream=stream0)
        del arg223_1
        buf464 = reinterpret_tensor(buf459, (4624, 512), (512, 1), 0); del buf459  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf462, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf463, (2048, 512), (1, 2048), 0), out=buf464)
        buf465 = reinterpret_tensor(buf464, (1, 4624, 512), (2367488, 512, 1), 0); del buf464  # reuse
        buf466 = buf440; del buf440  # reuse
        buf467 = buf439; del buf439  # reuse
        # Source Nodes: [x_289, x_295, x_296], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_per_fused__to_copy_add_native_layer_norm_61.run(buf465, buf438, buf455, arg218_1, arg224_1, buf466, buf467, 4624, 512, grid=grid(4624), stream=stream0)
        del arg218_1
        del arg224_1
        buf469 = buf412; del buf412  # reuse
        # Source Nodes: [shifted_x_7, x_298], Original ATen: [aten.constant_pad_nd, aten.roll]
        triton_poi_fused_constant_pad_nd_roll_62.run(buf465, buf466, buf467, arg225_1, arg226_1, buf469, 2508800, grid=grid(2508800), stream=stream0)
        del arg225_1
        del arg226_1
        buf470 = reinterpret_tensor(buf455, (100, 49, 512), (25088, 512, 1), 0); del buf455  # reuse
        # Source Nodes: [linear_62], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_63.run(buf469, buf470, 2508800, grid=grid(2508800), stream=stream0)
        buf471 = buf443; del buf443  # reuse
        # Source Nodes: [linear_62], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_52.run(arg227_1, buf471, 786432, grid=grid(786432), stream=stream0)
        del arg227_1
        buf472 = buf444; del buf444  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf470, (4900, 512), (512, 1), 0), reinterpret_tensor(buf471, (512, 1536), (1, 512), 0), out=buf472)
        buf473 = buf451; del buf451  # reuse
        # Source Nodes: [attn_74, q_31], Original ATen: [aten.clone, aten.mul]
        triton_poi_fused_clone_mul_53.run(buf472, arg228_1, buf473, 2508800, grid=grid(2508800), stream=stream0)
        buf474 = reinterpret_tensor(buf445, (100, 16, 32, 49), (25600, 1600, 49, 1), 0); del buf445  # reuse
        # Source Nodes: [attn_74], Original ATen: [aten.clone]
        triton_poi_fused_clone_54.run(buf472, arg228_1, buf474, 51200, 49, grid=grid(51200, 49), stream=stream0)
        buf475 = buf447; del buf447  # reuse
        # Source Nodes: [attn_74], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf473, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf474, (1600, 32, 49), (1600, 49, 1), 0), out=buf475)
        buf479 = buf450; del buf450  # reuse
        # Source Nodes: [attn_76, attn_78, matmul_31], Original ATen: [aten._softmax, aten._to_copy, aten.add]
        triton_per_fused__softmax__to_copy_add_67.run(buf475, arg230_1, arg229_1, buf190, buf479, 78400, 49, grid=grid(78400), stream=stream0)
        del arg229_1
        del arg230_1
        buf480 = reinterpret_tensor(buf474, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf474  # reuse
        # Source Nodes: [matmul_31], Original ATen: [aten.clone]
        triton_poi_fused_clone_56.run(buf472, arg228_1, buf480, 2508800, grid=grid(2508800), stream=stream0)
        del arg228_1
        buf481 = reinterpret_tensor(buf470, (1600, 49, 32), (1568, 32, 1), 0); del buf470  # reuse
        # Source Nodes: [matmul_31], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf479, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf480, (1600, 49, 32), (1600, 32, 1), 0), out=buf481)
        buf482 = buf453; del buf453  # reuse
        # Source Nodes: [x_300], Original ATen: [aten.clone]
        triton_poi_fused_clone_57.run(buf481, buf482, 2508800, grid=grid(2508800), stream=stream0)
        buf483 = buf454; del buf454  # reuse
        # Source Nodes: [x_301], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_37.run(arg231_1, buf483, 262144, grid=grid(262144), stream=stream0)
        del arg231_1
        buf484 = reinterpret_tensor(buf481, (4900, 512), (512, 1), 0); del buf481  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf482, (4900, 512), (512, 1), 0), reinterpret_tensor(buf483, (512, 512), (1, 512), 0), out=buf484)
        buf489 = buf438; del buf438  # reuse
        # Source Nodes: [layer_norm_36, x_308, x_309], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_red_fused__to_copy_add_native_layer_norm_68.run(buf465, buf484, arg232_1, arg233_1, arg234_1, buf489, 4624, 512, grid=grid(4624), stream=stream0)
        del arg233_1
        del arg234_1
        buf490 = reinterpret_tensor(buf463, (2048, 512), (512, 1), 0); del buf463  # reuse
        # Source Nodes: [x_309], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg235_1, buf490, 1048576, grid=grid(1048576), stream=stream0)
        del arg235_1
        buf491 = reinterpret_tensor(buf462, (4624, 2048), (2048, 1), 0); del buf462  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf489, (4624, 512), (512, 1), 0), reinterpret_tensor(buf490, (512, 2048), (1, 512), 0), out=buf491)
        buf492 = reinterpret_tensor(buf491, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf491  # reuse
        # Source Nodes: [x_310], Original ATen: [aten.gelu]
        triton_poi_fused_gelu_60.run(buf492, arg236_1, 9469952, grid=grid(9469952), stream=stream0)
        del arg236_1
        buf493 = reinterpret_tensor(buf490, (512, 2048), (2048, 1), 0); del buf490  # reuse
        # Source Nodes: [x_312], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg237_1, buf493, 1048576, grid=grid(1048576), stream=stream0)
        del arg237_1
        buf494 = reinterpret_tensor(buf489, (4624, 512), (512, 1), 0); del buf489  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf492, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf493, (2048, 512), (1, 2048), 0), out=buf494)
        buf495 = reinterpret_tensor(buf494, (1, 4624, 512), (2367488, 512, 1), 0); del buf494  # reuse
        buf496 = buf467; del buf467  # reuse
        buf497 = buf466; del buf466  # reuse
        # Source Nodes: [x_308, x_314, x_315], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_per_fused__to_copy_add_native_layer_norm_69.run(buf495, buf465, buf484, arg232_1, arg238_1, buf496, buf497, 4624, 512, grid=grid(4624), stream=stream0)
        del arg232_1
        del arg238_1
        buf499 = reinterpret_tensor(buf484, (100, 49, 512), (25088, 512, 1), 0); del buf484  # reuse
        # Source Nodes: [linear_66], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_51.run(buf495, buf496, buf497, arg239_1, arg240_1, buf499, 2508800, grid=grid(2508800), stream=stream0)
        del arg239_1
        del arg240_1
        buf500 = buf471; del buf471  # reuse
        # Source Nodes: [linear_66], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_52.run(arg241_1, buf500, 786432, grid=grid(786432), stream=stream0)
        del arg241_1
        buf501 = buf472; del buf472  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf499, (4900, 512), (512, 1), 0), reinterpret_tensor(buf500, (512, 1536), (1, 512), 0), out=buf501)
        buf502 = buf480; del buf480  # reuse
        # Source Nodes: [attn_80, q_33], Original ATen: [aten.clone, aten.mul]
        triton_poi_fused_clone_mul_53.run(buf501, arg242_1, buf502, 2508800, grid=grid(2508800), stream=stream0)
        buf503 = reinterpret_tensor(buf473, (100, 16, 32, 49), (25600, 1600, 49, 1), 0); del buf473  # reuse
        # Source Nodes: [attn_80], Original ATen: [aten.clone]
        triton_poi_fused_clone_54.run(buf501, arg242_1, buf503, 51200, 49, grid=grid(51200, 49), stream=stream0)
        buf504 = buf475; del buf475  # reuse
        # Source Nodes: [attn_80], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf502, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf503, (1600, 32, 49), (1600, 49, 1), 0), out=buf504)
        buf507 = buf479; del buf479  # reuse
        # Source Nodes: [attn_81, attn_82, matmul_33], Original ATen: [aten._softmax, aten._to_copy, aten.add]
        triton_per_fused__softmax__to_copy_add_55.run(buf504, arg244_1, arg243_1, buf507, 78400, 49, grid=grid(78400), stream=stream0)
        del arg243_1
        del arg244_1
        buf508 = reinterpret_tensor(buf503, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf503  # reuse
        # Source Nodes: [matmul_33], Original ATen: [aten.clone]
        triton_poi_fused_clone_56.run(buf501, arg242_1, buf508, 2508800, grid=grid(2508800), stream=stream0)
        del arg242_1
        buf509 = reinterpret_tensor(buf499, (1600, 49, 32), (1568, 32, 1), 0); del buf499  # reuse
        # Source Nodes: [matmul_33], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf507, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf508, (1600, 49, 32), (1600, 32, 1), 0), out=buf509)
        buf510 = buf482; del buf482  # reuse
        # Source Nodes: [x_319], Original ATen: [aten.clone]
        triton_poi_fused_clone_57.run(buf509, buf510, 2508800, grid=grid(2508800), stream=stream0)
        buf511 = buf483; del buf483  # reuse
        # Source Nodes: [x_320], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_37.run(arg245_1, buf511, 262144, grid=grid(262144), stream=stream0)
        del arg245_1
        buf512 = reinterpret_tensor(buf509, (4900, 512), (512, 1), 0); del buf509  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf510, (4900, 512), (512, 1), 0), reinterpret_tensor(buf511, (512, 512), (1, 512), 0), out=buf512)
        buf516 = buf465; del buf465  # reuse
        # Source Nodes: [layer_norm_38, x_326, x_327], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_red_fused__to_copy_add_native_layer_norm_58.run(buf495, buf512, arg246_1, arg247_1, arg248_1, buf516, 4624, 512, grid=grid(4624), stream=stream0)
        del arg247_1
        del arg248_1
        buf517 = reinterpret_tensor(buf493, (2048, 512), (512, 1), 0); del buf493  # reuse
        # Source Nodes: [x_327], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg249_1, buf517, 1048576, grid=grid(1048576), stream=stream0)
        del arg249_1
        buf518 = reinterpret_tensor(buf492, (4624, 2048), (2048, 1), 0); del buf492  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf516, (4624, 512), (512, 1), 0), reinterpret_tensor(buf517, (512, 2048), (1, 512), 0), out=buf518)
        buf519 = reinterpret_tensor(buf518, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf518  # reuse
        # Source Nodes: [x_328], Original ATen: [aten.gelu]
        triton_poi_fused_gelu_60.run(buf519, arg250_1, 9469952, grid=grid(9469952), stream=stream0)
        del arg250_1
        buf520 = reinterpret_tensor(buf517, (512, 2048), (2048, 1), 0); del buf517  # reuse
        # Source Nodes: [x_330], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg251_1, buf520, 1048576, grid=grid(1048576), stream=stream0)
        del arg251_1
        buf521 = reinterpret_tensor(buf516, (4624, 512), (512, 1), 0); del buf516  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf519, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf520, (2048, 512), (1, 2048), 0), out=buf521)
        buf522 = reinterpret_tensor(buf521, (1, 4624, 512), (2367488, 512, 1), 0); del buf521  # reuse
        buf523 = buf497; del buf497  # reuse
        buf524 = buf496; del buf496  # reuse
        # Source Nodes: [x_326, x_332, x_333], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_per_fused__to_copy_add_native_layer_norm_61.run(buf522, buf495, buf512, arg246_1, arg252_1, buf523, buf524, 4624, 512, grid=grid(4624), stream=stream0)
        del arg246_1
        del arg252_1
        buf526 = buf469; del buf469  # reuse
        # Source Nodes: [shifted_x_8, x_335], Original ATen: [aten.constant_pad_nd, aten.roll]
        triton_poi_fused_constant_pad_nd_roll_62.run(buf522, buf523, buf524, arg253_1, arg254_1, buf526, 2508800, grid=grid(2508800), stream=stream0)
        del arg253_1
        del arg254_1
        buf527 = reinterpret_tensor(buf512, (100, 49, 512), (25088, 512, 1), 0); del buf512  # reuse
        # Source Nodes: [linear_70], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_63.run(buf526, buf527, 2508800, grid=grid(2508800), stream=stream0)
        buf528 = buf500; del buf500  # reuse
        # Source Nodes: [linear_70], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_52.run(arg255_1, buf528, 786432, grid=grid(786432), stream=stream0)
        del arg255_1
        buf529 = buf501; del buf501  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf527, (4900, 512), (512, 1), 0), reinterpret_tensor(buf528, (512, 1536), (1, 512), 0), out=buf529)
        buf530 = buf508; del buf508  # reuse
        # Source Nodes: [attn_84, q_35], Original ATen: [aten.clone, aten.mul]
        triton_poi_fused_clone_mul_53.run(buf529, arg256_1, buf530, 2508800, grid=grid(2508800), stream=stream0)
        buf531 = reinterpret_tensor(buf502, (100, 16, 32, 49), (25600, 1600, 49, 1), 0); del buf502  # reuse
        # Source Nodes: [attn_84], Original ATen: [aten.clone]
        triton_poi_fused_clone_54.run(buf529, arg256_1, buf531, 51200, 49, grid=grid(51200, 49), stream=stream0)
        buf532 = buf504; del buf504  # reuse
        # Source Nodes: [attn_84], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf530, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf531, (1600, 32, 49), (1600, 49, 1), 0), out=buf532)
        buf536 = buf507; del buf507  # reuse
        # Source Nodes: [attn_86, attn_88, matmul_35], Original ATen: [aten._softmax, aten._to_copy, aten.add]
        triton_per_fused__softmax__to_copy_add_67.run(buf532, arg258_1, arg257_1, buf190, buf536, 78400, 49, grid=grid(78400), stream=stream0)
        del arg257_1
        del arg258_1
        buf537 = reinterpret_tensor(buf531, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf531  # reuse
        # Source Nodes: [matmul_35], Original ATen: [aten.clone]
        triton_poi_fused_clone_56.run(buf529, arg256_1, buf537, 2508800, grid=grid(2508800), stream=stream0)
        del arg256_1
        buf538 = reinterpret_tensor(buf527, (1600, 49, 32), (1568, 32, 1), 0); del buf527  # reuse
        # Source Nodes: [matmul_35], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf536, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf537, (1600, 49, 32), (1600, 32, 1), 0), out=buf538)
        buf539 = buf510; del buf510  # reuse
        # Source Nodes: [x_337], Original ATen: [aten.clone]
        triton_poi_fused_clone_57.run(buf538, buf539, 2508800, grid=grid(2508800), stream=stream0)
        buf540 = buf511; del buf511  # reuse
        # Source Nodes: [x_338], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_37.run(arg259_1, buf540, 262144, grid=grid(262144), stream=stream0)
        del arg259_1
        buf541 = reinterpret_tensor(buf538, (4900, 512), (512, 1), 0); del buf538  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf539, (4900, 512), (512, 1), 0), reinterpret_tensor(buf540, (512, 512), (1, 512), 0), out=buf541)
        buf546 = buf495; del buf495  # reuse
        # Source Nodes: [layer_norm_40, x_345, x_346], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_red_fused__to_copy_add_native_layer_norm_68.run(buf522, buf541, arg260_1, arg261_1, arg262_1, buf546, 4624, 512, grid=grid(4624), stream=stream0)
        del arg261_1
        del arg262_1
        buf547 = reinterpret_tensor(buf520, (2048, 512), (512, 1), 0); del buf520  # reuse
        # Source Nodes: [x_346], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg263_1, buf547, 1048576, grid=grid(1048576), stream=stream0)
        del arg263_1
        buf548 = reinterpret_tensor(buf519, (4624, 2048), (2048, 1), 0); del buf519  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf546, (4624, 512), (512, 1), 0), reinterpret_tensor(buf547, (512, 2048), (1, 512), 0), out=buf548)
        buf549 = reinterpret_tensor(buf548, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf548  # reuse
        # Source Nodes: [x_347], Original ATen: [aten.gelu]
        triton_poi_fused_gelu_60.run(buf549, arg264_1, 9469952, grid=grid(9469952), stream=stream0)
        del arg264_1
        buf550 = reinterpret_tensor(buf547, (512, 2048), (2048, 1), 0); del buf547  # reuse
        # Source Nodes: [x_349], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg265_1, buf550, 1048576, grid=grid(1048576), stream=stream0)
        del arg265_1
        buf551 = reinterpret_tensor(buf546, (4624, 512), (512, 1), 0); del buf546  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf549, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf550, (2048, 512), (1, 2048), 0), out=buf551)
        buf552 = reinterpret_tensor(buf551, (1, 4624, 512), (2367488, 512, 1), 0); del buf551  # reuse
        buf553 = buf524; del buf524  # reuse
        buf554 = buf523; del buf523  # reuse
        # Source Nodes: [x_345, x_351, x_352], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_per_fused__to_copy_add_native_layer_norm_69.run(buf552, buf522, buf541, arg260_1, arg266_1, buf553, buf554, 4624, 512, grid=grid(4624), stream=stream0)
        del arg260_1
        del arg266_1
        buf556 = reinterpret_tensor(buf541, (100, 49, 512), (25088, 512, 1), 0); del buf541  # reuse
        # Source Nodes: [linear_74], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_51.run(buf552, buf553, buf554, arg267_1, arg268_1, buf556, 2508800, grid=grid(2508800), stream=stream0)
        del arg267_1
        del arg268_1
        buf557 = buf528; del buf528  # reuse
        # Source Nodes: [linear_74], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_52.run(arg269_1, buf557, 786432, grid=grid(786432), stream=stream0)
        del arg269_1
        buf558 = buf529; del buf529  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf556, (4900, 512), (512, 1), 0), reinterpret_tensor(buf557, (512, 1536), (1, 512), 0), out=buf558)
        buf559 = buf537; del buf537  # reuse
        # Source Nodes: [attn_90, q_37], Original ATen: [aten.clone, aten.mul]
        triton_poi_fused_clone_mul_53.run(buf558, arg270_1, buf559, 2508800, grid=grid(2508800), stream=stream0)
        buf560 = reinterpret_tensor(buf530, (100, 16, 32, 49), (25600, 1600, 49, 1), 0); del buf530  # reuse
        # Source Nodes: [attn_90], Original ATen: [aten.clone]
        triton_poi_fused_clone_54.run(buf558, arg270_1, buf560, 51200, 49, grid=grid(51200, 49), stream=stream0)
        buf561 = buf532; del buf532  # reuse
        # Source Nodes: [attn_90], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf559, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf560, (1600, 32, 49), (1600, 49, 1), 0), out=buf561)
        buf564 = buf536; del buf536  # reuse
        # Source Nodes: [attn_91, attn_92, matmul_37], Original ATen: [aten._softmax, aten._to_copy, aten.add]
        triton_per_fused__softmax__to_copy_add_55.run(buf561, arg272_1, arg271_1, buf564, 78400, 49, grid=grid(78400), stream=stream0)
        del arg271_1
        del arg272_1
        buf565 = reinterpret_tensor(buf560, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf560  # reuse
        # Source Nodes: [matmul_37], Original ATen: [aten.clone]
        triton_poi_fused_clone_56.run(buf558, arg270_1, buf565, 2508800, grid=grid(2508800), stream=stream0)
        del arg270_1
        buf566 = reinterpret_tensor(buf556, (1600, 49, 32), (1568, 32, 1), 0); del buf556  # reuse
        # Source Nodes: [matmul_37], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf564, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf565, (1600, 49, 32), (1600, 32, 1), 0), out=buf566)
        buf567 = buf539; del buf539  # reuse
        # Source Nodes: [x_356], Original ATen: [aten.clone]
        triton_poi_fused_clone_57.run(buf566, buf567, 2508800, grid=grid(2508800), stream=stream0)
        buf568 = buf540; del buf540  # reuse
        # Source Nodes: [x_357], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_37.run(arg273_1, buf568, 262144, grid=grid(262144), stream=stream0)
        del arg273_1
        buf569 = reinterpret_tensor(buf566, (4900, 512), (512, 1), 0); del buf566  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf567, (4900, 512), (512, 1), 0), reinterpret_tensor(buf568, (512, 512), (1, 512), 0), out=buf569)
        buf573 = buf522; del buf522  # reuse
        # Source Nodes: [layer_norm_42, x_363, x_364], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_red_fused__to_copy_add_native_layer_norm_58.run(buf552, buf569, arg274_1, arg275_1, arg276_1, buf573, 4624, 512, grid=grid(4624), stream=stream0)
        del arg275_1
        del arg276_1
        buf574 = reinterpret_tensor(buf550, (2048, 512), (512, 1), 0); del buf550  # reuse
        # Source Nodes: [x_364], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg277_1, buf574, 1048576, grid=grid(1048576), stream=stream0)
        del arg277_1
        buf575 = reinterpret_tensor(buf549, (4624, 2048), (2048, 1), 0); del buf549  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf573, (4624, 512), (512, 1), 0), reinterpret_tensor(buf574, (512, 2048), (1, 512), 0), out=buf575)
        buf576 = reinterpret_tensor(buf575, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf575  # reuse
        # Source Nodes: [x_365], Original ATen: [aten.gelu]
        triton_poi_fused_gelu_60.run(buf576, arg278_1, 9469952, grid=grid(9469952), stream=stream0)
        del arg278_1
        buf577 = reinterpret_tensor(buf574, (512, 2048), (2048, 1), 0); del buf574  # reuse
        # Source Nodes: [x_367], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg279_1, buf577, 1048576, grid=grid(1048576), stream=stream0)
        del arg279_1
        buf578 = reinterpret_tensor(buf573, (4624, 512), (512, 1), 0); del buf573  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf576, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf577, (2048, 512), (1, 2048), 0), out=buf578)
        buf579 = reinterpret_tensor(buf578, (1, 4624, 512), (2367488, 512, 1), 0); del buf578  # reuse
        buf580 = buf554; del buf554  # reuse
        buf581 = buf553; del buf553  # reuse
        # Source Nodes: [x_363, x_369, x_370], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_per_fused__to_copy_add_native_layer_norm_61.run(buf579, buf552, buf569, arg274_1, arg280_1, buf580, buf581, 4624, 512, grid=grid(4624), stream=stream0)
        del arg274_1
        del arg280_1
        buf583 = buf526; del buf526  # reuse
        # Source Nodes: [shifted_x_9, x_372], Original ATen: [aten.constant_pad_nd, aten.roll]
        triton_poi_fused_constant_pad_nd_roll_62.run(buf579, buf580, buf581, arg281_1, arg282_1, buf583, 2508800, grid=grid(2508800), stream=stream0)
        del arg281_1
        del arg282_1
        buf584 = reinterpret_tensor(buf569, (100, 49, 512), (25088, 512, 1), 0); del buf569  # reuse
        # Source Nodes: [linear_78], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_63.run(buf583, buf584, 2508800, grid=grid(2508800), stream=stream0)
        buf585 = buf557; del buf557  # reuse
        # Source Nodes: [linear_78], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_52.run(arg283_1, buf585, 786432, grid=grid(786432), stream=stream0)
        del arg283_1
        buf586 = buf558; del buf558  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf584, (4900, 512), (512, 1), 0), reinterpret_tensor(buf585, (512, 1536), (1, 512), 0), out=buf586)
        buf587 = buf565; del buf565  # reuse
        # Source Nodes: [attn_94, q_39], Original ATen: [aten.clone, aten.mul]
        triton_poi_fused_clone_mul_53.run(buf586, arg284_1, buf587, 2508800, grid=grid(2508800), stream=stream0)
        buf588 = reinterpret_tensor(buf559, (100, 16, 32, 49), (25600, 1600, 49, 1), 0); del buf559  # reuse
        # Source Nodes: [attn_94], Original ATen: [aten.clone]
        triton_poi_fused_clone_54.run(buf586, arg284_1, buf588, 51200, 49, grid=grid(51200, 49), stream=stream0)
        buf589 = buf561; del buf561  # reuse
        # Source Nodes: [attn_94], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf587, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf588, (1600, 32, 49), (1600, 49, 1), 0), out=buf589)
        buf593 = buf564; del buf564  # reuse
        # Source Nodes: [attn_96, attn_98, matmul_39], Original ATen: [aten._softmax, aten._to_copy, aten.add]
        triton_per_fused__softmax__to_copy_add_67.run(buf589, arg286_1, arg285_1, buf190, buf593, 78400, 49, grid=grid(78400), stream=stream0)
        del arg285_1
        del arg286_1
        buf594 = reinterpret_tensor(buf588, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf588  # reuse
        # Source Nodes: [matmul_39], Original ATen: [aten.clone]
        triton_poi_fused_clone_56.run(buf586, arg284_1, buf594, 2508800, grid=grid(2508800), stream=stream0)
        del arg284_1
        buf595 = reinterpret_tensor(buf584, (1600, 49, 32), (1568, 32, 1), 0); del buf584  # reuse
        # Source Nodes: [matmul_39], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf593, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf594, (1600, 49, 32), (1600, 32, 1), 0), out=buf595)
        buf596 = buf567; del buf567  # reuse
        # Source Nodes: [x_374], Original ATen: [aten.clone]
        triton_poi_fused_clone_57.run(buf595, buf596, 2508800, grid=grid(2508800), stream=stream0)
        buf597 = buf568; del buf568  # reuse
        # Source Nodes: [x_375], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_37.run(arg287_1, buf597, 262144, grid=grid(262144), stream=stream0)
        del arg287_1
        buf598 = reinterpret_tensor(buf595, (4900, 512), (512, 1), 0); del buf595  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf596, (4900, 512), (512, 1), 0), reinterpret_tensor(buf597, (512, 512), (1, 512), 0), out=buf598)
        buf603 = buf552; del buf552  # reuse
        # Source Nodes: [layer_norm_44, x_382, x_383], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_red_fused__to_copy_add_native_layer_norm_68.run(buf579, buf598, arg288_1, arg289_1, arg290_1, buf603, 4624, 512, grid=grid(4624), stream=stream0)
        del arg289_1
        del arg290_1
        buf604 = reinterpret_tensor(buf577, (2048, 512), (512, 1), 0); del buf577  # reuse
        # Source Nodes: [x_383], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg291_1, buf604, 1048576, grid=grid(1048576), stream=stream0)
        del arg291_1
        buf605 = reinterpret_tensor(buf576, (4624, 2048), (2048, 1), 0); del buf576  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf603, (4624, 512), (512, 1), 0), reinterpret_tensor(buf604, (512, 2048), (1, 512), 0), out=buf605)
        buf606 = reinterpret_tensor(buf605, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf605  # reuse
        # Source Nodes: [x_384], Original ATen: [aten.gelu]
        triton_poi_fused_gelu_60.run(buf606, arg292_1, 9469952, grid=grid(9469952), stream=stream0)
        del arg292_1
        buf607 = reinterpret_tensor(buf604, (512, 2048), (2048, 1), 0); del buf604  # reuse
        # Source Nodes: [x_386], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg293_1, buf607, 1048576, grid=grid(1048576), stream=stream0)
        del arg293_1
        buf608 = reinterpret_tensor(buf603, (4624, 512), (512, 1), 0); del buf603  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf606, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf607, (2048, 512), (1, 2048), 0), out=buf608)
        buf609 = reinterpret_tensor(buf608, (1, 4624, 512), (2367488, 512, 1), 0); del buf608  # reuse
        buf610 = buf581; del buf581  # reuse
        buf611 = buf580; del buf580  # reuse
        # Source Nodes: [x_382, x_388, x_389], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_per_fused__to_copy_add_native_layer_norm_69.run(buf609, buf579, buf598, arg288_1, arg294_1, buf610, buf611, 4624, 512, grid=grid(4624), stream=stream0)
        del arg288_1
        del arg294_1
        buf613 = reinterpret_tensor(buf598, (100, 49, 512), (25088, 512, 1), 0); del buf598  # reuse
        # Source Nodes: [linear_82], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_51.run(buf609, buf610, buf611, arg295_1, arg296_1, buf613, 2508800, grid=grid(2508800), stream=stream0)
        del arg295_1
        del arg296_1
        buf614 = buf585; del buf585  # reuse
        # Source Nodes: [linear_82], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_52.run(arg297_1, buf614, 786432, grid=grid(786432), stream=stream0)
        del arg297_1
        buf615 = buf586; del buf586  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf613, (4900, 512), (512, 1), 0), reinterpret_tensor(buf614, (512, 1536), (1, 512), 0), out=buf615)
        buf616 = buf594; del buf594  # reuse
        # Source Nodes: [attn_100, q_41], Original ATen: [aten.clone, aten.mul]
        triton_poi_fused_clone_mul_53.run(buf615, arg298_1, buf616, 2508800, grid=grid(2508800), stream=stream0)
        buf617 = reinterpret_tensor(buf587, (100, 16, 32, 49), (25600, 1600, 49, 1), 0); del buf587  # reuse
        # Source Nodes: [attn_100], Original ATen: [aten.clone]
        triton_poi_fused_clone_54.run(buf615, arg298_1, buf617, 51200, 49, grid=grid(51200, 49), stream=stream0)
        buf618 = buf589; del buf589  # reuse
        # Source Nodes: [attn_100], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf616, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf617, (1600, 32, 49), (1600, 49, 1), 0), out=buf618)
        buf621 = buf593; del buf593  # reuse
        # Source Nodes: [attn_101, attn_102, matmul_41], Original ATen: [aten._softmax, aten._to_copy, aten.add]
        triton_per_fused__softmax__to_copy_add_55.run(buf618, arg300_1, arg299_1, buf621, 78400, 49, grid=grid(78400), stream=stream0)
        del arg299_1
        del arg300_1
        buf622 = reinterpret_tensor(buf617, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf617  # reuse
        # Source Nodes: [matmul_41], Original ATen: [aten.clone]
        triton_poi_fused_clone_56.run(buf615, arg298_1, buf622, 2508800, grid=grid(2508800), stream=stream0)
        del arg298_1
        buf623 = reinterpret_tensor(buf613, (1600, 49, 32), (1568, 32, 1), 0); del buf613  # reuse
        # Source Nodes: [matmul_41], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf621, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf622, (1600, 49, 32), (1600, 32, 1), 0), out=buf623)
        buf624 = buf596; del buf596  # reuse
        # Source Nodes: [x_393], Original ATen: [aten.clone]
        triton_poi_fused_clone_57.run(buf623, buf624, 2508800, grid=grid(2508800), stream=stream0)
        buf625 = buf597; del buf597  # reuse
        # Source Nodes: [x_394], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_37.run(arg301_1, buf625, 262144, grid=grid(262144), stream=stream0)
        del arg301_1
        buf626 = reinterpret_tensor(buf623, (4900, 512), (512, 1), 0); del buf623  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf624, (4900, 512), (512, 1), 0), reinterpret_tensor(buf625, (512, 512), (1, 512), 0), out=buf626)
        buf630 = buf579; del buf579  # reuse
        # Source Nodes: [layer_norm_46, x_400, x_401], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_red_fused__to_copy_add_native_layer_norm_58.run(buf609, buf626, arg302_1, arg303_1, arg304_1, buf630, 4624, 512, grid=grid(4624), stream=stream0)
        del arg303_1
        del arg304_1
        buf631 = reinterpret_tensor(buf607, (2048, 512), (512, 1), 0); del buf607  # reuse
        # Source Nodes: [x_401], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg305_1, buf631, 1048576, grid=grid(1048576), stream=stream0)
        del arg305_1
        buf632 = reinterpret_tensor(buf606, (4624, 2048), (2048, 1), 0); del buf606  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf630, (4624, 512), (512, 1), 0), reinterpret_tensor(buf631, (512, 2048), (1, 512), 0), out=buf632)
        buf633 = reinterpret_tensor(buf632, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf632  # reuse
        # Source Nodes: [x_402], Original ATen: [aten.gelu]
        triton_poi_fused_gelu_60.run(buf633, arg306_1, 9469952, grid=grid(9469952), stream=stream0)
        del arg306_1
        buf634 = reinterpret_tensor(buf631, (512, 2048), (2048, 1), 0); del buf631  # reuse
        # Source Nodes: [x_404], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg307_1, buf634, 1048576, grid=grid(1048576), stream=stream0)
        del arg307_1
        buf635 = reinterpret_tensor(buf630, (4624, 512), (512, 1), 0); del buf630  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf633, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf634, (2048, 512), (1, 2048), 0), out=buf635)
        buf636 = reinterpret_tensor(buf635, (1, 4624, 512), (2367488, 512, 1), 0); del buf635  # reuse
        buf637 = buf611; del buf611  # reuse
        buf638 = buf610; del buf610  # reuse
        # Source Nodes: [x_400, x_406, x_407], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_per_fused__to_copy_add_native_layer_norm_61.run(buf636, buf609, buf626, arg302_1, arg308_1, buf637, buf638, 4624, 512, grid=grid(4624), stream=stream0)
        del arg302_1
        del arg308_1
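        # Last shifted-window block of this stage: from here the planner emits explicit
        # `del buf...` statements as each scratch buffer reaches its final use (including buf190,
        # the presumed shift mask), so the working set shrinks before the next stage begins.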
        buf640 = buf583; del buf583  # reuse
        # Source Nodes: [shifted_x_10, x_409], Original ATen: [aten.constant_pad_nd, aten.roll]
        triton_poi_fused_constant_pad_nd_roll_62.run(buf636, buf637, buf638, arg309_1, arg310_1, buf640, 2508800, grid=grid(2508800), stream=stream0)
        del arg309_1
        del arg310_1
        buf641 = reinterpret_tensor(buf626, (100, 49, 512), (25088, 512, 1), 0); del buf626  # reuse
        # Source Nodes: [linear_86], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_63.run(buf640, buf641, 2508800, grid=grid(2508800), stream=stream0)
        del buf640
        buf642 = buf614; del buf614  # reuse
        # Source Nodes: [linear_86], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_52.run(arg311_1, buf642, 786432, grid=grid(786432), stream=stream0)
        del arg311_1
        buf643 = buf615; del buf615  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf641, (4900, 512), (512, 1), 0), reinterpret_tensor(buf642, (512, 1536), (1, 512), 0), out=buf643)
        del buf642
        buf644 = buf622; del buf622  # reuse
        # Source Nodes: [attn_104, q_43], Original ATen: [aten.clone, aten.mul]
        triton_poi_fused_clone_mul_53.run(buf643, arg312_1, buf644, 2508800, grid=grid(2508800), stream=stream0)
        buf645 = reinterpret_tensor(buf616, (100, 16, 32, 49), (25600, 1600, 49, 1), 0); del buf616  # reuse
        # Source Nodes: [attn_104], Original ATen: [aten.clone]
        triton_poi_fused_clone_54.run(buf643, arg312_1, buf645, 51200, 49, grid=grid(51200, 49), stream=stream0)
        buf646 = buf618; del buf618  # reuse
        # Source Nodes: [attn_104], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf644, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf645, (1600, 32, 49), (1600, 49, 1), 0), out=buf646)
        del buf644
        buf650 = buf621; del buf621  # reuse
        # Source Nodes: [attn_106, attn_108, matmul_43], Original ATen: [aten._softmax, aten._to_copy, aten.add]
        triton_per_fused__softmax__to_copy_add_67.run(buf646, arg314_1, arg313_1, buf190, buf650, 78400, 49, grid=grid(78400), stream=stream0)
        del arg313_1
        del arg314_1
        del buf190
        del buf646
        buf651 = reinterpret_tensor(buf645, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf645  # reuse
        # Source Nodes: [matmul_43], Original ATen: [aten.clone]
        triton_poi_fused_clone_56.run(buf643, arg312_1, buf651, 2508800, grid=grid(2508800), stream=stream0)
        del arg312_1
        del buf643
        buf652 = reinterpret_tensor(buf641, (1600, 49, 32), (1568, 32, 1), 0); del buf641  # reuse
        # Source Nodes: [matmul_43], Original ATen: [aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(buf650, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf651, (1600, 49, 32), (1600, 32, 1), 0), out=buf652)
        del buf650
        del buf651
        buf653 = buf624; del buf624  # reuse
        # Source Nodes: [x_411], Original ATen: [aten.clone]
        triton_poi_fused_clone_57.run(buf652, buf653, 2508800, grid=grid(2508800), stream=stream0)
        buf654 = buf625; del buf625  # reuse
        # Source Nodes: [x_412], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_37.run(arg315_1, buf654, 262144, grid=grid(262144), stream=stream0)
        del arg315_1
        buf655 = reinterpret_tensor(buf652, (4900, 512), (512, 1), 0); del buf652  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf653, (4900, 512), (512, 1), 0), reinterpret_tensor(buf654, (512, 512), (1, 512), 0), out=buf655)
        del buf653
        del buf654
        buf660 = buf609; del buf609  # reuse
        # Source Nodes: [layer_norm_48, x_419, x_420], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_red_fused__to_copy_add_native_layer_norm_68.run(buf636, buf655, arg316_1, arg317_1, arg318_1, buf660, 4624, 512, grid=grid(4624), stream=stream0)
        del arg317_1
        del arg318_1
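        # MLP part of the block, as indicated by the weight shapes: a 512->2048
        # linear (arg319_1), GELU fused with its bias (arg320_1), a 2048->512
        # linear back (arg321_1), then a fused residual add + LayerNorm below.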
        buf661 = reinterpret_tensor(buf634, (2048, 512), (512, 1), 0); del buf634  # reuse
        # Source Nodes: [x_420], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg319_1, buf661, 1048576, grid=grid(1048576), stream=stream0)
        del arg319_1
        buf662 = reinterpret_tensor(buf633, (4624, 2048), (2048, 1), 0); del buf633  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf660, (4624, 512), (512, 1), 0), reinterpret_tensor(buf661, (512, 2048), (1, 512), 0), out=buf662)
        buf663 = reinterpret_tensor(buf662, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf662  # reuse
        # Source Nodes: [x_421], Original ATen: [aten.gelu]
        triton_poi_fused_gelu_60.run(buf663, arg320_1, 9469952, grid=grid(9469952), stream=stream0)
        del arg320_1
        buf664 = reinterpret_tensor(buf661, (512, 2048), (2048, 1), 0); del buf661  # reuse
        # Source Nodes: [x_423], Original ATen: [aten._to_copy]
        triton_poi_fused__to_copy_59.run(arg321_1, buf664, 1048576, grid=grid(1048576), stream=stream0)
        del arg321_1
        buf665 = reinterpret_tensor(buf660, (4624, 512), (512, 1), 0); del buf660  # reuse
        # Source Nodes: [], Original ATen: []
        extern_kernels.mm(reinterpret_tensor(buf663, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf664, (2048, 512), (1, 2048), 0), out=buf665)
        del buf663
        del buf664
        buf666 = empty_strided_cuda((1, 4624, 512), (2367488, 512, 1), torch.float32)
        buf667 = buf638; del buf638  # reuse
        buf668 = buf637; del buf637  # reuse
        # Source Nodes: [x_419, x_425, x_out_2], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
        triton_per_fused__to_copy_add_native_layer_norm_70.run(buf636, buf655, arg316_1, buf665, arg322_1, buf666, buf667, buf668, 4624, 512, grid=grid(4624), stream=stream0)
        del arg316_1
        del arg322_1
        del buf636
        del buf655
        del buf665
        buf670 = reinterpret_tensor(buf34, (1, 128, 272, 272), (9469952, 73984, 272, 1), 0); del buf34  # reuse
        # Source Nodes: [out], Original ATen: [aten.clone]
        triton_poi_fused_clone_71.run(buf68, buf72, buf73, arg36_1, arg37_1, buf670, 128, 73984, grid=grid(128, 73984), stream=stream0)
        del arg36_1
        del arg37_1
        del buf68
        del buf72
        del buf73
        buf671 = empty_strided_cuda((1, 256, 136, 136), (4734976, 18496, 136, 1), torch.float32)
        # Source Nodes: [out_1], Original ATen: [aten.clone]
        triton_poi_fused_clone_72.run(buf139, buf143, buf144, arg69_1, arg70_1, buf671, 256, 18496, grid=grid(256, 18496), stream=stream0)
        del arg69_1
        del arg70_1
        del buf139
        del buf143
        del buf144
        buf672 = empty_strided_cuda((1, 512, 68, 68), (2367488, 4624, 68, 1), torch.float32)
        # Source Nodes: [out_2], Original ATen: [aten.clone]
        triton_poi_fused_clone_73.run(buf666, buf667, buf668, arg323_1, arg324_1, buf672, 512, 4624, grid=grid(512, 4624), stream=stream0)
        del arg323_1
        del arg324_1
        del buf666
        del buf667
        del buf668
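    # Judging from the shapes allocated above, the three returned buffers are
    # feature maps at 1/4, 1/8 and 1/16 of the 1088x1088 input resolution:
    # buf670 (1, 128, 272, 272), buf671 (1, 256, 136, 136), buf672 (1, 512, 68, 68).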
    return (buf670, buf671, buf672, )


def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
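    # rand_strided allocates CUDA tensors with the requested shape/stride and
    # fills them with random data; arg0_1 plays the role of the network input,
    # and the remaining args stand in for the model's parameters and buffers.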
    arg0_1 = rand_strided((1, 3, 1088, 1088), (3551232, 1183744, 1088, 1), device='cuda:0', dtype=torch.float32)
    arg1_1 = rand_strided((128, 3, 4, 4), (48, 16, 4, 1), device='cuda:0', dtype=torch.float32)
    arg2_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg3_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg4_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg5_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg6_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg7_1 = rand_strided((384, 128), (128, 1), device='cuda:0', dtype=torch.float32)
    arg8_1 = rand_strided((384, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg9_1 = rand_strided((169, 4), (4, 1), device='cuda:0', dtype=torch.float32)
    arg10_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
    arg11_1 = rand_strided((128, 128), (128, 1), device='cuda:0', dtype=torch.float32)
    arg12_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg13_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg14_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg15_1 = rand_strided((512, 128), (128, 1), device='cuda:0', dtype=torch.float32)
    arg16_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg17_1 = rand_strided((128, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg18_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg19_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg20_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg21_1 = rand_strided((384, 128), (128, 1), device='cuda:0', dtype=torch.float32)
    arg22_1 = rand_strided((384, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg23_1 = rand_strided((169, 4), (4, 1), device='cuda:0', dtype=torch.float32)
    arg24_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
    arg25_1 = rand_strided((128, 128), (128, 1), device='cuda:0', dtype=torch.float32)
    arg26_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg27_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg28_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg29_1 = rand_strided((512, 128), (128, 1), device='cuda:0', dtype=torch.float32)
    arg30_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg31_1 = rand_strided((128, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg32_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg33_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg34_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg35_1 = rand_strided((256, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg36_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg37_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg38_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg39_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg40_1 = rand_strided((768, 256), (256, 1), device='cuda:0', dtype=torch.float32)
    arg41_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg42_1 = rand_strided((169, 8), (8, 1), device='cuda:0', dtype=torch.float32)
    arg43_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
    arg44_1 = rand_strided((256, 256), (256, 1), device='cuda:0', dtype=torch.float32)
    arg45_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg46_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg47_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg48_1 = rand_strided((1024, 256), (256, 1), device='cuda:0', dtype=torch.float32)
    arg49_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg50_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.float32)
    arg51_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg52_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg53_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg54_1 = rand_strided((768, 256), (256, 1), device='cuda:0', dtype=torch.float32)
    arg55_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg56_1 = rand_strided((169, 8), (8, 1), device='cuda:0', dtype=torch.float32)
    arg57_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
    arg58_1 = rand_strided((256, 256), (256, 1), device='cuda:0', dtype=torch.float32)
    arg59_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg60_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg61_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg62_1 = rand_strided((1024, 256), (256, 1), device='cuda:0', dtype=torch.float32)
    arg63_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg64_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.float32)
    arg65_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg66_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg67_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg68_1 = rand_strided((512, 1024), (1024, 1), device='cuda:0', dtype=torch.float32)
    arg69_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg70_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg71_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg72_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg73_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg74_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg75_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
    arg76_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
    arg77_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg78_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg79_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg80_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg81_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg82_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg83_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    arg84_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg85_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg86_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg87_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg88_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg89_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
    arg90_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
    arg91_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg92_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg93_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg94_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg95_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg96_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg97_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    arg98_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg99_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg100_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg101_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg102_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg103_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
    arg104_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
    arg105_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg106_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg107_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg108_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg109_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg110_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg111_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    arg112_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg113_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg114_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg115_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg116_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg117_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
    arg118_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
    arg119_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg120_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg121_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg122_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg123_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg124_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg125_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    arg126_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg127_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg128_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg129_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg130_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg131_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
    arg132_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
    arg133_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg134_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg135_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg136_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg137_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg138_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg139_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    arg140_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg141_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg142_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg143_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg144_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg145_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
    arg146_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
    arg147_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg148_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg149_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg150_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg151_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg152_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg153_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    arg154_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg155_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg156_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg157_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg158_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg159_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
    arg160_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
    arg161_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg162_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg163_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg164_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg165_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg166_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg167_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    arg168_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg169_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg170_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg171_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg172_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg173_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
    arg174_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
    arg175_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg176_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg177_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg178_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg179_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg180_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg181_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    arg182_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg183_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg184_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg185_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg186_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg187_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
    arg188_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
    arg189_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg190_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg191_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg192_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg193_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg194_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg195_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    arg196_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg197_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg198_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg199_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg200_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg201_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
    arg202_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
    arg203_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg204_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg205_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg206_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg207_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg208_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg209_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    arg210_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg211_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg212_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg213_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg214_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg215_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
    arg216_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
    arg217_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg218_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg219_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg220_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg221_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg222_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg223_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    arg224_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg225_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg226_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg227_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg228_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg229_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
    arg230_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
    arg231_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg232_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg233_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg234_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg235_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg236_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg237_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    arg238_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg239_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg240_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg241_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg242_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg243_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
    arg244_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
    arg245_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg246_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg247_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg248_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg249_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg250_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg251_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    arg252_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg253_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg254_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg255_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg256_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg257_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
    arg258_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
    arg259_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg260_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg261_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg262_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg263_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg264_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg265_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    arg266_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg267_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg268_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg269_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg270_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg271_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
    arg272_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
    arg273_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg274_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg275_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg276_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg277_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg278_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg279_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    arg280_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg281_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg282_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg283_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg284_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg285_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
    arg286_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
    arg287_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg288_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg289_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg290_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg291_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg292_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg293_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    arg294_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg295_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg296_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg297_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg298_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg299_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
    arg300_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
    arg301_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg302_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg303_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg304_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg305_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg306_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg307_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    arg308_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg309_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg310_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg311_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg312_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg313_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
    arg314_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
    arg315_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg316_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg317_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg318_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg319_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
    arg320_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg321_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
    arg322_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg323_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    arg324_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
    fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1, arg115_1, arg116_1, arg117_1, arg118_1, arg119_1, arg120_1, arg121_1, arg122_1, arg123_1, arg124_1, arg125_1, arg126_1, arg127_1, arg128_1, arg129_1, arg130_1, arg131_1, arg132_1, arg133_1, arg134_1, arg135_1, arg136_1, arg137_1, arg138_1, arg139_1, arg140_1, arg141_1, arg142_1, arg143_1, arg144_1, arg145_1, arg146_1, arg147_1, arg148_1, arg149_1, arg150_1, arg151_1, arg152_1, arg153_1, arg154_1, arg155_1, arg156_1, arg157_1, arg158_1, arg159_1, arg160_1, arg161_1, arg162_1, arg163_1, arg164_1, arg165_1, arg166_1, arg167_1, arg168_1, arg169_1, arg170_1, arg171_1, arg172_1, arg173_1, arg174_1, arg175_1, arg176_1, arg177_1, arg178_1, arg179_1, arg180_1, arg181_1, arg182_1, arg183_1, arg184_1, arg185_1, arg186_1, arg187_1, arg188_1, arg189_1, arg190_1, arg191_1, arg192_1, arg193_1, arg194_1, arg195_1, arg196_1, arg197_1, arg198_1, arg199_1, arg200_1, arg201_1, arg202_1, arg203_1, arg204_1, arg205_1, arg206_1, arg207_1, arg208_1, arg209_1, arg210_1, arg211_1, arg212_1, arg213_1, arg214_1, arg215_1, arg216_1, arg217_1, arg218_1, arg219_1, arg220_1, arg221_1, arg222_1, arg223_1, arg224_1, arg225_1, arg226_1, arg227_1, arg228_1, arg229_1, arg230_1, arg231_1, arg232_1, arg233_1, arg234_1, arg235_1, arg236_1, arg237_1, arg238_1, arg239_1, arg240_1, arg241_1, arg242_1, arg243_1, arg244_1, arg245_1, arg246_1, arg247_1, arg248_1, arg249_1, arg250_1, arg251_1, arg252_1, arg253_1, arg254_1, arg255_1, arg256_1, arg257_1, arg258_1, arg259_1, arg260_1, arg261_1, arg262_1, arg263_1, arg264_1, arg265_1, arg266_1, arg267_1, arg268_1, arg269_1, arg270_1, arg271_1, arg272_1, arg273_1, arg274_1, arg275_1, arg276_1, arg277_1, arg278_1, arg279_1, arg280_1, arg281_1, arg282_1, arg283_1, arg284_1, arg285_1, arg286_1, arg287_1, arg288_1, arg289_1, arg290_1, arg291_1, arg292_1, arg293_1, arg294_1, arg295_1, arg296_1, arg297_1, arg298_1, arg299_1, arg300_1, arg301_1, arg302_1, arg303_1, arg304_1, arg305_1, arg306_1, arg307_1, arg308_1, arg309_1, arg310_1, arg311_1, arg312_1, arg313_1, arg314_1, arg315_1, arg316_1, arg317_1, arg318_1, arg319_1, arg320_1, arg321_1, arg322_1, arg323_1, arg324_1])
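    # print_performance times repeated executions of fn (controlled by `times`
    # and `repeat`) and reports the measured per-call runtime.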
    return print_performance(fn, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)