# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
# Short aliases for operator namespaces (presumably used by the generated
# wrapper `call` function later in this file).
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
# Runtime helpers bound once at import time. From their names these are
# presumably a size/stride checker and CPU/CUDA empty_strided allocators
# -- TODO confirm against torch._C._dynamo.guards.
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
# Inductor ops; by name, presumably carve tensors out of a preallocated pool
# and reinterpret existing storage with a new size/stride -- verify.
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
# Shared compiler instance: every kernel below is built through
# `async_compile.triton(name, source, device_str=...)`.
async_compile = AsyncCompile()
# kernel path: /tmp/torchinductor_root/ut/cutum24mmaatgtfsv4woxycqlzbsdogt2u54xhctaoi4vh3kdezb.py
# Source Nodes: [x], Original ATen: [aten._to_copy, aten.convolution]
# x => convert_element_type, convert_element_type_1, convert_element_type_2, convolution
#
# Elementwise dtype cast: out_ptr0[i] = in_ptr0[i] for i in [0, 3551232).
# triton_meta declares in_ptr0 as *fp32 and out_ptr0 as *fp16. The explicit
# `.to(tl.float32)` is therefore a no-op, and the narrowing to fp16
# presumably happens when tl.store writes into the fp16 buffer.
# 3551232 == 3 * 1088 * 1088; presumably the convolution's input image
# -- TODO confirm against the wrapper.
# Load and store pass mask=None (xmask is computed but unused). Bounds
# checking was dropped, presumably because every XBLOCK the autotuner may
# pick divides 3551232 (== 2**12 * 867). The source string is generated
# code, so do not hand-edit xnumel.
triton_poi_fused__to_copy_convolution_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[4194304],
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_convolution_0', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 3551232
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
# kernel path: /tmp/torchinductor_root/5e/c5egziobbbb7h3rfbwmjwo2duj6cruzzve6uugmk3p43rvyyzsbe.py
# Source Nodes: [x], Original ATen: [aten._to_copy, aten.convolution]
# x => convert_element_type, convert_element_type_1, convert_element_type_2, convolution
#
# Elementwise dtype cast of 6144 elements. triton_meta declares in_ptr0 as
# *fp32 and out_ptr0 as *fp16; the narrowing presumably happens in tl.store.
# 6144 == 128 * 3 * 4 * 4; presumably the convolution weight -- TODO confirm.
# Unmasked (mask=None): relies on XBLOCK dividing 6144 (== 2**11 * 3).
triton_poi_fused__to_copy_convolution_1 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[8192],
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_convolution_1', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 6144
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/nv/cnvgdx7r4kd34wufst4iduuf34fhbxltx3bb4j6q4f62cqpgs6pw.py
# Source Nodes: [x], Original ATen: [aten._to_copy, aten.convolution]
# x => convert_element_type, convert_element_type_1, convert_element_type_2, convolution
#
# Elementwise dtype cast of 128 elements (in_ptr0 *fp32 -> out_ptr0 *fp16 per
# triton_meta). Unlike the two kernels above, the load and store are masked
# with xmask, so lanes with xindex >= 128 are inactive.
# Presumably the 128-channel convolution bias -- TODO confirm.
triton_poi_fused__to_copy_convolution_2 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[128],
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_convolution_2', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 128
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), xmask)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/qw/cqw52v5yidn33sc72vds5xzut4fwzpkiznabfnmbghkdfdx4dv2c.py
# Source Nodes: [x_2], Original ATen: [aten._to_copy, aten.native_layer_norm]
# x_2 => clone, convert_element_type_3, var_mean
#
# LayerNorm statistics over 128 channels, fused with a per-channel add.
# For each position x0 in [0, 73984) (73984 == 272 * 272), this kernel runs a
# Welford reduction over r1 in [0, 128) of
#     v = float32(in_ptr0[x0 + 73984*r1]) + float32(in_ptr1[r1])
# so in_ptr0 is read channel-major, with stride 73984 between channels.
# in_ptr1 is a per-channel term -- presumably the convolution bias; confirm.
# Both inputs are *fp16 per triton_meta and are upcast right after loading.
# Outputs:
#     out_ptr0[x0] = mean
#     out_ptr1[x0] = Welford `m2` accumulator, not yet divided by the count
# (triton_poi_fused__to_copy_native_layer_norm_4 divides it by 128 to get
# the variance). The count (tmp7) is computed but not stored.
triton_red_fused__to_copy_native_layer_norm_3 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
size_hints=[131072, 128],
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_native_layer_norm_3', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 73984
rnumel = 128
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex
tmp5_mean = tl.zeros([XBLOCK, RBLOCK], tl.float32)
tmp5_m2 = tl.zeros([XBLOCK, RBLOCK], tl.float32)
tmp5_weight = tl.zeros([XBLOCK, RBLOCK], tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp0 = tl.load(in_ptr0 + (x0 + (73984*r1)), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp2 = tmp0 + tmp1
tmp3 = tmp2.to(tl.float32)
tmp4 = tl.broadcast_to(tmp3, [XBLOCK, RBLOCK])
tmp5_mean_next, tmp5_m2_next, tmp5_weight_next = triton_helpers.welford_reduce(
tmp4, tmp5_mean, tmp5_m2, tmp5_weight, roffset == 0
)
tmp5_mean = tl.where(rmask & xmask, tmp5_mean_next, tmp5_mean)
tmp5_m2 = tl.where(rmask & xmask, tmp5_m2_next, tmp5_m2)
tmp5_weight = tl.where(rmask & xmask, tmp5_weight_next, tmp5_weight)
tmp5_tmp, tmp6_tmp, tmp7_tmp = triton_helpers.welford(
tmp5_mean, tmp5_m2, tmp5_weight, 1
)
tmp5 = tmp5_tmp[:, None]
tmp6 = tmp6_tmp[:, None]
tmp7 = tmp7_tmp[:, None]
tl.store(out_ptr0 + (x0), tmp5, xmask)
tl.store(out_ptr1 + (x0), tmp6, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/z7/cz7aixplgjdoqi3iyr5pyvohfufykasw4xp6tdzqkrnniybuo3jw.py
# Source Nodes: [x_2], Original ATen: [aten._to_copy, aten.native_layer_norm]
# x_2 => add, add_1, clone, convert_element_type_3, mul, mul_1, rsqrt, sub, var_mean
#
# Applies LayerNorm (eps = 1e-5) elementwise over 9469952 == 128 * 73984
# values, keeping in_ptr0's channel-major layout. With
#     c = xindex // 73984   (channel)
#     p = xindex % 73984    (position)
# each element computes
#     v = float32(in_ptr0[xindex]) + float32(in_ptr1[c])
#     out_ptr0[xindex] = (v - in_ptr2[p]) * rsqrt(in_ptr3[p] / 128 + 1e-5)
#                        * in_ptr4[c] + in_ptr5[c]
# Dividing m2 by 128 yields the biased (population) variance.
# in_ptr2 / in_ptr3 are presumably the mean / m2 written by
# triton_red_fused__to_copy_native_layer_norm_3, and in_ptr4 / in_ptr5 the
# LayerNorm weight / bias -- confirm in the wrapper.
# The output is *fp32 per triton_meta. Masks are None, which relies on XBLOCK
# dividing 9469952 (== 2**15 * 289).
triton_poi_fused__to_copy_native_layer_norm_4 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[16777216],
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_native_layer_norm_4', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 6, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 9469952
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x2 = xindex
x1 = (xindex // 73984)
x0 = xindex % 73984
tmp0 = tl.load(in_ptr0 + (x2), None).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (x1), None, eviction_policy='evict_last').to(tl.float32)
tmp4 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last')
tmp6 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
tmp13 = tl.load(in_ptr4 + (x1), None, eviction_policy='evict_last')
tmp15 = tl.load(in_ptr5 + (x1), None, eviction_policy='evict_last')
tmp2 = tmp0 + tmp1
tmp3 = tmp2.to(tl.float32)
tmp5 = tmp3 - tmp4
tmp7 = 128.0
tmp8 = tmp6 / tmp7
tmp9 = 1e-05
tmp10 = tmp8 + tmp9
tmp11 = libdevice.rsqrt(tmp10)
tmp12 = tmp5 * tmp11
tmp14 = tmp12 * tmp13
tmp16 = tmp14 + tmp15
tl.store(out_ptr0 + (x2), tmp16, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/zd/czdqvqzahblfj2alw5l62ev6piz5bnkywru3bk67m4u7dfxo6ukh.py
# Source Nodes: [x_7], Original ATen: [aten.native_layer_norm]
# x_7 => var_mean_1
#
# Per-position LayerNorm statistics over 128 channels for the second
# LayerNorm (x_7). This is the same Welford reduction as
# triton_red_fused__to_copy_native_layer_norm_3, with two differences: it
# reads a single *fp32, channel-major input (in_ptr0[x0 + 73984*r1]) and does
# no per-channel add.
# Outputs:
#     out_ptr0[x0] = mean
#     out_ptr1[x0] = Welford `m2` accumulator, not divided by the count
# The count (tmp4) is computed but not stored.
triton_red_fused_native_layer_norm_5 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
size_hints=[131072, 128],
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_native_layer_norm_5', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 2, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 73984
rnumel = 128
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex
tmp2_mean = tl.zeros([XBLOCK, RBLOCK], tl.float32)
tmp2_m2 = tl.zeros([XBLOCK, RBLOCK], tl.float32)
tmp2_weight = tl.zeros([XBLOCK, RBLOCK], tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp0 = tl.load(in_ptr0 + (x0 + (73984*r1)), rmask & xmask, eviction_policy='evict_first', other=0.0)
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
tmp2_mean_next, tmp2_m2_next, tmp2_weight_next = triton_helpers.welford_reduce(
tmp1, tmp2_mean, tmp2_m2, tmp2_weight, roffset == 0
)
tmp2_mean = tl.where(rmask & xmask, tmp2_mean_next, tmp2_mean)
tmp2_m2 = tl.where(rmask & xmask, tmp2_m2_next, tmp2_m2)
tmp2_weight = tl.where(rmask & xmask, tmp2_weight_next, tmp2_weight)
tmp2_tmp, tmp3_tmp, tmp4_tmp = triton_helpers.welford(
tmp2_mean, tmp2_m2, tmp2_weight, 1
)
tmp2 = tmp2_tmp[:, None]
tmp3 = tmp3_tmp[:, None]
tmp4 = tmp4_tmp[:, None]
tl.store(out_ptr0 + (x0), tmp2, xmask)
tl.store(out_ptr1 + (x0), tmp3, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/uz/cuzrg4rqsfxddxjfergppynngxa5vajxfugeyyhxthnn5pkepo7w.py
# Source Nodes: [linear], Original ATen: [aten._to_copy]
# linear => convert_element_type_8
#
# Fuses three steps: LayerNorm apply, 7x7 window partition with zero padding,
# and a cast to the (*fp16 per triton_meta) input of the following linear.
#
# Tiling: x (channel x2, 0..127) runs along tile dim 0 and y (token
# y3, 0..74528) along tile dim 1; 74529 == 39 * 39 * 49. Index mapping:
#     y0 = y % 49    -> position inside a 7x7 window (row y0 // 7, col y0 % 7)
#     y1 = y // 49   -> window index in a 39 x 39 grid
#     row = 7*(y1 // 39) + y0 // 7
#     col = 7*(y1 % 39) + y0 % 7
#
# Padding: 39 * 7 == 273 > 272, so the last window row/column reaches one
# pixel past the 272 x 272 map. tmp5 = (row < 272) & (col < 272) gates every
# load, and tl.where stores 0.0 at the padded positions.
#
# For in-bounds positions, with p = 272*row + col and c = x2:
#     out_ptr0[c + 128*y] = (in_ptr0[p + 73984*c] - in_ptr1[p])
#                           * rsqrt(in_ptr2[p] / 128 + 1e-5)
#                           * in_ptr3[c] + in_ptr4[c]
# So the output is [window * 49 + token, channel] row-major.
# In the in_ptr0 index, the `(... // 272) % 272` and `% 272` wrappers reduce
# to row and col whenever tmp5 holds.
# in_ptr1 / in_ptr2 are presumably the mean / m2 from
# triton_red_fused_native_layer_norm_5 -- confirm in the wrapper.
triton_poi_fused__to_copy_6 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[131072, 128], tile_hint=TileHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp16', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 7), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_6', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
ynumel = 74529
xnumel = 128
yoffset = (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)) * YBLOCK
yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
ymask = yindex < ynumel
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
y0 = yindex % 49
y1 = (yindex // 49)
x2 = xindex
y3 = yindex
tmp0 = (7*(y1 // 39)) + (y0 // 7)
tmp1 = tl.full([1, 1], 272, tl.int64)
tmp2 = tmp0 < tmp1
tmp3 = (7*(y1 % 39)) + (y0 % 7)
tmp4 = tmp3 < tmp1
tmp5 = tmp2 & tmp4
tmp6 = tl.load(in_ptr0 + ((272*((((7*(y1 % 39)) + (272*(y0 // 7)) + (1904*(y1 // 39)) + (y0 % 7)) // 272) % 272)) + (73984*x2) + (((7*(y1 % 39)) + (y0 % 7)) % 272)), tmp5 & xmask & ymask, eviction_policy='evict_last', other=0.0)
tmp7 = tl.load(in_ptr1 + (tl.broadcast_to((7*(y1 % 39)) + (272*(y0 // 7)) + (1904*(y1 // 39)) + (y0 % 7), [XBLOCK, YBLOCK])), tmp5 & xmask & ymask, eviction_policy='evict_last', other=0.0)
tmp8 = tmp6 - tmp7
tmp9 = tl.load(in_ptr2 + (tl.broadcast_to((7*(y1 % 39)) + (272*(y0 // 7)) + (1904*(y1 // 39)) + (y0 % 7), [XBLOCK, YBLOCK])), tmp5 & xmask & ymask, eviction_policy='evict_last', other=0.0)
tmp10 = 128.0
tmp11 = tmp9 / tmp10
tmp12 = 1e-05
tmp13 = tmp11 + tmp12
tmp14 = libdevice.rsqrt(tmp13)
tmp15 = tmp8 * tmp14
tmp16 = tl.load(in_ptr3 + (tl.broadcast_to(x2, [XBLOCK, YBLOCK])), tmp5 & xmask & ymask, eviction_policy='evict_last', other=0.0)
tmp17 = tmp15 * tmp16
tmp18 = tl.load(in_ptr4 + (tl.broadcast_to(x2, [XBLOCK, YBLOCK])), tmp5 & xmask & ymask, eviction_policy='evict_last', other=0.0)
tmp19 = tmp17 + tmp18
tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
tmp21 = tl.where(tmp5, tmp19, tmp20)
tmp22 = tmp21.to(tl.float32)
tl.store(out_ptr0 + (x2 + (128*y3)), tmp22, xmask & ymask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/27/c27kmexuvaczehtobz6jqsbfjvi5xgmlmysk5r6xkxskrs6dfhtn.py
# Source Nodes: [linear], Original ATen: [aten._to_copy]
# linear => convert_element_type_7
#
# Elementwise dtype cast of 49152 elements. triton_meta declares in_ptr0 as
# *fp32 and out_ptr0 as *fp16; the narrowing presumably happens in tl.store.
# 49152 == 384 * 128; presumably the qkv linear weight -- TODO confirm.
# Unmasked (mask=None): relies on XBLOCK dividing 49152 (== 2**14 * 3).
triton_poi_fused__to_copy_7 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[65536],
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_7', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 49152
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/gv/cgvv6tj3pm4btpyl5nxivtckwp4tnmqcbgk3oirzkoxwjd7w4qhy.py
# Source Nodes: [attn, q_1], Original ATen: [aten.clone, aten.mul]
# attn => clone_4
# q_1 => mul_6
#
# Builds the scaled query (q_1) tensor for windowed attention.
# xindex decomposes as (w, h, t, d):
#     d = x0 in [0, 32), t = x1 in [0, 49), h = x2 in [0, 4), w = x3 in [0, 1521)
# and each element computes
#     out_ptr0[(w*4 + h)*1600 + t*32 + d] =
#         (float32(in_ptr0[w*18816 + t*384 + h*32 + d]) + in_ptr1[h*32 + d])
#         * 0.1767766952966369
# in_ptr0 rows are 384 wide (18816 == 49 * 384). This kernel reads columns
# [0, 128); triton_poi_fused_clone_9 and triton_poi_fused_clone_11 read the
# slices at offsets 128 and 256. in_ptr1 is a per-column term (*fp32 per
# triton_meta, presumably the projection bias).
# 0.1767766952966369 ~= 32 ** -0.5, presumably the 1/sqrt(head_dim) scale
# for head_dim 32.
# Each (w, h) slice holds 49 * 32 == 1568 values but is written at a stride
# of 1600, so 32 trailing elements per slice are left unwritten here.
triton_poi_fused_clone_mul_8 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[16777216],
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp16', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_mul_8', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 9539712
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex % 32
x1 = (xindex // 32) % 49
x2 = (xindex // 1568) % 4
x3 = (xindex // 6272)
x4 = xindex % 1568
x5 = (xindex // 1568)
tmp0 = tl.load(in_ptr0 + (x0 + (32*x2) + (384*x1) + (18816*x3)), xmask).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (x0 + (32*x2)), xmask, eviction_policy='evict_last')
tmp2 = tmp1.to(tl.float32)
tmp3 = tmp0 + tmp2
tmp4 = 0.1767766952966369
tmp5 = tmp3 * tmp4
tl.store(out_ptr0 + (x4 + (1600*x5)), tmp5, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/3d/c3dgt2c4b5437df63yihnutvowcnecyibipqaqopntg6wcxwli7t.py
# Source Nodes: [attn], Original ATen: [aten.clone]
# attn => clone_5
#
# Builds a transposed copy of columns [128, 256) of the 384-wide rows, with a
# per-column term added (presumably the key tensor for q @ k^T -- confirm).
#
# Tiling: x = t (token x3, 0..48) on tile dim 0, and y in [0, 194688) on
# tile dim 1, where 194688 == 1521 * 128. Index mapping:
#     w = y // 128           (window)
#     y % 128 == h*32 + d    (head h, channel d)
# Each element computes
#     out_ptr0[(w*4 + h)*1600 + d*49 + t] =
#         float32(in_ptr0[128 + (y % 128) + t*384 + w*18816])
#         + in_ptr1[128 + (y % 128)]
# Each (w, h) slice is stored as [d, t], which is transposed relative to
# triton_poi_fused_clone_mul_8's [t, d] layout. It also uses 1568 of its
# 1600-element stride.
triton_poi_fused_clone_9 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[262144, 64], tile_hint=TileHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp16', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_9', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
ynumel = 194688
xnumel = 49
yoffset = (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)) * YBLOCK
yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
ymask = yindex < ynumel
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
x3 = xindex
y2 = (yindex // 128)
y4 = yindex % 128
y0 = yindex % 32
y5 = (yindex // 32)
tmp0 = tl.load(in_ptr0 + (128 + y4 + (384*x3) + (18816*y2)), xmask & ymask, eviction_policy='evict_last').to(tl.float32)
tmp1 = tl.load(in_ptr1 + (128 + y4), ymask, eviction_policy='evict_last')
tmp2 = tmp1.to(tl.float32)
tmp3 = tmp0 + tmp2
tl.store(out_ptr0 + (x3 + (49*y0) + (1600*y5)), tmp3, xmask & ymask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/ja/cjajea6ta7qkhr6lfq7dh5ukfc3b2pbysx55ye2fbw22b5lnzzve.py
# Source Nodes: [attn_1, attn_2, matmul_1], Original ATen: [aten._softmax, aten._to_copy, aten.add]
# attn_1 => add_4
# attn_2 => amax, div_2, exp, sub_3, sum_1
# matmul_1 => convert_element_type_14
#
# Row-wise softmax of attention scores after adding a gathered bias; the
# output is *fp16 per triton_meta.
#
# There is one row per x in [0, 298116), where 298116 == 1521 * 4 * 49:
#     q  = x % 49          (row within a 49 x 49 block)
#     h  = (x // 49) % 4
#     x5 = x // 49         (== block index, window*4 + h)
# Each row has 49 entries r in [0, 49), all held in registers at once
# (persistent reduction, RBLOCK = 64). Per entry:
#     idx   = in_ptr1[49*q + r]   (*i64), plus 169 if negative
#     score = float32(in_ptr0[49*x + r]) + in_ptr2[h + 4*idx]
#     out_ptr2[2432*x5 + 49*q + r] = softmax over r of score
#
# in_ptr2 is read as a row-major [169, 4] table; presumably a relative
# position bias (169 == 13 * 13) -- confirm. tl.device_assert traps any idx
# outside [0, 169) on active lanes.
# Padding lanes (r >= 49) contribute -inf to the max and 0 to the sum, so
# they do not affect the result.
# Scores are read densely (49 per row), but each 49 x 49 output block is
# written at a stride of 2432 (2401 used).
triton_per_fused__softmax__to_copy_add_10 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.persistent_reduction(
size_hints=[524288, 64],
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*i64', 2: '*fp32', 3: '*fp16', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__softmax__to_copy_add_10', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr):
xnumel = 298116
rnumel = 49
RBLOCK: tl.constexpr = 64
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rindex = tl.arange(0, RBLOCK)[None, :]
roffset = 0
rmask = rindex < rnumel
r3 = rindex
x4 = xindex
x0 = xindex % 49
x1 = (xindex // 49) % 4
x5 = (xindex // 49)
tmp0 = tl.load(in_ptr0 + (r3 + (49*x4)), rmask & xmask, other=0.0).to(tl.float32)
tmp2 = tl.load(in_ptr1 + (r3 + (49*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0)
tmp1 = tmp0.to(tl.float32)
tmp3 = tl.full([XBLOCK, RBLOCK], 169, tl.int32)
tmp4 = tmp2 + tmp3
tmp5 = tmp2 < 0
tmp6 = tl.where(tmp5, tmp4, tmp2)
tl.device_assert(((0 <= tmp6) & (tmp6 < 169)) | ~(rmask & xmask), "index out of bounds: 0 <= tmp6 < 169")
tmp8 = tl.load(in_ptr2 + (x1 + (4*tmp6)), rmask & xmask, eviction_policy='evict_last')
tmp9 = tmp1 + tmp8
tmp10 = tl.broadcast_to(tmp9, [XBLOCK, RBLOCK])
tmp12 = tl.where(rmask & xmask, tmp10, float("-inf"))
tmp13 = triton_helpers.max2(tmp12, 1)[:, None]
tmp14 = tmp9 - tmp13
tmp15 = tl_math.exp(tmp14)
tmp16 = tl.broadcast_to(tmp15, [XBLOCK, RBLOCK])
tmp18 = tl.where(rmask & xmask, tmp16, 0)
tmp19 = tl.sum(tmp18, 1)[:, None]
tmp20 = tmp15 / tmp19
tmp21 = tmp20.to(tl.float32)
tl.store(out_ptr2 + (r3 + (49*x0) + (2432*x5)), tmp21, rmask & xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/y3/cy3breg2mghmf5ikb42jodayxixk2k3gmcadoiwf3jyrnnxwruwr.py
# Source Nodes: [matmul_1], Original ATen: [aten.clone]
# matmul_1 => clone_8
#
# Copies columns [256, 384) of the 384-wide rows, with a per-column term
# added and no scaling (presumably the value tensor for attn @ v -- confirm).
# The indexing matches triton_poi_fused_clone_mul_8, with
#     d = x0, t = x1, h = x2, w = x3
# and each element computes
#     out_ptr0[(w*4 + h)*1600 + t*32 + d] =
#         float32(in_ptr0[256 + w*18816 + t*384 + h*32 + d])
#         + in_ptr1[256 + h*32 + d]
# The output is *fp16 per triton_meta. Each (w, h) slice uses 1568 of its
# 1600-element stride.
triton_poi_fused_clone_11 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[16777216],
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp16', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_11', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 9539712
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex % 32
x1 = (xindex // 32) % 49
x2 = (xindex // 1568) % 4
x3 = (xindex // 6272)
x4 = xindex % 1568
x5 = (xindex // 1568)
tmp0 = tl.load(in_ptr0 + (256 + x0 + (32*x2) + (384*x1) + (18816*x3)), xmask).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (256 + x0 + (32*x2)), xmask, eviction_policy='evict_last')
tmp2 = tmp1.to(tl.float32)
tmp3 = tmp0 + tmp2
tl.store(out_ptr0 + (x4 + (1600*x5)), tmp3, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/wx/cwxwgfbkicmgjn2axopqk6bihn72jskwcbv4lwok4w227xf3wkxj.py
# Source Nodes: [x_11], Original ATen: [aten.clone]
# x_11 => clone_9
# Pointwise permutation copy (summary of the kernel text below), 9539712 elements.
# in_ptr0 is read with strides 1 (x0, extent 32), 32 (x2, extent 49),
# 1568 (x1, extent 4) and 6272 (x3). out_ptr0 is written contiguously in
# (x3, x2, x1, x0) order, so the copy swaps the extent-4 and extent-49 axes of
# each 6272-element group. Both buffers are *fp16 and no arithmetic is done.
triton_poi_fused_clone_12 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[16777216],
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_12', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 9539712
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex % 32
x1 = (xindex // 32) % 4
x2 = (xindex // 128) % 49
x3 = (xindex // 6272)
x4 = xindex
tmp0 = tl.load(in_ptr0 + (x0 + (32*x2) + (1568*x1) + (6272*x3)), xmask).to(tl.float32)
tl.store(out_ptr0 + (x4), tmp0, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/e5/ce5swurtaidsanunlofocegeahut6uzzovemh7x35gwa2mv3uvbw.py
# Source Nodes: [x_12], Original ATen: [aten._to_copy]
# x_12 => convert_element_type_18
# Elementwise dtype cast of 16384 elements. The signature declares in_ptr0 as
# *fp32 and out_ptr0 as *fp16. The in-kernel .to(tl.float32) on an fp32 value is
# a no-op, so the narrowing presumably happens at tl.store, which converts to the
# pointer's element type.
# NOTE(review): the load and store pass mask=None even though xmask is computed.
# This assumes the launch grid never indexes past xnumel = 16384 — confirm.
triton_poi_fused__to_copy_13 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[16384],
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_13', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 16384
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/6a/c6ak2mlzd7gs74dn6lkznsahwca2ehhzdludsreoi74bucu7ieca.py
# Source Nodes: [layer_norm_2, x_18, x_19], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
# layer_norm_2 => add_6, add_7, mul_7, mul_8, rsqrt_2, sub_4, var_mean_2
# x_18 => add_5
# x_19 => convert_element_type_24
# Fused residual add + LayerNorm over rows of 128 (summary of the kernel text below).
# For each of xnumel = 73984 rows (x0) and rnumel = 128 columns (r1):
#   v = in_ptr0[x0 + 73984*r1] + (in_ptr1[<window-indexed offset>] + in_ptr2[r1])
# in_ptr0 is read strided: consecutive r1 values are 73984 elements apart.
# Pass 1 accumulates a Welford mean / M2 of v per row. Pass 2 reloads the same
# inputs, recomputes v, and writes row-major:
#   out_ptr2[r1 + 128*x0] = (v - mean) * rsqrt(M2 / 128 + 1e-5) * in_ptr3[r1] + in_ptr4[r1]
# NOTE(review): the in_ptr1 offset uses strides 128, 896, 6272 and 244608 over
# (x0 % 272) % 7, (x0 // 272) % 7, (x0 % 272) // 7 and x0 // 1904. That looks like
# reversing 7x7 windows (39 per side) back into a 272x272 grid — confirm.
triton_red_fused__to_copy_add_native_layer_norm_14 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
size_hints=[131072, 128],
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp16', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_native_layer_norm_14', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 8, 'num_reduction': 2, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 73984
rnumel = 128
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex
tmp8_mean = tl.zeros([XBLOCK, RBLOCK], tl.float32)
tmp8_m2 = tl.zeros([XBLOCK, RBLOCK], tl.float32)
tmp8_weight = tl.zeros([XBLOCK, RBLOCK], tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp0 = tl.load(in_ptr0 + (x0 + (73984*r1)), rmask & xmask, eviction_policy='evict_last', other=0.0)
tmp1 = tl.load(in_ptr1 + (r1 + (128*((x0 % 272) % 7)) + (896*((x0 // 272) % 7)) + (6272*((x0 % 272) // 7)) + (244608*(x0 // 1904))), rmask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp2 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp1 + tmp3
tmp5 = tmp4.to(tl.float32)
tmp6 = tmp0 + tmp5
tmp7 = tl.broadcast_to(tmp6, [XBLOCK, RBLOCK])
tmp8_mean_next, tmp8_m2_next, tmp8_weight_next = triton_helpers.welford_reduce(
tmp7, tmp8_mean, tmp8_m2, tmp8_weight, roffset == 0
)
tmp8_mean = tl.where(rmask & xmask, tmp8_mean_next, tmp8_mean)
tmp8_m2 = tl.where(rmask & xmask, tmp8_m2_next, tmp8_m2)
tmp8_weight = tl.where(rmask & xmask, tmp8_weight_next, tmp8_weight)
tmp8_tmp, tmp9_tmp, tmp10_tmp = triton_helpers.welford(
tmp8_mean, tmp8_m2, tmp8_weight, 1
)
tmp8 = tmp8_tmp[:, None]
tmp9 = tmp9_tmp[:, None]
tmp10 = tmp10_tmp[:, None]
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp11 = tl.load(in_ptr0 + (x0 + (73984*r1)), rmask & xmask, eviction_policy='evict_first', other=0.0)
tmp12 = tl.load(in_ptr1 + (r1 + (128*((x0 % 272) % 7)) + (896*((x0 // 272) % 7)) + (6272*((x0 % 272) // 7)) + (244608*(x0 // 1904))), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp13 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp25 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp27 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp14 = tmp13.to(tl.float32)
tmp15 = tmp12 + tmp14
tmp16 = tmp15.to(tl.float32)
tmp17 = tmp11 + tmp16
tmp18 = tmp17 - tmp8
tmp19 = 128.0
tmp20 = tmp9 / tmp19
tmp21 = 1e-05
tmp22 = tmp20 + tmp21
tmp23 = libdevice.rsqrt(tmp22)
tmp24 = tmp18 * tmp23
tmp26 = tmp24 * tmp25
tmp28 = tmp26 + tmp27
tmp29 = tmp28.to(tl.float32)
tl.store(out_ptr2 + (r1 + (128*x0)), tmp29, rmask & xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/mj/cmjiy7nnzibomx47pwcnzzevlowthzxr6o6uin5k2sbiqh237ico.py
# Source Nodes: [x_19], Original ATen: [aten._to_copy]
# x_19 => convert_element_type_23
# Elementwise dtype cast of 65536 elements. The signature declares in_ptr0 as
# *fp32 and out_ptr0 as *fp16. The in-kernel .to(tl.float32) on an fp32 value is
# a no-op, so the narrowing presumably happens at tl.store.
# NOTE(review): the load and store are unmasked (mask=None). This assumes the grid
# never indexes past xnumel = 65536 — confirm.
triton_poi_fused__to_copy_15 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[65536],
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_15', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 65536
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/u7/cu7c5ukagtxwy7y4iiqkp65cte74cratlcairghkion2ijccm4hb.py
# Source Nodes: [x_20], Original ATen: [aten.gelu]
# x_20 => add_8, convert_element_type_28, convert_element_type_29, erf, mul_10, mul_11, mul_9
# In-place add + exact (erf-based) GELU over 37879808 elements:
#   x = in_out_ptr0[i] + in_ptr0[i % 512]
#   in_out_ptr0[i] = 0.5 * x * (1 + erf(x * 0.7071067811865476))   # 0.7071... = 1/sqrt(2)
# in_out_ptr0 is *fp16 and in_ptr0 is *fp32 per the signature. The math runs in
# fp32. in_ptr0 is presumably a 512-wide bias vector — confirm at the call site.
# NOTE(review): the loads and store are unmasked (mask=None). This assumes the
# grid exactly covers xnumel — confirm.
triton_poi_fused_gelu_16 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[67108864],
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_gelu_16', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 37879808
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x2 = xindex
x0 = xindex % 512
tmp0 = tl.load(in_out_ptr0 + (x2), None).to(tl.float32)
tmp1 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last')
tmp2 = tmp1.to(tl.float32)
tmp3 = tmp0 + tmp2
tmp4 = tmp3.to(tl.float32)
tmp5 = 0.5
tmp6 = tmp4 * tmp5
tmp7 = 0.7071067811865476
tmp8 = tmp4 * tmp7
tmp9 = libdevice.erf(tmp8)
tmp10 = 1.0
tmp11 = tmp9 + tmp10
tmp12 = tmp6 * tmp11
tmp13 = tmp12.to(tl.float32)
tl.store(in_out_ptr0 + (x2), tmp13, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/ce/ccebhjg73upjqgd74bporiec6gb6m46ufzq5fs4uybxgkhzqobna.py
# Source Nodes: [x_18, x_24, x_25], Original ATen: [aten.add, aten.native_layer_norm]
# x_18 => add_5
# x_24 => add_9
# x_25 => var_mean_3
# Two residual adds + LayerNorm statistics over rows of 128 (summary of the kernel
# text below). For each of 73984 rows (x0) and 128 columns (r1):
#   s = in_ptr0[x0 + 73984*r1]
#     + (in_ptr1[<window-indexed offset>] + in_ptr2[r1])
#     + (in_ptr3[r1 + 128*x0] + in_ptr4[r1])
# Outputs:
#   out_ptr0[r1 + 128*x0] = s   (row-major)
#   out_ptr1[x0] = per-row Welford mean of s
#   out_ptr2[x0] = per-row Welford M2, the sum of squared deviations. This is NOT
#                  the variance; a consumer must divide by 128.
# The Welford weight (tmp16) is computed but not stored. The in_ptr1 offset is the
# same expression used by triton_red_fused__to_copy_add_native_layer_norm_14.
triton_red_fused_add_native_layer_norm_17 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
size_hints=[131072, 128],
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp32', 3: '*fp16', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: 'i32', 9: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_add_native_layer_norm_17', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 2, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 73984
rnumel = 128
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex
tmp14_mean = tl.zeros([XBLOCK, RBLOCK], tl.float32)
tmp14_m2 = tl.zeros([XBLOCK, RBLOCK], tl.float32)
tmp14_weight = tl.zeros([XBLOCK, RBLOCK], tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp0 = tl.load(in_ptr0 + (x0 + (73984*r1)), rmask & xmask, eviction_policy='evict_first', other=0.0)
tmp1 = tl.load(in_ptr1 + (r1 + (128*((x0 % 272) % 7)) + (896*((x0 // 272) % 7)) + (6272*((x0 % 272) // 7)) + (244608*(x0 // 1904))), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp2 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp7 = tl.load(in_ptr3 + (r1 + (128*x0)), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp8 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp1 + tmp3
tmp5 = tmp4.to(tl.float32)
tmp6 = tmp0 + tmp5
tmp9 = tmp8.to(tl.float32)
tmp10 = tmp7 + tmp9
tmp11 = tmp10.to(tl.float32)
tmp12 = tmp6 + tmp11
tmp13 = tl.broadcast_to(tmp12, [XBLOCK, RBLOCK])
tmp14_mean_next, tmp14_m2_next, tmp14_weight_next = triton_helpers.welford_reduce(
tmp13, tmp14_mean, tmp14_m2, tmp14_weight, roffset == 0
)
tmp14_mean = tl.where(rmask & xmask, tmp14_mean_next, tmp14_mean)
tmp14_m2 = tl.where(rmask & xmask, tmp14_m2_next, tmp14_m2)
tmp14_weight = tl.where(rmask & xmask, tmp14_weight_next, tmp14_weight)
tl.store(out_ptr0 + (r1 + (128*x0)), tmp12, rmask & xmask)
tmp14_tmp, tmp15_tmp, tmp16_tmp = triton_helpers.welford(
tmp14_mean, tmp14_m2, tmp14_weight, 1
)
tmp14 = tmp14_tmp[:, None]
tmp15 = tmp15_tmp[:, None]
tmp16 = tmp16_tmp[:, None]
tl.store(out_ptr1 + (x0), tmp14, xmask)
tl.store(out_ptr2 + (x0), tmp15, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/bd/cbdbrul6cgbc725lx7noa5ulqmr2r6ohwoqlh54cc2oipmh3exwr.py
# Source Nodes: [shifted_x, x_27], Original ATen: [aten.constant_pad_nd, aten.roll]
# shifted_x => index_1, index_2
# x_27 => constant_pad_nd_1
# Cyclic shift + zero padding + LayerNorm apply (summary of the kernel text below).
# The output is a contiguous 273 x 273 x 128 grid (x2 = row, x1 = col, x0 = channel).
# Each output (x2, x1) reads source row = (3 + x2) % 273 and col = (3 + x1) % 273,
# i.e. a roll by -3 along both spatial axes of the 273-wide grid. When either
# source coordinate equals 272 (outside the 272 x 272 input), the output is 0.
# Otherwise:
#   out = (in_ptr0[x0 + 128*col + 34816*row] - in_ptr1[272*row + col])
#         * rsqrt(in_ptr2[272*row + col] / 128 + 1e-5) * in_ptr3[x0] + in_ptr4[x0]
# in_ptr1 / in_ptr2 are presumably per-position mean and M2 buffers like those
# stored by triton_red_fused_add_native_layer_norm_17 — confirm at the call site.
triton_poi_fused_constant_pad_nd_roll_18 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[16777216],
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_constant_pad_nd_roll_18', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 9539712
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x2 = (xindex // 34944)
x1 = (xindex // 128) % 273
x0 = xindex % 128
x4 = xindex
tmp0 = (3 + x2) % 273
tmp1 = tl.full([1], 272, tl.int64)
tmp2 = tmp0 < tmp1
tmp3 = (3 + x1) % 273
tmp4 = tmp3 < tmp1
tmp5 = tmp2 & tmp4
tmp6 = tl.load(in_ptr0 + (x0 + (128*((3 + x1) % 273)) + (34816*((3 + x2) % 273))), tmp5 & xmask, other=0.0)
tmp7 = tl.load(in_ptr1 + ((272*((3 + x2) % 273)) + ((3 + x1) % 273)), tmp5 & xmask, eviction_policy='evict_last', other=0.0)
tmp8 = tmp6 - tmp7
tmp9 = tl.load(in_ptr2 + ((272*((3 + x2) % 273)) + ((3 + x1) % 273)), tmp5 & xmask, eviction_policy='evict_last', other=0.0)
tmp10 = 128.0
tmp11 = tmp9 / tmp10
tmp12 = 1e-05
tmp13 = tmp11 + tmp12
tmp14 = libdevice.rsqrt(tmp13)
tmp15 = tmp8 * tmp14
tmp16 = tl.load(in_ptr3 + (x0), tmp5 & xmask, eviction_policy='evict_last', other=0.0)
tmp17 = tmp15 * tmp16
tmp18 = tl.load(in_ptr4 + (x0), tmp5 & xmask, eviction_policy='evict_last', other=0.0)
tmp19 = tmp17 + tmp18
tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
tmp21 = tl.where(tmp5, tmp19, tmp20)
tl.store(out_ptr0 + (x4), tmp21, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/wf/cwfhqs2zm6bsjsrw2tiokspft5espgjofwlyywwnhdmmbyrpctxz.py
# Source Nodes: [linear_4], Original ATen: [aten._to_copy]
# linear_4 => convert_element_type_37
# Window partition + dtype cast (summary of the kernel text below).
# Reads in_ptr0 (*fp32) as a 273 x 273 x 128 grid (row stride 34944 = 273*128).
# Writes out_ptr0 (*fp16) contiguously as 1521 x 49 x 128:
#   x2 in [0, 1521) -> window (x2 // 39, x2 % 39) of a 39 x 39 window grid
#   x1 in [0, 49)   -> position (x1 // 7, x1 % 7) inside a 7 x 7 window
#   x0 in [0, 128)  -> channel
# Source offset = x0 + 128*(7*(x2 % 39) + x1 % 7) + 34944*(7*(x2 // 39) + x1 // 7).
triton_poi_fused__to_copy_19 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[16777216],
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_19', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 9539712
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex % 128
x1 = (xindex // 128) % 49
x2 = (xindex // 6272)
x3 = xindex
tmp0 = tl.load(in_ptr0 + (x0 + (128*(x1 % 7)) + (896*(x2 % 39)) + (34944*(x1 // 7)) + (244608*(x2 // 39))), xmask)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x3), tmp1, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/or/corkqb2j6hymgta3qqablpxtt4psncgc2x5wtkcbslwfpukdab7t.py
# Source Nodes: [img_mask, setitem, setitem_1, setitem_2, setitem_3, setitem_4], Original ATen: [aten.fill, aten.lift_fresh, aten.slice, aten.zeros]
# img_mask => full
# setitem => copy, lift_fresh_copy_2
# setitem_1 => copy_1, lift_fresh_copy_3
# setitem_2 => copy_2, lift_fresh_copy_4
# setitem_3 => copy_3, full_default, lift_fresh_copy_5
# setitem_4 => copy_4, lift_fresh_copy_6
# Writes a 273 x 273 region-label map in place (x1 = row, x0 = column). The
# kernel performs no loads (num_load 0), so every in_out_ptr0[0:74529] element is
# overwritten. Tracing the tl.where chain below gives:
#   rows [0, 266):   cols [0, 266) -> 0, [266, 270) -> 1, [270, 273) -> 2
#   rows [266, 270): cols [0, 266) -> 3, [266, 270) -> 4, [270, 273) -> 0
#   rows [270, 273): every column -> 0
# The entries still 0 in rows [266, 270) x cols [270, 273) and in rows >= 270
# appear to be filled by the later setitem_5 .. setitem_8 kernels.
# NOTE(review): presumably the img_mask used for shifted-window attention
# masking — confirm against the model source.
triton_poi_fused_fill_lift_fresh_slice_zeros_20 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[131072],
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0,), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_fill_lift_fresh_slice_zeros_20', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 74529
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x1 = (xindex // 273)
x0 = xindex % 273
x2 = xindex
tmp0 = x1
tmp1 = tl.full([1], 266, tl.int64)
tmp2 = tmp0 >= tmp1
tmp3 = tl.full([1], 270, tl.int64)
tmp4 = tmp0 < tmp3
tmp5 = tmp2 & tmp4
tmp6 = x0
tmp7 = tmp6 < tmp1
tmp8 = tmp7 & tmp5
tmp9 = 3.0
tmp10 = tl.full(tmp9.shape, 0.0, tmp9.dtype)
tmp11 = tl.where(tmp8, tmp9, tmp10)
tmp12 = 0.0
tmp13 = tl.where(tmp7, tmp11, tmp12)
tmp14 = tl.full(tmp13.shape, 0.0, tmp13.dtype)
tmp15 = tl.where(tmp5, tmp13, tmp14)
tmp16 = tmp0 < tmp1
tmp17 = tmp6 >= tmp3
tmp18 = tmp17 & tmp16
tmp19 = 2.0
tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
tmp21 = tl.where(tmp18, tmp19, tmp20)
tmp22 = tmp16 & tmp16
tmp23 = tmp6 >= tmp1
tmp24 = tmp6 < tmp3
tmp25 = tmp23 & tmp24
tmp26 = tmp25 & tmp22
tmp27 = 1.0
tmp28 = tl.full(tmp27.shape, 0.0, tmp27.dtype)
tmp29 = tl.where(tmp26, tmp27, tmp28)
tmp30 = tmp16 & tmp22
tmp31 = tmp7 & tmp30
tmp32 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
tmp33 = tl.where(tmp31, tmp12, tmp32)
tmp34 = tl.where(tmp7, tmp33, tmp12)
tmp35 = tl.full(tmp34.shape, 0.0, tmp34.dtype)
tmp36 = tl.where(tmp30, tmp34, tmp35)
tmp37 = tl.where(tmp16, tmp36, tmp12)
tmp38 = tl.where(tmp25, tmp29, tmp37)
tmp39 = tl.full(tmp38.shape, 0.0, tmp38.dtype)
tmp40 = tl.where(tmp22, tmp38, tmp39)
tmp41 = tmp7 & tmp22
tmp42 = tl.where(tmp41, tmp12, tmp32)
tmp43 = tl.where(tmp7, tmp42, tmp12)
tmp44 = tl.full(tmp43.shape, 0.0, tmp43.dtype)
tmp45 = tl.where(tmp22, tmp43, tmp44)
tmp46 = tl.where(tmp16, tmp45, tmp12)
tmp47 = tl.where(tmp16, tmp40, tmp46)
tmp48 = tl.where(tmp17, tmp21, tmp47)
tmp49 = tl.full(tmp48.shape, 0.0, tmp48.dtype)
tmp50 = tl.where(tmp16, tmp48, tmp49)
tmp51 = tmp25 & tmp16
tmp52 = tl.where(tmp51, tmp27, tmp28)
tmp53 = tl.where(tmp25, tmp52, tmp46)
tmp54 = tl.full(tmp53.shape, 0.0, tmp53.dtype)
tmp55 = tl.where(tmp16, tmp53, tmp54)
tmp56 = tmp7 & tmp16
tmp57 = tl.where(tmp56, tmp12, tmp32)
tmp58 = tl.where(tmp7, tmp57, tmp12)
tmp59 = tl.full(tmp58.shape, 0.0, tmp58.dtype)
tmp60 = tl.where(tmp16, tmp58, tmp59)
tmp61 = tl.where(tmp16, tmp60, tmp12)
tmp62 = tl.where(tmp16, tmp55, tmp61)
tmp63 = tl.where(tmp16, tmp50, tmp62)
tmp64 = tl.where(tmp5, tmp15, tmp63)
tmp65 = tmp25 & tmp5
tmp66 = 4.0
tmp67 = tl.full(tmp66.shape, 0.0, tmp66.dtype)
tmp68 = tl.where(tmp65, tmp66, tmp67)
tmp69 = tl.where(tmp25, tmp68, tmp64)
tmp70 = tl.full(tmp69.shape, 0.0, tmp69.dtype)
tmp71 = tl.where(tmp5, tmp69, tmp70)
tmp72 = tl.where(tmp5, tmp71, tmp64)
tl.store(in_out_ptr0 + (x2), tmp72, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/js/cjs7cpqsrlg6uha5jw3ddfb244ovf73mv7lc45xezom63apdnnpq.py
# Source Nodes: [setitem_8], Original ATen: [aten.fill, aten.lift_fresh]
# setitem_8 => copy_8, lift_fresh_copy_10
# Computes the labels for the last three rows (270..272) of a 273 x 273 label map:
# 819 = 3*273 elements, x1 = row offset from 270, x0 = column. Tracing the
# tl.where chain below gives:
#   cols [0, 266) -> 6, [266, 270) -> 7, [270, 273) -> 8
# Results go to out_ptr0[x2] for x2 in [0, 819). Whether that is a standalone
# buffer or a view at row 270 of the full map depends on the caller.
# The in_ptr0 reads at 73710 + x2 (73710 = 270*273) feed only branches that are
# discarded, because tmp7 (270 + x1 >= 270) is always true here. The stored
# values therefore do not depend on in_ptr0.
triton_poi_fused_fill_lift_fresh_21 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[1024],
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_fill_lift_fresh_21', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 819
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex % 273
x1 = (xindex // 273)
x2 = xindex
tmp55 = tl.load(in_ptr0 + (73710 + x2), xmask)
tmp0 = x0
tmp1 = tl.full([1], 270, tl.int64)
tmp2 = tmp0 >= tmp1
tmp3 = 8.0
tmp4 = tl.full(tmp3.shape, 0.0, tmp3.dtype)
tmp5 = tl.where(tmp2, tmp3, tmp4)
tmp6 = 270 + x1
tmp7 = tmp6 >= tmp1
tmp8 = tl.full([1], 266, tl.int64)
tmp9 = tmp0 >= tmp8
tmp10 = tmp0 < tmp1
tmp11 = tmp9 & tmp10
tmp12 = tmp11 & tmp7
tmp13 = 7.0
tmp14 = tl.full(tmp13.shape, 0.0, tmp13.dtype)
tmp15 = tl.where(tmp12, tmp13, tmp14)
tmp16 = tmp7 & tmp7
tmp17 = tmp0 < tmp8
tmp18 = tmp17 & tmp16
tmp19 = 6.0
tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
tmp21 = tl.where(tmp18, tmp19, tmp20)
tmp22 = 0.0
tmp23 = tl.where(tmp17, tmp21, tmp22)
tmp24 = tl.full(tmp23.shape, 0.0, tmp23.dtype)
tmp25 = tl.where(tmp16, tmp23, tmp24)
tmp26 = tmp6 >= tmp8
tmp27 = tmp6 < tmp1
tmp28 = tmp26 & tmp27
tmp29 = tmp28 & tmp7
tmp30 = tmp2 & tmp29
tmp31 = 5.0
tmp32 = tl.full(tmp31.shape, 0.0, tmp31.dtype)
tmp33 = tl.where(tmp30, tmp31, tmp32)
tmp34 = tl.load(in_ptr0 + (73710 + x2), tmp29 & xmask, other=0.0)
tmp35 = tl.where(tmp2, tmp33, tmp34)
tmp36 = tl.full(tmp35.shape, 0.0, tmp35.dtype)
tmp37 = tl.where(tmp29, tmp35, tmp36)
tmp38 = tl.load(in_ptr0 + (73710 + x2), tmp7 & xmask, other=0.0)
tmp39 = tl.where(tmp28, tmp37, tmp38)
tmp40 = tl.where(tmp7, tmp25, tmp39)
tmp41 = tl.where(tmp11, tmp15, tmp40)
tmp42 = tl.full(tmp41.shape, 0.0, tmp41.dtype)
tmp43 = tl.where(tmp7, tmp41, tmp42)
tmp44 = tmp17 & tmp7
tmp45 = tl.where(tmp44, tmp19, tmp20)
tmp46 = tl.where(tmp17, tmp45, tmp22)
tmp47 = tl.full(tmp46.shape, 0.0, tmp46.dtype)
tmp48 = tl.where(tmp7, tmp46, tmp47)
tmp49 = tmp2 & tmp28
tmp50 = tl.where(tmp49, tmp31, tmp32)
tmp51 = tl.load(in_ptr0 + (73710 + x2), tmp28 & xmask, other=0.0)
tmp52 = tl.where(tmp2, tmp50, tmp51)
tmp53 = tl.full(tmp52.shape, 0.0, tmp52.dtype)
tmp54 = tl.where(tmp28, tmp52, tmp53)
tmp56 = tl.where(tmp28, tmp54, tmp55)
tmp57 = tl.where(tmp7, tmp48, tmp56)
tmp58 = tl.where(tmp7, tmp43, tmp57)
tmp59 = tl.where(tmp2, tmp5, tmp58)
tl.store(out_ptr0 + (x2), tmp59, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/u5/cu5hdqskxvnzge6kx3b3irwa2qsnzqulsyxqs3z77v5oyt3nwued.py
# Source Nodes: [setitem_5, setitem_6, setitem_7], Original ATen: [aten.fill, aten.lift_fresh, aten.slice]
# setitem_5 => copy_5, lift_fresh_copy_7
# setitem_6 => copy_6, full_default_1, lift_fresh_copy_8
# setitem_7 => copy_7, lift_fresh_copy_9
#
# In-place pointwise update of a 273 x 273 fp32 buffer (xnumel = 74529 = 273 * 273,
# row = xindex // 273, col = xindex % 273); in_out_ptr0 is both read and overwritten.
# The value stored for each element (tmp59) is:
#   * row >= 270                   -> in_ptr0[xindex - 73710] (73710 = 270 * 273), so
#                                     in_ptr0 supplies the last three rows; that load is
#                                     masked by `row >= 270`, so the negative offsets of
#                                     earlier rows are never dereferenced.
#   * 266 <= row < 270, col >= 270 -> the constant 5.0.
#   * every other element          -> the value already in in_out_ptr0.
# The 6.0 / 7.0 branches (tmp12..tmp24, tmp43, tmp48) are computed but never reach the
# store, because tmp59 selects tmp5 whenever row >= 270. tmp28 is always False (it ANDs
# 266 <= row < 270 with row >= 270). This redundancy is generated code and is kept as-is.
# NOTE(review): in_ptr0 is presumably the 3-row buffer produced by an earlier kernel;
# confirm at the call site.
# The string passed to async_compile.triton is Python source, so its indentation is
# significant: the function body must be indented under `def triton_`.
triton_poi_fused_fill_lift_fresh_slice_22 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
    size_hints=[131072],
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_fill_lift_fresh_slice_22', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 74529
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x1 = (xindex // 273)
    x2 = xindex
    x0 = xindex % 273
    tmp55 = tl.load(in_out_ptr0 + (x2), xmask)
    tmp0 = x1
    tmp1 = tl.full([1], 270, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.load(in_ptr0 + ((-73710) + x2), tmp2 & xmask, other=0.0)
    tmp4 = tl.full(tmp3.shape, 0.0, tmp3.dtype)
    tmp5 = tl.where(tmp2, tmp3, tmp4)
    tmp6 = x0
    tmp7 = tl.full([1], 266, tl.int64)
    tmp8 = tmp6 >= tmp7
    tmp9 = tmp6 < tmp1
    tmp10 = tmp8 & tmp9
    tmp11 = tmp10 & tmp2
    tmp12 = 7.0
    tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
    tmp14 = tl.where(tmp11, tmp12, tmp13)
    tmp15 = tmp2 & tmp2
    tmp16 = tmp6 < tmp7
    tmp17 = tmp16 & tmp15
    tmp18 = 6.0
    tmp19 = tl.full(tmp18.shape, 0.0, tmp18.dtype)
    tmp20 = tl.where(tmp17, tmp18, tmp19)
    tmp21 = 0.0
    tmp22 = tl.where(tmp16, tmp20, tmp21)
    tmp23 = tl.full(tmp22.shape, 0.0, tmp22.dtype)
    tmp24 = tl.where(tmp15, tmp22, tmp23)
    tmp25 = tmp0 >= tmp7
    tmp26 = tmp0 < tmp1
    tmp27 = tmp25 & tmp26
    tmp28 = tmp27 & tmp2
    tmp29 = tmp6 >= tmp1
    tmp30 = tmp29 & tmp28
    tmp31 = 5.0
    tmp32 = tl.full(tmp31.shape, 0.0, tmp31.dtype)
    tmp33 = tl.where(tmp30, tmp31, tmp32)
    tmp34 = tl.load(in_out_ptr0 + (x2), tmp28 & xmask, other=0.0)
    tmp35 = tl.where(tmp29, tmp33, tmp34)
    tmp36 = tl.full(tmp35.shape, 0.0, tmp35.dtype)
    tmp37 = tl.where(tmp28, tmp35, tmp36)
    tmp38 = tl.load(in_out_ptr0 + (x2), tmp2 & xmask, other=0.0)
    tmp39 = tl.where(tmp27, tmp37, tmp38)
    tmp40 = tl.where(tmp2, tmp24, tmp39)
    tmp41 = tl.where(tmp10, tmp14, tmp40)
    tmp42 = tl.full(tmp41.shape, 0.0, tmp41.dtype)
    tmp43 = tl.where(tmp2, tmp41, tmp42)
    tmp44 = tmp16 & tmp2
    tmp45 = tl.where(tmp44, tmp18, tmp19)
    tmp46 = tl.where(tmp16, tmp45, tmp21)
    tmp47 = tl.full(tmp46.shape, 0.0, tmp46.dtype)
    tmp48 = tl.where(tmp2, tmp46, tmp47)
    tmp49 = tmp29 & tmp27
    tmp50 = tl.where(tmp49, tmp31, tmp32)
    tmp51 = tl.load(in_out_ptr0 + (x2), tmp27 & xmask, other=0.0)
    tmp52 = tl.where(tmp29, tmp50, tmp51)
    tmp53 = tl.full(tmp52.shape, 0.0, tmp52.dtype)
    tmp54 = tl.where(tmp27, tmp52, tmp53)
    tmp56 = tl.where(tmp27, tmp54, tmp55)
    tmp57 = tl.where(tmp2, tmp48, tmp56)
    tmp58 = tl.where(tmp2, tmp43, tmp57)
    tmp59 = tl.where(tmp2, tmp5, tmp58)
    tl.store(in_out_ptr0 + (x2), tmp59, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/7i/c7io46xdt3ptxo6ltcbckdtexmcbsso63trxe37lfrye2ltki4sg.py
# Source Nodes: [attn_6, attn_8, matmul_3], Original ATen: [aten._softmax, aten._to_copy, aten.add]
# attn_6 => add_15
# attn_8 => amax_1, div_3, exp_1, sub_6, sum_2
# matmul_3 => convert_element_type_43
#
# Persistent-reduction softmax over rows of 49 (rnumel = 49, padded to RBLOCK = 64).
# There are 298116 = 1521 * 4 * 49 rows. For row xindex:
#   x0 = xindex % 49         -> row position inside a 49-wide tile
#   x1 = (xindex // 49) % 4  -> column of the 4-wide bias table in_ptr2
#   x2 = xindex // 196       -> tile index in [0, 1521), used as (x2 // 39, x2 % 39)
# Each logit is in_ptr0 (fp16, upcast to fp32) plus two terms:
#   * a gathered bias in_ptr2[x1 + 4 * idx]. Here idx = in_ptr1[r3 + 49 * x0] (int64);
#     it is wrapped by +169 when negative and checked by tl.device_assert to lie in
#     [0, 169).
#   * a mask term. It is 0.0 where the two in_ptr3 samples for key r3 and query x0 are
#     equal, and -100.0 otherwise. Both samples come from the same 7 x 7 tile of a
#     273-wide grid (1911 = 7 * 273).
# Padded lanes (r3 >= 49) and out-of-range rows are excluded from the max and the sum
# via tl.where. The normalized probabilities go to out_ptr3, which the signature
# declares as *fp16 (tl.store casts to the pointer's element type). The store uses a
# stride of 2432 per (tile, x1) pair, so each 49 x 49 = 2401 block is padded to 2432
# elements.
# NOTE(review): this looks like shifted-window attention (relative-position bias plus a
# region mask). That naming is inferred; confirm against the model.
# The string passed to async_compile.triton is Python source, so its indentation is
# significant: the function body must be indented under `def triton_`.
triton_per_fused__softmax__to_copy_add_23 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.persistent_reduction(
    size_hints=[524288, 64],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*i64', 2: '*fp32', 3: '*fp32', 4: '*fp16', 5: 'i32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__softmax__to_copy_add_23', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 4, 'num_reduction': 2, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr3, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 298116
    rnumel = 49
    RBLOCK: tl.constexpr = 64
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = rindex < rnumel
    r3 = rindex
    x4 = xindex
    x0 = xindex % 49
    x1 = (xindex // 49) % 4
    x2 = (xindex // 196)
    x5 = (xindex // 49)
    tmp0 = tl.load(in_ptr0 + (r3 + (49*x4)), rmask & xmask, other=0.0).to(tl.float32)
    tmp2 = tl.load(in_ptr1 + (r3 + (49*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0)
    tmp10 = tl.load(in_ptr3 + ((7*(x2 % 39)) + (273*(r3 // 7)) + (1911*(x2 // 39)) + (r3 % 7)), rmask & xmask, eviction_policy='evict_last', other=0.0)
    tmp11 = tl.load(in_ptr3 + ((7*(x2 % 39)) + (273*(x0 // 7)) + (1911*(x2 // 39)) + (x0 % 7)), xmask, eviction_policy='evict_last')
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tl.full([XBLOCK, RBLOCK], 169, tl.int32)
    tmp4 = tmp2 + tmp3
    tmp5 = tmp2 < 0
    tmp6 = tl.where(tmp5, tmp4, tmp2)
    tl.device_assert(((0 <= tmp6) & (tmp6 < 169)) | ~(rmask & xmask), "index out of bounds: 0 <= tmp6 < 169")
    tmp8 = tl.load(in_ptr2 + (x1 + (4*tmp6)), rmask & xmask, eviction_policy='evict_last')
    tmp9 = tmp1 + tmp8
    tmp12 = tmp10 - tmp11
    tmp13 = 0.0
    tmp14 = tmp12 == tmp13
    tmp15 = tmp12 != tmp13
    tmp16 = -100.0
    tmp17 = tl.where(tmp15, tmp16, tmp12)
    tmp18 = tl.where(tmp14, tmp13, tmp17)
    tmp19 = tmp9 + tmp18
    tmp20 = tl.broadcast_to(tmp19, [XBLOCK, RBLOCK])
    tmp22 = tl.where(rmask & xmask, tmp20, float("-inf"))
    tmp23 = triton_helpers.max2(tmp22, 1)[:, None]
    tmp24 = tmp19 - tmp23
    tmp25 = tl_math.exp(tmp24)
    tmp26 = tl.broadcast_to(tmp25, [XBLOCK, RBLOCK])
    tmp28 = tl.where(rmask & xmask, tmp26, 0)
    tmp29 = tl.sum(tmp28, 1)[:, None]
    tmp30 = tmp25 / tmp29
    tmp31 = tmp30.to(tl.float32)
    tl.store(out_ptr3 + (r3 + (49*x0) + (2432*x5)), tmp31, rmask & xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/ag/caggjavokhkkhu3pmmyii7rbkrszu5jdeq7id5az4t423douybr2.py
# Source Nodes: [layer_norm_4, x_37, x_38], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
# layer_norm_4 => add_19, add_20, mul_15, mul_16, rsqrt_4, sub_7, var_mean_4
# x_37 => add_18
# x_38 => convert_element_type_53
#
# Two-pass looped LayerNorm over 128 channels for 73984 = 272 * 272 rows.
# For row x0 and channel r the normalized input is
#   v = in_ptr0[128*x0 + r] + (float(in_ptr1[gather(x0) + r]) + in_ptr2[r])
# where in_ptr0 is fp32 row-major, in_ptr1 is fp16 and in_ptr2 is a per-channel fp32
# vector (see the signature).
# The in_ptr1 gather maps (i, j) = (x0 // 272, x0 % 272) to
# (i', j') = ((270 + i) % 273, (270 + j) % 273). It then addresses a buffer laid out as
# [i' // 7][j' // 7][i' % 7][j' % 7][128] (strides 244608, 6272, 896, 128). In other
# words, 7 x 7 tiles of a cyclically shifted 273 x 273 grid, read over a 272 x 272 crop.
# Pass 1 accumulates mean / M2 with Welford. Pass 2 re-reads the inputs and writes
#   (v - mean) * rsqrt(M2 / 128 + 1e-5) * in_ptr3[r] + in_ptr4[r]
# to out_ptr3 (declared *fp16). M2 / 128 is the biased (population) variance.
# NOTE(review): the gather presumably fuses window-reverse + roll + residual add; confirm
# against the FX graph.
# The string passed to async_compile.triton is Python source, so its indentation is
# significant: both `for roffset` loop bodies must be indented.
triton_red_fused__to_copy_add_native_layer_norm_24 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[131072, 128],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp16', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_native_layer_norm_24', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 8, 'num_reduction': 2, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr3, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 73984
    rnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    tmp8_mean = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp8_m2 = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp8_weight = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (128*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0)
        tmp1 = tl.load(in_ptr1 + (r1 + (128*(((270 + (x0 % 272)) % 273) % 7)) + (896*(((270 + (x0 // 272)) % 273) % 7)) + (6272*(((270 + (x0 % 272)) % 273) // 7)) + (244608*(((270 + (x0 // 272)) % 273) // 7))), rmask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp3 = tmp2.to(tl.float32)
        tmp4 = tmp1 + tmp3
        tmp5 = tmp4.to(tl.float32)
        tmp6 = tmp0 + tmp5
        tmp7 = tl.broadcast_to(tmp6, [XBLOCK, RBLOCK])
        tmp8_mean_next, tmp8_m2_next, tmp8_weight_next = triton_helpers.welford_reduce(
            tmp7, tmp8_mean, tmp8_m2, tmp8_weight, roffset == 0
        )
        tmp8_mean = tl.where(rmask & xmask, tmp8_mean_next, tmp8_mean)
        tmp8_m2 = tl.where(rmask & xmask, tmp8_m2_next, tmp8_m2)
        tmp8_weight = tl.where(rmask & xmask, tmp8_weight_next, tmp8_weight)
    tmp8_tmp, tmp9_tmp, tmp10_tmp = triton_helpers.welford(
        tmp8_mean, tmp8_m2, tmp8_weight, 1
    )
    tmp8 = tmp8_tmp[:, None]
    tmp9 = tmp9_tmp[:, None]
    tmp10 = tmp10_tmp[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp11 = tl.load(in_ptr0 + (r1 + (128*x0)), rmask & xmask, eviction_policy='evict_first', other=0.0)
        tmp12 = tl.load(in_ptr1 + (r1 + (128*(((270 + (x0 % 272)) % 273) % 7)) + (896*(((270 + (x0 // 272)) % 273) % 7)) + (6272*(((270 + (x0 % 272)) % 273) // 7)) + (244608*(((270 + (x0 // 272)) % 273) // 7))), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp13 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp25 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp27 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp14 = tmp13.to(tl.float32)
        tmp15 = tmp12 + tmp14
        tmp16 = tmp15.to(tl.float32)
        tmp17 = tmp11 + tmp16
        tmp18 = tmp17 - tmp8
        tmp19 = 128.0
        tmp20 = tmp9 / tmp19
        tmp21 = 1e-05
        tmp22 = tmp20 + tmp21
        tmp23 = libdevice.rsqrt(tmp22)
        tmp24 = tmp18 * tmp23
        tmp26 = tmp24 * tmp25
        tmp28 = tmp26 + tmp27
        tmp29 = tmp28.to(tl.float32)
        tl.store(out_ptr3 + (r1 + (128*x0)), tmp29, rmask & xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/lg/clgd3dty2q3ne3lnfcn3rmpdng74ev3dyoxwhsbz5berzpswiq3j.py
# Source Nodes: [x_37, x_43, x_out], Original ATen: [aten.add, aten.native_layer_norm]
# x_37 => add_18
# x_43 => add_22
# x_out => var_mean_6
#
# Persistent reduction over 128 channels for 73984 = 272 * 272 rows.
# For row x0 and channel r it builds
#   y = (in_ptr0[128*x0 + r] + (float(in_ptr1[gather(x0) + r]) + in_ptr2[r]))
#       + (float(in_ptr3[128*x0 + r]) + in_ptr4[r])
# The in_ptr1 gather reads 7 x 7 tiles of a grid whose coordinates are shifted by
# (270 + i) % 273; the strides are 244608 / 6272 / 896 / 128.
# Outputs:
#   out_ptr0[128*x0 + r] = y                          (fp32)
#   out_ptr1[x0]         = mean(y) = sum(y) / 128
#   out_ptr2[x0]         = sum((y - mean)^2)          (M2; NOT divided by 128)
# Only the statistics (var_mean_6) are fused here; the normalization itself is not done
# in this kernel. tmp15 is computed but unused (generated code).
# The string passed to async_compile.triton is Python source, so its indentation is
# significant: the function body must be indented under `def triton_`.
triton_per_fused_add_native_layer_norm_25 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.persistent_reduction(
    size_hints=[131072, 128],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: '*fp32', 3: '*fp16', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: 'i32', 9: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_add_native_layer_norm_25', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 4, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 73984
    rnumel = 128
    RBLOCK: tl.constexpr = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (128*x0)), rmask & xmask, other=0.0)
    tmp1 = tl.load(in_ptr1 + (r1 + (128*(((270 + (x0 % 272)) % 273) % 7)) + (896*(((270 + (x0 // 272)) % 273) % 7)) + (6272*(((270 + (x0 % 272)) % 273) // 7)) + (244608*(((270 + (x0 // 272)) % 273) // 7))), rmask & xmask, other=0.0).to(tl.float32)
    tmp2 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
    tmp7 = tl.load(in_ptr3 + (r1 + (128*x0)), rmask & xmask, other=0.0).to(tl.float32)
    tmp8 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0)
    tmp3 = tmp2.to(tl.float32)
    tmp4 = tmp1 + tmp3
    tmp5 = tmp4.to(tl.float32)
    tmp6 = tmp0 + tmp5
    tmp9 = tmp8.to(tl.float32)
    tmp10 = tmp7 + tmp9
    tmp11 = tmp10.to(tl.float32)
    tmp12 = tmp6 + tmp11
    tmp13 = tl.broadcast_to(tmp12, [XBLOCK, RBLOCK])
    tmp15 = tl.where(rmask & xmask, tmp13, 0)
    tmp16 = tl.broadcast_to(tmp13, [XBLOCK, RBLOCK])
    tmp18 = tl.where(rmask & xmask, tmp16, 0)
    tmp19 = tl.sum(tmp18, 1)[:, None]
    tmp20 = tl.full([XBLOCK, 1], 128, tl.int32)
    tmp21 = tmp20.to(tl.float32)
    tmp22 = tmp19 / tmp21
    tmp23 = tmp13 - tmp22
    tmp24 = tmp23 * tmp23
    tmp25 = tl.broadcast_to(tmp24, [XBLOCK, RBLOCK])
    tmp27 = tl.where(rmask & xmask, tmp25, 0)
    tmp28 = tl.sum(tmp27, 1)[:, None]
    tl.store(out_ptr0 + (r1 + (128*x0)), tmp12, rmask & xmask)
    tl.store(out_ptr1 + (x0), tmp22, xmask)
    tl.store(out_ptr2 + (x0), tmp28, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/2t/c2tajzagxv474224mjjurbovktceibfbfuzbpihyo7vcvrjy4ggt.py
# Source Nodes: [x_47, x_48], Original ATen: [aten._to_copy, aten.native_layer_norm]
# x_47 => add_23, add_24, mul_20, mul_21, rsqrt_5, sub_8, var_mean_5
# x_48 => convert_element_type_65
#
# Two-pass looped LayerNorm over 512 channels for 18496 = 136 * 136 rows.
# Row x0 = 136 * h + w gathers four 128-channel slices from in_ptr0 (fp32). With
# base = 69632 * h + 256 * w and c = channel index within a slice:
#   channels   0..127 <- in_ptr0[base +     0 + c]
#   channels 128..255 <- in_ptr0[base + 34816 + c]   (written as 34688 + r1)
#   channels 256..383 <- in_ptr0[base +   128 + c]   (written as -128 + r1)
#   channels 384..511 <- in_ptr0[base + 34944 + c]   (written as 34560 + r1)
# If in_ptr0 is a [272, 272, 128] grid (row stride 34816), these four slices are the
# (2h, 2w), (2h+1, 2w), (2h, 2w+1) and (2h+1, 2w+1) neighbours, concatenated.
# NOTE(review): presumably a Swin-style patch-merging concat; confirm.
# Pass 1 accumulates Welford mean / M2. Pass 2 writes
#   (v - mean) * rsqrt(M2 / 512 + 1e-5) * in_ptr1[r] + in_ptr2[r]
# to out_ptr3 (declared *fp16). tmp2/tmp24 and tmp37/tmp59 are range checks that are
# computed but unused (generated code).
# The string passed to async_compile.triton is Python source, so its indentation is
# significant: both `for roffset` loop bodies must be indented.
triton_red_fused__to_copy_native_layer_norm_26 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[32768, 512],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp16', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_native_layer_norm_26', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 10, 'num_reduction': 2, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, out_ptr3, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 18496
    rnumel = 512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    tmp32_mean = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp32_m2 = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp32_weight = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = r1
        tmp1 = tl.full([1, 1], 0, tl.int64)
        tmp2 = tmp0 >= tmp1
        tmp3 = tl.full([1, 1], 128, tl.int64)
        tmp4 = tmp0 < tmp3
        tmp5 = tl.load(in_ptr0 + (r1 + (256*(x0 % 136)) + (69632*(x0 // 136))), rmask & tmp4 & xmask, eviction_policy='evict_last', other=0.0)
        tmp6 = tl.full(tmp5.shape, 0.0, tmp5.dtype)
        tmp7 = tl.where(tmp4, tmp5, tmp6)
        tmp8 = tmp0 >= tmp3
        tmp9 = tl.full([1, 1], 256, tl.int64)
        tmp10 = tmp0 < tmp9
        tmp11 = tmp8 & tmp10
        tmp12 = tl.load(in_ptr0 + (34688 + r1 + (256*(x0 % 136)) + (69632*(x0 // 136))), rmask & tmp11 & xmask, eviction_policy='evict_last', other=0.0)
        tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
        tmp14 = tl.where(tmp11, tmp12, tmp13)
        tmp15 = tmp0 >= tmp9
        tmp16 = tl.full([1, 1], 384, tl.int64)
        tmp17 = tmp0 < tmp16
        tmp18 = tmp15 & tmp17
        tmp19 = tl.load(in_ptr0 + ((-128) + r1 + (256*(x0 % 136)) + (69632*(x0 // 136))), rmask & tmp18 & xmask, eviction_policy='evict_last', other=0.0)
        tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
        tmp21 = tl.where(tmp18, tmp19, tmp20)
        tmp22 = tmp0 >= tmp16
        tmp23 = tl.full([1, 1], 512, tl.int64)
        tmp24 = tmp0 < tmp23
        tmp25 = tl.load(in_ptr0 + (34560 + r1 + (256*(x0 % 136)) + (69632*(x0 // 136))), rmask & tmp22 & xmask, eviction_policy='evict_last', other=0.0)
        tmp26 = tl.full(tmp25.shape, 0.0, tmp25.dtype)
        tmp27 = tl.where(tmp22, tmp25, tmp26)
        tmp28 = tl.where(tmp18, tmp21, tmp27)
        tmp29 = tl.where(tmp11, tmp14, tmp28)
        tmp30 = tl.where(tmp4, tmp7, tmp29)
        tmp31 = tl.broadcast_to(tmp30, [XBLOCK, RBLOCK])
        tmp32_mean_next, tmp32_m2_next, tmp32_weight_next = triton_helpers.welford_reduce(
            tmp31, tmp32_mean, tmp32_m2, tmp32_weight, roffset == 0
        )
        tmp32_mean = tl.where(rmask & xmask, tmp32_mean_next, tmp32_mean)
        tmp32_m2 = tl.where(rmask & xmask, tmp32_m2_next, tmp32_m2)
        tmp32_weight = tl.where(rmask & xmask, tmp32_weight_next, tmp32_weight)
    tmp32_tmp, tmp33_tmp, tmp34_tmp = triton_helpers.welford(
        tmp32_mean, tmp32_m2, tmp32_weight, 1
    )
    tmp32 = tmp32_tmp[:, None]
    tmp33 = tmp33_tmp[:, None]
    tmp34 = tmp34_tmp[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp73 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp75 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp35 = r1
        tmp36 = tl.full([1, 1], 0, tl.int64)
        tmp37 = tmp35 >= tmp36
        tmp38 = tl.full([1, 1], 128, tl.int64)
        tmp39 = tmp35 < tmp38
        tmp40 = tl.load(in_ptr0 + (r1 + (256*(x0 % 136)) + (69632*(x0 // 136))), rmask & tmp39 & xmask, eviction_policy='evict_first', other=0.0)
        tmp41 = tl.full(tmp40.shape, 0.0, tmp40.dtype)
        tmp42 = tl.where(tmp39, tmp40, tmp41)
        tmp43 = tmp35 >= tmp38
        tmp44 = tl.full([1, 1], 256, tl.int64)
        tmp45 = tmp35 < tmp44
        tmp46 = tmp43 & tmp45
        tmp47 = tl.load(in_ptr0 + (34688 + r1 + (256*(x0 % 136)) + (69632*(x0 // 136))), rmask & tmp46 & xmask, eviction_policy='evict_first', other=0.0)
        tmp48 = tl.full(tmp47.shape, 0.0, tmp47.dtype)
        tmp49 = tl.where(tmp46, tmp47, tmp48)
        tmp50 = tmp35 >= tmp44
        tmp51 = tl.full([1, 1], 384, tl.int64)
        tmp52 = tmp35 < tmp51
        tmp53 = tmp50 & tmp52
        tmp54 = tl.load(in_ptr0 + ((-128) + r1 + (256*(x0 % 136)) + (69632*(x0 // 136))), rmask & tmp53 & xmask, eviction_policy='evict_first', other=0.0)
        tmp55 = tl.full(tmp54.shape, 0.0, tmp54.dtype)
        tmp56 = tl.where(tmp53, tmp54, tmp55)
        tmp57 = tmp35 >= tmp51
        tmp58 = tl.full([1, 1], 512, tl.int64)
        tmp59 = tmp35 < tmp58
        tmp60 = tl.load(in_ptr0 + (34560 + r1 + (256*(x0 % 136)) + (69632*(x0 // 136))), rmask & tmp57 & xmask, eviction_policy='evict_first', other=0.0)
        tmp61 = tl.full(tmp60.shape, 0.0, tmp60.dtype)
        tmp62 = tl.where(tmp57, tmp60, tmp61)
        tmp63 = tl.where(tmp53, tmp56, tmp62)
        tmp64 = tl.where(tmp46, tmp49, tmp63)
        tmp65 = tl.where(tmp39, tmp42, tmp64)
        tmp66 = tmp65 - tmp32
        tmp67 = 512.0
        tmp68 = tmp33 / tmp67
        tmp69 = 1e-05
        tmp70 = tmp68 + tmp69
        tmp71 = libdevice.rsqrt(tmp70)
        tmp72 = tmp66 * tmp71
        tmp74 = tmp72 * tmp73
        tmp76 = tmp74 + tmp75
        tmp77 = tmp76.to(tl.float32)
        tl.store(out_ptr3 + (r1 + (512*x0)), tmp77, rmask & xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/gb/cgbhaobp4brzi4bqej3wkl6sz477wwvivithdd33dfu6fcxdbvus.py
# Source Nodes: [x_48], Original ATen: [aten._to_copy]
# x_48 => convert_element_type_64
#
# Elementwise fp32 -> fp16 copy of 131072 contiguous elements (signature: in *fp32,
# out *fp16). The load and the store pass mask=None, so every program must be fully in
# range. That holds because 131072 = 2**17 is a multiple of any power-of-two
# XBLOCK <= 2**17 (tl.arange requires a power-of-two extent). xmask is computed but
# unused.
# The string passed to async_compile.triton is Python source, so its indentation is
# significant: the function body must be indented under `def triton_`.
triton_poi_fused__to_copy_27 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
    size_hints=[131072],
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_27', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 131072
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/5c/c5cuctfarzmazdidadennbrjjffmqgssmpavewem5xeljfxn3dam.py
# Source Nodes: [x_50], Original ATen: [aten._to_copy, aten.native_layer_norm]
# x_50 => convert_element_type_70, var_mean_7
#
# LayerNorm statistics over 256 channels for 18496 rows, with one program per row
# (no_x_dim=True, XBLOCK = 1). in_ptr0 is fp16 row-major (offset 256 * x0 + r) and is
# upcast to fp32. Outputs:
#   out_ptr0[x0] = mean = sum(x) / 256
#   out_ptr1[x0] = sum((x - mean)^2)   (M2; NOT divided by 256)
# tmp4 is computed but unused (generated code).
# The string passed to async_compile.triton is Python source, so its indentation is
# significant: the function body must be indented under `def triton_`.
triton_per_fused__to_copy_native_layer_norm_28 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.persistent_reduction(
    size_hints=[32768, 256],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_native_layer_norm_28', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 1, 'num_reduction': 4, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, out_ptr0, out_ptr1, xnumel, rnumel):
    xnumel = 18496
    XBLOCK: tl.constexpr = 1
    rnumel = 256
    RBLOCK: tl.constexpr = 256
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[:]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (256*x0)), rmask & xmask, other=0.0).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.broadcast_to(tmp1, [RBLOCK])
    tmp4 = tl.where(rmask & xmask, tmp2, 0)
    tmp5 = tl.broadcast_to(tmp2, [RBLOCK])
    tmp7 = tl.where(rmask & xmask, tmp5, 0)
    tmp8 = triton_helpers.promote_to_tensor(tl.sum(tmp7, 0))
    tmp9 = tl.full([1], 256, tl.int32)
    tmp10 = tmp9.to(tl.float32)
    tmp11 = tmp8 / tmp10
    tmp12 = tmp2 - tmp11
    tmp13 = tmp12 * tmp12
    tmp14 = tl.broadcast_to(tmp13, [RBLOCK])
    tmp16 = tl.where(rmask & xmask, tmp14, 0)
    tmp17 = triton_helpers.promote_to_tensor(tl.sum(tmp16, 0))
    tl.store(out_ptr0 + (x0), tmp11, xmask)
    tl.store(out_ptr1 + (x0), tmp17, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/eu/ceuvro67qgbkaa4ffhjqtznjwlcpwtbvlkoxtlnkt25syksoejbq.py
# Source Nodes: [linear_9], Original ATen: [aten._to_copy]
# linear_9 => convert_element_type_73
#
# Fuses a LayerNorm apply, zero padding and 7 x 7 tiling, and writes fp16 output.
# 5017600 = 400 * 49 * 256 elements are written contiguously as [400][49][256]:
#   x0 = channel (0..255)
#   x1 = position inside a 7 x 7 tile (0..48)
#   x2 = tile index (0..399, a 20 x 20 arrangement)
# Grid coordinates: row = 7*(x2 // 20) + x1 // 7 and col = 7*(x2 % 20) + x1 % 7
# (both 0..139).
# Inside the 136 x 136 source grid (tmp5 is true), the output is
#   (float(in_ptr0[256*(136*row + col) + x0]) - in_ptr1[136*row + col])
#     * rsqrt(in_ptr2[136*row + col] / 256 + 1e-5) * in_ptr3[x0] + in_ptr4[x0]
# Outside it (rows/cols 136..139), every load is masked and the output is exactly 0.0.
# NOTE(review): in_ptr1 / in_ptr2 are used as a per-position mean and M2 over 256
# channels; presumably produced by a separate statistics kernel. Confirm at the call site.
# The store uses mask=None. That is safe because 5017600 = 2**12 * 1225 is a multiple
# of any power-of-two XBLOCK <= 4096.
# The string passed to async_compile.triton is Python source, so its indentation is
# significant: the function body must be indented under `def triton_`.
triton_poi_fused__to_copy_29 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
    size_hints=[8388608],
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp16', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_29', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 5017600
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x1 = (xindex // 256) % 49
    x2 = (xindex // 12544)
    x0 = xindex % 256
    x3 = xindex
    tmp0 = (7*(x2 // 20)) + (x1 // 7)
    tmp1 = tl.full([1], 136, tl.int64)
    tmp2 = tmp0 < tmp1
    tmp3 = (7*(x2 % 20)) + (x1 % 7)
    tmp4 = tmp3 < tmp1
    tmp5 = tmp2 & tmp4
    tmp6 = tl.load(in_ptr0 + (x0 + (256*(x1 % 7)) + (1792*(x2 % 20)) + (34816*(x1 // 7)) + (243712*(x2 // 20))), tmp5, other=0.0).to(tl.float32)
    tmp7 = tmp6.to(tl.float32)
    tmp8 = tl.load(in_ptr1 + ((7*(x2 % 20)) + (136*(x1 // 7)) + (952*(x2 // 20)) + (x1 % 7)), tmp5, eviction_policy='evict_last', other=0.0)
    tmp9 = tmp7 - tmp8
    tmp10 = tl.load(in_ptr2 + ((7*(x2 % 20)) + (136*(x1 // 7)) + (952*(x2 // 20)) + (x1 % 7)), tmp5, eviction_policy='evict_last', other=0.0)
    tmp11 = 256.0
    tmp12 = tmp10 / tmp11
    tmp13 = 1e-05
    tmp14 = tmp12 + tmp13
    tmp15 = libdevice.rsqrt(tmp14)
    tmp16 = tmp9 * tmp15
    tmp17 = tl.load(in_ptr3 + (x0), tmp5, eviction_policy='evict_last', other=0.0)
    tmp18 = tmp16 * tmp17
    tmp19 = tl.load(in_ptr4 + (x0), tmp5, eviction_policy='evict_last', other=0.0)
    tmp20 = tmp18 + tmp19
    tmp21 = tl.full(tmp20.shape, 0.0, tmp20.dtype)
    tmp22 = tl.where(tmp5, tmp20, tmp21)
    tmp23 = tmp22.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp23, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/m7/cm7zh34yglzig2ihdzmpad6sfpayr2nhod7fd3bo4flse5huvirg.py
# Source Nodes: [linear_9], Original ATen: [aten._to_copy]
# linear_9 => convert_element_type_72
#
# Elementwise fp32 -> fp16 copy of 196608 contiguous elements (signature: in *fp32,
# out *fp16). The load and the store pass mask=None. That is safe because
# 196608 = 2**16 * 3 is a multiple of any power-of-two XBLOCK <= 2**16. xmask is
# computed but unused.
# The string passed to async_compile.triton is Python source, so its indentation is
# significant: the function body must be indented under `def triton_`.
triton_poi_fused__to_copy_30 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
    size_hints=[262144],
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_30', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 196608
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/zc/czcprbvfrl4zyuc66lgt4ai6jmptvsogizmygs2osms6op7tomn3.py
# Source Nodes: [attn_10, q_5], Original ATen: [aten.clone, aten.mul]
# attn_10 => clone_30
# q_5 => mul_28
# Pointwise kernel fusing a bias add, a constant scale and a layout copy.
# Flat index decomposition (xnumel = 5017600 = 400 * 8 * 49 * 32):
#   x0 = i % 32, x1 = (i // 32) % 49, x2 = (i // 1568) % 8, x3 = i // 12544
# Computes
#   out_ptr0[x4 + 1600*x5] = (in_ptr0[x0 + 32*x2 + 768*x1 + 37632*x3]
#                             + in_ptr1[x0 + 32*x2]) * 0.1767766952966369
# with x4 = i % 1568 and x5 = i // 1568; 0.1767766952966369 == 1/sqrt(32).
# The input is read with strides (x3: 37632, x1: 768, x2: 32, x0: 1) and written
# with strides (x3: 12800, x2: 1600, x1: 32, x0: 1), so the x1/x2 axes swap.
# Each group of 1568 written values sits at a stride of 1600, leaving 32
# unwritten slots per group (presumably alignment padding chosen by Inductor).
# NOTE(review): this reads like the query slice (columns 0..255 of a 768-wide
# row) of a fused qkv projection, moved from (window, token, head, dim) to
# (window, head, token, dim) and scaled by head_dim**-0.5 with head_dim = 32.
# That is inferred from the strides, not stated in the code; confirm.
# The triple-quoted string is generated kernel source; do not hand-edit it.
triton_poi_fused_clone_mul_31 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[8388608],
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp16', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_mul_31', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 5017600
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex % 32
x1 = (xindex // 32) % 49
x2 = (xindex // 1568) % 8
x3 = (xindex // 12544)
x4 = xindex % 1568
x5 = (xindex // 1568)
tmp0 = tl.load(in_ptr0 + (x0 + (32*x2) + (768*x1) + (37632*x3)), None).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (x0 + (32*x2)), None, eviction_policy='evict_last')
tmp2 = tmp1.to(tl.float32)
tmp3 = tmp0 + tmp2
tmp4 = 0.1767766952966369
tmp5 = tmp3 * tmp4
tl.store(out_ptr0 + (x4 + (1600*x5)), tmp5, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/e2/ce2cfq53qjwse2okyyoak6avwxvkjxqftsjmnigoql77qkqcbrrs.py
# Source Nodes: [attn_10], Original ATen: [aten.clone]
# attn_10 => clone_31
# 2-D tiled pointwise kernel: bias add plus a transposing copy.
# y runs over ynumel = 102400 (= 400 * 256) and x over xnumel = 49:
#   y2 = y // 256, y4 = y % 256, y0 = y % 32, y5 = y // 32; x3 = x
# Computes
#   out_ptr0[x3 + 49*y0 + 1600*y5] = in_ptr0[256 + y4 + 768*x3 + 37632*y2]
#                                    + in_ptr1[256 + y4]
# The x3 index (stride 768 in the input) becomes the unit-stride index of the
# output, i.e. the last two logical dims are transposed.
# Only xmask guards the loads and store. ymask is computed but unused,
# presumably because Inductor proved ynumel is a multiple of YBLOCK.
# NOTE(review): the +256 offset presumably selects the key slice of a fused
# 768-wide qkv row, written as k^T with a padded per-(window, head) stride of
# 1600. Inferred from the strides; confirm.
# The triple-quoted string is generated kernel source; do not hand-edit it.
triton_poi_fused_clone_32 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[131072, 64], tile_hint=TileHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp16', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_32', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
ynumel = 102400
xnumel = 49
yoffset = (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)) * YBLOCK
yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
ymask = yindex < ynumel
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
x3 = xindex
y2 = (yindex // 256)
y4 = yindex % 256
y0 = yindex % 32
y5 = (yindex // 32)
tmp0 = tl.load(in_ptr0 + (256 + y4 + (768*x3) + (37632*y2)), xmask, eviction_policy='evict_last').to(tl.float32)
tmp1 = tl.load(in_ptr1 + (256 + y4), None, eviction_policy='evict_last')
tmp2 = tmp1.to(tl.float32)
tmp3 = tmp0 + tmp2
tl.store(out_ptr0 + (x3 + (49*y0) + (1600*y5)), tmp3, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/zy/czybsxl6sxk5pc55zhlieseau6jqy7zrc273232l2wqlpaymolf5.py
# Source Nodes: [attn_11, attn_12, matmul_5], Original ATen: [aten._softmax, aten._to_copy, aten.add]
# attn_11 => add_29
# attn_12 => amax_2, div_6, exp_2, sub_12, sum_3
# matmul_5 => convert_element_type_79
# Persistent-reduction softmax with a gathered additive bias.
# There is one row per x (xnumel = 156800 = 3200 * 49). The rnumel = 49 columns
# are held in a single RBLOCK = 64 tile, so lanes 49..63 are masked off by rmask.
#   x0 = x % 49, x1 = (x // 49) % 8, x5 = x // 49; r3 = column
# Per element:
#   idx   = in_ptr1[r3 + 49*x0]   (int64; negative values are wrapped by +169,
#                                  then device_assert checks 0 <= idx < 169)
#   logit = in_ptr0[r3 + 49*x] + in_ptr2[x1 + 8*idx]
# It then computes a max-subtracted softmax over r: exp(logit - rowmax) / rowsum.
# Masked lanes contribute -inf to the max and 0 to the sum, so they do not
# affect the result.
# Output goes to out_ptr2[r3 + 49*x0 + 2432*x5]: groups of 49*49 = 2401 values
# at a stride of 2432 (presumably alignment padding).
# NOTE(review): in_ptr2 looks like a (169, 8) relative-position bias table
# indexed per head x1, and in_ptr1 like a 49x49 relative-position index.
# Inferred from the indexing, not stated here; confirm.
# The triple-quoted string is generated kernel source; do not hand-edit it.
triton_per_fused__softmax__to_copy_add_33 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.persistent_reduction(
size_hints=[262144, 64],
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*i64', 2: '*fp32', 3: '*fp16', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__softmax__to_copy_add_33', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr):
xnumel = 156800
rnumel = 49
RBLOCK: tl.constexpr = 64
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rindex = tl.arange(0, RBLOCK)[None, :]
roffset = 0
rmask = rindex < rnumel
r3 = rindex
x4 = xindex
x0 = xindex % 49
x1 = (xindex // 49) % 8
x5 = (xindex // 49)
tmp0 = tl.load(in_ptr0 + (r3 + (49*x4)), rmask & xmask, other=0.0).to(tl.float32)
tmp2 = tl.load(in_ptr1 + (r3 + (49*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0)
tmp1 = tmp0.to(tl.float32)
tmp3 = tl.full([XBLOCK, RBLOCK], 169, tl.int32)
tmp4 = tmp2 + tmp3
tmp5 = tmp2 < 0
tmp6 = tl.where(tmp5, tmp4, tmp2)
tl.device_assert(((0 <= tmp6) & (tmp6 < 169)) | ~(rmask & xmask), "index out of bounds: 0 <= tmp6 < 169")
tmp8 = tl.load(in_ptr2 + (x1 + (8*tmp6)), rmask & xmask, eviction_policy='evict_last')
tmp9 = tmp1 + tmp8
tmp10 = tl.broadcast_to(tmp9, [XBLOCK, RBLOCK])
tmp12 = tl.where(rmask & xmask, tmp10, float("-inf"))
tmp13 = triton_helpers.max2(tmp12, 1)[:, None]
tmp14 = tmp9 - tmp13
tmp15 = tl_math.exp(tmp14)
tmp16 = tl.broadcast_to(tmp15, [XBLOCK, RBLOCK])
tmp18 = tl.where(rmask & xmask, tmp16, 0)
tmp19 = tl.sum(tmp18, 1)[:, None]
tmp20 = tmp15 / tmp19
tmp21 = tmp20.to(tl.float32)
tl.store(out_ptr2 + (r3 + (49*x0) + (2432*x5)), tmp21, rmask & xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/du/cdubf7lve33gjbq64z4vakpxvsuh4hfbnswhoxywkh4mxo4oybs2.py
# Source Nodes: [matmul_5], Original ATen: [aten.clone]
# matmul_5 => clone_34
# Pointwise bias add plus layout copy (xnumel = 5017600 = 400 * 8 * 49 * 32):
#   x0 = i % 32, x1 = (i // 32) % 49, x2 = (i // 1568) % 8, x3 = i // 12544
#   out_ptr0[x4 + 1600*x5] = in_ptr0[512 + x0 + 32*x2 + 768*x1 + 37632*x3]
#                            + in_ptr1[512 + x0 + 32*x2]
# with x4 = i % 1568 and x5 = i // 1568. The x1/x2 axes swap between input and
# output. The 1600-element group stride leaves 32 unwritten slots after each
# 1568 values. No scaling is applied.
# NOTE(review): the +512 offset presumably selects the value slice of a fused
# 768-wide qkv row for matmul_5. Inferred from the strides; confirm.
# The triple-quoted string is generated kernel source; do not hand-edit it.
triton_poi_fused_clone_34 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[8388608],
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp16', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_34', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 5017600
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex % 32
x1 = (xindex // 32) % 49
x2 = (xindex // 1568) % 8
x3 = (xindex // 12544)
x4 = xindex % 1568
x5 = (xindex // 1568)
tmp0 = tl.load(in_ptr0 + (512 + x0 + (32*x2) + (768*x1) + (37632*x3)), None).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (512 + x0 + (32*x2)), None, eviction_policy='evict_last')
tmp2 = tmp1.to(tl.float32)
tmp3 = tmp0 + tmp2
tl.store(out_ptr0 + (x4 + (1600*x5)), tmp3, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/id/cid673u2q44fnref24kvgclfxy4xefbos3kmpw6yx3axmr4zh7pe.py
# Source Nodes: [x_54], Original ATen: [aten.clone]
# x_54 => clone_35
# Pointwise permuting copy into a contiguous buffer (xnumel = 5017600):
#   x0 = i % 32, x1 = (i // 32) % 8, x2 = (i // 256) % 49, x3 = i // 12544
#   out_ptr0[i] = in_ptr0[x0 + 32*x2 + 1568*x1 + 12544*x3]
# The input is read with strides (x3: 12544, x1: 1568, x2: 32, x0: 1). The
# output is contiguous in (x3, x2, x1, x0) order, so the x1 and x2 axes swap.
# NOTE(review): presumably merges attention heads back,
# (window, head, token, dim) -> (window, token, head*dim). Inferred; confirm.
# The triple-quoted string is generated kernel source; do not hand-edit it.
triton_poi_fused_clone_35 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[8388608],
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_35', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 5017600
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex % 32
x1 = (xindex // 32) % 8
x2 = (xindex // 256) % 49
x3 = (xindex // 12544)
x4 = xindex
tmp0 = tl.load(in_ptr0 + (x0 + (32*x2) + (1568*x1) + (12544*x3)), None).to(tl.float32)
tl.store(out_ptr0 + (x4), tmp0, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/gd/cgd5tsctovridqlbm4jovksrawiwqsk2hdwl7hi35gmu56kuxwy7.py
# Source Nodes: [layer_norm_8, x_61, x_62], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
# layer_norm_8 => add_31, add_32, convert_element_type_87, mul_29, mul_30, rsqrt_8, sub_13, var_mean_8
# x_61 => add_30
# x_62 => convert_element_type_90
# Looped-reduction LayerNorm over rows of 256 with a fused residual add.
# xnumel = 18496 rows (= 136 * 136). rnumel = 256 columns are processed in
# chunks of RBLOCK, a launch-time constexpr.
# Per element the row value is
#   v = in_ptr0[r1 + 256*x0]
#       + (in_ptr1[r1 + 256*((x0 % 136) % 7) + 1792*((x0 // 136) % 7)
#                  + 12544*((x0 % 136) // 7) + 250880*(x0 // 952)]
#          + in_ptr2[r1])
# Pass 1 accumulates the Welford mean/M2/weight of v. `roffset == 0` marks the
# first chunk for welford_reduce, and lanes outside rmask & xmask keep their
# previous accumulator values.
# Pass 2 reloads the inputs and recomputes v instead of storing it, then writes
#   out_ptr2[r1 + 256*x0] = (v - mean) * rsqrt(M2 / 256 + 1e-05)
#                           * in_ptr3[r1] + in_ptr4[r1]
# Pass 1 loads use evict_last and pass 2 loads use evict_first, presumably
# because pass 2 re-reads the same data once more.
# NOTE(review): the in_ptr1 gather looks like a window-reverse. It would map 7x7
# windows (12544 = 49*256 per window), laid out 20 per window-row
# (250880 = 20*12544), back onto a 136-wide row-major grid (952 = 7*136).
# Inferred from the strides; confirm.
# The triple-quoted string is generated kernel source; do not hand-edit it.
triton_red_fused__to_copy_add_native_layer_norm_36 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
size_hints=[32768, 256],
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp16', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_native_layer_norm_36', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 8, 'num_reduction': 2, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 18496
rnumel = 256
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex
tmp8_mean = tl.zeros([XBLOCK, RBLOCK], tl.float32)
tmp8_m2 = tl.zeros([XBLOCK, RBLOCK], tl.float32)
tmp8_weight = tl.zeros([XBLOCK, RBLOCK], tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp0 = tl.load(in_ptr0 + (r1 + (256*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r1 + (256*((x0 % 136) % 7)) + (1792*((x0 // 136) % 7)) + (12544*((x0 % 136) // 7)) + (250880*(x0 // 952))), rmask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp2 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp1 + tmp3
tmp5 = tmp0 + tmp4
tmp6 = tmp5.to(tl.float32)
tmp7 = tl.broadcast_to(tmp6, [XBLOCK, RBLOCK])
tmp8_mean_next, tmp8_m2_next, tmp8_weight_next = triton_helpers.welford_reduce(
tmp7, tmp8_mean, tmp8_m2, tmp8_weight, roffset == 0
)
tmp8_mean = tl.where(rmask & xmask, tmp8_mean_next, tmp8_mean)
tmp8_m2 = tl.where(rmask & xmask, tmp8_m2_next, tmp8_m2)
tmp8_weight = tl.where(rmask & xmask, tmp8_weight_next, tmp8_weight)
tmp8_tmp, tmp9_tmp, tmp10_tmp = triton_helpers.welford(
tmp8_mean, tmp8_m2, tmp8_weight, 1
)
tmp8 = tmp8_tmp[:, None]
tmp9 = tmp9_tmp[:, None]
tmp10 = tmp10_tmp[:, None]
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp11 = tl.load(in_ptr0 + (r1 + (256*x0)), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp12 = tl.load(in_ptr1 + (r1 + (256*((x0 % 136) % 7)) + (1792*((x0 // 136) % 7)) + (12544*((x0 % 136) // 7)) + (250880*(x0 // 952))), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp13 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp25 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp27 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp14 = tmp13.to(tl.float32)
tmp15 = tmp12 + tmp14
tmp16 = tmp11 + tmp15
tmp17 = tmp16.to(tl.float32)
tmp18 = tmp17 - tmp8
tmp19 = 256.0
tmp20 = tmp9 / tmp19
tmp21 = 1e-05
tmp22 = tmp20 + tmp21
tmp23 = libdevice.rsqrt(tmp22)
tmp24 = tmp18 * tmp23
tmp26 = tmp24 * tmp25
tmp28 = tmp26 + tmp27
tmp29 = tmp28.to(tl.float32)
tl.store(out_ptr2 + (r1 + (256*x0)), tmp29, rmask & xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/vy/cvywxyrmptehw4ywu65ox3yzbfmlde2dw2jy3subo7aa44exaykq.py
# Source Nodes: [x_62], Original ATen: [aten._to_copy]
# x_62 => convert_element_type_89
# Pointwise dtype-conversion kernel (`_to_copy`): out_ptr0[i] = in_ptr0[i] for
# xnumel = 262144 elements. in_ptr0 is declared '*fp32' and out_ptr0 '*fp16'.
# The value stays fp32 after `.to(tl.float32)`, so narrowing to fp16 presumably
# happens in tl.store -- confirm.
# The load and store are unmasked (mask None); xmask is computed but unused.
# The triple-quoted string is generated kernel source; do not hand-edit it.
triton_poi_fused__to_copy_37 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[262144],
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_37', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 262144
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/nk/cnkppgvyewrogbxy2ewbwul5osjk3plsam5hnrnavfi2z6kblkyk.py
# Source Nodes: [x_63], Original ATen: [aten.gelu]
# x_63 => add_33, convert_element_type_94, convert_element_type_95, erf_2, mul_31, mul_32, mul_33
# In-place pointwise bias add followed by the erf-based (exact) GELU.
# xnumel = 18939904 (= 18496 * 1024); x0 = i % 1024 indexes the bias in_ptr0.
#   v = in_out_ptr0[i] + in_ptr0[x0]
#   in_out_ptr0[i] = 0.5 * v * (1 + erf(v * 0.7071067811865476))
# 0.7071067811865476 == 1/sqrt(2). Because in_out_ptr0 is overwritten, it is
# listed under 'mutated_arg_names' in inductor_meta.
# The triple-quoted string is generated kernel source; do not hand-edit it.
triton_poi_fused_gelu_38 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[33554432],
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_gelu_38', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 18939904
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x2 = xindex
x0 = xindex % 1024
tmp0 = tl.load(in_out_ptr0 + (x2), None).to(tl.float32)
tmp1 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last')
tmp2 = tmp1.to(tl.float32)
tmp3 = tmp0 + tmp2
tmp4 = tmp3.to(tl.float32)
tmp5 = 0.5
tmp6 = tmp4 * tmp5
tmp7 = 0.7071067811865476
tmp8 = tmp4 * tmp7
tmp9 = libdevice.erf(tmp8)
tmp10 = 1.0
tmp11 = tmp9 + tmp10
tmp12 = tmp6 * tmp11
tmp13 = tmp12.to(tl.float32)
tl.store(in_out_ptr0 + (x2), tmp13, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/mt/cmthogcbw6ytpaza2hq7tqd6jxon7w65cu5n467yareft33sxr42.py
# Source Nodes: [x_61, x_67, x_68], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
# x_61 => add_30
# x_67 => add_34
# x_68 => convert_element_type_101, var_mean_9
# Persistent reduction with one row per program (XBLOCK = 1, 'no_x_dim': True)
# and RBLOCK = rnumel = 256, over xnumel = 18496 rows.
# Per element:
#   a = in_ptr0[r1 + 256*x0]
#       + (in_ptr1[r1 + 256*((x0 % 136) % 7) + 1792*((x0 // 136) % 7)
#                  + 12544*((x0 % 136) // 7) + 250880*(x0 // 952)]
#          + in_ptr2[r1])
#   b = in_out_ptr0[r1 + 256*x0] + in_ptr3[r1]
#   s = a + b          (written back into in_out_ptr0 in place)
# It stores only row statistics, not a normalized row:
#   out_ptr0[x0] = sum(s) / 256                 (row mean)
#   out_ptr1[x0] = sum((s - mean) ** 2)         (two-pass, not Welford)
# The normalization itself is presumably done by a consumer kernel.
# tmp14 is computed but never used.
# The triple-quoted string is generated kernel source; do not hand-edit it.
triton_per_fused__to_copy_add_native_layer_norm_39 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.persistent_reduction(
size_hints=[32768, 256],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: 'i32', 8: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_native_layer_norm_39', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': True, 'num_load': 5, 'num_reduction': 4, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, out_ptr1, xnumel, rnumel):
xnumel = 18496
XBLOCK: tl.constexpr = 1
rnumel = 256
RBLOCK: tl.constexpr = 256
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = xindex < xnumel
rindex = tl.arange(0, RBLOCK)[:]
roffset = 0
rmask = rindex < rnumel
r1 = rindex
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r1 + (256*x0)), rmask & xmask, other=0.0).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r1 + (256*((x0 % 136) % 7)) + (1792*((x0 // 136) % 7)) + (12544*((x0 % 136) // 7)) + (250880*(x0 // 952))), rmask & xmask, other=0.0).to(tl.float32)
tmp2 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp6 = tl.load(in_out_ptr0 + (r1 + (256*x0)), rmask & xmask, other=0.0).to(tl.float32)
tmp7 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp1 + tmp3
tmp5 = tmp0 + tmp4
tmp8 = tmp7.to(tl.float32)
tmp9 = tmp6 + tmp8
tmp10 = tmp5 + tmp9
tmp11 = tmp10.to(tl.float32)
tmp12 = tl.broadcast_to(tmp11, [RBLOCK])
tmp14 = tl.where(rmask & xmask, tmp12, 0)
tmp15 = tl.broadcast_to(tmp12, [RBLOCK])
tmp17 = tl.where(rmask & xmask, tmp15, 0)
tmp18 = triton_helpers.promote_to_tensor(tl.sum(tmp17, 0))
tmp19 = tl.full([1], 256, tl.int32)
tmp20 = tmp19.to(tl.float32)
tmp21 = tmp18 / tmp20
tmp22 = tmp12 - tmp21
tmp23 = tmp22 * tmp22
tmp24 = tl.broadcast_to(tmp23, [RBLOCK])
tmp26 = tl.where(rmask & xmask, tmp24, 0)
tmp27 = triton_helpers.promote_to_tensor(tl.sum(tmp26, 0))
tl.store(in_out_ptr0 + (r1 + (256*x0)), tmp10, rmask & xmask)
tl.store(out_ptr0 + (x0), tmp21, xmask)
tl.store(out_ptr1 + (x0), tmp27, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/nf/cnfaj3ajf5aqknow4tehtkbrxhtd7hg3eo5w2rldjul7cns6445j.py
# Source Nodes: [shifted_x_1, x_70], Original ATen: [aten.constant_pad_nd, aten.roll]
# shifted_x_1 => index_7, index_8
# x_70 => constant_pad_nd_3
# Pointwise kernel fusing zero padding, a cyclic shift and the LayerNorm
# normalize/affine step. xnumel = 5017600 (= 140 * 140 * 256):
#   x2 = i // 35840, x1 = (i // 256) % 140, x0 = i % 256
# The source coordinates are h = (3 + x2) % 140 and w = (3 + x1) % 140.
# When both are < 136 it computes
#   (in_ptr0[x0 + 256*w + 34816*h] - in_ptr1[136*h + w])
#   * rsqrt(in_ptr2[136*h + w] / 256 + 1e-05) * in_ptr3[x0] + in_ptr4[x0]
# Otherwise it writes 0.0. Every load is masked by the same predicate (tmp5),
# so out-of-range coordinates are never dereferenced.
# This equals padding a 136x136 grid to 140x140 with zeros and then rolling by
# -3 along both spatial axes: out[j] = padded[(j + 3) % 140].
# in_ptr0 is fp16 and out_ptr0 is declared '*fp32'.
# NOTE(review): in_ptr1/in_ptr2 look like the per-row mean and sum of squared
# deviations from an earlier reduction (divided by 256 here); confirm.
# The triple-quoted string is generated kernel source; do not hand-edit it.
triton_poi_fused_constant_pad_nd_roll_40 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[8388608],
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_constant_pad_nd_roll_40', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 5017600
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x2 = (xindex // 35840)
x1 = (xindex // 256) % 140
x0 = xindex % 256
x4 = xindex
tmp0 = (3 + x2) % 140
tmp1 = tl.full([1], 136, tl.int64)
tmp2 = tmp0 < tmp1
tmp3 = (3 + x1) % 140
tmp4 = tmp3 < tmp1
tmp5 = tmp2 & tmp4
tmp6 = tl.load(in_ptr0 + (x0 + (256*((3 + x1) % 140)) + (34816*((3 + x2) % 140))), tmp5, other=0.0).to(tl.float32)
tmp7 = tmp6.to(tl.float32)
tmp8 = tl.load(in_ptr1 + ((136*((3 + x2) % 140)) + ((3 + x1) % 140)), tmp5, eviction_policy='evict_last', other=0.0)
tmp9 = tmp7 - tmp8
tmp10 = tl.load(in_ptr2 + ((136*((3 + x2) % 140)) + ((3 + x1) % 140)), tmp5, eviction_policy='evict_last', other=0.0)
tmp11 = 256.0
tmp12 = tmp10 / tmp11
tmp13 = 1e-05
tmp14 = tmp12 + tmp13
tmp15 = libdevice.rsqrt(tmp14)
tmp16 = tmp9 * tmp15
tmp17 = tl.load(in_ptr3 + (x0), tmp5, eviction_policy='evict_last', other=0.0)
tmp18 = tmp16 * tmp17
tmp19 = tl.load(in_ptr4 + (x0), tmp5, eviction_policy='evict_last', other=0.0)
tmp20 = tmp18 + tmp19
tmp21 = tl.full(tmp20.shape, 0.0, tmp20.dtype)
tmp22 = tl.where(tmp5, tmp20, tmp21)
tl.store(out_ptr0 + (x4), tmp22, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/wt/cwthqbmye7k6z3qf2osgm2rmstbobfi2vu64ljr6xtk6lwp2yfij.py
# Source Nodes: [linear_13], Original ATen: [aten._to_copy]
# linear_13 => convert_element_type_104
# Pointwise gather plus dtype conversion (in_ptr0 '*fp32' -> out_ptr0 '*fp16'),
# xnumel = 5017600:
#   x0 = i % 256, x1 = (i // 256) % 49, x2 = i // 12544
#   out_ptr0[i] = in_ptr0[x0 + 256*(x1 % 7) + 1792*(x2 % 20)
#                         + 35840*(x1 // 7) + 250880*(x2 // 20)]
# If in_ptr0 is read as a row-major 140 x 140 x 256 grid (35840 = 140*256),
# this takes row 7*(x2 // 20) + x1 // 7 and column 7*(x2 % 20) + x1 % 7. That
# splits the grid into 400 contiguous 7x7 windows, 20 windows per window-row.
# NOTE(review): the 140x140 grid shape is inferred from the strides; confirm.
# The value stays fp32 after `.to(tl.float32)`, so narrowing to fp16 presumably
# happens in tl.store.
# The triple-quoted string is generated kernel source; do not hand-edit it.
triton_poi_fused__to_copy_41 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[8388608],
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_41', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 5017600
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex % 256
x1 = (xindex // 256) % 49
x2 = (xindex // 12544)
x3 = xindex
tmp0 = tl.load(in_ptr0 + (x0 + (256*(x1 % 7)) + (1792*(x2 % 20)) + (35840*(x1 // 7)) + (250880*(x2 // 20))), None)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/xb/cxbmto2fl747s6zog6a2j7e767nbetwcpmuu6g7lvjzgwbc622kq.py
# Source Nodes: [img_mask_1, setitem_10, setitem_11, setitem_12, setitem_13, setitem_9], Original ATen: [aten.fill, aten.lift_fresh, aten.slice, aten.zeros]
# img_mask_1 => full_1
# setitem_10 => copy_10, lift_fresh_copy_14
# setitem_11 => copy_11, lift_fresh_copy_15
# setitem_12 => copy_12, full_default_4, lift_fresh_copy_16
# setitem_13 => copy_13, lift_fresh_copy_17
# setitem_9 => copy_9, lift_fresh_copy_13
# Writes constant region labels in place into a 140 x 140 fp32 buffer
# (in_out_ptr0, xnumel = 19600, row = xindex // 140, col = xindex % 140).
# The kernel issues no loads (num_load = 0), so the buffer's previous contents
# are never read.
# Value stored per (row band, col band), with both axes split at 133 and 137:
#   row < 133:          col < 133 -> 0.0, 133 <= col < 137 -> 1.0, col >= 137 -> 2.0
#   133 <= row < 137:   col < 133 -> 3.0, 133 <= col < 137 -> 4.0, col >= 137 -> 0.0
#   row >= 137:         0.0 for every column
# The long tl.where chain is the literal lowering of several successive
# slice-assignments; most of its branches fold to the constants above.
# NOTE(review): the 0.0 entries in the last two cases are presumably filled
# in by later setitem kernels. The band layout looks like a Swin-style
# shifted-window `img_mask` (window 7, shift 3, 140-wide map). Both points
# are inferred from Source Nodes names and constants — confirm at the call site.
triton_poi_fused_fill_lift_fresh_slice_zeros_42 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[32768],
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_fill_lift_fresh_slice_zeros_42', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 19600
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x1 = (xindex // 140)
x0 = xindex % 140
x2 = xindex
tmp0 = x1
tmp1 = tl.full([1], 133, tl.int64)
tmp2 = tmp0 >= tmp1
tmp3 = tl.full([1], 137, tl.int64)
tmp4 = tmp0 < tmp3
tmp5 = tmp2 & tmp4
tmp6 = x0
tmp7 = tmp6 < tmp1
tmp8 = tmp7 & tmp5
tmp9 = 3.0
tmp10 = tl.full(tmp9.shape, 0.0, tmp9.dtype)
tmp11 = tl.where(tmp8, tmp9, tmp10)
tmp12 = 0.0
tmp13 = tl.where(tmp7, tmp11, tmp12)
tmp14 = tl.full(tmp13.shape, 0.0, tmp13.dtype)
tmp15 = tl.where(tmp5, tmp13, tmp14)
tmp16 = tmp0 < tmp1
tmp17 = tmp6 >= tmp3
tmp18 = tmp17 & tmp16
tmp19 = 2.0
tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
tmp21 = tl.where(tmp18, tmp19, tmp20)
tmp22 = tmp16 & tmp16
tmp23 = tmp6 >= tmp1
tmp24 = tmp6 < tmp3
tmp25 = tmp23 & tmp24
tmp26 = tmp25 & tmp22
tmp27 = 1.0
tmp28 = tl.full(tmp27.shape, 0.0, tmp27.dtype)
tmp29 = tl.where(tmp26, tmp27, tmp28)
tmp30 = tmp16 & tmp22
tmp31 = tmp7 & tmp30
tmp32 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
tmp33 = tl.where(tmp31, tmp12, tmp32)
tmp34 = tl.where(tmp7, tmp33, tmp12)
tmp35 = tl.full(tmp34.shape, 0.0, tmp34.dtype)
tmp36 = tl.where(tmp30, tmp34, tmp35)
tmp37 = tl.where(tmp16, tmp36, tmp12)
tmp38 = tl.where(tmp25, tmp29, tmp37)
tmp39 = tl.full(tmp38.shape, 0.0, tmp38.dtype)
tmp40 = tl.where(tmp22, tmp38, tmp39)
tmp41 = tmp7 & tmp22
tmp42 = tl.where(tmp41, tmp12, tmp32)
tmp43 = tl.where(tmp7, tmp42, tmp12)
tmp44 = tl.full(tmp43.shape, 0.0, tmp43.dtype)
tmp45 = tl.where(tmp22, tmp43, tmp44)
tmp46 = tl.where(tmp16, tmp45, tmp12)
tmp47 = tl.where(tmp16, tmp40, tmp46)
tmp48 = tl.where(tmp17, tmp21, tmp47)
tmp49 = tl.full(tmp48.shape, 0.0, tmp48.dtype)
tmp50 = tl.where(tmp16, tmp48, tmp49)
tmp51 = tmp25 & tmp16
tmp52 = tl.where(tmp51, tmp27, tmp28)
tmp53 = tl.where(tmp25, tmp52, tmp46)
tmp54 = tl.full(tmp53.shape, 0.0, tmp53.dtype)
tmp55 = tl.where(tmp16, tmp53, tmp54)
tmp56 = tmp7 & tmp16
tmp57 = tl.where(tmp56, tmp12, tmp32)
tmp58 = tl.where(tmp7, tmp57, tmp12)
tmp59 = tl.full(tmp58.shape, 0.0, tmp58.dtype)
tmp60 = tl.where(tmp16, tmp58, tmp59)
tmp61 = tl.where(tmp16, tmp60, tmp12)
tmp62 = tl.where(tmp16, tmp55, tmp61)
tmp63 = tl.where(tmp16, tmp50, tmp62)
tmp64 = tl.where(tmp5, tmp15, tmp63)
tmp65 = tmp25 & tmp5
tmp66 = 4.0
tmp67 = tl.full(tmp66.shape, 0.0, tmp66.dtype)
tmp68 = tl.where(tmp65, tmp66, tmp67)
tmp69 = tl.where(tmp25, tmp68, tmp64)
tmp70 = tl.full(tmp69.shape, 0.0, tmp69.dtype)
tmp71 = tl.where(tmp5, tmp69, tmp70)
tmp72 = tl.where(tmp5, tmp71, tmp64)
tl.store(in_out_ptr0 + (x2), tmp72, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/7z/c7ze75vhgc4kguv4mmbm47unz47zkwxnsdyj2jnp2sgkef4vzr75.py
# Source Nodes: [setitem_17], Original ATen: [aten.fill, aten.lift_fresh]
# setitem_17 => copy_17, lift_fresh_copy_21
# Produces rows 137..139 of a 140-wide label map as a separate 420-element
# fp32 buffer (out_ptr0, xnumel = 420). Indexing: col = x0 = xindex % 140,
# row = 137 + x1 with x1 = xindex // 140.
# Because row = tmp6 = 137 + x1 is always >= 137, tmp7 is always true (tmp28,
# 133 <= row < 137, is always false). Every value read from in_ptr0 at offset
# 19180 + x2 is therefore discarded by the tl.where chain (19180 = 137 * 140,
# i.e. rows 137..139). The output depends only on the column:
#   col < 133 -> 6.0, 133 <= col < 137 -> 7.0, col >= 137 -> 8.0
# NOTE(review): the dead loads still execute. The unconditional one (tmp55)
# reads up to element 19599, so it assumes in_ptr0 holds at least 19600
# elements — confirm at the call site.
triton_poi_fused_fill_lift_fresh_43 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[512],
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_fill_lift_fresh_43', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 420
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex % 140
x1 = (xindex // 140)
x2 = xindex
tmp55 = tl.load(in_ptr0 + (19180 + x2), xmask)
tmp0 = x0
tmp1 = tl.full([1], 137, tl.int64)
tmp2 = tmp0 >= tmp1
tmp3 = 8.0
tmp4 = tl.full(tmp3.shape, 0.0, tmp3.dtype)
tmp5 = tl.where(tmp2, tmp3, tmp4)
tmp6 = 137 + x1
tmp7 = tmp6 >= tmp1
tmp8 = tl.full([1], 133, tl.int64)
tmp9 = tmp0 >= tmp8
tmp10 = tmp0 < tmp1
tmp11 = tmp9 & tmp10
tmp12 = tmp11 & tmp7
tmp13 = 7.0
tmp14 = tl.full(tmp13.shape, 0.0, tmp13.dtype)
tmp15 = tl.where(tmp12, tmp13, tmp14)
tmp16 = tmp7 & tmp7
tmp17 = tmp0 < tmp8
tmp18 = tmp17 & tmp16
tmp19 = 6.0
tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
tmp21 = tl.where(tmp18, tmp19, tmp20)
tmp22 = 0.0
tmp23 = tl.where(tmp17, tmp21, tmp22)
tmp24 = tl.full(tmp23.shape, 0.0, tmp23.dtype)
tmp25 = tl.where(tmp16, tmp23, tmp24)
tmp26 = tmp6 >= tmp8
tmp27 = tmp6 < tmp1
tmp28 = tmp26 & tmp27
tmp29 = tmp28 & tmp7
tmp30 = tmp2 & tmp29
tmp31 = 5.0
tmp32 = tl.full(tmp31.shape, 0.0, tmp31.dtype)
tmp33 = tl.where(tmp30, tmp31, tmp32)
tmp34 = tl.load(in_ptr0 + (19180 + x2), tmp29 & xmask, other=0.0)
tmp35 = tl.where(tmp2, tmp33, tmp34)
tmp36 = tl.full(tmp35.shape, 0.0, tmp35.dtype)
tmp37 = tl.where(tmp29, tmp35, tmp36)
tmp38 = tl.load(in_ptr0 + (19180 + x2), tmp7 & xmask, other=0.0)
tmp39 = tl.where(tmp28, tmp37, tmp38)
tmp40 = tl.where(tmp7, tmp25, tmp39)
tmp41 = tl.where(tmp11, tmp15, tmp40)
tmp42 = tl.full(tmp41.shape, 0.0, tmp41.dtype)
tmp43 = tl.where(tmp7, tmp41, tmp42)
tmp44 = tmp17 & tmp7
tmp45 = tl.where(tmp44, tmp19, tmp20)
tmp46 = tl.where(tmp17, tmp45, tmp22)
tmp47 = tl.full(tmp46.shape, 0.0, tmp46.dtype)
tmp48 = tl.where(tmp7, tmp46, tmp47)
tmp49 = tmp2 & tmp28
tmp50 = tl.where(tmp49, tmp31, tmp32)
tmp51 = tl.load(in_ptr0 + (19180 + x2), tmp28 & xmask, other=0.0)
tmp52 = tl.where(tmp2, tmp50, tmp51)
tmp53 = tl.full(tmp52.shape, 0.0, tmp52.dtype)
tmp54 = tl.where(tmp28, tmp52, tmp53)
tmp56 = tl.where(tmp28, tmp54, tmp55)
tmp57 = tl.where(tmp7, tmp48, tmp56)
tmp58 = tl.where(tmp7, tmp43, tmp57)
tmp59 = tl.where(tmp2, tmp5, tmp58)
tl.store(out_ptr0 + (x2), tmp59, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/7z/c7zt444wgwxirjwbsyxeqljxxhoqggbgugx7kf4kefsaruhixphf.py
# Source Nodes: [setitem_14, setitem_15, setitem_16], Original ATen: [aten.fill, aten.lift_fresh, aten.slice]
# setitem_14 => copy_14, lift_fresh_copy_18
# setitem_15 => copy_15, full_default_5, lift_fresh_copy_19
# setitem_16 => copy_16, lift_fresh_copy_20
# Updates a 140 x 140 fp32 buffer in place (in_out_ptr0, xnumel = 19600,
# row = x1 = xindex // 140, col = x0 = xindex % 140):
#   row >= 137:                   value = in_ptr0[xindex - 19180]
#                                 (19180 = 137 * 140, so in_ptr0 holds rows
#                                 137..139 as a compact 3 x 140 block)
#   133 <= row < 137, col >= 137: value = 5.0
#   all other positions:          the current value (tmp55) is written back
#                                 unchanged
# For row >= 137 the final tl.where picks tmp5 (the in_ptr0 load), so the
# 6.0 / 7.0 constants computed in those branches never reach the store.
# NOTE(review): in_ptr0 is presumably the 420-element buffer produced by
# triton_poi_fused_fill_lift_fresh_43 — confirm at the call site.
triton_poi_fused_fill_lift_fresh_slice_44 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[32768],
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_fill_lift_fresh_slice_44', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 19600
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x1 = (xindex // 140)
x2 = xindex
x0 = xindex % 140
tmp55 = tl.load(in_out_ptr0 + (x2), xmask)
tmp0 = x1
tmp1 = tl.full([1], 137, tl.int64)
tmp2 = tmp0 >= tmp1
tmp3 = tl.load(in_ptr0 + ((-19180) + x2), tmp2 & xmask, other=0.0)
tmp4 = tl.full(tmp3.shape, 0.0, tmp3.dtype)
tmp5 = tl.where(tmp2, tmp3, tmp4)
tmp6 = x0
tmp7 = tl.full([1], 133, tl.int64)
tmp8 = tmp6 >= tmp7
tmp9 = tmp6 < tmp1
tmp10 = tmp8 & tmp9
tmp11 = tmp10 & tmp2
tmp12 = 7.0
tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
tmp14 = tl.where(tmp11, tmp12, tmp13)
tmp15 = tmp2 & tmp2
tmp16 = tmp6 < tmp7
tmp17 = tmp16 & tmp15
tmp18 = 6.0
tmp19 = tl.full(tmp18.shape, 0.0, tmp18.dtype)
tmp20 = tl.where(tmp17, tmp18, tmp19)
tmp21 = 0.0
tmp22 = tl.where(tmp16, tmp20, tmp21)
tmp23 = tl.full(tmp22.shape, 0.0, tmp22.dtype)
tmp24 = tl.where(tmp15, tmp22, tmp23)
tmp25 = tmp0 >= tmp7
tmp26 = tmp0 < tmp1
tmp27 = tmp25 & tmp26
tmp28 = tmp27 & tmp2
tmp29 = tmp6 >= tmp1
tmp30 = tmp29 & tmp28
tmp31 = 5.0
tmp32 = tl.full(tmp31.shape, 0.0, tmp31.dtype)
tmp33 = tl.where(tmp30, tmp31, tmp32)
tmp34 = tl.load(in_out_ptr0 + (x2), tmp28 & xmask, other=0.0)
tmp35 = tl.where(tmp29, tmp33, tmp34)
tmp36 = tl.full(tmp35.shape, 0.0, tmp35.dtype)
tmp37 = tl.where(tmp28, tmp35, tmp36)
tmp38 = tl.load(in_out_ptr0 + (x2), tmp2 & xmask, other=0.0)
tmp39 = tl.where(tmp27, tmp37, tmp38)
tmp40 = tl.where(tmp2, tmp24, tmp39)
tmp41 = tl.where(tmp10, tmp14, tmp40)
tmp42 = tl.full(tmp41.shape, 0.0, tmp41.dtype)
tmp43 = tl.where(tmp2, tmp41, tmp42)
tmp44 = tmp16 & tmp2
tmp45 = tl.where(tmp44, tmp18, tmp19)
tmp46 = tl.where(tmp16, tmp45, tmp21)
tmp47 = tl.full(tmp46.shape, 0.0, tmp46.dtype)
tmp48 = tl.where(tmp2, tmp46, tmp47)
tmp49 = tmp29 & tmp27
tmp50 = tl.where(tmp49, tmp31, tmp32)
tmp51 = tl.load(in_out_ptr0 + (x2), tmp27 & xmask, other=0.0)
tmp52 = tl.where(tmp29, tmp50, tmp51)
tmp53 = tl.full(tmp52.shape, 0.0, tmp52.dtype)
tmp54 = tl.where(tmp27, tmp52, tmp53)
tmp56 = tl.where(tmp27, tmp54, tmp55)
tmp57 = tl.where(tmp2, tmp48, tmp56)
tmp58 = tl.where(tmp2, tmp43, tmp57)
tmp59 = tl.where(tmp2, tmp5, tmp58)
tl.store(in_out_ptr0 + (x2), tmp59, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/bh/cbhjqxz5ihwt4mxdvd3tbvekognxv56zu43p77uwnzktcpmbcvo7.py
# Source Nodes: [attn_16, attn_18, matmul_7], Original ATen: [aten._softmax, aten._to_copy, aten.add]
# attn_16 => add_40
# attn_18 => amax_3, div_7, exp_3, sub_15, sum_4
# matmul_7 => convert_element_type_110
# Softmax over rows of 49 logits, done as a persistent reduction (the whole
# row is held at once; rnumel = 49 is padded to RBLOCK = 64 and masked).
# There are xnumel = 156800 rows, stored contiguously in in_ptr0 (fp16) at
# r3 + 49 * xindex.
# Index decomposition: x0 = xindex % 49, x1 = (xindex // 49) % 8,
# x2 = xindex // 392, x5 = xindex // 49; r3 is the position along the row.
# Per element (computed in fp32):
#   logit = in_ptr0[xindex, r3]
#         + in_ptr2[x1 + 8 * idx], where idx = in_ptr1[49 * x0 + r3] (int64).
#           Negative idx is wrapped by +169 and checked to lie in [0, 169)
#           with tl.device_assert.
#         + mask term: d = in_ptr3[pixel(r3)] - in_ptr3[pixel(x0)];
#           adds 0.0 when d == 0 and -100.0 otherwise.
#   pixel(p) = row 7 * (x2 // 20) + p // 7, col 7 * (x2 % 20) + p % 7 of a
#   map with row stride 140 (980 = 7 * 140).
# The softmax subtracts the row max before exp and divides by the row sum.
# The fp32 result is narrowed to fp16 when stored (out_ptr3 is '*fp16').
# It is written at r3 + 49 * x0 + 2432 * x5, so each 49 x 49 block uses a
# padded stride of 2432 instead of 2401.
# NOTE(review): these look like Swin shifted-window attention scores:
# x1 = head (8 heads), in_ptr2 = relative-position bias table
# (169 = 13 * 13 entries), in_ptr3 = the 140 x 140 region-label mask.
# This is inferred from names and constants — confirm.
triton_per_fused__softmax__to_copy_add_45 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.persistent_reduction(
size_hints=[262144, 64],
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*i64', 2: '*fp32', 3: '*fp32', 4: '*fp16', 5: 'i32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__softmax__to_copy_add_45', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 4, 'num_reduction': 2, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr3, xnumel, rnumel, XBLOCK : tl.constexpr):
xnumel = 156800
rnumel = 49
RBLOCK: tl.constexpr = 64
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rindex = tl.arange(0, RBLOCK)[None, :]
roffset = 0
rmask = rindex < rnumel
r3 = rindex
x4 = xindex
x0 = xindex % 49
x1 = (xindex // 49) % 8
x2 = (xindex // 392)
x5 = (xindex // 49)
tmp0 = tl.load(in_ptr0 + (r3 + (49*x4)), rmask & xmask, other=0.0).to(tl.float32)
tmp2 = tl.load(in_ptr1 + (r3 + (49*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0)
tmp10 = tl.load(in_ptr3 + ((7*(x2 % 20)) + (140*(r3 // 7)) + (980*(x2 // 20)) + (r3 % 7)), rmask & xmask, eviction_policy='evict_last', other=0.0)
tmp11 = tl.load(in_ptr3 + ((7*(x2 % 20)) + (140*(x0 // 7)) + (980*(x2 // 20)) + (x0 % 7)), xmask, eviction_policy='evict_last')
tmp1 = tmp0.to(tl.float32)
tmp3 = tl.full([XBLOCK, RBLOCK], 169, tl.int32)
tmp4 = tmp2 + tmp3
tmp5 = tmp2 < 0
tmp6 = tl.where(tmp5, tmp4, tmp2)
tl.device_assert(((0 <= tmp6) & (tmp6 < 169)) | ~(rmask & xmask), "index out of bounds: 0 <= tmp6 < 169")
tmp8 = tl.load(in_ptr2 + (x1 + (8*tmp6)), rmask & xmask, eviction_policy='evict_last')
tmp9 = tmp1 + tmp8
tmp12 = tmp10 - tmp11
tmp13 = 0.0
tmp14 = tmp12 == tmp13
tmp15 = tmp12 != tmp13
tmp16 = -100.0
tmp17 = tl.where(tmp15, tmp16, tmp12)
tmp18 = tl.where(tmp14, tmp13, tmp17)
tmp19 = tmp9 + tmp18
tmp20 = tl.broadcast_to(tmp19, [XBLOCK, RBLOCK])
tmp22 = tl.where(rmask & xmask, tmp20, float("-inf"))
tmp23 = triton_helpers.max2(tmp22, 1)[:, None]
tmp24 = tmp19 - tmp23
tmp25 = tl_math.exp(tmp24)
tmp26 = tl.broadcast_to(tmp25, [XBLOCK, RBLOCK])
tmp28 = tl.where(rmask & xmask, tmp26, 0)
tmp29 = tl.sum(tmp28, 1)[:, None]
tmp30 = tmp25 / tmp29
tmp31 = tmp30.to(tl.float32)
tl.store(out_ptr3 + (r3 + (49*x0) + (2432*x5)), tmp31, rmask & xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/f2/cf2npb2ejacogxpb6cdcfuv4z5bmltrg2zqsllg7hjzccj546id4.py
# Source Nodes: [layer_norm_10, x_80, x_81], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
# layer_norm_10 => add_44, add_45, convert_element_type_118, mul_37, mul_38, rsqrt_10, sub_16, var_mean_10
# x_80 => add_43
# x_81 => convert_element_type_121
# Residual add fused with LayerNorm over 256 channels, for 18496 (= 136 * 136)
# rows. It is a looped reduction with two passes over r1: the first
# accumulates Welford mean/M2, the second normalizes.
#   v   = in_ptr0[x0, r1] + (in_ptr1[gather(x0), r1] + in_ptr2[r1])
#   out = (v - mean) * rsqrt(M2 / 256 + 1e-5) * in_ptr3[r1] + in_ptr4[r1]
# The result is stored as fp16 to out_ptr3 at r1 + 256 * x0; v itself is not
# stored, and both passes reload in_ptr0/in_ptr1 to recompute it.
# gather(x0) reads pixel (h, w) with
#   h = (137 + x0 // 136) % 140,  w = (137 + x0 % 136) % 140,
# i.e. output index i reads source index (i - 3) mod 140 on each axis (the
# same mapping as torch.roll(shifts=3)). Only the first 136 x 136 outputs are
# produced. Offsets into in_ptr1 use strides 250880 / 12544 / 1792 / 256 for
# (h // 7, w // 7, h % 7, w % 7): 250880 = 20 * 49 * 256,
# 12544 = 49 * 256, 1792 = 7 * 256. In other words, 7 x 7 tiles of 256
# channels laid out tile-by-tile.
# NOTE(review): likely the window-reverse + reverse cyclic shift + crop of a
# Swin block — inferred from the indexing, confirm.
triton_red_fused__to_copy_add_native_layer_norm_46 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
size_hints=[32768, 256],
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp16', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_native_layer_norm_46', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 8, 'num_reduction': 2, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr3, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 18496
rnumel = 256
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex
tmp8_mean = tl.zeros([XBLOCK, RBLOCK], tl.float32)
tmp8_m2 = tl.zeros([XBLOCK, RBLOCK], tl.float32)
tmp8_weight = tl.zeros([XBLOCK, RBLOCK], tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp0 = tl.load(in_ptr0 + (r1 + (256*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r1 + (256*(((137 + (x0 % 136)) % 140) % 7)) + (1792*(((137 + (x0 // 136)) % 140) % 7)) + (12544*(((137 + (x0 % 136)) % 140) // 7)) + (250880*(((137 + (x0 // 136)) % 140) // 7))), rmask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp2 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp1 + tmp3
tmp5 = tmp0 + tmp4
tmp6 = tmp5.to(tl.float32)
tmp7 = tl.broadcast_to(tmp6, [XBLOCK, RBLOCK])
tmp8_mean_next, tmp8_m2_next, tmp8_weight_next = triton_helpers.welford_reduce(
tmp7, tmp8_mean, tmp8_m2, tmp8_weight, roffset == 0
)
tmp8_mean = tl.where(rmask & xmask, tmp8_mean_next, tmp8_mean)
tmp8_m2 = tl.where(rmask & xmask, tmp8_m2_next, tmp8_m2)
tmp8_weight = tl.where(rmask & xmask, tmp8_weight_next, tmp8_weight)
tmp8_tmp, tmp9_tmp, tmp10_tmp = triton_helpers.welford(
tmp8_mean, tmp8_m2, tmp8_weight, 1
)
tmp8 = tmp8_tmp[:, None]
tmp9 = tmp9_tmp[:, None]
tmp10 = tmp10_tmp[:, None]
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp11 = tl.load(in_ptr0 + (r1 + (256*x0)), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp12 = tl.load(in_ptr1 + (r1 + (256*(((137 + (x0 % 136)) % 140) % 7)) + (1792*(((137 + (x0 // 136)) % 140) % 7)) + (12544*(((137 + (x0 % 136)) % 140) // 7)) + (250880*(((137 + (x0 // 136)) % 140) // 7))), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp13 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp25 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp27 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp14 = tmp13.to(tl.float32)
tmp15 = tmp12 + tmp14
tmp16 = tmp11 + tmp15
tmp17 = tmp16.to(tl.float32)
tmp18 = tmp17 - tmp8
tmp19 = 256.0
tmp20 = tmp9 / tmp19
tmp21 = 1e-05
tmp22 = tmp20 + tmp21
tmp23 = libdevice.rsqrt(tmp22)
tmp24 = tmp18 * tmp23
tmp26 = tmp24 * tmp25
tmp28 = tmp26 + tmp27
tmp29 = tmp28.to(tl.float32)
tl.store(out_ptr3 + (r1 + (256*x0)), tmp29, rmask & xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/qa/cqaaamnieqtt5cxfgvrjon2locvummefm6c6uw5vo5jw43ghlrb6.py
# Source Nodes: [x_80, x_86, x_out_1], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
# x_80 => add_43
# x_86 => add_47
# x_out_1 => convert_element_type_137, var_mean_12
# Two residual adds fused with per-row LayerNorm statistics. Each program
# handles one row of 256 channels (XBLOCK = 1, RBLOCK = rnumel = 256,
# persistent reduction); there are xnumel = 18496 rows.
#   a = in_ptr0[x0, r1] + (in_ptr1[gather(x0), r1] + in_ptr2[r1])
#   b = in_out_ptr0[x0, r1] + in_ptr3[r1]
#   in_out_ptr0[x0, r1] <- a + b   (overwritten in place; narrowed to fp16 by
#                                   the store)
#   out_ptr0[x0] <- sum(a + b) / 256   (fp32 mean)
#   out_ptr1[x0] <- sum((a + b - mean)^2)  (fp32 sum of squared deviations;
#                                           NOT divided by 256)
# The statistics use the fp32 value of a + b, not the fp16 value that is
# stored.
# gather(x0) uses the same rolled/cropped tile indexing as elsewhere in this
# file: h = (137 + x0 // 136) % 140, w = (137 + x0 % 136) % 140, with strides
# 250880 / 12544 / 1792 / 256 for (h // 7, w // 7, h % 7, w % 7).
# tmp14 is computed but never used. The normalization itself is not applied
# here — presumably a later kernel consumes out_ptr0/out_ptr1 (not visible
# in this chunk).
triton_per_fused__to_copy_add_native_layer_norm_47 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.persistent_reduction(
size_hints=[32768, 256],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: 'i32', 8: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_native_layer_norm_47', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': True, 'num_load': 5, 'num_reduction': 4, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, out_ptr1, xnumel, rnumel):
xnumel = 18496
XBLOCK: tl.constexpr = 1
rnumel = 256
RBLOCK: tl.constexpr = 256
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = xindex < xnumel
rindex = tl.arange(0, RBLOCK)[:]
roffset = 0
rmask = rindex < rnumel
r1 = rindex
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r1 + (256*x0)), rmask & xmask, other=0.0).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r1 + (256*(((137 + (x0 % 136)) % 140) % 7)) + (1792*(((137 + (x0 // 136)) % 140) % 7)) + (12544*(((137 + (x0 % 136)) % 140) // 7)) + (250880*(((137 + (x0 // 136)) % 140) // 7))), rmask & xmask, other=0.0).to(tl.float32)
tmp2 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp6 = tl.load(in_out_ptr0 + (r1 + (256*x0)), rmask & xmask, other=0.0).to(tl.float32)
tmp7 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp1 + tmp3
tmp5 = tmp0 + tmp4
tmp8 = tmp7.to(tl.float32)
tmp9 = tmp6 + tmp8
tmp10 = tmp5 + tmp9
tmp11 = tmp10.to(tl.float32)
tmp12 = tl.broadcast_to(tmp11, [RBLOCK])
tmp14 = tl.where(rmask & xmask, tmp12, 0)
tmp15 = tl.broadcast_to(tmp12, [RBLOCK])
tmp17 = tl.where(rmask & xmask, tmp15, 0)
tmp18 = triton_helpers.promote_to_tensor(tl.sum(tmp17, 0))
tmp19 = tl.full([1], 256, tl.int32)
tmp20 = tmp19.to(tl.float32)
tmp21 = tmp18 / tmp20
tmp22 = tmp12 - tmp21
tmp23 = tmp22 * tmp22
tmp24 = tl.broadcast_to(tmp23, [RBLOCK])
tmp26 = tl.where(rmask & xmask, tmp24, 0)
tmp27 = triton_helpers.promote_to_tensor(tl.sum(tmp26, 0))
tl.store(in_out_ptr0 + (r1 + (256*x0)), tmp10, rmask & xmask)
tl.store(out_ptr0 + (x0), tmp21, xmask)
tl.store(out_ptr1 + (x0), tmp27, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/vi/cvi4wbzakyffipye66rug6dumlcgc7kxs22zwyjq3yzsdisedy7v.py
# Source Nodes: [x_90, x_91], Original ATen: [aten._to_copy, aten.native_layer_norm]
# x_90 => add_48, add_49, convert_element_type_132, mul_42, mul_43, rsqrt_11, sub_17, var_mean_11
# x_91 => convert_element_type_134
# Gather-concat fused with LayerNorm over 1024 channels, for 4624 (= 68 * 68)
# rows. It is a looped reduction with two passes: Welford mean/M2 first, then
# normalization. Each 1024-wide row v is assembled on the fly from four
# 256-channel slices of in_ptr0 (fp16). The slice is chosen by r1, with
# base = 512 * (x0 % 68) + 69632 * (x0 // 68):
#   r1 in [0, 256):     base + r1
#   r1 in [256, 512):   base + 34560 + r1  (= base + 34816 + (r1 - 256))
#   r1 in [512, 768):   base - 256 + r1    (= base + 256 + (r1 - 512))
#   r1 in [768, 1024):  base + 34304 + r1  (= base + 34816 + 256 + (r1 - 768))
# out = (v - mean) * rsqrt(M2 / 1024 + 1e-5) * in_ptr1[r1] + in_ptr2[r1],
# stored as fp16 to out_ptr3 at r1 + 1024 * x0.
# NOTE(review): assume in_ptr0 is a contiguous (rows, 136, 256) tensor. Then
# its row stride is 34816 = 136 * 256, base selects pixel (2i, 2j) with
# i = x0 // 68 and j = x0 % 68, and the four slices are pixels (2i, 2j),
# (2i+1, 2j), (2i, 2j+1), (2i+1, 2j+1). That looks like Swin patch merging
# — the shape is inferred from the strides, confirm at the call site.
# tmp2/tmp24 (and tmp38/tmp60 in the second loop) are computed but unused.
triton_red_fused__to_copy_native_layer_norm_48 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
size_hints=[8192, 1024],
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp32', 3: '*fp16', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_native_layer_norm_48', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 10, 'num_reduction': 2, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, out_ptr3, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 4624
rnumel = 1024
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex
tmp33_mean = tl.zeros([XBLOCK, RBLOCK], tl.float32)
tmp33_m2 = tl.zeros([XBLOCK, RBLOCK], tl.float32)
tmp33_weight = tl.zeros([XBLOCK, RBLOCK], tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp0 = r1
tmp1 = tl.full([1, 1], 0, tl.int64)
tmp2 = tmp0 >= tmp1
tmp3 = tl.full([1, 1], 256, tl.int64)
tmp4 = tmp0 < tmp3
tmp5 = tl.load(in_ptr0 + (r1 + (512*(x0 % 68)) + (69632*(x0 // 68))), rmask & tmp4 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp6 = tl.full(tmp5.shape, 0.0, tmp5.dtype)
tmp7 = tl.where(tmp4, tmp5, tmp6)
tmp8 = tmp0 >= tmp3
tmp9 = tl.full([1, 1], 512, tl.int64)
tmp10 = tmp0 < tmp9
tmp11 = tmp8 & tmp10
tmp12 = tl.load(in_ptr0 + (34560 + r1 + (512*(x0 % 68)) + (69632*(x0 // 68))), rmask & tmp11 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
tmp14 = tl.where(tmp11, tmp12, tmp13)
tmp15 = tmp0 >= tmp9
tmp16 = tl.full([1, 1], 768, tl.int64)
tmp17 = tmp0 < tmp16
tmp18 = tmp15 & tmp17
tmp19 = tl.load(in_ptr0 + ((-256) + r1 + (512*(x0 % 68)) + (69632*(x0 // 68))), rmask & tmp18 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
tmp21 = tl.where(tmp18, tmp19, tmp20)
tmp22 = tmp0 >= tmp16
tmp23 = tl.full([1, 1], 1024, tl.int64)
tmp24 = tmp0 < tmp23
tmp25 = tl.load(in_ptr0 + (34304 + r1 + (512*(x0 % 68)) + (69632*(x0 // 68))), rmask & tmp22 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp26 = tl.full(tmp25.shape, 0.0, tmp25.dtype)
tmp27 = tl.where(tmp22, tmp25, tmp26)
tmp28 = tl.where(tmp18, tmp21, tmp27)
tmp29 = tl.where(tmp11, tmp14, tmp28)
tmp30 = tl.where(tmp4, tmp7, tmp29)
tmp31 = tmp30.to(tl.float32)
tmp32 = tl.broadcast_to(tmp31, [XBLOCK, RBLOCK])
tmp33_mean_next, tmp33_m2_next, tmp33_weight_next = triton_helpers.welford_reduce(
tmp32, tmp33_mean, tmp33_m2, tmp33_weight, roffset == 0
)
tmp33_mean = tl.where(rmask & xmask, tmp33_mean_next, tmp33_mean)
tmp33_m2 = tl.where(rmask & xmask, tmp33_m2_next, tmp33_m2)
tmp33_weight = tl.where(rmask & xmask, tmp33_weight_next, tmp33_weight)
tmp33_tmp, tmp34_tmp, tmp35_tmp = triton_helpers.welford(
tmp33_mean, tmp33_m2, tmp33_weight, 1
)
tmp33 = tmp33_tmp[:, None]
tmp34 = tmp34_tmp[:, None]
tmp35 = tmp35_tmp[:, None]
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp75 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp77 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp36 = r1
tmp37 = tl.full([1, 1], 0, tl.int64)
tmp38 = tmp36 >= tmp37
tmp39 = tl.full([1, 1], 256, tl.int64)
tmp40 = tmp36 < tmp39
tmp41 = tl.load(in_ptr0 + (r1 + (512*(x0 % 68)) + (69632*(x0 // 68))), rmask & tmp40 & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp42 = tl.full(tmp41.shape, 0.0, tmp41.dtype)
tmp43 = tl.where(tmp40, tmp41, tmp42)
tmp44 = tmp36 >= tmp39
tmp45 = tl.full([1, 1], 512, tl.int64)
tmp46 = tmp36 < tmp45
tmp47 = tmp44 & tmp46
tmp48 = tl.load(in_ptr0 + (34560 + r1 + (512*(x0 % 68)) + (69632*(x0 // 68))), rmask & tmp47 & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp49 = tl.full(tmp48.shape, 0.0, tmp48.dtype)
tmp50 = tl.where(tmp47, tmp48, tmp49)
tmp51 = tmp36 >= tmp45
tmp52 = tl.full([1, 1], 768, tl.int64)
tmp53 = tmp36 < tmp52
tmp54 = tmp51 & tmp53
tmp55 = tl.load(in_ptr0 + ((-256) + r1 + (512*(x0 % 68)) + (69632*(x0 // 68))), rmask & tmp54 & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp56 = tl.full(tmp55.shape, 0.0, tmp55.dtype)
tmp57 = tl.where(tmp54, tmp55, tmp56)
tmp58 = tmp36 >= tmp52
tmp59 = tl.full([1, 1], 1024, tl.int64)
tmp60 = tmp36 < tmp59
tmp61 = tl.load(in_ptr0 + (34304 + r1 + (512*(x0 % 68)) + (69632*(x0 // 68))), rmask & tmp58 & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp62 = tl.full(tmp61.shape, 0.0, tmp61.dtype)
tmp63 = tl.where(tmp58, tmp61, tmp62)
tmp64 = tl.where(tmp54, tmp57, tmp63)
tmp65 = tl.where(tmp47, tmp50, tmp64)
tmp66 = tl.where(tmp40, tmp43, tmp65)
tmp67 = tmp66.to(tl.float32)
tmp68 = tmp67 - tmp33
tmp69 = 1024.0
tmp70 = tmp34 / tmp69
tmp71 = 1e-05
tmp72 = tmp70 + tmp71
tmp73 = libdevice.rsqrt(tmp72)
tmp74 = tmp68 * tmp73
tmp76 = tmp74 * tmp75
tmp78 = tmp76 + tmp77
tmp79 = tmp78.to(tl.float32)
tl.store(out_ptr3 + (r1 + (1024*x0)), tmp79, rmask & xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/2b/c2br45eimgqwhlwhtljpzxkxy33i7wtdmeqadpkf2nk76k6jqfxd.py
# Source Nodes: [x_91], Original ATen: [aten._to_copy]
# x_91 => convert_element_type_133
# Elementwise dtype cast of 524288 values: in_ptr0 (fp32) -> out_ptr0 (fp16),
# with the same flat index for input and output.
# The load and store pass mask None, so xmask is computed but unused. This
# is only in bounds if the launch grid never exceeds xnumel.
# NOTE(review): presumably Inductor dropped the mask because 524288 = 2**19 is
# a multiple of every power-of-two XBLOCK it may pick — confirm.
triton_poi_fused__to_copy_49 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[524288],
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_49', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 524288
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/um/cumnariyjk3qtyk262n4xg2k375rwdvxooxdl5onswzasc5zkkpc.py
# Source Nodes: [x_93], Original ATen: [aten._to_copy, aten.native_layer_norm]
# x_93 => convert_element_type_140, var_mean_13
# Per-row LayerNorm statistics over 512 channels as a persistent reduction
# (XBLOCK = 1, RBLOCK = 512: one program per row, whole row in registers).
# For each of the 4624 rows of 512 contiguous fp16 values (row stride 512)
# the values are upcast to fp32 and the kernel stores:
#   out_ptr0[row] = mean of the row
#   out_ptr1[row] = sum of squared deviations from that mean (M2)
# M2 is NOT divided by 512 here; the variance is M2 / 512.
triton_per_fused__to_copy_native_layer_norm_50 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.persistent_reduction(
    size_hints=[8192, 512],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp32', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_native_layer_norm_50', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 1, 'num_reduction': 4, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, out_ptr0, out_ptr1, xnumel, rnumel):
    xnumel = 4624
    XBLOCK: tl.constexpr = 1
    rnumel = 512
    RBLOCK: tl.constexpr = 512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[:]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (512*x0)), rmask & xmask, other=0.0).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = tl.broadcast_to(tmp1, [RBLOCK])
    tmp4 = tl.where(rmask & xmask, tmp2, 0)
    tmp5 = tl.broadcast_to(tmp2, [RBLOCK])
    tmp7 = tl.where(rmask & xmask, tmp5, 0)
    tmp8 = triton_helpers.promote_to_tensor(tl.sum(tmp7, 0))
    tmp9 = tl.full([1], 512, tl.int32)
    tmp10 = tmp9.to(tl.float32)
    tmp11 = tmp8 / tmp10
    tmp12 = tmp2 - tmp11
    tmp13 = tmp12 * tmp12
    tmp14 = tl.broadcast_to(tmp13, [RBLOCK])
    tmp16 = tl.where(rmask & xmask, tmp14, 0)
    tmp17 = triton_helpers.promote_to_tensor(tl.sum(tmp16, 0))
    tl.store(out_ptr0 + (x0), tmp11, xmask)
    tl.store(out_ptr1 + (x0), tmp17, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/g3/cg3khvtgpcqpjxr7l476vsoytoskgybzulj6ccz6csy5y4spmsvy.py
# Source Nodes: [linear_18], Original ATen: [aten._to_copy]
# linear_18 => convert_element_type_143
# Fused LayerNorm apply + window partition with zero padding.
# Input in_ptr0 is read as a 68 x 68 grid of 512-channel fp16 cells
# (row stride 68 * 512 = 34816). Output is 100 windows x 49 cells x 512
# channels (xnumel = 2508800), i.e. a 10 x 10 grid of 7 x 7 windows covering
# a 70 x 70 extent. For source cell (row, col):
#   out = (x - mean) * rsqrt(M2 / 512 + 1e-05) * weight + bias
# where mean/M2 come from in_ptr1/in_ptr2 (indexed row * 68 + col) and the
# 512-wide weight/bias come from in_ptr3/in_ptr4. Cells with row >= 68 or
# col >= 68 (the padding band) are written as 0.0.
triton_poi_fused__to_copy_51 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
    size_hints=[4194304],
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp16', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_51', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2508800
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x1 = (xindex // 512) % 49
    x2 = (xindex // 25088)
    x0 = xindex % 512
    x3 = xindex
    tmp0 = (7*(x2 // 10)) + (x1 // 7)
    tmp1 = tl.full([1], 68, tl.int64)
    tmp2 = tmp0 < tmp1
    tmp3 = (7*(x2 % 10)) + (x1 % 7)
    tmp4 = tmp3 < tmp1
    tmp5 = tmp2 & tmp4
    tmp6 = tl.load(in_ptr0 + (x0 + (512*(x1 % 7)) + (3584*(x2 % 10)) + (34816*(x1 // 7)) + (243712*(x2 // 10))), tmp5, other=0.0).to(tl.float32)
    tmp7 = tmp6.to(tl.float32)
    tmp8 = tl.load(in_ptr1 + ((7*(x2 % 10)) + (68*(x1 // 7)) + (476*(x2 // 10)) + (x1 % 7)), tmp5, eviction_policy='evict_last', other=0.0)
    tmp9 = tmp7 - tmp8
    tmp10 = tl.load(in_ptr2 + ((7*(x2 % 10)) + (68*(x1 // 7)) + (476*(x2 // 10)) + (x1 % 7)), tmp5, eviction_policy='evict_last', other=0.0)
    tmp11 = 512.0
    tmp12 = tmp10 / tmp11
    tmp13 = 1e-05
    tmp14 = tmp12 + tmp13
    tmp15 = libdevice.rsqrt(tmp14)
    tmp16 = tmp9 * tmp15
    tmp17 = tl.load(in_ptr3 + (x0), tmp5, eviction_policy='evict_last', other=0.0)
    tmp18 = tmp16 * tmp17
    tmp19 = tl.load(in_ptr4 + (x0), tmp5, eviction_policy='evict_last', other=0.0)
    tmp20 = tmp18 + tmp19
    tmp21 = tl.full(tmp20.shape, 0.0, tmp20.dtype)
    tmp22 = tl.where(tmp5, tmp20, tmp21)
    tmp23 = tmp22.to(tl.float32)
    tl.store(out_ptr0 + (x3), tmp23, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/7h/c7h2rh5tddfvhz4ra3gissuywjohe7jgaicl73dtmurv236yclh4.py
# Source Nodes: [linear_18], Original ATen: [aten._to_copy]
# linear_18 => convert_element_type_142
# Flat elementwise copy of 786432 (= 1536 * 512) elements from an fp32 buffer
# into an fp16 buffer (per the triton_meta signature). The value is held in
# fp32 and narrowed to the fp16 pointer element type by tl.store.
# NOTE(review): presumably the linear_18 weight being cast for a half-precision
# matmul -- confirm at the call site.
triton_poi_fused__to_copy_52 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
    size_hints=[1048576],
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_52', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 786432
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/lj/cljyipqxaos2lyhrcrywearlyzfxb2hz6o2kxzgjx4qc6pmpeudj.py
# Source Nodes: [attn_20, q_9], Original ATen: [aten.clone, aten.mul]
# attn_20 => clone_56
# q_9 => mul_50
# Builds the scaled query slice. in_ptr0 is read as fp16 rows of 1536 values
# (75264 = 49 * 1536 per group of 49 rows); this kernel takes columns
# [0, 512), splits them into 16 chunks of 32, adds the matching fp32 bias from
# in_ptr1 and multiplies by 0.1767766952966369 (= 1 / sqrt(32)).
# Output order is (group, chunk, row, channel) = (100, 16, 49, 32), with each
# 49 x 32 = 1568-element chunk placed at a 1600-element stride (padded).
# NOTE(review): axis names (window / head / token) are inferred from the
# window-attention shapes -- confirm against the model.
triton_poi_fused_clone_mul_53 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
    size_hints=[4194304],
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp16', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_mul_53', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2508800
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 32
    x1 = (xindex // 32) % 49
    x2 = (xindex // 1568) % 16
    x3 = (xindex // 25088)
    x4 = xindex % 1568
    x5 = (xindex // 1568)
    tmp0 = tl.load(in_ptr0 + (x0 + (32*x2) + (1536*x1) + (75264*x3)), None).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (x0 + (32*x2)), None, eviction_policy='evict_last')
    tmp2 = tmp1.to(tl.float32)
    tmp3 = tmp0 + tmp2
    tmp4 = 0.1767766952966369
    tmp5 = tmp3 * tmp4
    tl.store(out_ptr0 + (x4 + (1600*x5)), tmp5, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/av/cavjexsophvnr2bav6uoax3inj3rzrgvwxxuzbozsbhi5b7t7wxk.py
# Source Nodes: [attn_20], Original ATen: [aten.clone]
# attn_20 => clone_57
# Builds the transposed key slice. in_ptr0 is read as fp16 rows of 1536
# values (75264 = 49 * 1536 per group); this kernel takes columns [512, 1024)
# and adds the fp32 bias at the same offset from in_ptr1.
# 2-D tiled: y enumerates (group, channel) pairs (51200 = 100 * 512) and
# x enumerates the 49 rows of a group. Each 32-channel chunk is stored as
# 32 x 49 (channel-major, i.e. transposed relative to the input) at a
# 1600-element chunk stride.
# NOTE(review): presumably this is the K^T operand of q @ k^T -- confirm.
triton_poi_fused_clone_54 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
    size_hints=[65536, 64], tile_hint=TileHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp16', 3: 'i32', 4: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_54', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 51200
    xnumel = 49
    yoffset = (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    x3 = xindex
    y2 = (yindex // 512)
    y4 = yindex % 512
    y0 = yindex % 32
    y5 = (yindex // 32)
    tmp0 = tl.load(in_ptr0 + (512 + y4 + (1536*x3) + (75264*y2)), xmask, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (512 + y4), None, eviction_policy='evict_last')
    tmp2 = tmp1.to(tl.float32)
    tmp3 = tmp0 + tmp2
    tl.store(out_ptr0 + (x3 + (49*y0) + (1600*y5)), tmp3, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/2m/c2mbsjkqnezozcymlpyhsrwbu3d475l7jp7llp3av6ycshavi5cp.py
# Source Nodes: [attn_21, attn_22, matmul_9], Original ATen: [aten._softmax, aten._to_copy, aten.add]
# attn_21 => add_54
# attn_22 => amax_4, div_10, exp_4, sub_21, sum_5
# matmul_9 => convert_element_type_149
# Softmax over the last axis (length 49) fused with a gathered bias.
# Each of the 78400 rows (= 100 * 16 * 49) reads 49 contiguous fp16 values and
# upcasts them to fp32. For row position x0 (= row % 49) and column r3, the
# bias index is the int64 in_ptr1[r3 + 49 * x0]. Negative indices are wrapped
# by +169 and checked with device_assert to lie in [0, 169). The value
# in_ptr2[x1 + 16 * idx] is then added, where x1 = (row // 49) % 16 selects
# one of 16 columns of a 169 x 16 fp32 table.
# The softmax subtracts the row max before exp; results are stored with a
# 2432-element stride per 49 x 49 (= 2401) group, i.e. padded.
triton_per_fused__softmax__to_copy_add_55 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.persistent_reduction(
    size_hints=[131072, 64],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*i64', 2: '*fp32', 3: '*fp16', 4: 'i32', 5: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__softmax__to_copy_add_55', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 2, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 78400
    rnumel = 49
    RBLOCK: tl.constexpr = 64
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = rindex < rnumel
    r3 = rindex
    x4 = xindex
    x0 = xindex % 49
    x1 = (xindex // 49) % 16
    x5 = (xindex // 49)
    tmp0 = tl.load(in_ptr0 + (r3 + (49*x4)), rmask & xmask, other=0.0).to(tl.float32)
    tmp2 = tl.load(in_ptr1 + (r3 + (49*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0)
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tl.full([XBLOCK, RBLOCK], 169, tl.int32)
    tmp4 = tmp2 + tmp3
    tmp5 = tmp2 < 0
    tmp6 = tl.where(tmp5, tmp4, tmp2)
    tl.device_assert(((0 <= tmp6) & (tmp6 < 169)) | ~(rmask & xmask), "index out of bounds: 0 <= tmp6 < 169")
    tmp8 = tl.load(in_ptr2 + (x1 + (16*tmp6)), rmask & xmask, eviction_policy='evict_last')
    tmp9 = tmp1 + tmp8
    tmp10 = tl.broadcast_to(tmp9, [XBLOCK, RBLOCK])
    tmp12 = tl.where(rmask & xmask, tmp10, float("-inf"))
    tmp13 = triton_helpers.max2(tmp12, 1)[:, None]
    tmp14 = tmp9 - tmp13
    tmp15 = tl_math.exp(tmp14)
    tmp16 = tl.broadcast_to(tmp15, [XBLOCK, RBLOCK])
    tmp18 = tl.where(rmask & xmask, tmp16, 0)
    tmp19 = tl.sum(tmp18, 1)[:, None]
    tmp20 = tmp15 / tmp19
    tmp21 = tmp20.to(tl.float32)
    tl.store(out_ptr2 + (r3 + (49*x0) + (2432*x5)), tmp21, rmask & xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/lk/clkuywa7pywq2gszyicfrwhbcan7mq2hjz7jsjx7fbjhf2kv6ywf.py
# Source Nodes: [matmul_9], Original ATen: [aten.clone]
# matmul_9 => clone_60
# Builds the value slice. It takes columns [1024, 1536) of the 1536-wide fp16
# input and adds the fp32 bias at the same offset; no scaling is applied.
# The 512 columns are split into 16 chunks of 32 and stored in
# (group, chunk, row, channel) = (100, 16, 49, 32) order, with a
# 1600-element stride per 49 x 32 chunk (padded).
triton_poi_fused_clone_56 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
    size_hints=[4194304],
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp16', 3: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_56', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2508800
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 32
    x1 = (xindex // 32) % 49
    x2 = (xindex // 1568) % 16
    x3 = (xindex // 25088)
    x4 = xindex % 1568
    x5 = (xindex // 1568)
    tmp0 = tl.load(in_ptr0 + (1024 + x0 + (32*x2) + (1536*x1) + (75264*x3)), None).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (1024 + x0 + (32*x2)), None, eviction_policy='evict_last')
    tmp2 = tmp1.to(tl.float32)
    tmp3 = tmp0 + tmp2
    tl.store(out_ptr0 + (x4 + (1600*x5)), tmp3, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/dg/cdg5noei5zaw6qlukk4dp5ch354ru3lymtetv3ljqnq626zlqxdh.py
# Source Nodes: [x_97], Original ATen: [aten.clone]
# x_97 => clone_61
# Pure permute-copy (no arithmetic). It reads a contiguous fp16 tensor of
# shape (100, 16, 49, 32) (strides 25088, 1568, 32, 1) and writes it
# contiguously as (100, 49, 16, 32). Each of the 49 rows per group thus
# gathers its 16 chunks of 32 into one 512-wide row.
# NOTE(review): presumably the head-merge after attn @ v -- confirm.
triton_poi_fused_clone_57 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
    size_hints=[4194304],
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_57', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2508800
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 32
    x1 = (xindex // 32) % 16
    x2 = (xindex // 512) % 49
    x3 = (xindex // 25088)
    x4 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (32*x2) + (1568*x1) + (25088*x3)), None).to(tl.float32)
    tl.store(out_ptr0 + (x4), tmp0, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/t2/ct2dp3mhidfvqpjtzrn2dhky572jcvj66o5hnov3dfawuo5qorjc.py
# Source Nodes: [layer_norm_14, x_104, x_105], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
# layer_norm_14 => add_56, add_57, convert_element_type_157, mul_51, mul_52, rsqrt_14, sub_22, var_mean_14
# x_104 => add_55
# x_105 => convert_element_type_160
# Residual add + LayerNorm over 512 channels for 4624 rows, written as a
# looped (non-persistent) Welford reduction over RBLOCK-sized chunks.
# Row x0 is cell (row, col) = (x0 // 68, x0 % 68) of a 68 x 68 grid, and
#   h = in_ptr0[x0, :] + in_ptr1[gather(row, col), :] + in_ptr2[:]
# where gather reads block (row // 7) * 10 + col // 7, position
# (row % 7) * 7 + col % 7 of a buffer laid out as 100 x 49 x 512 (the inverse
# of a 7 x 7 window partition; positions past 68 are never read).
# Pass 1 accumulates the Welford mean/M2 of h. Pass 2 recomputes h and stores
#   (h - mean) * rsqrt(M2 / 512 + 1e-05) * in_ptr3 + in_ptr4
# to out_ptr2. The residual sum h itself is not written by this kernel.
triton_red_fused__to_copy_add_native_layer_norm_58 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
    size_hints=[8192, 512],
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp16', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_native_layer_norm_58', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 8, 'num_reduction': 2, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 4624
    rnumel = 512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    tmp8_mean = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp8_m2 = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp8_weight = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = tl.load(in_ptr0 + (r1 + (512*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr1 + (r1 + (512*((x0 % 68) % 7)) + (3584*((x0 // 68) % 7)) + (25088*((x0 % 68) // 7)) + (250880*(x0 // 476))), rmask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp2 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp3 = tmp2.to(tl.float32)
        tmp4 = tmp1 + tmp3
        tmp5 = tmp0 + tmp4
        tmp6 = tmp5.to(tl.float32)
        tmp7 = tl.broadcast_to(tmp6, [XBLOCK, RBLOCK])
        tmp8_mean_next, tmp8_m2_next, tmp8_weight_next = triton_helpers.welford_reduce(
            tmp7, tmp8_mean, tmp8_m2, tmp8_weight, roffset == 0
        )
        tmp8_mean = tl.where(rmask & xmask, tmp8_mean_next, tmp8_mean)
        tmp8_m2 = tl.where(rmask & xmask, tmp8_m2_next, tmp8_m2)
        tmp8_weight = tl.where(rmask & xmask, tmp8_weight_next, tmp8_weight)
    tmp8_tmp, tmp9_tmp, tmp10_tmp = triton_helpers.welford(
        tmp8_mean, tmp8_m2, tmp8_weight, 1
    )
    tmp8 = tmp8_tmp[:, None]
    tmp9 = tmp9_tmp[:, None]
    tmp10 = tmp10_tmp[:, None]
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp11 = tl.load(in_ptr0 + (r1 + (512*x0)), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp12 = tl.load(in_ptr1 + (r1 + (512*((x0 % 68) % 7)) + (3584*((x0 // 68) % 7)) + (25088*((x0 % 68) // 7)) + (250880*(x0 // 476))), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp13 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp25 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp27 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0)
        tmp14 = tmp13.to(tl.float32)
        tmp15 = tmp12 + tmp14
        tmp16 = tmp11 + tmp15
        tmp17 = tmp16.to(tl.float32)
        tmp18 = tmp17 - tmp8
        tmp19 = 512.0
        tmp20 = tmp9 / tmp19
        tmp21 = 1e-05
        tmp22 = tmp20 + tmp21
        tmp23 = libdevice.rsqrt(tmp22)
        tmp24 = tmp18 * tmp23
        tmp26 = tmp24 * tmp25
        tmp28 = tmp26 + tmp27
        tmp29 = tmp28.to(tl.float32)
        tl.store(out_ptr2 + (r1 + (512*x0)), tmp29, rmask & xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/np/cnp6uoxnm72l4yszb432o46mj4nxiei7ycexyxlr2x6ni274sotg.py
# Source Nodes: [x_105], Original ATen: [aten._to_copy]
# x_105 => convert_element_type_159
# Flat elementwise copy of 1048576 (= 2048 * 512) elements from an fp32 buffer
# into an fp16 buffer (per the triton_meta signature). tl.store narrows the
# fp32 value to the fp16 pointer element type.
# NOTE(review): presumably the x_105 linear weight being cast for a
# half-precision matmul -- confirm at the call site.
triton_poi_fused__to_copy_59 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
    size_hints=[1048576],
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_59', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 1048576
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/nq/cnqdp7jcikbji2dqz7a5oy3qn777ber3u2a6msqynur5lblmly63.py
# Source Nodes: [x_106], Original ATen: [aten.gelu]
# x_106 => add_58, convert_element_type_164, convert_element_type_165, erf_4, mul_53, mul_54, mul_55
# In-place bias add + exact (erf-based) GELU over an fp16 buffer of
# 9469952 (= 4624 * 2048) elements. in_out_ptr0 is both read and written
# (listed in mutated_arg_names). The fp32 bias in_ptr0 is indexed by
# x % 2048, so it repeats every 2048 elements. With v = x + bias, it stores
#   0.5 * v * (1 + erf(v * 0.7071067811865476))   # 0.7071... = 1 / sqrt(2)
triton_poi_fused_gelu_60 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
    size_hints=[16777216],
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_gelu_60', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 9469952
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x2 = xindex
    x0 = xindex % 2048
    tmp0 = tl.load(in_out_ptr0 + (x2), None).to(tl.float32)
    tmp1 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last')
    tmp2 = tmp1.to(tl.float32)
    tmp3 = tmp0 + tmp2
    tmp4 = tmp3.to(tl.float32)
    tmp5 = 0.5
    tmp6 = tmp4 * tmp5
    tmp7 = 0.7071067811865476
    tmp8 = tmp4 * tmp7
    tmp9 = libdevice.erf(tmp8)
    tmp10 = 1.0
    tmp11 = tmp9 + tmp10
    tmp12 = tmp6 * tmp11
    tmp13 = tmp12.to(tl.float32)
    tl.store(in_out_ptr0 + (x2), tmp13, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/qh/cqhfnjwnhzm6zoxn4h6l4ih7jzgbrmhwlxlccqxcasplc2pj2oio.py
# Source Nodes: [x_104, x_110, x_111], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
# x_104 => add_55
# x_110 => add_59
# x_111 => convert_element_type_171, var_mean_15
# Second residual add + LayerNorm statistics, one program per row
# (XBLOCK = 1, RBLOCK = 512, persistent reduction) over 4624 rows x 512.
# For row x0 = (row, col) = (x0 // 68, x0 % 68) it recomputes
#   a = in_ptr0[x0, :] + in_ptr1[gather(row, col), :] + in_ptr2[:]
# using the same 7 x 7 window-reverse gather (block (row // 7) * 10 + col // 7,
# position (row % 7) * 7 + col % 7). It then adds b = in_out_ptr0[x0, :] +
# in_ptr3[:] and writes a + b back into in_out_ptr0 (mutated in place).
# It also stores the row mean to out_ptr0 and the sum of squared deviations
# (M2, not divided by 512) to out_ptr1.
triton_per_fused__to_copy_add_native_layer_norm_61 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.persistent_reduction(
    size_hints=[8192, 512],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: 'i32', 8: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_native_layer_norm_61', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': True, 'num_load': 5, 'num_reduction': 4, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, out_ptr1, xnumel, rnumel):
    xnumel = 4624
    XBLOCK: tl.constexpr = 1
    rnumel = 512
    RBLOCK: tl.constexpr = 512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[:]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (512*x0)), rmask & xmask, other=0.0).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (r1 + (512*((x0 % 68) % 7)) + (3584*((x0 // 68) % 7)) + (25088*((x0 % 68) // 7)) + (250880*(x0 // 476))), rmask & xmask, other=0.0).to(tl.float32)
    tmp2 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
    tmp6 = tl.load(in_out_ptr0 + (r1 + (512*x0)), rmask & xmask, other=0.0).to(tl.float32)
    tmp7 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0)
    tmp3 = tmp2.to(tl.float32)
    tmp4 = tmp1 + tmp3
    tmp5 = tmp0 + tmp4
    tmp8 = tmp7.to(tl.float32)
    tmp9 = tmp6 + tmp8
    tmp10 = tmp5 + tmp9
    tmp11 = tmp10.to(tl.float32)
    tmp12 = tl.broadcast_to(tmp11, [RBLOCK])
    tmp14 = tl.where(rmask & xmask, tmp12, 0)
    tmp15 = tl.broadcast_to(tmp12, [RBLOCK])
    tmp17 = tl.where(rmask & xmask, tmp15, 0)
    tmp18 = triton_helpers.promote_to_tensor(tl.sum(tmp17, 0))
    tmp19 = tl.full([1], 512, tl.int32)
    tmp20 = tmp19.to(tl.float32)
    tmp21 = tmp18 / tmp20
    tmp22 = tmp12 - tmp21
    tmp23 = tmp22 * tmp22
    tmp24 = tl.broadcast_to(tmp23, [RBLOCK])
    tmp26 = tl.where(rmask & xmask, tmp24, 0)
    tmp27 = triton_helpers.promote_to_tensor(tl.sum(tmp26, 0))
    tl.store(in_out_ptr0 + (r1 + (512*x0)), tmp10, rmask & xmask)
    tl.store(out_ptr0 + (x0), tmp21, xmask)
    tl.store(out_ptr1 + (x0), tmp27, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/d6/cd647osl6hgyj5l4klpv44ixju2cu5kkyigf4txwmejkih3db2aj.py
# Source Nodes: [shifted_x_2, x_113], Original ATen: [aten.constant_pad_nd, aten.roll]
# shifted_x_2 => index_13, index_14
# x_113 => constant_pad_nd_5
# Pointwise kernel over xnumel = 2508800 = 70*70*512 elements.
# Output index decomposition: x2 = row (0..69), x1 = col (0..69), x0 = channel (0..511).
# Each output (row, col) reads source position ((3 + row) % 70, (3 + col) % 70),
# i.e. a cyclic shift by 3. When that source row or col is >= 68, the output is 0.0
# (tmp5 / tmp22). This is the constant-pad region: in_ptr0 is addressed with a
# 68-wide row stride of 34816 = 68*512.
# Inside the valid region the stored value is
#   (in_ptr0 - in_ptr1) * rsqrt(in_ptr2 / 512 + 1e-05) * in_ptr3[ch] + in_ptr4[ch]
# where in_ptr1 / in_ptr2 are read per position at index 68*src_row + src_col.
# NOTE(review): presumably in_ptr1 / in_ptr2 are the mean and the sum of squared
# deviations from a preceding layer-norm reduction. Confirm against the caller.
# Per the triton_meta signature, in_ptr0 is fp16 and every other pointer is fp32.
# The result is stored to out_ptr0 as fp32.
# NOTE(review): xmask is computed but unused, and the store is unmasked (mask=None).
# This appears to rely on xnumel being a multiple of every XBLOCK the autotuner picks.
triton_poi_fused_constant_pad_nd_roll_62 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[4194304],
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_constant_pad_nd_roll_62', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 2508800
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x2 = (xindex // 35840)
x1 = (xindex // 512) % 70
x0 = xindex % 512
x4 = xindex
tmp0 = (3 + x2) % 70
tmp1 = tl.full([1], 68, tl.int64)
tmp2 = tmp0 < tmp1
tmp3 = (3 + x1) % 70
tmp4 = tmp3 < tmp1
tmp5 = tmp2 & tmp4
tmp6 = tl.load(in_ptr0 + (x0 + (512*((3 + x1) % 70)) + (34816*((3 + x2) % 70))), tmp5, other=0.0).to(tl.float32)
tmp7 = tmp6.to(tl.float32)
tmp8 = tl.load(in_ptr1 + ((68*((3 + x2) % 70)) + ((3 + x1) % 70)), tmp5, eviction_policy='evict_last', other=0.0)
tmp9 = tmp7 - tmp8
tmp10 = tl.load(in_ptr2 + ((68*((3 + x2) % 70)) + ((3 + x1) % 70)), tmp5, eviction_policy='evict_last', other=0.0)
tmp11 = 512.0
tmp12 = tmp10 / tmp11
tmp13 = 1e-05
tmp14 = tmp12 + tmp13
tmp15 = libdevice.rsqrt(tmp14)
tmp16 = tmp9 * tmp15
tmp17 = tl.load(in_ptr3 + (x0), tmp5, eviction_policy='evict_last', other=0.0)
tmp18 = tmp16 * tmp17
tmp19 = tl.load(in_ptr4 + (x0), tmp5, eviction_policy='evict_last', other=0.0)
tmp20 = tmp18 + tmp19
tmp21 = tl.full(tmp20.shape, 0.0, tmp20.dtype)
tmp22 = tl.where(tmp5, tmp20, tmp21)
tl.store(out_ptr0 + (x4), tmp22, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/kf/ckfhysrliuileref3yojypgpmuck4hqbsxqymr3ij7j54y6ysdwm.py
# Source Nodes: [linear_22], Original ATen: [aten._to_copy]
# linear_22 => convert_element_type_174
# Pointwise gather + dtype conversion over xnumel = 2508800 = 100*49*512 elements.
# Output index decomposition: x2 = window (0..99), x1 = position inside a 7x7
# window (0..48), x0 = channel (0..511). The output is written contiguously at x3.
# The source offset addresses a 70-row x 70-col x 512-channel layout
# (row stride 35840 = 70*512, col stride 512). It reads
#   row = 7*(x2 // 10) + x1 // 7,  col = 7*(x2 % 10) + x1 % 7
# so the 70x70 grid is split into a 10x10 grid of non-overlapping 7x7 windows.
# triton_meta declares in_ptr0 as '*fp32' and out_ptr0 as '*fp16', so the stored
# value ends up in half precision.
# NOTE(review): the load and store are unmasked (xmask is unused). This presumably
# relies on xnumel being a multiple of XBLOCK.
triton_poi_fused__to_copy_63 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[4194304],
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp16', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_63', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 2508800
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex % 512
x1 = (xindex // 512) % 49
x2 = (xindex // 25088)
x3 = xindex
tmp0 = tl.load(in_ptr0 + (x0 + (512*(x1 % 7)) + (3584*(x2 % 10)) + (35840*(x1 // 7)) + (250880*(x2 // 10))), None)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x3), tmp1, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/la/cla235hx3mubj6sgpoukkyfixg6bxgnvd6ngdgy3zwatq5lyqmcp.py
# Source Nodes: [img_mask_2, setitem_18, setitem_19, setitem_20, setitem_21, setitem_22], Original ATen: [aten.fill, aten.lift_fresh, aten.slice, aten.zeros]
# img_mask_2 => full_2
# setitem_18 => copy_18, lift_fresh_copy_24
# setitem_19 => copy_19, lift_fresh_copy_25
# setitem_20 => copy_20, lift_fresh_copy_26
# setitem_21 => copy_21, full_default_8, lift_fresh_copy_27
# setitem_22 => copy_22, lift_fresh_copy_28
# Pointwise kernel with no loads (num_load = 0). It writes a 70x70 fp32 label map
# into in_out_ptr0 (xnumel = 4900; x1 = row, x0 = col).
# Rows and cols are split into bands [0, 63), [63, 67) and [67, 70) using the
# constants 63 and 67. The nested tl.where chain reduces to:
#   rows [0, 63):   col band 0 -> 0.0, band 1 -> 1.0, band 2 -> 2.0
#   rows [63, 67):  col band 0 -> 3.0, band 1 -> 4.0, band 2 -> 0.0
#   rows [67, 70):  0.0 for every col
# NOTE(review): the remaining labels (rows [67, 70), and rows [63, 67) x cols
# [67, 70)) appear to be filled afterwards by triton_poi_fused_fill_lift_fresh_65
# and triton_poi_fused_fill_lift_fresh_slice_66. Verify the launch order in the
# wrapper.
triton_poi_fused_fill_lift_fresh_slice_zeros_64 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[8192],
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0,), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_fill_lift_fresh_slice_zeros_64', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 4900
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x1 = (xindex // 70)
x0 = xindex % 70
x2 = xindex
tmp0 = x1
tmp1 = tl.full([1], 63, tl.int64)
tmp2 = tmp0 >= tmp1
tmp3 = tl.full([1], 67, tl.int64)
tmp4 = tmp0 < tmp3
tmp5 = tmp2 & tmp4
tmp6 = x0
tmp7 = tmp6 < tmp1
tmp8 = tmp7 & tmp5
tmp9 = 3.0
tmp10 = tl.full(tmp9.shape, 0.0, tmp9.dtype)
tmp11 = tl.where(tmp8, tmp9, tmp10)
tmp12 = 0.0
tmp13 = tl.where(tmp7, tmp11, tmp12)
tmp14 = tl.full(tmp13.shape, 0.0, tmp13.dtype)
tmp15 = tl.where(tmp5, tmp13, tmp14)
tmp16 = tmp0 < tmp1
tmp17 = tmp6 >= tmp3
tmp18 = tmp17 & tmp16
tmp19 = 2.0
tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
tmp21 = tl.where(tmp18, tmp19, tmp20)
tmp22 = tmp16 & tmp16
tmp23 = tmp6 >= tmp1
tmp24 = tmp6 < tmp3
tmp25 = tmp23 & tmp24
tmp26 = tmp25 & tmp22
tmp27 = 1.0
tmp28 = tl.full(tmp27.shape, 0.0, tmp27.dtype)
tmp29 = tl.where(tmp26, tmp27, tmp28)
tmp30 = tmp16 & tmp22
tmp31 = tmp7 & tmp30
tmp32 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
tmp33 = tl.where(tmp31, tmp12, tmp32)
tmp34 = tl.where(tmp7, tmp33, tmp12)
tmp35 = tl.full(tmp34.shape, 0.0, tmp34.dtype)
tmp36 = tl.where(tmp30, tmp34, tmp35)
tmp37 = tl.where(tmp16, tmp36, tmp12)
tmp38 = tl.where(tmp25, tmp29, tmp37)
tmp39 = tl.full(tmp38.shape, 0.0, tmp38.dtype)
tmp40 = tl.where(tmp22, tmp38, tmp39)
tmp41 = tmp7 & tmp22
tmp42 = tl.where(tmp41, tmp12, tmp32)
tmp43 = tl.where(tmp7, tmp42, tmp12)
tmp44 = tl.full(tmp43.shape, 0.0, tmp43.dtype)
tmp45 = tl.where(tmp22, tmp43, tmp44)
tmp46 = tl.where(tmp16, tmp45, tmp12)
tmp47 = tl.where(tmp16, tmp40, tmp46)
tmp48 = tl.where(tmp17, tmp21, tmp47)
tmp49 = tl.full(tmp48.shape, 0.0, tmp48.dtype)
tmp50 = tl.where(tmp16, tmp48, tmp49)
tmp51 = tmp25 & tmp16
tmp52 = tl.where(tmp51, tmp27, tmp28)
tmp53 = tl.where(tmp25, tmp52, tmp46)
tmp54 = tl.full(tmp53.shape, 0.0, tmp53.dtype)
tmp55 = tl.where(tmp16, tmp53, tmp54)
tmp56 = tmp7 & tmp16
tmp57 = tl.where(tmp56, tmp12, tmp32)
tmp58 = tl.where(tmp7, tmp57, tmp12)
tmp59 = tl.full(tmp58.shape, 0.0, tmp58.dtype)
tmp60 = tl.where(tmp16, tmp58, tmp59)
tmp61 = tl.where(tmp16, tmp60, tmp12)
tmp62 = tl.where(tmp16, tmp55, tmp61)
tmp63 = tl.where(tmp16, tmp50, tmp62)
tmp64 = tl.where(tmp5, tmp15, tmp63)
tmp65 = tmp25 & tmp5
tmp66 = 4.0
tmp67 = tl.full(tmp66.shape, 0.0, tmp66.dtype)
tmp68 = tl.where(tmp65, tmp66, tmp67)
tmp69 = tl.where(tmp25, tmp68, tmp64)
tmp70 = tl.full(tmp69.shape, 0.0, tmp69.dtype)
tmp71 = tl.where(tmp5, tmp69, tmp70)
tmp72 = tl.where(tmp5, tmp71, tmp64)
tl.store(in_out_ptr0 + (x2), tmp72, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/iy/ciyf4skxx4pkmks5oyxbr7ildelf2g5isnbw6b6bh6z7lglqwl7j.py
# Source Nodes: [setitem_26], Original ATen: [aten.fill, aten.lift_fresh]
# setitem_26 => copy_26, lift_fresh_copy_32
# Pointwise kernel over xnumel = 210 = 3*70 elements. It fills a 3x70 fp32 buffer
# (out_ptr0) that stands for rows 67..69 of a 70-wide label map:
# x1 = row - 67, x0 = col, and tmp6 = 67 + x1 is the absolute row.
# The where chain reduces to:
#   col < 63 -> 6.0,  63 <= col < 67 -> 7.0,  col >= 67 -> 8.0
# in_ptr0 is read at offset 4690 + x2 (4690 = 67*70, i.e. rows 67..69 of a
# 70-wide map). Because tmp7 (67 + x1 >= 67) is always true, those loaded values
# (tmp34, tmp38, tmp51, tmp55) never reach the stored result; they are effectively
# dead loads.
triton_poi_fused_fill_lift_fresh_65 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[256],
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_fill_lift_fresh_65', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 210
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex % 70
x1 = (xindex // 70)
x2 = xindex
tmp55 = tl.load(in_ptr0 + (4690 + x2), xmask)
tmp0 = x0
tmp1 = tl.full([1], 67, tl.int64)
tmp2 = tmp0 >= tmp1
tmp3 = 8.0
tmp4 = tl.full(tmp3.shape, 0.0, tmp3.dtype)
tmp5 = tl.where(tmp2, tmp3, tmp4)
tmp6 = 67 + x1
tmp7 = tmp6 >= tmp1
tmp8 = tl.full([1], 63, tl.int64)
tmp9 = tmp0 >= tmp8
tmp10 = tmp0 < tmp1
tmp11 = tmp9 & tmp10
tmp12 = tmp11 & tmp7
tmp13 = 7.0
tmp14 = tl.full(tmp13.shape, 0.0, tmp13.dtype)
tmp15 = tl.where(tmp12, tmp13, tmp14)
tmp16 = tmp7 & tmp7
tmp17 = tmp0 < tmp8
tmp18 = tmp17 & tmp16
tmp19 = 6.0
tmp20 = tl.full(tmp19.shape, 0.0, tmp19.dtype)
tmp21 = tl.where(tmp18, tmp19, tmp20)
tmp22 = 0.0
tmp23 = tl.where(tmp17, tmp21, tmp22)
tmp24 = tl.full(tmp23.shape, 0.0, tmp23.dtype)
tmp25 = tl.where(tmp16, tmp23, tmp24)
tmp26 = tmp6 >= tmp8
tmp27 = tmp6 < tmp1
tmp28 = tmp26 & tmp27
tmp29 = tmp28 & tmp7
tmp30 = tmp2 & tmp29
tmp31 = 5.0
tmp32 = tl.full(tmp31.shape, 0.0, tmp31.dtype)
tmp33 = tl.where(tmp30, tmp31, tmp32)
tmp34 = tl.load(in_ptr0 + (4690 + x2), tmp29 & xmask, other=0.0)
tmp35 = tl.where(tmp2, tmp33, tmp34)
tmp36 = tl.full(tmp35.shape, 0.0, tmp35.dtype)
tmp37 = tl.where(tmp29, tmp35, tmp36)
tmp38 = tl.load(in_ptr0 + (4690 + x2), tmp7 & xmask, other=0.0)
tmp39 = tl.where(tmp28, tmp37, tmp38)
tmp40 = tl.where(tmp7, tmp25, tmp39)
tmp41 = tl.where(tmp11, tmp15, tmp40)
tmp42 = tl.full(tmp41.shape, 0.0, tmp41.dtype)
tmp43 = tl.where(tmp7, tmp41, tmp42)
tmp44 = tmp17 & tmp7
tmp45 = tl.where(tmp44, tmp19, tmp20)
tmp46 = tl.where(tmp17, tmp45, tmp22)
tmp47 = tl.full(tmp46.shape, 0.0, tmp46.dtype)
tmp48 = tl.where(tmp7, tmp46, tmp47)
tmp49 = tmp2 & tmp28
tmp50 = tl.where(tmp49, tmp31, tmp32)
tmp51 = tl.load(in_ptr0 + (4690 + x2), tmp28 & xmask, other=0.0)
tmp52 = tl.where(tmp2, tmp50, tmp51)
tmp53 = tl.full(tmp52.shape, 0.0, tmp52.dtype)
tmp54 = tl.where(tmp28, tmp52, tmp53)
tmp56 = tl.where(tmp28, tmp54, tmp55)
tmp57 = tl.where(tmp7, tmp48, tmp56)
tmp58 = tl.where(tmp7, tmp43, tmp57)
tmp59 = tl.where(tmp2, tmp5, tmp58)
tl.store(out_ptr0 + (x2), tmp59, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/wo/cwo26dzkhl3aq3ehh4clegihnjhiug4uevjligspazwbynwcz3bd.py
# Source Nodes: [setitem_23, setitem_24, setitem_25], Original ATen: [aten.fill, aten.lift_fresh, aten.slice]
# setitem_23 => copy_23, lift_fresh_copy_29
# setitem_24 => copy_24, full_default_9, lift_fresh_copy_30
# setitem_25 => copy_25, lift_fresh_copy_31
# Pointwise kernel that works in place on a 70x70 fp32 label map in_out_ptr0
# (xnumel = 4900; x1 = row, x0 = col). Net effect of the where chain:
#   rows [67, 70):            copied from in_ptr0[x2 - 4690] (4690 = 67*70), i.e. a
#                             3x70 buffer that holds those three rows;
#   rows [63, 67), col >= 67: set to 5.0;
#   everything else:          re-stored unchanged, as loaded from in_out_ptr0.
# NOTE(review): presumably in_ptr0 is the 3x70 buffer written by
# triton_poi_fused_fill_lift_fresh_65. Confirm in the wrapper.
triton_poi_fused_fill_lift_fresh_slice_66 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[8192],
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_fill_lift_fresh_slice_66', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 4900
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x1 = (xindex // 70)
x2 = xindex
x0 = xindex % 70
tmp55 = tl.load(in_out_ptr0 + (x2), xmask)
tmp0 = x1
tmp1 = tl.full([1], 67, tl.int64)
tmp2 = tmp0 >= tmp1
tmp3 = tl.load(in_ptr0 + ((-4690) + x2), tmp2 & xmask, other=0.0)
tmp4 = tl.full(tmp3.shape, 0.0, tmp3.dtype)
tmp5 = tl.where(tmp2, tmp3, tmp4)
tmp6 = x0
tmp7 = tl.full([1], 63, tl.int64)
tmp8 = tmp6 >= tmp7
tmp9 = tmp6 < tmp1
tmp10 = tmp8 & tmp9
tmp11 = tmp10 & tmp2
tmp12 = 7.0
tmp13 = tl.full(tmp12.shape, 0.0, tmp12.dtype)
tmp14 = tl.where(tmp11, tmp12, tmp13)
tmp15 = tmp2 & tmp2
tmp16 = tmp6 < tmp7
tmp17 = tmp16 & tmp15
tmp18 = 6.0
tmp19 = tl.full(tmp18.shape, 0.0, tmp18.dtype)
tmp20 = tl.where(tmp17, tmp18, tmp19)
tmp21 = 0.0
tmp22 = tl.where(tmp16, tmp20, tmp21)
tmp23 = tl.full(tmp22.shape, 0.0, tmp22.dtype)
tmp24 = tl.where(tmp15, tmp22, tmp23)
tmp25 = tmp0 >= tmp7
tmp26 = tmp0 < tmp1
tmp27 = tmp25 & tmp26
tmp28 = tmp27 & tmp2
tmp29 = tmp6 >= tmp1
tmp30 = tmp29 & tmp28
tmp31 = 5.0
tmp32 = tl.full(tmp31.shape, 0.0, tmp31.dtype)
tmp33 = tl.where(tmp30, tmp31, tmp32)
tmp34 = tl.load(in_out_ptr0 + (x2), tmp28 & xmask, other=0.0)
tmp35 = tl.where(tmp29, tmp33, tmp34)
tmp36 = tl.full(tmp35.shape, 0.0, tmp35.dtype)
tmp37 = tl.where(tmp28, tmp35, tmp36)
tmp38 = tl.load(in_out_ptr0 + (x2), tmp2 & xmask, other=0.0)
tmp39 = tl.where(tmp27, tmp37, tmp38)
tmp40 = tl.where(tmp2, tmp24, tmp39)
tmp41 = tl.where(tmp10, tmp14, tmp40)
tmp42 = tl.full(tmp41.shape, 0.0, tmp41.dtype)
tmp43 = tl.where(tmp2, tmp41, tmp42)
tmp44 = tmp16 & tmp2
tmp45 = tl.where(tmp44, tmp18, tmp19)
tmp46 = tl.where(tmp16, tmp45, tmp21)
tmp47 = tl.full(tmp46.shape, 0.0, tmp46.dtype)
tmp48 = tl.where(tmp2, tmp46, tmp47)
tmp49 = tmp29 & tmp27
tmp50 = tl.where(tmp49, tmp31, tmp32)
tmp51 = tl.load(in_out_ptr0 + (x2), tmp27 & xmask, other=0.0)
tmp52 = tl.where(tmp29, tmp50, tmp51)
tmp53 = tl.full(tmp52.shape, 0.0, tmp52.dtype)
tmp54 = tl.where(tmp27, tmp52, tmp53)
tmp56 = tl.where(tmp27, tmp54, tmp55)
tmp57 = tl.where(tmp2, tmp48, tmp56)
tmp58 = tl.where(tmp2, tmp43, tmp57)
tmp59 = tl.where(tmp2, tmp5, tmp58)
tl.store(in_out_ptr0 + (x2), tmp59, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/6z/c6zbanuvwcqjxdnkk3jjviuhk6fvpanrlsj4nzjsrldlmzsgyaei.py
# Source Nodes: [attn_26, attn_28, matmul_11], Original ATen: [aten._softmax, aten._to_copy, aten.add]
# attn_26 => add_65
# attn_28 => amax_5, div_11, exp_5, sub_24, sum_6
# matmul_11 => convert_element_type_180
# Persistent-reduction softmax along the last axis (rnumel = 49, padded to RBLOCK 64)
# for xnumel = 78400 = 100*16*49 rows.
# Index decomposition: x2 = window (0..99), x1 = head (0..15), x0 = query position
# within a 7x7 window (0..48), r3 = key position (0..48).
# Each score in_ptr0[x4, r3] (fp16, row-contiguous) gets two additive terms:
#   * a bias read from in_ptr2 (fp32) at x1 + 16*idx, where idx = in_ptr1[x0, r3]
#     (int64, [49, 49] row-major). A negative idx is wrapped by +169, and the
#     result is checked against [0, 169) with tl.device_assert.
#   * a mask term from the 70x70 fp32 label map in_ptr3, sampled at the key and
#     query positions of window x2: row = 7*(x2 // 10) + pos // 7,
#     col = 7*(x2 % 10) + pos % 7. The term is 0.0 when the two labels are equal
#     and -100.0 otherwise (tmp12..tmp18).
# The softmax is max-subtracted and computed in fp32. The result goes to out_ptr3
# (declared '*fp16') at r3 + 49*x0 + 2432*x5 (x5 = window*16 + head), so each
# (window, head) slice is 2432 elements apart rather than 49*49 = 2401.
# NOTE(review): presumably a padded/aligned allocation. Confirm against the buffer
# allocation in the wrapper.
triton_per_fused__softmax__to_copy_add_67 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.persistent_reduction(
size_hints=[131072, 64],
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*i64', 2: '*fp32', 3: '*fp32', 4: '*fp16', 5: 'i32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__softmax__to_copy_add_67', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 4, 'num_reduction': 2, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr3, xnumel, rnumel, XBLOCK : tl.constexpr):
xnumel = 78400
rnumel = 49
RBLOCK: tl.constexpr = 64
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rindex = tl.arange(0, RBLOCK)[None, :]
roffset = 0
rmask = rindex < rnumel
r3 = rindex
x4 = xindex
x0 = xindex % 49
x1 = (xindex // 49) % 16
x2 = (xindex // 784)
x5 = (xindex // 49)
tmp0 = tl.load(in_ptr0 + (r3 + (49*x4)), rmask & xmask, other=0.0).to(tl.float32)
tmp2 = tl.load(in_ptr1 + (r3 + (49*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0)
tmp10 = tl.load(in_ptr3 + ((7*(x2 % 10)) + (70*(r3 // 7)) + (490*(x2 // 10)) + (r3 % 7)), rmask & xmask, eviction_policy='evict_last', other=0.0)
tmp11 = tl.load(in_ptr3 + ((7*(x2 % 10)) + (70*(x0 // 7)) + (490*(x2 // 10)) + (x0 % 7)), xmask, eviction_policy='evict_last')
tmp1 = tmp0.to(tl.float32)
tmp3 = tl.full([XBLOCK, RBLOCK], 169, tl.int32)
tmp4 = tmp2 + tmp3
tmp5 = tmp2 < 0
tmp6 = tl.where(tmp5, tmp4, tmp2)
tl.device_assert(((0 <= tmp6) & (tmp6 < 169)) | ~(rmask & xmask), "index out of bounds: 0 <= tmp6 < 169")
tmp8 = tl.load(in_ptr2 + (x1 + (16*tmp6)), rmask & xmask, eviction_policy='evict_last')
tmp9 = tmp1 + tmp8
tmp12 = tmp10 - tmp11
tmp13 = 0.0
tmp14 = tmp12 == tmp13
tmp15 = tmp12 != tmp13
tmp16 = -100.0
tmp17 = tl.where(tmp15, tmp16, tmp12)
tmp18 = tl.where(tmp14, tmp13, tmp17)
tmp19 = tmp9 + tmp18
tmp20 = tl.broadcast_to(tmp19, [XBLOCK, RBLOCK])
tmp22 = tl.where(rmask & xmask, tmp20, float("-inf"))
tmp23 = triton_helpers.max2(tmp22, 1)[:, None]
tmp24 = tmp19 - tmp23
tmp25 = tl_math.exp(tmp24)
tmp26 = tl.broadcast_to(tmp25, [XBLOCK, RBLOCK])
tmp28 = tl.where(rmask & xmask, tmp26, 0)
tmp29 = tl.sum(tmp28, 1)[:, None]
tmp30 = tmp25 / tmp29
tmp31 = tmp30.to(tl.float32)
tl.store(out_ptr3 + (r3 + (49*x0) + (2432*x5)), tmp31, rmask & xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/6x/c6xasej3mmacmhpbp2ayc2wmpde5tmj3algi7vysqjsht3g5bcoy.py
# Source Nodes: [layer_norm_16, x_123, x_124], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
# layer_norm_16 => add_69, add_70, convert_element_type_188, mul_59, mul_60, rsqrt_16, sub_25, var_mean_16
# x_123 => add_68
# x_124 => convert_element_type_191
# Looped (non-persistent) reduction over rnumel = 512 channels for
# xnumel = 4624 = 68*68 positions (x0 = 68*row + col).
# The reduced value is the residual sum
#   in_ptr0[x0, ch] (fp16) + gather(in_ptr1) (fp16) + in_ptr2[ch] (fp32).
# The in_ptr1 offset maps (row, col) to r = (67 + row) % 70, c = (67 + col) % 70,
# a cyclic shift by 3 on a 70-wide grid. Rows/cols 65 and 66 of that grid are
# never read. The address is
#   (c % 7)*512 + (r % 7)*3584 + (c // 7)*25088 + (r // 7)*250880 + ch
# i.e. a [10, 10, 7, 7, 512] window-partitioned layout, the same layout that
# triton_poi_fused__to_copy_63 writes.
# Pass 1: Welford mean / M2 (sum of squared deviations) per position.
# Pass 2: recomputes the residual sum, which is not stored here, and writes
#   (x - mean) * rsqrt(M2 / 512 + 1e-05) * in_ptr3[ch] + in_ptr4[ch]
# to out_ptr3 (declared '*fp16').
triton_red_fused__to_copy_add_native_layer_norm_68 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.reduction(
size_hints=[8192, 512],
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp16', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_native_layer_norm_68', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 8, 'num_reduction': 2, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr3, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
xnumel = 4624
rnumel = 512
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
rbase = tl.arange(0, RBLOCK)[None, :]
x0 = xindex
tmp8_mean = tl.zeros([XBLOCK, RBLOCK], tl.float32)
tmp8_m2 = tl.zeros([XBLOCK, RBLOCK], tl.float32)
tmp8_weight = tl.zeros([XBLOCK, RBLOCK], tl.float32)
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp0 = tl.load(in_ptr0 + (r1 + (512*x0)), rmask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r1 + (512*(((67 + (x0 % 68)) % 70) % 7)) + (3584*(((67 + (x0 // 68)) % 70) % 7)) + (25088*(((67 + (x0 % 68)) % 70) // 7)) + (250880*(((67 + (x0 // 68)) % 70) // 7))), rmask & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp2 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp1 + tmp3
tmp5 = tmp0 + tmp4
tmp6 = tmp5.to(tl.float32)
tmp7 = tl.broadcast_to(tmp6, [XBLOCK, RBLOCK])
tmp8_mean_next, tmp8_m2_next, tmp8_weight_next = triton_helpers.welford_reduce(
tmp7, tmp8_mean, tmp8_m2, tmp8_weight, roffset == 0
)
tmp8_mean = tl.where(rmask & xmask, tmp8_mean_next, tmp8_mean)
tmp8_m2 = tl.where(rmask & xmask, tmp8_m2_next, tmp8_m2)
tmp8_weight = tl.where(rmask & xmask, tmp8_weight_next, tmp8_weight)
tmp8_tmp, tmp9_tmp, tmp10_tmp = triton_helpers.welford(
tmp8_mean, tmp8_m2, tmp8_weight, 1
)
tmp8 = tmp8_tmp[:, None]
tmp9 = tmp9_tmp[:, None]
tmp10 = tmp10_tmp[:, None]
for roffset in range(0, rnumel, RBLOCK):
rindex = roffset + rbase
rmask = rindex < rnumel
r1 = rindex
tmp11 = tl.load(in_ptr0 + (r1 + (512*x0)), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp12 = tl.load(in_ptr1 + (r1 + (512*(((67 + (x0 % 68)) % 70) % 7)) + (3584*(((67 + (x0 // 68)) % 70) % 7)) + (25088*(((67 + (x0 % 68)) % 70) // 7)) + (250880*(((67 + (x0 // 68)) % 70) // 7))), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp13 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp25 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp27 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp14 = tmp13.to(tl.float32)
tmp15 = tmp12 + tmp14
tmp16 = tmp11 + tmp15
tmp17 = tmp16.to(tl.float32)
tmp18 = tmp17 - tmp8
tmp19 = 512.0
tmp20 = tmp9 / tmp19
tmp21 = 1e-05
tmp22 = tmp20 + tmp21
tmp23 = libdevice.rsqrt(tmp22)
tmp24 = tmp18 * tmp23
tmp26 = tmp24 * tmp25
tmp28 = tmp26 + tmp27
tmp29 = tmp28.to(tl.float32)
tl.store(out_ptr3 + (r1 + (512*x0)), tmp29, rmask & xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/rv/crvoqdpaf6qnpdaldpzwrrgi7gxvyresscwiboxzh75r7txhfnnf.py
# Source Nodes: [x_123, x_129, x_130], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
# x_123 => add_68
# x_129 => add_72
# x_130 => convert_element_type_202, var_mean_17
# Persistent reduction with one program per position (XBLOCK = 1,
# RBLOCK = rnumel = 512 channels), xnumel = 4624 = 68*68.
# It computes the two-level residual sum
#   tmp10 = (in_ptr0 + gather(in_ptr1) + in_ptr2[ch]) + (in_out_ptr0 + in_ptr3[ch])
# The in_ptr1 gather reads row/col shifted by +67 mod 70 from a [10, 10, 7, 7, 512]
# window-partitioned layout.
# tmp10 is written back in place to in_out_ptr0 (declared '*fp16'), and two
# per-position statistics are stored:
#   out_ptr0[x0] = mean over the 512 channels;
#   out_ptr1[x0] = sum of squared deviations from that mean (not divided by 512).
# NOTE: tmp14 is computed but never used.
triton_per_fused__to_copy_add_native_layer_norm_69 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.persistent_reduction(
size_hints=[8192, 512],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp16', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: 'i32', 8: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_native_layer_norm_69', 'mutated_arg_names': ['in_out_ptr0'], 'no_x_dim': True, 'num_load': 5, 'num_reduction': 4, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, out_ptr1, xnumel, rnumel):
xnumel = 4624
XBLOCK: tl.constexpr = 1
rnumel = 512
RBLOCK: tl.constexpr = 512
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = xindex < xnumel
rindex = tl.arange(0, RBLOCK)[:]
roffset = 0
rmask = rindex < rnumel
r1 = rindex
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r1 + (512*x0)), rmask & xmask, other=0.0).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r1 + (512*(((67 + (x0 % 68)) % 70) % 7)) + (3584*(((67 + (x0 // 68)) % 70) % 7)) + (25088*(((67 + (x0 % 68)) % 70) // 7)) + (250880*(((67 + (x0 // 68)) % 70) // 7))), rmask & xmask, other=0.0).to(tl.float32)
tmp2 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp6 = tl.load(in_out_ptr0 + (r1 + (512*x0)), rmask & xmask, other=0.0).to(tl.float32)
tmp7 = tl.load(in_ptr3 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp1 + tmp3
tmp5 = tmp0 + tmp4
tmp8 = tmp7.to(tl.float32)
tmp9 = tmp6 + tmp8
tmp10 = tmp5 + tmp9
tmp11 = tmp10.to(tl.float32)
tmp12 = tl.broadcast_to(tmp11, [RBLOCK])
tmp14 = tl.where(rmask & xmask, tmp12, 0)
tmp15 = tl.broadcast_to(tmp12, [RBLOCK])
tmp17 = tl.where(rmask & xmask, tmp15, 0)
tmp18 = triton_helpers.promote_to_tensor(tl.sum(tmp17, 0))
tmp19 = tl.full([1], 512, tl.int32)
tmp20 = tmp19.to(tl.float32)
tmp21 = tmp18 / tmp20
tmp22 = tmp12 - tmp21
tmp23 = tmp22 * tmp22
tmp24 = tl.broadcast_to(tmp23, [RBLOCK])
tmp26 = tl.where(rmask & xmask, tmp24, 0)
tmp27 = triton_helpers.promote_to_tensor(tl.sum(tmp26, 0))
tl.store(in_out_ptr0 + (r1 + (512*x0)), tmp10, rmask & xmask)
tl.store(out_ptr0 + (x0), tmp21, xmask)
tl.store(out_ptr1 + (x0), tmp27, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/i3/ci3pm4qxa5bdompouf64ojpj35sx6k5ehnbqcwoemqypakfg6p26.py
# Source Nodes: [x_419, x_425, x_out_2], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
# x_419 => add_236
# x_425 => add_240
# x_out_2 => convert_element_type_698, var_mean_49
# Persistent reduction with one program per position (XBLOCK = 1,
# RBLOCK = rnumel = 512 channels), xnumel = 4624 = 68*68.
# It computes the two-level residual sum
#   (in_ptr0 + gather(in_ptr1) + in_ptr2[ch]) + (in_ptr3 + in_ptr4[ch])
# The in_ptr1 gather reads row/col shifted by +67 mod 70 from a [10, 10, 7, 7, 512]
# window-partitioned layout.
# The result is stored out of place: out_ptr0 (declared '*fp32') receives the
# residual sum, out_ptr1[x0] the mean over the 512 channels, and out_ptr2[x0] the
# sum of squared deviations from that mean (not divided by 512). No input is
# mutated (mutated_arg_names is empty).
# NOTE: tmp14 is computed but never used.
triton_per_fused__to_copy_add_native_layer_norm_70 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.persistent_reduction(
size_hints=[8192, 512],
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp16', 2: '*fp32', 3: '*fp16', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: '*fp32', 8: 'i32', 9: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_native_layer_norm_70', 'mutated_arg_names': [], 'no_x_dim': True, 'num_load': 5, 'num_reduction': 4, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, out_ptr1, out_ptr2, xnumel, rnumel):
xnumel = 4624
XBLOCK: tl.constexpr = 1
rnumel = 512
RBLOCK: tl.constexpr = 512
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = xindex < xnumel
rindex = tl.arange(0, RBLOCK)[:]
roffset = 0
rmask = rindex < rnumel
r1 = rindex
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r1 + (512*x0)), rmask & xmask, other=0.0).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r1 + (512*(((67 + (x0 % 68)) % 70) % 7)) + (3584*(((67 + (x0 // 68)) % 70) % 7)) + (25088*(((67 + (x0 % 68)) % 70) // 7)) + (250880*(((67 + (x0 // 68)) % 70) // 7))), rmask & xmask, other=0.0).to(tl.float32)
tmp2 = tl.load(in_ptr2 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp6 = tl.load(in_ptr3 + (r1 + (512*x0)), rmask & xmask, other=0.0).to(tl.float32)
tmp7 = tl.load(in_ptr4 + (r1), rmask, eviction_policy='evict_last', other=0.0)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp1 + tmp3
tmp5 = tmp0 + tmp4
tmp8 = tmp7.to(tl.float32)
tmp9 = tmp6 + tmp8
tmp10 = tmp5 + tmp9
tmp11 = tmp10.to(tl.float32)
tmp12 = tl.broadcast_to(tmp11, [RBLOCK])
tmp14 = tl.where(rmask & xmask, tmp12, 0)
tmp15 = tl.broadcast_to(tmp12, [RBLOCK])
tmp17 = tl.where(rmask & xmask, tmp15, 0)
tmp18 = triton_helpers.promote_to_tensor(tl.sum(tmp17, 0))
tmp19 = tl.full([1], 512, tl.int32)
tmp20 = tmp19.to(tl.float32)
tmp21 = tmp18 / tmp20
tmp22 = tmp12 - tmp21
tmp23 = tmp22 * tmp22
tmp24 = tl.broadcast_to(tmp23, [RBLOCK])
tmp26 = tl.where(rmask & xmask, tmp24, 0)
tmp27 = triton_helpers.promote_to_tensor(tl.sum(tmp26, 0))
tl.store(out_ptr0 + (r1 + (512*x0)), tmp11, rmask & xmask)
tl.store(out_ptr1 + (x0), tmp21, xmask)
tl.store(out_ptr2 + (x0), tmp27, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/uz/cuzwnuuhvp4xvn2ygexj3suttmbl2cv7nhkqtwzvahp3hcuf5ncx.py
# Source Nodes: [out], Original ATen: [aten.clone]
# out => clone_27
#
# Pointwise 2-D tiled kernel: applies the normalize + affine half of a LayerNorm
# and, in the same pass, transposes the data between two layouts.
#   y0 indexes 128 channels (ynumel=128).
#   x1 indexes 73984 positions (xnumel=73984 = 272*272).
#
# Inputs (all fp32):
#   in_ptr0: read at y0 + 128*x1, i.e. a [73984, 128] buffer with channels fastest.
#   in_ptr1 [73984]: subtracted from every element of its position
#       (presumably the per-position mean).
#   in_ptr2 [73984]: divided by 128.0, plus eps = 1e-05, then passed to rsqrt
#       (presumably the per-position sum of squared deviations).
#   in_ptr3, in_ptr4 [128]: per-channel multiply and add
#       (presumably the LayerNorm weight and bias).
#
# Output (fp32):
#   out_ptr0: written at x1 + 73984*y0, i.e. a [128, 73984] buffer with positions fastest.
triton_poi_fused_clone_71 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[128, 131072], tile_hint=TileHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_71', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
ynumel = 128
xnumel = 73984
yoffset = (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)) * YBLOCK
yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
ymask = yindex < ynumel
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
x1 = xindex
y0 = yindex
tmp0 = tl.load(in_ptr0 + (y0 + (128*x1)), xmask & ymask, eviction_policy='evict_last')
tmp1 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last')
tmp3 = tl.load(in_ptr2 + (x1), xmask, eviction_policy='evict_last')
tmp10 = tl.load(in_ptr3 + (y0), ymask, eviction_policy='evict_last')
tmp12 = tl.load(in_ptr4 + (y0), ymask, eviction_policy='evict_last')
tmp2 = tmp0 - tmp1
tmp4 = 128.0
tmp5 = tmp3 / tmp4
tmp6 = 1e-05
tmp7 = tmp5 + tmp6
tmp8 = libdevice.rsqrt(tmp7)
tmp9 = tmp2 * tmp8
tmp11 = tmp9 * tmp10
tmp13 = tmp11 + tmp12
tl.store(out_ptr0 + (x1 + (73984*y0)), tmp13, xmask & ymask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/5s/c5s7m6ozgfpwmvadtkkx5yxo36i4abkudvpqhbmoam4ymufqrcsm.py
# Source Nodes: [out_1], Original ATen: [aten.clone]
# out_1 => clone_53
#
# Pointwise 2-D tiled kernel: applies the normalize + affine half of a LayerNorm
# and, in the same pass, transposes the data between two layouts.
#   y0 indexes 256 channels (ynumel=256).
#   x1 indexes 18496 positions (xnumel=18496 = 136*136).
#
# Inputs:
#   in_ptr0 (fp16): read at y0 + 256*x1, i.e. a [18496, 256] buffer with channels
#       fastest. It is upcast to fp32 before any arithmetic.
#   in_ptr1 (fp32, [18496]): subtracted from every element of its position
#       (presumably the per-position mean).
#   in_ptr2 (fp32, [18496]): divided by 256.0, plus eps = 1e-05, then passed to rsqrt
#       (presumably the per-position sum of squared deviations).
#   in_ptr3, in_ptr4 (fp32, [256]): per-channel multiply and add
#       (presumably the LayerNorm weight and bias).
#
# Output (fp32):
#   out_ptr0: written at x1 + 18496*y0, i.e. a [256, 18496] buffer with positions fastest.
triton_poi_fused_clone_72 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[256, 32768], tile_hint=TileHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {0: '*fp16', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_72', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
ynumel = 256
xnumel = 18496
yoffset = (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)) * YBLOCK
yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
ymask = yindex < ynumel
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
x1 = xindex
y0 = yindex
tmp0 = tl.load(in_ptr0 + (y0 + (256*x1)), xmask & ymask, eviction_policy='evict_last').to(tl.float32)
tmp2 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last')
tmp4 = tl.load(in_ptr2 + (x1), xmask, eviction_policy='evict_last')
tmp11 = tl.load(in_ptr3 + (y0), ymask, eviction_policy='evict_last')
tmp13 = tl.load(in_ptr4 + (y0), ymask, eviction_policy='evict_last')
tmp1 = tmp0.to(tl.float32)
tmp3 = tmp1 - tmp2
tmp5 = 256.0
tmp6 = tmp4 / tmp5
tmp7 = 1e-05
tmp8 = tmp6 + tmp7
tmp9 = libdevice.rsqrt(tmp8)
tmp10 = tmp3 * tmp9
tmp12 = tmp10 * tmp11
tmp14 = tmp12 + tmp13
tl.store(out_ptr0 + (x1 + (18496*y0)), tmp14, xmask & ymask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_root/tc/ctcwm3wctt5jaq4f7ptheotzpz5uo4d7jkikv5w3fzgbb6we22en.py
# Source Nodes: [out_2], Original ATen: [aten.clone]
# out_2 => clone_271
#
# Pointwise 2-D tiled kernel: applies the normalize + affine half of a LayerNorm
# and, in the same pass, transposes the data between two layouts.
#   y0 indexes 512 channels (ynumel=512).
#   x1 indexes 4624 positions (xnumel=4624 = 68*68).
#
# Inputs (all fp32):
#   in_ptr0: read at y0 + 512*x1, i.e. a [4624, 512] buffer with channels fastest.
#   in_ptr1 [4624]: subtracted from every element of its position
#       (presumably the per-position mean).
#   in_ptr2 [4624]: divided by 512.0, plus eps = 1e-05, then passed to rsqrt
#       (presumably the per-position sum of squared deviations).
#   in_ptr3, in_ptr4 [512]: per-channel multiply and add
#       (presumably the LayerNorm weight and bias).
#
# Output (fp32):
#   out_ptr0: written at x1 + 4624*y0, i.e. a [512, 4624] buffer with positions fastest.
triton_poi_fused_clone_73 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
@triton_heuristics.pointwise(
size_hints=[512, 8192], tile_hint=TileHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32', 6: 'i32', 7: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=58), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clone_73', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': '72c34bdb145549777ca2f0838f26abe42bb446cf528c78d229508b5a55e67a78', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': False, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
ynumel = 512
xnumel = 4624
yoffset = (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)) * YBLOCK
yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
ymask = yindex < ynumel
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
x1 = xindex
y0 = yindex
tmp0 = tl.load(in_ptr0 + (y0 + (512*x1)), xmask & ymask, eviction_policy='evict_last')
tmp1 = tl.load(in_ptr1 + (x1), xmask, eviction_policy='evict_last')
tmp3 = tl.load(in_ptr2 + (x1), xmask, eviction_policy='evict_last')
tmp10 = tl.load(in_ptr3 + (y0), ymask, eviction_policy='evict_last')
tmp12 = tl.load(in_ptr4 + (y0), ymask, eviction_policy='evict_last')
tmp2 = tmp0 - tmp1
tmp4 = 512.0
tmp5 = tmp3 / tmp4
tmp6 = 1e-05
tmp7 = tmp5 + tmp6
tmp8 = libdevice.rsqrt(tmp7)
tmp9 = tmp2 * tmp8
tmp11 = tmp9 * tmp10
tmp13 = tmp11 + tmp12
tl.store(out_ptr0 + (x1 + (4624*y0)), tmp13, xmask & ymask)
''', device_str='cuda')
# Wait for every kernel submitted via `async_compile.triton(...)` above before `call`
# can run. Module globals are handed over so the compiler can, presumably, replace
# the pending handles bound to the module-level kernel names with the compiled
# kernels. NOTE(review): confirm against torch._inductor.async_compile.AsyncCompile.wait.
async_compile.wait(globals())
# Drop the module-level AsyncCompile instance created at the top of the file
# (`async_compile = AsyncCompile()`); compilation is finished at this point.
del async_compile
def call(args):
arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1, arg115_1, arg116_1, arg117_1, arg118_1, arg119_1, arg120_1, arg121_1, arg122_1, arg123_1, arg124_1, arg125_1, arg126_1, arg127_1, arg128_1, arg129_1, arg130_1, arg131_1, arg132_1, arg133_1, arg134_1, arg135_1, arg136_1, arg137_1, arg138_1, arg139_1, arg140_1, arg141_1, arg142_1, arg143_1, arg144_1, arg145_1, arg146_1, arg147_1, arg148_1, arg149_1, arg150_1, arg151_1, arg152_1, arg153_1, arg154_1, arg155_1, arg156_1, arg157_1, arg158_1, arg159_1, arg160_1, arg161_1, arg162_1, arg163_1, arg164_1, arg165_1, arg166_1, arg167_1, arg168_1, arg169_1, arg170_1, arg171_1, arg172_1, arg173_1, arg174_1, arg175_1, arg176_1, arg177_1, arg178_1, arg179_1, arg180_1, arg181_1, arg182_1, arg183_1, arg184_1, arg185_1, arg186_1, arg187_1, arg188_1, arg189_1, arg190_1, arg191_1, arg192_1, arg193_1, arg194_1, arg195_1, arg196_1, arg197_1, arg198_1, arg199_1, arg200_1, arg201_1, arg202_1, arg203_1, arg204_1, arg205_1, arg206_1, arg207_1, arg208_1, arg209_1, arg210_1, 
arg211_1, arg212_1, arg213_1, arg214_1, arg215_1, arg216_1, arg217_1, arg218_1, arg219_1, arg220_1, arg221_1, arg222_1, arg223_1, arg224_1, arg225_1, arg226_1, arg227_1, arg228_1, arg229_1, arg230_1, arg231_1, arg232_1, arg233_1, arg234_1, arg235_1, arg236_1, arg237_1, arg238_1, arg239_1, arg240_1, arg241_1, arg242_1, arg243_1, arg244_1, arg245_1, arg246_1, arg247_1, arg248_1, arg249_1, arg250_1, arg251_1, arg252_1, arg253_1, arg254_1, arg255_1, arg256_1, arg257_1, arg258_1, arg259_1, arg260_1, arg261_1, arg262_1, arg263_1, arg264_1, arg265_1, arg266_1, arg267_1, arg268_1, arg269_1, arg270_1, arg271_1, arg272_1, arg273_1, arg274_1, arg275_1, arg276_1, arg277_1, arg278_1, arg279_1, arg280_1, arg281_1, arg282_1, arg283_1, arg284_1, arg285_1, arg286_1, arg287_1, arg288_1, arg289_1, arg290_1, arg291_1, arg292_1, arg293_1, arg294_1, arg295_1, arg296_1, arg297_1, arg298_1, arg299_1, arg300_1, arg301_1, arg302_1, arg303_1, arg304_1, arg305_1, arg306_1, arg307_1, arg308_1, arg309_1, arg310_1, arg311_1, arg312_1, arg313_1, arg314_1, arg315_1, arg316_1, arg317_1, arg318_1, arg319_1, arg320_1, arg321_1, arg322_1, arg323_1, arg324_1 = args
args.clear()
assert_size_stride(arg0_1, (1, 3, 1088, 1088), (3551232, 1183744, 1088, 1))
assert_size_stride(arg1_1, (128, 3, 4, 4), (48, 16, 4, 1))
assert_size_stride(arg2_1, (128, ), (1, ))
assert_size_stride(arg3_1, (128, ), (1, ))
assert_size_stride(arg4_1, (128, ), (1, ))
assert_size_stride(arg5_1, (128, ), (1, ))
assert_size_stride(arg6_1, (128, ), (1, ))
assert_size_stride(arg7_1, (384, 128), (128, 1))
assert_size_stride(arg8_1, (384, ), (1, ))
assert_size_stride(arg9_1, (169, 4), (4, 1))
assert_size_stride(arg10_1, (49, 49), (49, 1))
assert_size_stride(arg11_1, (128, 128), (128, 1))
assert_size_stride(arg12_1, (128, ), (1, ))
assert_size_stride(arg13_1, (128, ), (1, ))
assert_size_stride(arg14_1, (128, ), (1, ))
assert_size_stride(arg15_1, (512, 128), (128, 1))
assert_size_stride(arg16_1, (512, ), (1, ))
assert_size_stride(arg17_1, (128, 512), (512, 1))
assert_size_stride(arg18_1, (128, ), (1, ))
assert_size_stride(arg19_1, (128, ), (1, ))
assert_size_stride(arg20_1, (128, ), (1, ))
assert_size_stride(arg21_1, (384, 128), (128, 1))
assert_size_stride(arg22_1, (384, ), (1, ))
assert_size_stride(arg23_1, (169, 4), (4, 1))
assert_size_stride(arg24_1, (49, 49), (49, 1))
assert_size_stride(arg25_1, (128, 128), (128, 1))
assert_size_stride(arg26_1, (128, ), (1, ))
assert_size_stride(arg27_1, (128, ), (1, ))
assert_size_stride(arg28_1, (128, ), (1, ))
assert_size_stride(arg29_1, (512, 128), (128, 1))
assert_size_stride(arg30_1, (512, ), (1, ))
assert_size_stride(arg31_1, (128, 512), (512, 1))
assert_size_stride(arg32_1, (128, ), (1, ))
assert_size_stride(arg33_1, (512, ), (1, ))
assert_size_stride(arg34_1, (512, ), (1, ))
assert_size_stride(arg35_1, (256, 512), (512, 1))
assert_size_stride(arg36_1, (128, ), (1, ))
assert_size_stride(arg37_1, (128, ), (1, ))
assert_size_stride(arg38_1, (256, ), (1, ))
assert_size_stride(arg39_1, (256, ), (1, ))
assert_size_stride(arg40_1, (768, 256), (256, 1))
assert_size_stride(arg41_1, (768, ), (1, ))
assert_size_stride(arg42_1, (169, 8), (8, 1))
assert_size_stride(arg43_1, (49, 49), (49, 1))
assert_size_stride(arg44_1, (256, 256), (256, 1))
assert_size_stride(arg45_1, (256, ), (1, ))
assert_size_stride(arg46_1, (256, ), (1, ))
assert_size_stride(arg47_1, (256, ), (1, ))
assert_size_stride(arg48_1, (1024, 256), (256, 1))
assert_size_stride(arg49_1, (1024, ), (1, ))
assert_size_stride(arg50_1, (256, 1024), (1024, 1))
assert_size_stride(arg51_1, (256, ), (1, ))
assert_size_stride(arg52_1, (256, ), (1, ))
assert_size_stride(arg53_1, (256, ), (1, ))
assert_size_stride(arg54_1, (768, 256), (256, 1))
assert_size_stride(arg55_1, (768, ), (1, ))
assert_size_stride(arg56_1, (169, 8), (8, 1))
assert_size_stride(arg57_1, (49, 49), (49, 1))
assert_size_stride(arg58_1, (256, 256), (256, 1))
assert_size_stride(arg59_1, (256, ), (1, ))
assert_size_stride(arg60_1, (256, ), (1, ))
assert_size_stride(arg61_1, (256, ), (1, ))
assert_size_stride(arg62_1, (1024, 256), (256, 1))
assert_size_stride(arg63_1, (1024, ), (1, ))
assert_size_stride(arg64_1, (256, 1024), (1024, 1))
assert_size_stride(arg65_1, (256, ), (1, ))
assert_size_stride(arg66_1, (1024, ), (1, ))
assert_size_stride(arg67_1, (1024, ), (1, ))
assert_size_stride(arg68_1, (512, 1024), (1024, 1))
assert_size_stride(arg69_1, (256, ), (1, ))
assert_size_stride(arg70_1, (256, ), (1, ))
assert_size_stride(arg71_1, (512, ), (1, ))
assert_size_stride(arg72_1, (512, ), (1, ))
assert_size_stride(arg73_1, (1536, 512), (512, 1))
assert_size_stride(arg74_1, (1536, ), (1, ))
assert_size_stride(arg75_1, (169, 16), (16, 1))
assert_size_stride(arg76_1, (49, 49), (49, 1))
assert_size_stride(arg77_1, (512, 512), (512, 1))
assert_size_stride(arg78_1, (512, ), (1, ))
assert_size_stride(arg79_1, (512, ), (1, ))
assert_size_stride(arg80_1, (512, ), (1, ))
assert_size_stride(arg81_1, (2048, 512), (512, 1))
assert_size_stride(arg82_1, (2048, ), (1, ))
assert_size_stride(arg83_1, (512, 2048), (2048, 1))
assert_size_stride(arg84_1, (512, ), (1, ))
assert_size_stride(arg85_1, (512, ), (1, ))
assert_size_stride(arg86_1, (512, ), (1, ))
assert_size_stride(arg87_1, (1536, 512), (512, 1))
assert_size_stride(arg88_1, (1536, ), (1, ))
assert_size_stride(arg89_1, (169, 16), (16, 1))
assert_size_stride(arg90_1, (49, 49), (49, 1))
assert_size_stride(arg91_1, (512, 512), (512, 1))
assert_size_stride(arg92_1, (512, ), (1, ))
assert_size_stride(arg93_1, (512, ), (1, ))
assert_size_stride(arg94_1, (512, ), (1, ))
assert_size_stride(arg95_1, (2048, 512), (512, 1))
assert_size_stride(arg96_1, (2048, ), (1, ))
assert_size_stride(arg97_1, (512, 2048), (2048, 1))
assert_size_stride(arg98_1, (512, ), (1, ))
assert_size_stride(arg99_1, (512, ), (1, ))
assert_size_stride(arg100_1, (512, ), (1, ))
assert_size_stride(arg101_1, (1536, 512), (512, 1))
assert_size_stride(arg102_1, (1536, ), (1, ))
assert_size_stride(arg103_1, (169, 16), (16, 1))
assert_size_stride(arg104_1, (49, 49), (49, 1))
assert_size_stride(arg105_1, (512, 512), (512, 1))
assert_size_stride(arg106_1, (512, ), (1, ))
assert_size_stride(arg107_1, (512, ), (1, ))
assert_size_stride(arg108_1, (512, ), (1, ))
assert_size_stride(arg109_1, (2048, 512), (512, 1))
assert_size_stride(arg110_1, (2048, ), (1, ))
assert_size_stride(arg111_1, (512, 2048), (2048, 1))
assert_size_stride(arg112_1, (512, ), (1, ))
assert_size_stride(arg113_1, (512, ), (1, ))
assert_size_stride(arg114_1, (512, ), (1, ))
assert_size_stride(arg115_1, (1536, 512), (512, 1))
assert_size_stride(arg116_1, (1536, ), (1, ))
assert_size_stride(arg117_1, (169, 16), (16, 1))
assert_size_stride(arg118_1, (49, 49), (49, 1))
assert_size_stride(arg119_1, (512, 512), (512, 1))
assert_size_stride(arg120_1, (512, ), (1, ))
assert_size_stride(arg121_1, (512, ), (1, ))
assert_size_stride(arg122_1, (512, ), (1, ))
assert_size_stride(arg123_1, (2048, 512), (512, 1))
assert_size_stride(arg124_1, (2048, ), (1, ))
assert_size_stride(arg125_1, (512, 2048), (2048, 1))
assert_size_stride(arg126_1, (512, ), (1, ))
assert_size_stride(arg127_1, (512, ), (1, ))
assert_size_stride(arg128_1, (512, ), (1, ))
assert_size_stride(arg129_1, (1536, 512), (512, 1))
assert_size_stride(arg130_1, (1536, ), (1, ))
assert_size_stride(arg131_1, (169, 16), (16, 1))
assert_size_stride(arg132_1, (49, 49), (49, 1))
assert_size_stride(arg133_1, (512, 512), (512, 1))
assert_size_stride(arg134_1, (512, ), (1, ))
assert_size_stride(arg135_1, (512, ), (1, ))
assert_size_stride(arg136_1, (512, ), (1, ))
assert_size_stride(arg137_1, (2048, 512), (512, 1))
assert_size_stride(arg138_1, (2048, ), (1, ))
assert_size_stride(arg139_1, (512, 2048), (2048, 1))
assert_size_stride(arg140_1, (512, ), (1, ))
assert_size_stride(arg141_1, (512, ), (1, ))
assert_size_stride(arg142_1, (512, ), (1, ))
assert_size_stride(arg143_1, (1536, 512), (512, 1))
assert_size_stride(arg144_1, (1536, ), (1, ))
assert_size_stride(arg145_1, (169, 16), (16, 1))
assert_size_stride(arg146_1, (49, 49), (49, 1))
assert_size_stride(arg147_1, (512, 512), (512, 1))
assert_size_stride(arg148_1, (512, ), (1, ))
assert_size_stride(arg149_1, (512, ), (1, ))
assert_size_stride(arg150_1, (512, ), (1, ))
assert_size_stride(arg151_1, (2048, 512), (512, 1))
assert_size_stride(arg152_1, (2048, ), (1, ))
assert_size_stride(arg153_1, (512, 2048), (2048, 1))
assert_size_stride(arg154_1, (512, ), (1, ))
assert_size_stride(arg155_1, (512, ), (1, ))
assert_size_stride(arg156_1, (512, ), (1, ))
assert_size_stride(arg157_1, (1536, 512), (512, 1))
assert_size_stride(arg158_1, (1536, ), (1, ))
assert_size_stride(arg159_1, (169, 16), (16, 1))
assert_size_stride(arg160_1, (49, 49), (49, 1))
assert_size_stride(arg161_1, (512, 512), (512, 1))
assert_size_stride(arg162_1, (512, ), (1, ))
assert_size_stride(arg163_1, (512, ), (1, ))
assert_size_stride(arg164_1, (512, ), (1, ))
assert_size_stride(arg165_1, (2048, 512), (512, 1))
assert_size_stride(arg166_1, (2048, ), (1, ))
assert_size_stride(arg167_1, (512, 2048), (2048, 1))
assert_size_stride(arg168_1, (512, ), (1, ))
assert_size_stride(arg169_1, (512, ), (1, ))
assert_size_stride(arg170_1, (512, ), (1, ))
assert_size_stride(arg171_1, (1536, 512), (512, 1))
assert_size_stride(arg172_1, (1536, ), (1, ))
assert_size_stride(arg173_1, (169, 16), (16, 1))
assert_size_stride(arg174_1, (49, 49), (49, 1))
assert_size_stride(arg175_1, (512, 512), (512, 1))
assert_size_stride(arg176_1, (512, ), (1, ))
assert_size_stride(arg177_1, (512, ), (1, ))
assert_size_stride(arg178_1, (512, ), (1, ))
assert_size_stride(arg179_1, (2048, 512), (512, 1))
assert_size_stride(arg180_1, (2048, ), (1, ))
assert_size_stride(arg181_1, (512, 2048), (2048, 1))
assert_size_stride(arg182_1, (512, ), (1, ))
assert_size_stride(arg183_1, (512, ), (1, ))
assert_size_stride(arg184_1, (512, ), (1, ))
assert_size_stride(arg185_1, (1536, 512), (512, 1))
assert_size_stride(arg186_1, (1536, ), (1, ))
assert_size_stride(arg187_1, (169, 16), (16, 1))
assert_size_stride(arg188_1, (49, 49), (49, 1))
assert_size_stride(arg189_1, (512, 512), (512, 1))
assert_size_stride(arg190_1, (512, ), (1, ))
assert_size_stride(arg191_1, (512, ), (1, ))
assert_size_stride(arg192_1, (512, ), (1, ))
assert_size_stride(arg193_1, (2048, 512), (512, 1))
assert_size_stride(arg194_1, (2048, ), (1, ))
assert_size_stride(arg195_1, (512, 2048), (2048, 1))
assert_size_stride(arg196_1, (512, ), (1, ))
assert_size_stride(arg197_1, (512, ), (1, ))
assert_size_stride(arg198_1, (512, ), (1, ))
assert_size_stride(arg199_1, (1536, 512), (512, 1))
assert_size_stride(arg200_1, (1536, ), (1, ))
assert_size_stride(arg201_1, (169, 16), (16, 1))
assert_size_stride(arg202_1, (49, 49), (49, 1))
assert_size_stride(arg203_1, (512, 512), (512, 1))
assert_size_stride(arg204_1, (512, ), (1, ))
assert_size_stride(arg205_1, (512, ), (1, ))
assert_size_stride(arg206_1, (512, ), (1, ))
assert_size_stride(arg207_1, (2048, 512), (512, 1))
assert_size_stride(arg208_1, (2048, ), (1, ))
assert_size_stride(arg209_1, (512, 2048), (2048, 1))
assert_size_stride(arg210_1, (512, ), (1, ))
assert_size_stride(arg211_1, (512, ), (1, ))
assert_size_stride(arg212_1, (512, ), (1, ))
assert_size_stride(arg213_1, (1536, 512), (512, 1))
assert_size_stride(arg214_1, (1536, ), (1, ))
assert_size_stride(arg215_1, (169, 16), (16, 1))
assert_size_stride(arg216_1, (49, 49), (49, 1))
assert_size_stride(arg217_1, (512, 512), (512, 1))
assert_size_stride(arg218_1, (512, ), (1, ))
assert_size_stride(arg219_1, (512, ), (1, ))
assert_size_stride(arg220_1, (512, ), (1, ))
assert_size_stride(arg221_1, (2048, 512), (512, 1))
assert_size_stride(arg222_1, (2048, ), (1, ))
assert_size_stride(arg223_1, (512, 2048), (2048, 1))
assert_size_stride(arg224_1, (512, ), (1, ))
assert_size_stride(arg225_1, (512, ), (1, ))
assert_size_stride(arg226_1, (512, ), (1, ))
assert_size_stride(arg227_1, (1536, 512), (512, 1))
assert_size_stride(arg228_1, (1536, ), (1, ))
assert_size_stride(arg229_1, (169, 16), (16, 1))
assert_size_stride(arg230_1, (49, 49), (49, 1))
assert_size_stride(arg231_1, (512, 512), (512, 1))
assert_size_stride(arg232_1, (512, ), (1, ))
assert_size_stride(arg233_1, (512, ), (1, ))
assert_size_stride(arg234_1, (512, ), (1, ))
assert_size_stride(arg235_1, (2048, 512), (512, 1))
assert_size_stride(arg236_1, (2048, ), (1, ))
assert_size_stride(arg237_1, (512, 2048), (2048, 1))
assert_size_stride(arg238_1, (512, ), (1, ))
assert_size_stride(arg239_1, (512, ), (1, ))
assert_size_stride(arg240_1, (512, ), (1, ))
assert_size_stride(arg241_1, (1536, 512), (512, 1))
assert_size_stride(arg242_1, (1536, ), (1, ))
assert_size_stride(arg243_1, (169, 16), (16, 1))
assert_size_stride(arg244_1, (49, 49), (49, 1))
assert_size_stride(arg245_1, (512, 512), (512, 1))
assert_size_stride(arg246_1, (512, ), (1, ))
assert_size_stride(arg247_1, (512, ), (1, ))
assert_size_stride(arg248_1, (512, ), (1, ))
assert_size_stride(arg249_1, (2048, 512), (512, 1))
assert_size_stride(arg250_1, (2048, ), (1, ))
assert_size_stride(arg251_1, (512, 2048), (2048, 1))
assert_size_stride(arg252_1, (512, ), (1, ))
assert_size_stride(arg253_1, (512, ), (1, ))
assert_size_stride(arg254_1, (512, ), (1, ))
assert_size_stride(arg255_1, (1536, 512), (512, 1))
assert_size_stride(arg256_1, (1536, ), (1, ))
assert_size_stride(arg257_1, (169, 16), (16, 1))
assert_size_stride(arg258_1, (49, 49), (49, 1))
assert_size_stride(arg259_1, (512, 512), (512, 1))
assert_size_stride(arg260_1, (512, ), (1, ))
assert_size_stride(arg261_1, (512, ), (1, ))
assert_size_stride(arg262_1, (512, ), (1, ))
assert_size_stride(arg263_1, (2048, 512), (512, 1))
assert_size_stride(arg264_1, (2048, ), (1, ))
assert_size_stride(arg265_1, (512, 2048), (2048, 1))
assert_size_stride(arg266_1, (512, ), (1, ))
assert_size_stride(arg267_1, (512, ), (1, ))
assert_size_stride(arg268_1, (512, ), (1, ))
assert_size_stride(arg269_1, (1536, 512), (512, 1))
assert_size_stride(arg270_1, (1536, ), (1, ))
assert_size_stride(arg271_1, (169, 16), (16, 1))
assert_size_stride(arg272_1, (49, 49), (49, 1))
assert_size_stride(arg273_1, (512, 512), (512, 1))
assert_size_stride(arg274_1, (512, ), (1, ))
assert_size_stride(arg275_1, (512, ), (1, ))
assert_size_stride(arg276_1, (512, ), (1, ))
assert_size_stride(arg277_1, (2048, 512), (512, 1))
assert_size_stride(arg278_1, (2048, ), (1, ))
assert_size_stride(arg279_1, (512, 2048), (2048, 1))
assert_size_stride(arg280_1, (512, ), (1, ))
assert_size_stride(arg281_1, (512, ), (1, ))
assert_size_stride(arg282_1, (512, ), (1, ))
assert_size_stride(arg283_1, (1536, 512), (512, 1))
assert_size_stride(arg284_1, (1536, ), (1, ))
assert_size_stride(arg285_1, (169, 16), (16, 1))
assert_size_stride(arg286_1, (49, 49), (49, 1))
assert_size_stride(arg287_1, (512, 512), (512, 1))
assert_size_stride(arg288_1, (512, ), (1, ))
assert_size_stride(arg289_1, (512, ), (1, ))
assert_size_stride(arg290_1, (512, ), (1, ))
assert_size_stride(arg291_1, (2048, 512), (512, 1))
assert_size_stride(arg292_1, (2048, ), (1, ))
assert_size_stride(arg293_1, (512, 2048), (2048, 1))
assert_size_stride(arg294_1, (512, ), (1, ))
assert_size_stride(arg295_1, (512, ), (1, ))
assert_size_stride(arg296_1, (512, ), (1, ))
assert_size_stride(arg297_1, (1536, 512), (512, 1))
assert_size_stride(arg298_1, (1536, ), (1, ))
assert_size_stride(arg299_1, (169, 16), (16, 1))
assert_size_stride(arg300_1, (49, 49), (49, 1))
assert_size_stride(arg301_1, (512, 512), (512, 1))
assert_size_stride(arg302_1, (512, ), (1, ))
assert_size_stride(arg303_1, (512, ), (1, ))
assert_size_stride(arg304_1, (512, ), (1, ))
assert_size_stride(arg305_1, (2048, 512), (512, 1))
assert_size_stride(arg306_1, (2048, ), (1, ))
assert_size_stride(arg307_1, (512, 2048), (2048, 1))
assert_size_stride(arg308_1, (512, ), (1, ))
assert_size_stride(arg309_1, (512, ), (1, ))
assert_size_stride(arg310_1, (512, ), (1, ))
assert_size_stride(arg311_1, (1536, 512), (512, 1))
assert_size_stride(arg312_1, (1536, ), (1, ))
assert_size_stride(arg313_1, (169, 16), (16, 1))
assert_size_stride(arg314_1, (49, 49), (49, 1))
assert_size_stride(arg315_1, (512, 512), (512, 1))
assert_size_stride(arg316_1, (512, ), (1, ))
assert_size_stride(arg317_1, (512, ), (1, ))
assert_size_stride(arg318_1, (512, ), (1, ))
assert_size_stride(arg319_1, (2048, 512), (512, 1))
assert_size_stride(arg320_1, (2048, ), (1, ))
assert_size_stride(arg321_1, (512, 2048), (2048, 1))
assert_size_stride(arg322_1, (512, ), (1, ))
assert_size_stride(arg323_1, (512, ), (1, ))
assert_size_stride(arg324_1, (512, ), (1, ))
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
buf0 = empty_strided_cuda((1, 3, 1088, 1088), (3551232, 1183744, 1088, 1), torch.float16)
# Source Nodes: [x], Original ATen: [aten._to_copy, aten.convolution]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_convolution_0.run(arg0_1, buf0, 3551232, grid=grid(3551232), stream=stream0)
del arg0_1
buf1 = empty_strided_cuda((128, 3, 4, 4), (48, 16, 4, 1), torch.float16)
# Source Nodes: [x], Original ATen: [aten._to_copy, aten.convolution]
triton_poi_fused__to_copy_convolution_1.run(arg1_1, buf1, 6144, grid=grid(6144), stream=stream0)
del arg1_1
buf2 = empty_strided_cuda((128, ), (1, ), torch.float16)
# Source Nodes: [x], Original ATen: [aten._to_copy, aten.convolution]
triton_poi_fused__to_copy_convolution_2.run(arg2_1, buf2, 128, grid=grid(128), stream=stream0)
del arg2_1
# Source Nodes: [x], Original ATen: [aten._to_copy, aten.convolution]
buf3 = extern_kernels.convolution(buf0, buf1, stride=(4, 4), padding=(0, 0), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
assert_size_stride(buf3, (1, 128, 272, 272), (9469952, 73984, 272, 1))
del buf0
del buf1
buf4 = empty_strided_cuda((1, 73984, 1), (73984, 1, 73984), torch.float32)
buf5 = empty_strided_cuda((1, 73984, 1), (73984, 1, 73984), torch.float32)
# Source Nodes: [x_2], Original ATen: [aten._to_copy, aten.native_layer_norm]
triton_red_fused__to_copy_native_layer_norm_3.run(buf3, buf2, buf4, buf5, 73984, 128, grid=grid(73984), stream=stream0)
buf7 = empty_strided_cuda((1, 73984, 128), (9469952, 1, 73984), torch.float32)
# Source Nodes: [x_2], Original ATen: [aten._to_copy, aten.native_layer_norm]
triton_poi_fused__to_copy_native_layer_norm_4.run(buf3, buf2, buf4, buf5, arg3_1, arg4_1, buf7, 9469952, grid=grid(9469952), stream=stream0)
del arg3_1
del arg4_1
del buf2
buf8 = buf5; del buf5 # reuse
buf9 = buf4; del buf4 # reuse
# Source Nodes: [x_7], Original ATen: [aten.native_layer_norm]
triton_red_fused_native_layer_norm_5.run(buf7, buf8, buf9, 73984, 128, grid=grid(73984), stream=stream0)
buf11 = empty_strided_cuda((1521, 49, 128), (6272, 128, 1), torch.float16)
# Source Nodes: [linear], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_6.run(buf7, buf8, buf9, arg5_1, arg6_1, buf11, 74529, 128, grid=grid(74529, 128), stream=stream0)
del arg5_1
del arg6_1
buf12 = empty_strided_cuda((384, 128), (128, 1), torch.float16)
# Source Nodes: [linear], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_7.run(arg7_1, buf12, 49152, grid=grid(49152), stream=stream0)
del arg7_1
buf13 = empty_strided_cuda((74529, 384), (384, 1), torch.float16)
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf11, (74529, 128), (128, 1), 0), reinterpret_tensor(buf12, (128, 384), (1, 128), 0), out=buf13)
buf14 = empty_strided_cuda((1521, 4, 49, 32), (6400, 1600, 32, 1), torch.float16)
# Source Nodes: [attn, q_1], Original ATen: [aten.clone, aten.mul]
triton_poi_fused_clone_mul_8.run(buf13, arg8_1, buf14, 9539712, grid=grid(9539712), stream=stream0)
buf15 = empty_strided_cuda((1521, 4, 32, 49), (6400, 1600, 49, 1), torch.float16)
# Source Nodes: [attn], Original ATen: [aten.clone]
triton_poi_fused_clone_9.run(buf13, arg8_1, buf15, 194688, 49, grid=grid(194688, 49), stream=stream0)
buf16 = empty_strided_cuda((6084, 49, 49), (2401, 49, 1), torch.float16)
# Source Nodes: [attn], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf14, (6084, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf15, (6084, 32, 49), (1600, 49, 1), 0), out=buf16)
buf19 = empty_strided_cuda((1521, 4, 49, 49), (9728, 2432, 49, 1), torch.float16)
# Source Nodes: [attn_1, attn_2, matmul_1], Original ATen: [aten._softmax, aten._to_copy, aten.add]
triton_per_fused__softmax__to_copy_add_10.run(buf16, arg10_1, arg9_1, buf19, 298116, 49, grid=grid(298116), stream=stream0)
del arg10_1
del arg9_1
buf20 = reinterpret_tensor(buf15, (1521, 4, 49, 32), (6400, 1600, 32, 1), 0); del buf15 # reuse
# Source Nodes: [matmul_1], Original ATen: [aten.clone]
triton_poi_fused_clone_11.run(buf13, arg8_1, buf20, 9539712, grid=grid(9539712), stream=stream0)
del arg8_1
buf21 = reinterpret_tensor(buf11, (6084, 49, 32), (1568, 32, 1), 0); del buf11 # reuse
# Source Nodes: [matmul_1], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf19, (6084, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf20, (6084, 49, 32), (1600, 32, 1), 0), out=buf21)
buf22 = empty_strided_cuda((1521, 49, 4, 32), (6272, 128, 32, 1), torch.float16)
# Source Nodes: [x_11], Original ATen: [aten.clone]
triton_poi_fused_clone_12.run(buf21, buf22, 9539712, grid=grid(9539712), stream=stream0)
buf23 = empty_strided_cuda((128, 128), (128, 1), torch.float16)
# Source Nodes: [x_12], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_13.run(arg11_1, buf23, 16384, grid=grid(16384), stream=stream0)
del arg11_1
buf24 = reinterpret_tensor(buf21, (74529, 128), (128, 1), 0); del buf21 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf22, (74529, 128), (128, 1), 0), reinterpret_tensor(buf23, (128, 128), (1, 128), 0), out=buf24)
buf28 = reinterpret_tensor(buf3, (1, 73984, 128), (9469952, 128, 1), 0); del buf3 # reuse
# Source Nodes: [layer_norm_2, x_18, x_19], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_red_fused__to_copy_add_native_layer_norm_14.run(buf7, buf24, arg12_1, arg13_1, arg14_1, buf28, 73984, 128, grid=grid(73984), stream=stream0)
del arg13_1
del arg14_1
buf29 = empty_strided_cuda((512, 128), (128, 1), torch.float16)
# Source Nodes: [x_19], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_15.run(arg15_1, buf29, 65536, grid=grid(65536), stream=stream0)
del arg15_1
buf30 = empty_strided_cuda((73984, 512), (512, 1), torch.float16)
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf28, (73984, 128), (128, 1), 0), reinterpret_tensor(buf29, (128, 512), (1, 128), 0), out=buf30)
buf31 = reinterpret_tensor(buf30, (1, 73984, 512), (37879808, 512, 1), 0); del buf30 # reuse
# Source Nodes: [x_20], Original ATen: [aten.gelu]
triton_poi_fused_gelu_16.run(buf31, arg16_1, 37879808, grid=grid(37879808), stream=stream0)
del arg16_1
buf32 = reinterpret_tensor(buf29, (128, 512), (512, 1), 0); del buf29 # reuse
# Source Nodes: [x_22], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_15.run(arg17_1, buf32, 65536, grid=grid(65536), stream=stream0)
del arg17_1
buf33 = reinterpret_tensor(buf28, (73984, 128), (128, 1), 0); del buf28 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf31, (73984, 512), (512, 1), 0), reinterpret_tensor(buf32, (512, 128), (1, 512), 0), out=buf33)
buf34 = empty_strided_cuda((1, 73984, 128), (9469952, 128, 1), torch.float32)
buf35 = buf9; del buf9 # reuse
buf36 = buf8; del buf8 # reuse
# Source Nodes: [x_18, x_24, x_25], Original ATen: [aten.add, aten.native_layer_norm]
triton_red_fused_add_native_layer_norm_17.run(buf7, buf24, arg12_1, buf33, arg18_1, buf34, buf35, buf36, 73984, 128, grid=grid(73984), stream=stream0)
del arg12_1
del arg18_1
buf38 = empty_strided_cuda((1, 273, 273, 128), (9539712, 34944, 128, 1), torch.float32)
# Source Nodes: [shifted_x, x_27], Original ATen: [aten.constant_pad_nd, aten.roll]
triton_poi_fused_constant_pad_nd_roll_18.run(buf34, buf35, buf36, arg19_1, arg20_1, buf38, 9539712, grid=grid(9539712), stream=stream0)
del arg19_1
del arg20_1
buf39 = reinterpret_tensor(buf24, (1521, 49, 128), (6272, 128, 1), 0); del buf24 # reuse
# Source Nodes: [linear_4], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_19.run(buf38, buf39, 9539712, grid=grid(9539712), stream=stream0)
del buf38
buf40 = buf12; del buf12 # reuse
# Source Nodes: [linear_4], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_7.run(arg21_1, buf40, 49152, grid=grid(49152), stream=stream0)
del arg21_1
buf41 = buf13; del buf13 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf39, (74529, 128), (128, 1), 0), reinterpret_tensor(buf40, (128, 384), (1, 128), 0), out=buf41)
del buf40
buf42 = buf20; del buf20 # reuse
# Source Nodes: [attn_4, q_3], Original ATen: [aten.clone, aten.mul]
triton_poi_fused_clone_mul_8.run(buf41, arg22_1, buf42, 9539712, grid=grid(9539712), stream=stream0)
buf43 = reinterpret_tensor(buf14, (1521, 4, 32, 49), (6400, 1600, 49, 1), 0); del buf14 # reuse
# Source Nodes: [attn_4], Original ATen: [aten.clone]
triton_poi_fused_clone_9.run(buf41, arg22_1, buf43, 194688, 49, grid=grid(194688, 49), stream=stream0)
buf44 = buf16; del buf16 # reuse
# Source Nodes: [attn_4], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf42, (6084, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf43, (6084, 32, 49), (1600, 49, 1), 0), out=buf44)
del buf42
buf45 = empty_strided_cuda((1, 273, 273, 1), (74560, 273, 1, 1), torch.float32)
buf46 = reinterpret_tensor(buf45, (1, 273, 273, 1), (74560, 273, 1, 74560), 0); del buf45 # reuse
# Source Nodes: [img_mask, setitem, setitem_1, setitem_2, setitem_3, setitem_4], Original ATen: [aten.fill, aten.lift_fresh, aten.slice, aten.zeros]
triton_poi_fused_fill_lift_fresh_slice_zeros_20.run(buf46, 74529, grid=grid(74529), stream=stream0)
buf47 = empty_strided_cuda((1, 3, 273, 1), (819, 273, 1, 819), torch.float32)
# Source Nodes: [setitem_8], Original ATen: [aten.fill, aten.lift_fresh]
triton_poi_fused_fill_lift_fresh_21.run(buf46, buf47, 819, grid=grid(819), stream=stream0)
buf48 = reinterpret_tensor(buf46, (1, 273, 273, 1), (74560, 273, 1, 1), 0); del buf46 # reuse
# Source Nodes: [setitem_5, setitem_6, setitem_7], Original ATen: [aten.fill, aten.lift_fresh, aten.slice]
triton_poi_fused_fill_lift_fresh_slice_22.run(buf48, buf47, 74529, grid=grid(74529), stream=stream0)
del buf47
buf52 = buf19; del buf19 # reuse
# Source Nodes: [attn_6, attn_8, matmul_3], Original ATen: [aten._softmax, aten._to_copy, aten.add]
triton_per_fused__softmax__to_copy_add_23.run(buf44, arg24_1, arg23_1, buf48, buf52, 298116, 49, grid=grid(298116), stream=stream0)
del arg23_1
del arg24_1
del buf44
del buf48
buf53 = reinterpret_tensor(buf43, (1521, 4, 49, 32), (6400, 1600, 32, 1), 0); del buf43 # reuse
# Source Nodes: [matmul_3], Original ATen: [aten.clone]
triton_poi_fused_clone_11.run(buf41, arg22_1, buf53, 9539712, grid=grid(9539712), stream=stream0)
del arg22_1
del buf41
buf54 = reinterpret_tensor(buf39, (6084, 49, 32), (1568, 32, 1), 0); del buf39 # reuse
# Source Nodes: [matmul_3], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf52, (6084, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf53, (6084, 49, 32), (1600, 32, 1), 0), out=buf54)
del buf52
del buf53
buf55 = buf22; del buf22 # reuse
# Source Nodes: [x_29], Original ATen: [aten.clone]
triton_poi_fused_clone_12.run(buf54, buf55, 9539712, grid=grid(9539712), stream=stream0)
buf56 = buf23; del buf23 # reuse
# Source Nodes: [x_30], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_13.run(arg25_1, buf56, 16384, grid=grid(16384), stream=stream0)
del arg25_1
buf57 = reinterpret_tensor(buf54, (74529, 128), (128, 1), 0); del buf54 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf55, (74529, 128), (128, 1), 0), reinterpret_tensor(buf56, (128, 128), (1, 128), 0), out=buf57)
del buf55
del buf56
buf62 = reinterpret_tensor(buf33, (1, 73984, 128), (9469952, 128, 1), 0); del buf33 # reuse
# Source Nodes: [layer_norm_4, x_37, x_38], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_red_fused__to_copy_add_native_layer_norm_24.run(buf34, buf57, arg26_1, arg27_1, arg28_1, buf62, 73984, 128, grid=grid(73984), stream=stream0)
del arg27_1
del arg28_1
buf63 = reinterpret_tensor(buf32, (512, 128), (128, 1), 0); del buf32 # reuse
# Source Nodes: [x_38], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_15.run(arg29_1, buf63, 65536, grid=grid(65536), stream=stream0)
del arg29_1
buf64 = reinterpret_tensor(buf31, (73984, 512), (512, 1), 0); del buf31 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf62, (73984, 128), (128, 1), 0), reinterpret_tensor(buf63, (128, 512), (1, 128), 0), out=buf64)
buf65 = reinterpret_tensor(buf64, (1, 73984, 512), (37879808, 512, 1), 0); del buf64 # reuse
# Source Nodes: [x_39], Original ATen: [aten.gelu]
triton_poi_fused_gelu_16.run(buf65, arg30_1, 37879808, grid=grid(37879808), stream=stream0)
del arg30_1
buf66 = reinterpret_tensor(buf63, (128, 512), (512, 1), 0); del buf63 # reuse
# Source Nodes: [x_41], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_15.run(arg31_1, buf66, 65536, grid=grid(65536), stream=stream0)
del arg31_1
buf67 = reinterpret_tensor(buf62, (73984, 128), (128, 1), 0); del buf62 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf65, (73984, 512), (512, 1), 0), reinterpret_tensor(buf66, (512, 128), (1, 512), 0), out=buf67)
del buf65
buf68 = reinterpret_tensor(buf7, (1, 73984, 128), (9469952, 128, 1), 0); del buf7 # reuse
buf72 = buf36; del buf36 # reuse
buf73 = buf35; del buf35 # reuse
# Source Nodes: [x_37, x_43, x_out], Original ATen: [aten.add, aten.native_layer_norm]
triton_per_fused_add_native_layer_norm_25.run(buf34, buf57, arg26_1, buf67, arg32_1, buf68, buf72, buf73, 73984, 128, grid=grid(73984), stream=stream0)
del arg26_1
del arg32_1
del buf57
buf76 = reinterpret_tensor(buf67, (1, 18496, 512), (9469952, 512, 1), 0); del buf67 # reuse
# Source Nodes: [x_47, x_48], Original ATen: [aten._to_copy, aten.native_layer_norm]
triton_red_fused__to_copy_native_layer_norm_26.run(buf68, arg33_1, arg34_1, buf76, 18496, 512, grid=grid(18496), stream=stream0)
del arg33_1
del arg34_1
buf77 = empty_strided_cuda((256, 512), (512, 1), torch.float16)
# Source Nodes: [x_48], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_27.run(arg35_1, buf77, 131072, grid=grid(131072), stream=stream0)
del arg35_1
buf78 = empty_strided_cuda((18496, 256), (256, 1), torch.float16)
# Source Nodes: [x_48], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf76, (18496, 512), (512, 1), 0), reinterpret_tensor(buf77, (512, 256), (1, 512), 0), out=buf78)
del buf77
buf79 = empty_strided_cuda((1, 18496, 1), (18496, 1, 18496), torch.float32)
buf80 = empty_strided_cuda((1, 18496, 1), (18496, 1, 18496), torch.float32)
# Source Nodes: [x_50], Original ATen: [aten._to_copy, aten.native_layer_norm]
triton_per_fused__to_copy_native_layer_norm_28.run(buf78, buf79, buf80, 18496, 256, grid=grid(18496), stream=stream0)
buf82 = empty_strided_cuda((400, 49, 256), (12544, 256, 1), torch.float16)
# Source Nodes: [linear_9], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_29.run(buf78, buf79, buf80, arg38_1, arg39_1, buf82, 5017600, grid=grid(5017600), stream=stream0)
del arg38_1
del arg39_1
buf83 = empty_strided_cuda((768, 256), (256, 1), torch.float16)
# Source Nodes: [linear_9], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_30.run(arg40_1, buf83, 196608, grid=grid(196608), stream=stream0)
del arg40_1
buf84 = empty_strided_cuda((19600, 768), (768, 1), torch.float16)
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf82, (19600, 256), (256, 1), 0), reinterpret_tensor(buf83, (256, 768), (1, 256), 0), out=buf84)
buf85 = empty_strided_cuda((400, 8, 49, 32), (12800, 1600, 32, 1), torch.float16)
# Source Nodes: [attn_10, q_5], Original ATen: [aten.clone, aten.mul]
triton_poi_fused_clone_mul_31.run(buf84, arg41_1, buf85, 5017600, grid=grid(5017600), stream=stream0)
buf86 = empty_strided_cuda((400, 8, 32, 49), (12800, 1600, 49, 1), torch.float16)
# Source Nodes: [attn_10], Original ATen: [aten.clone]
triton_poi_fused_clone_32.run(buf84, arg41_1, buf86, 102400, 49, grid=grid(102400, 49), stream=stream0)
buf87 = empty_strided_cuda((3200, 49, 49), (2401, 49, 1), torch.float16)
# Source Nodes: [attn_10], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf85, (3200, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf86, (3200, 32, 49), (1600, 49, 1), 0), out=buf87)
buf90 = empty_strided_cuda((400, 8, 49, 49), (19456, 2432, 49, 1), torch.float16)
# Source Nodes: [attn_11, attn_12, matmul_5], Original ATen: [aten._softmax, aten._to_copy, aten.add]
triton_per_fused__softmax__to_copy_add_33.run(buf87, arg43_1, arg42_1, buf90, 156800, 49, grid=grid(156800), stream=stream0)
del arg42_1
del arg43_1
buf91 = reinterpret_tensor(buf86, (400, 8, 49, 32), (12800, 1600, 32, 1), 0); del buf86 # reuse
# Source Nodes: [matmul_5], Original ATen: [aten.clone]
triton_poi_fused_clone_34.run(buf84, arg41_1, buf91, 5017600, grid=grid(5017600), stream=stream0)
del arg41_1
buf92 = reinterpret_tensor(buf82, (3200, 49, 32), (1568, 32, 1), 0); del buf82 # reuse
# Source Nodes: [matmul_5], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf90, (3200, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf91, (3200, 49, 32), (1600, 32, 1), 0), out=buf92)
buf93 = empty_strided_cuda((400, 49, 8, 32), (12544, 256, 32, 1), torch.float16)
# Source Nodes: [x_54], Original ATen: [aten.clone]
triton_poi_fused_clone_35.run(buf92, buf93, 5017600, grid=grid(5017600), stream=stream0)
buf94 = reinterpret_tensor(buf66, (256, 256), (256, 1), 0); del buf66 # reuse
# Source Nodes: [x_55], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_15.run(arg44_1, buf94, 65536, grid=grid(65536), stream=stream0)
del arg44_1
buf95 = reinterpret_tensor(buf92, (19600, 256), (256, 1), 0); del buf92 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf93, (19600, 256), (256, 1), 0), reinterpret_tensor(buf94, (256, 256), (1, 256), 0), out=buf95)
buf99 = empty_strided_cuda((1, 18496, 256), (4734976, 256, 1), torch.float16)
# Source Nodes: [layer_norm_8, x_61, x_62], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_red_fused__to_copy_add_native_layer_norm_36.run(buf78, buf95, arg45_1, arg46_1, arg47_1, buf99, 18496, 256, grid=grid(18496), stream=stream0)
del arg46_1
del arg47_1
buf100 = empty_strided_cuda((1024, 256), (256, 1), torch.float16)
# Source Nodes: [x_62], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_37.run(arg48_1, buf100, 262144, grid=grid(262144), stream=stream0)
del arg48_1
buf101 = empty_strided_cuda((18496, 1024), (1024, 1), torch.float16)
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf99, (18496, 256), (256, 1), 0), reinterpret_tensor(buf100, (256, 1024), (1, 256), 0), out=buf101)
buf102 = reinterpret_tensor(buf101, (1, 18496, 1024), (18939904, 1024, 1), 0); del buf101 # reuse
# Source Nodes: [x_63], Original ATen: [aten.gelu]
triton_poi_fused_gelu_38.run(buf102, arg49_1, 18939904, grid=grid(18939904), stream=stream0)
del arg49_1
buf103 = reinterpret_tensor(buf100, (256, 1024), (1024, 1), 0); del buf100 # reuse
# Source Nodes: [x_65], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_37.run(arg50_1, buf103, 262144, grid=grid(262144), stream=stream0)
del arg50_1
buf104 = reinterpret_tensor(buf99, (18496, 256), (256, 1), 0); del buf99 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf102, (18496, 1024), (1024, 1), 0), reinterpret_tensor(buf103, (1024, 256), (1, 1024), 0), out=buf104)
buf105 = reinterpret_tensor(buf104, (1, 18496, 256), (4734976, 256, 1), 0); del buf104 # reuse
buf106 = buf80; del buf80 # reuse
buf107 = buf79; del buf79 # reuse
# Source Nodes: [x_61, x_67, x_68], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_per_fused__to_copy_add_native_layer_norm_39.run(buf105, buf78, buf95, arg45_1, arg51_1, buf106, buf107, 18496, 256, grid=grid(18496), stream=stream0)
del arg45_1
del arg51_1
buf109 = empty_strided_cuda((1, 140, 140, 256), (5017600, 35840, 256, 1), torch.float32)
# Source Nodes: [shifted_x_1, x_70], Original ATen: [aten.constant_pad_nd, aten.roll]
triton_poi_fused_constant_pad_nd_roll_40.run(buf105, buf106, buf107, arg52_1, arg53_1, buf109, 5017600, grid=grid(5017600), stream=stream0)
del arg52_1
del arg53_1
buf110 = reinterpret_tensor(buf95, (400, 49, 256), (12544, 256, 1), 0); del buf95 # reuse
# Source Nodes: [linear_13], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_41.run(buf109, buf110, 5017600, grid=grid(5017600), stream=stream0)
del buf109
buf111 = buf83; del buf83 # reuse
# Source Nodes: [linear_13], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_30.run(arg54_1, buf111, 196608, grid=grid(196608), stream=stream0)
del arg54_1
buf112 = buf84; del buf84 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf110, (19600, 256), (256, 1), 0), reinterpret_tensor(buf111, (256, 768), (1, 256), 0), out=buf112)
del buf111
buf113 = buf91; del buf91 # reuse
# Source Nodes: [attn_14, q_7], Original ATen: [aten.clone, aten.mul]
triton_poi_fused_clone_mul_31.run(buf112, arg55_1, buf113, 5017600, grid=grid(5017600), stream=stream0)
buf114 = reinterpret_tensor(buf85, (400, 8, 32, 49), (12800, 1600, 49, 1), 0); del buf85 # reuse
# Source Nodes: [attn_14], Original ATen: [aten.clone]
triton_poi_fused_clone_32.run(buf112, arg55_1, buf114, 102400, 49, grid=grid(102400, 49), stream=stream0)
buf115 = buf87; del buf87 # reuse
# Source Nodes: [attn_14], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf113, (3200, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf114, (3200, 32, 49), (1600, 49, 1), 0), out=buf115)
del buf113
buf116 = empty_strided_cuda((1, 140, 140, 1), (19616, 140, 1, 1), torch.float32)
buf117 = reinterpret_tensor(buf116, (1, 140, 140, 1), (19616, 140, 1, 19616), 0); del buf116 # reuse
# Source Nodes: [img_mask_1, setitem_10, setitem_11, setitem_12, setitem_13, setitem_9], Original ATen: [aten.fill, aten.lift_fresh, aten.slice, aten.zeros]
triton_poi_fused_fill_lift_fresh_slice_zeros_42.run(buf117, 19600, grid=grid(19600), stream=stream0)
buf118 = empty_strided_cuda((1, 3, 140, 1), (420, 140, 1, 420), torch.float32)
# Source Nodes: [setitem_17], Original ATen: [aten.fill, aten.lift_fresh]
triton_poi_fused_fill_lift_fresh_43.run(buf117, buf118, 420, grid=grid(420), stream=stream0)
buf119 = reinterpret_tensor(buf117, (1, 140, 140, 1), (19616, 140, 1, 1), 0); del buf117 # reuse
# Source Nodes: [setitem_14, setitem_15, setitem_16], Original ATen: [aten.fill, aten.lift_fresh, aten.slice]
triton_poi_fused_fill_lift_fresh_slice_44.run(buf119, buf118, 19600, grid=grid(19600), stream=stream0)
del buf118
buf123 = buf90; del buf90 # reuse
# Source Nodes: [attn_16, attn_18, matmul_7], Original ATen: [aten._softmax, aten._to_copy, aten.add]
triton_per_fused__softmax__to_copy_add_45.run(buf115, arg57_1, arg56_1, buf119, buf123, 156800, 49, grid=grid(156800), stream=stream0)
del arg56_1
del arg57_1
del buf115
del buf119
buf124 = reinterpret_tensor(buf114, (400, 8, 49, 32), (12800, 1600, 32, 1), 0); del buf114 # reuse
# Source Nodes: [matmul_7], Original ATen: [aten.clone]
triton_poi_fused_clone_34.run(buf112, arg55_1, buf124, 5017600, grid=grid(5017600), stream=stream0)
del arg55_1
del buf112
buf125 = reinterpret_tensor(buf110, (3200, 49, 32), (1568, 32, 1), 0); del buf110 # reuse
# Source Nodes: [matmul_7], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf123, (3200, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf124, (3200, 49, 32), (1600, 32, 1), 0), out=buf125)
del buf123
del buf124
buf126 = buf93; del buf93 # reuse
# Source Nodes: [x_72], Original ATen: [aten.clone]
triton_poi_fused_clone_35.run(buf125, buf126, 5017600, grid=grid(5017600), stream=stream0)
buf127 = buf94; del buf94 # reuse
# Source Nodes: [x_73], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_15.run(arg58_1, buf127, 65536, grid=grid(65536), stream=stream0)
del arg58_1
buf128 = reinterpret_tensor(buf125, (19600, 256), (256, 1), 0); del buf125 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf126, (19600, 256), (256, 1), 0), reinterpret_tensor(buf127, (256, 256), (1, 256), 0), out=buf128)
del buf126
del buf127
buf133 = reinterpret_tensor(buf78, (1, 18496, 256), (4734976, 256, 1), 0); del buf78 # reuse
# Source Nodes: [layer_norm_10, x_80, x_81], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_red_fused__to_copy_add_native_layer_norm_46.run(buf105, buf128, arg59_1, arg60_1, arg61_1, buf133, 18496, 256, grid=grid(18496), stream=stream0)
del arg60_1
del arg61_1
buf134 = reinterpret_tensor(buf103, (1024, 256), (256, 1), 0); del buf103 # reuse
# Source Nodes: [x_81], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_37.run(arg62_1, buf134, 262144, grid=grid(262144), stream=stream0)
del arg62_1
buf135 = reinterpret_tensor(buf102, (18496, 1024), (1024, 1), 0); del buf102 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf133, (18496, 256), (256, 1), 0), reinterpret_tensor(buf134, (256, 1024), (1, 256), 0), out=buf135)
buf136 = reinterpret_tensor(buf135, (1, 18496, 1024), (18939904, 1024, 1), 0); del buf135 # reuse
# Source Nodes: [x_82], Original ATen: [aten.gelu]
triton_poi_fused_gelu_38.run(buf136, arg63_1, 18939904, grid=grid(18939904), stream=stream0)
del arg63_1
buf137 = reinterpret_tensor(buf134, (256, 1024), (1024, 1), 0); del buf134 # reuse
# Source Nodes: [x_84], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_37.run(arg64_1, buf137, 262144, grid=grid(262144), stream=stream0)
del arg64_1
buf138 = reinterpret_tensor(buf133, (18496, 256), (256, 1), 0); del buf133 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf136, (18496, 1024), (1024, 1), 0), reinterpret_tensor(buf137, (1024, 256), (1, 1024), 0), out=buf138)
del buf136
buf139 = reinterpret_tensor(buf138, (1, 18496, 256), (4734976, 256, 1), 0); del buf138 # reuse
buf143 = buf107; del buf107 # reuse
buf144 = buf106; del buf106 # reuse
# Source Nodes: [x_80, x_86, x_out_1], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_per_fused__to_copy_add_native_layer_norm_47.run(buf139, buf105, buf128, arg59_1, arg65_1, buf143, buf144, 18496, 256, grid=grid(18496), stream=stream0)
del arg59_1
del arg65_1
del buf128
buf147 = reinterpret_tensor(buf105, (1, 4624, 1024), (4734976, 1024, 1), 0); del buf105 # reuse
# Source Nodes: [x_90, x_91], Original ATen: [aten._to_copy, aten.native_layer_norm]
triton_red_fused__to_copy_native_layer_norm_48.run(buf139, arg66_1, arg67_1, buf147, 4624, 1024, grid=grid(4624), stream=stream0)
del arg66_1
del arg67_1
buf148 = empty_strided_cuda((512, 1024), (1024, 1), torch.float16)
# Source Nodes: [x_91], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_49.run(arg68_1, buf148, 524288, grid=grid(524288), stream=stream0)
del arg68_1
buf149 = empty_strided_cuda((4624, 512), (512, 1), torch.float16)
# Source Nodes: [x_91], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf147, (4624, 1024), (1024, 1), 0), reinterpret_tensor(buf148, (1024, 512), (1, 1024), 0), out=buf149)
del buf147
del buf148
buf150 = empty_strided_cuda((1, 4624, 1), (4640, 1, 4640), torch.float32)
buf151 = empty_strided_cuda((1, 4624, 1), (4640, 1, 4640), torch.float32)
# Source Nodes: [x_93], Original ATen: [aten._to_copy, aten.native_layer_norm]
triton_per_fused__to_copy_native_layer_norm_50.run(buf149, buf150, buf151, 4624, 512, grid=grid(4624), stream=stream0)
buf153 = empty_strided_cuda((100, 49, 512), (25088, 512, 1), torch.float16)
# Source Nodes: [linear_18], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_51.run(buf149, buf150, buf151, arg71_1, arg72_1, buf153, 2508800, grid=grid(2508800), stream=stream0)
del arg71_1
del arg72_1
buf154 = empty_strided_cuda((1536, 512), (512, 1), torch.float16)
# Source Nodes: [linear_18], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_52.run(arg73_1, buf154, 786432, grid=grid(786432), stream=stream0)
del arg73_1
buf155 = empty_strided_cuda((4900, 1536), (1536, 1), torch.float16)
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf153, (4900, 512), (512, 1), 0), reinterpret_tensor(buf154, (512, 1536), (1, 512), 0), out=buf155)
buf156 = empty_strided_cuda((100, 16, 49, 32), (25600, 1600, 32, 1), torch.float16)
# Source Nodes: [attn_20, q_9], Original ATen: [aten.clone, aten.mul]
triton_poi_fused_clone_mul_53.run(buf155, arg74_1, buf156, 2508800, grid=grid(2508800), stream=stream0)
buf157 = empty_strided_cuda((100, 16, 32, 49), (25600, 1600, 49, 1), torch.float16)
# Source Nodes: [attn_20], Original ATen: [aten.clone]
triton_poi_fused_clone_54.run(buf155, arg74_1, buf157, 51200, 49, grid=grid(51200, 49), stream=stream0)
buf158 = empty_strided_cuda((1600, 49, 49), (2401, 49, 1), torch.float16)
# Source Nodes: [attn_20], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf156, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf157, (1600, 32, 49), (1600, 49, 1), 0), out=buf158)
buf161 = empty_strided_cuda((100, 16, 49, 49), (38912, 2432, 49, 1), torch.float16)
# Source Nodes: [attn_21, attn_22, matmul_9], Original ATen: [aten._softmax, aten._to_copy, aten.add]
triton_per_fused__softmax__to_copy_add_55.run(buf158, arg76_1, arg75_1, buf161, 78400, 49, grid=grid(78400), stream=stream0)
del arg75_1
del arg76_1
buf162 = reinterpret_tensor(buf157, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf157 # reuse
# Source Nodes: [matmul_9], Original ATen: [aten.clone]
triton_poi_fused_clone_56.run(buf155, arg74_1, buf162, 2508800, grid=grid(2508800), stream=stream0)
del arg74_1
buf163 = reinterpret_tensor(buf153, (1600, 49, 32), (1568, 32, 1), 0); del buf153 # reuse
# Source Nodes: [matmul_9], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf161, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf162, (1600, 49, 32), (1600, 32, 1), 0), out=buf163)
buf164 = empty_strided_cuda((100, 49, 16, 32), (25088, 512, 32, 1), torch.float16)
# Source Nodes: [x_97], Original ATen: [aten.clone]
triton_poi_fused_clone_57.run(buf163, buf164, 2508800, grid=grid(2508800), stream=stream0)
buf165 = reinterpret_tensor(buf137, (512, 512), (512, 1), 0); del buf137 # reuse
# Source Nodes: [x_98], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_37.run(arg77_1, buf165, 262144, grid=grid(262144), stream=stream0)
del arg77_1
buf166 = reinterpret_tensor(buf163, (4900, 512), (512, 1), 0); del buf163 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf164, (4900, 512), (512, 1), 0), reinterpret_tensor(buf165, (512, 512), (1, 512), 0), out=buf166)
buf170 = empty_strided_cuda((1, 4624, 512), (2367488, 512, 1), torch.float16)
# Source Nodes: [layer_norm_14, x_104, x_105], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_red_fused__to_copy_add_native_layer_norm_58.run(buf149, buf166, arg78_1, arg79_1, arg80_1, buf170, 4624, 512, grid=grid(4624), stream=stream0)
del arg79_1
del arg80_1
buf171 = empty_strided_cuda((2048, 512), (512, 1), torch.float16)
# Source Nodes: [x_105], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg81_1, buf171, 1048576, grid=grid(1048576), stream=stream0)
del arg81_1
buf172 = reinterpret_tensor(buf76, (4624, 2048), (2048, 1), 0); del buf76 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf170, (4624, 512), (512, 1), 0), reinterpret_tensor(buf171, (512, 2048), (1, 512), 0), out=buf172)
buf173 = reinterpret_tensor(buf172, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf172 # reuse
# Source Nodes: [x_106], Original ATen: [aten.gelu]
triton_poi_fused_gelu_60.run(buf173, arg82_1, 9469952, grid=grid(9469952), stream=stream0)
del arg82_1
buf174 = reinterpret_tensor(buf171, (512, 2048), (2048, 1), 0); del buf171 # reuse
# Source Nodes: [x_108], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg83_1, buf174, 1048576, grid=grid(1048576), stream=stream0)
del arg83_1
buf175 = reinterpret_tensor(buf170, (4624, 512), (512, 1), 0); del buf170 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf173, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf174, (2048, 512), (1, 2048), 0), out=buf175)
buf176 = reinterpret_tensor(buf175, (1, 4624, 512), (2367488, 512, 1), 0); del buf175 # reuse
buf177 = buf151; del buf151 # reuse
buf178 = buf150; del buf150 # reuse
# Source Nodes: [x_104, x_110, x_111], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_per_fused__to_copy_add_native_layer_norm_61.run(buf176, buf149, buf166, arg78_1, arg84_1, buf177, buf178, 4624, 512, grid=grid(4624), stream=stream0)
del arg78_1
del arg84_1
buf180 = empty_strided_cuda((1, 70, 70, 512), (2508800, 35840, 512, 1), torch.float32)
# Source Nodes: [shifted_x_2, x_113], Original ATen: [aten.constant_pad_nd, aten.roll]
triton_poi_fused_constant_pad_nd_roll_62.run(buf176, buf177, buf178, arg85_1, arg86_1, buf180, 2508800, grid=grid(2508800), stream=stream0)
del arg85_1
del arg86_1
buf181 = reinterpret_tensor(buf166, (100, 49, 512), (25088, 512, 1), 0); del buf166 # reuse
# Source Nodes: [linear_22], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_63.run(buf180, buf181, 2508800, grid=grid(2508800), stream=stream0)
buf182 = buf154; del buf154 # reuse
# Source Nodes: [linear_22], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_52.run(arg87_1, buf182, 786432, grid=grid(786432), stream=stream0)
del arg87_1
buf183 = buf155; del buf155 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf181, (4900, 512), (512, 1), 0), reinterpret_tensor(buf182, (512, 1536), (1, 512), 0), out=buf183)
buf184 = buf162; del buf162 # reuse
# Source Nodes: [attn_24, q_11], Original ATen: [aten.clone, aten.mul]
triton_poi_fused_clone_mul_53.run(buf183, arg88_1, buf184, 2508800, grid=grid(2508800), stream=stream0)
buf185 = reinterpret_tensor(buf156, (100, 16, 32, 49), (25600, 1600, 49, 1), 0); del buf156 # reuse
# Source Nodes: [attn_24], Original ATen: [aten.clone]
triton_poi_fused_clone_54.run(buf183, arg88_1, buf185, 51200, 49, grid=grid(51200, 49), stream=stream0)
buf186 = buf158; del buf158 # reuse
# Source Nodes: [attn_24], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf184, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf185, (1600, 32, 49), (1600, 49, 1), 0), out=buf186)
buf187 = empty_strided_cuda((1, 70, 70, 1), (4928, 70, 1, 1), torch.float32)
buf188 = reinterpret_tensor(buf187, (1, 70, 70, 1), (4928, 70, 1, 4928), 0); del buf187 # reuse
# Source Nodes: [img_mask_2, setitem_18, setitem_19, setitem_20, setitem_21, setitem_22], Original ATen: [aten.fill, aten.lift_fresh, aten.slice, aten.zeros]
triton_poi_fused_fill_lift_fresh_slice_zeros_64.run(buf188, 4900, grid=grid(4900), stream=stream0)
buf189 = empty_strided_cuda((1, 3, 70, 1), (210, 70, 1, 210), torch.float32)
# Source Nodes: [setitem_26], Original ATen: [aten.fill, aten.lift_fresh]
triton_poi_fused_fill_lift_fresh_65.run(buf188, buf189, 210, grid=grid(210), stream=stream0)
buf190 = reinterpret_tensor(buf188, (1, 70, 70, 1), (4928, 70, 1, 1), 0); del buf188 # reuse
# Source Nodes: [setitem_23, setitem_24, setitem_25], Original ATen: [aten.fill, aten.lift_fresh, aten.slice]
triton_poi_fused_fill_lift_fresh_slice_66.run(buf190, buf189, 4900, grid=grid(4900), stream=stream0)
del buf189
buf194 = buf161; del buf161 # reuse
# Source Nodes: [attn_26, attn_28, matmul_11], Original ATen: [aten._softmax, aten._to_copy, aten.add]
triton_per_fused__softmax__to_copy_add_67.run(buf186, arg90_1, arg89_1, buf190, buf194, 78400, 49, grid=grid(78400), stream=stream0)
del arg89_1
del arg90_1
buf195 = reinterpret_tensor(buf185, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf185 # reuse
# Source Nodes: [matmul_11], Original ATen: [aten.clone]
triton_poi_fused_clone_56.run(buf183, arg88_1, buf195, 2508800, grid=grid(2508800), stream=stream0)
del arg88_1
buf196 = reinterpret_tensor(buf181, (1600, 49, 32), (1568, 32, 1), 0); del buf181 # reuse
# Source Nodes: [matmul_11], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf194, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf195, (1600, 49, 32), (1600, 32, 1), 0), out=buf196)
buf197 = buf164; del buf164 # reuse
# Source Nodes: [x_115], Original ATen: [aten.clone]
triton_poi_fused_clone_57.run(buf196, buf197, 2508800, grid=grid(2508800), stream=stream0)
buf198 = buf165; del buf165 # reuse
# Source Nodes: [x_116], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_37.run(arg91_1, buf198, 262144, grid=grid(262144), stream=stream0)
del arg91_1
buf199 = reinterpret_tensor(buf196, (4900, 512), (512, 1), 0); del buf196 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf197, (4900, 512), (512, 1), 0), reinterpret_tensor(buf198, (512, 512), (1, 512), 0), out=buf199)
buf204 = reinterpret_tensor(buf149, (1, 4624, 512), (2367488, 512, 1), 0); del buf149 # reuse
# Source Nodes: [layer_norm_16, x_123, x_124], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_red_fused__to_copy_add_native_layer_norm_68.run(buf176, buf199, arg92_1, arg93_1, arg94_1, buf204, 4624, 512, grid=grid(4624), stream=stream0)
del arg93_1
del arg94_1
buf205 = reinterpret_tensor(buf174, (2048, 512), (512, 1), 0); del buf174 # reuse
# Source Nodes: [x_124], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg95_1, buf205, 1048576, grid=grid(1048576), stream=stream0)
del arg95_1
buf206 = reinterpret_tensor(buf173, (4624, 2048), (2048, 1), 0); del buf173 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf204, (4624, 512), (512, 1), 0), reinterpret_tensor(buf205, (512, 2048), (1, 512), 0), out=buf206)
buf207 = reinterpret_tensor(buf206, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf206 # reuse
# Source Nodes: [x_125], Original ATen: [aten.gelu]
triton_poi_fused_gelu_60.run(buf207, arg96_1, 9469952, grid=grid(9469952), stream=stream0)
del arg96_1
buf208 = reinterpret_tensor(buf205, (512, 2048), (2048, 1), 0); del buf205 # reuse
# Source Nodes: [x_127], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg97_1, buf208, 1048576, grid=grid(1048576), stream=stream0)
del arg97_1
buf209 = reinterpret_tensor(buf204, (4624, 512), (512, 1), 0); del buf204 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf207, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf208, (2048, 512), (1, 2048), 0), out=buf209)
buf210 = reinterpret_tensor(buf209, (1, 4624, 512), (2367488, 512, 1), 0); del buf209 # reuse
buf211 = buf178; del buf178 # reuse
buf212 = buf177; del buf177 # reuse
# Source Nodes: [x_123, x_129, x_130], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_per_fused__to_copy_add_native_layer_norm_69.run(buf210, buf176, buf199, arg92_1, arg98_1, buf211, buf212, 4624, 512, grid=grid(4624), stream=stream0)
del arg92_1
del arg98_1
buf214 = reinterpret_tensor(buf199, (100, 49, 512), (25088, 512, 1), 0); del buf199 # reuse
# Source Nodes: [linear_26], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_51.run(buf210, buf211, buf212, arg99_1, arg100_1, buf214, 2508800, grid=grid(2508800), stream=stream0)
del arg100_1
del arg99_1
buf215 = buf182; del buf182 # reuse
# Source Nodes: [linear_26], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_52.run(arg101_1, buf215, 786432, grid=grid(786432), stream=stream0)
del arg101_1
buf216 = buf183; del buf183 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf214, (4900, 512), (512, 1), 0), reinterpret_tensor(buf215, (512, 1536), (1, 512), 0), out=buf216)
buf217 = buf195; del buf195 # reuse
# Source Nodes: [attn_30, q_13], Original ATen: [aten.clone, aten.mul]
triton_poi_fused_clone_mul_53.run(buf216, arg102_1, buf217, 2508800, grid=grid(2508800), stream=stream0)
buf218 = reinterpret_tensor(buf184, (100, 16, 32, 49), (25600, 1600, 49, 1), 0); del buf184 # reuse
# Source Nodes: [attn_30], Original ATen: [aten.clone]
triton_poi_fused_clone_54.run(buf216, arg102_1, buf218, 51200, 49, grid=grid(51200, 49), stream=stream0)
buf219 = buf186; del buf186 # reuse
# Source Nodes: [attn_30], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf217, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf218, (1600, 32, 49), (1600, 49, 1), 0), out=buf219)
buf222 = buf194; del buf194 # reuse
# Source Nodes: [attn_31, attn_32, matmul_13], Original ATen: [aten._softmax, aten._to_copy, aten.add]
triton_per_fused__softmax__to_copy_add_55.run(buf219, arg104_1, arg103_1, buf222, 78400, 49, grid=grid(78400), stream=stream0)
del arg103_1
del arg104_1
buf223 = reinterpret_tensor(buf218, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf218 # reuse
# Source Nodes: [matmul_13], Original ATen: [aten.clone]
triton_poi_fused_clone_56.run(buf216, arg102_1, buf223, 2508800, grid=grid(2508800), stream=stream0)
del arg102_1
buf224 = reinterpret_tensor(buf214, (1600, 49, 32), (1568, 32, 1), 0); del buf214 # reuse
# Source Nodes: [matmul_13], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf222, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf223, (1600, 49, 32), (1600, 32, 1), 0), out=buf224)
buf225 = buf197; del buf197 # reuse
# Source Nodes: [x_134], Original ATen: [aten.clone]
triton_poi_fused_clone_57.run(buf224, buf225, 2508800, grid=grid(2508800), stream=stream0)
buf226 = buf198; del buf198 # reuse
# Source Nodes: [x_135], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_37.run(arg105_1, buf226, 262144, grid=grid(262144), stream=stream0)
del arg105_1
buf227 = reinterpret_tensor(buf224, (4900, 512), (512, 1), 0); del buf224 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf225, (4900, 512), (512, 1), 0), reinterpret_tensor(buf226, (512, 512), (1, 512), 0), out=buf227)
buf231 = buf176; del buf176 # reuse
# Source Nodes: [layer_norm_18, x_141, x_142], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_red_fused__to_copy_add_native_layer_norm_58.run(buf210, buf227, arg106_1, arg107_1, arg108_1, buf231, 4624, 512, grid=grid(4624), stream=stream0)
del arg107_1
del arg108_1
buf232 = reinterpret_tensor(buf208, (2048, 512), (512, 1), 0); del buf208 # reuse
# Source Nodes: [x_142], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg109_1, buf232, 1048576, grid=grid(1048576), stream=stream0)
del arg109_1
buf233 = reinterpret_tensor(buf207, (4624, 2048), (2048, 1), 0); del buf207 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf231, (4624, 512), (512, 1), 0), reinterpret_tensor(buf232, (512, 2048), (1, 512), 0), out=buf233)
buf234 = reinterpret_tensor(buf233, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf233 # reuse
# Source Nodes: [x_143], Original ATen: [aten.gelu]
triton_poi_fused_gelu_60.run(buf234, arg110_1, 9469952, grid=grid(9469952), stream=stream0)
del arg110_1
buf235 = reinterpret_tensor(buf232, (512, 2048), (2048, 1), 0); del buf232 # reuse
# Source Nodes: [x_145], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg111_1, buf235, 1048576, grid=grid(1048576), stream=stream0)
del arg111_1
buf236 = reinterpret_tensor(buf231, (4624, 512), (512, 1), 0); del buf231 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf234, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf235, (2048, 512), (1, 2048), 0), out=buf236)
buf237 = reinterpret_tensor(buf236, (1, 4624, 512), (2367488, 512, 1), 0); del buf236 # reuse
buf238 = buf212; del buf212 # reuse
buf239 = buf211; del buf211 # reuse
# Source Nodes: [x_141, x_147, x_148], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_per_fused__to_copy_add_native_layer_norm_61.run(buf237, buf210, buf227, arg106_1, arg112_1, buf238, buf239, 4624, 512, grid=grid(4624), stream=stream0)
del arg106_1
del arg112_1
buf241 = buf180; del buf180 # reuse
# Source Nodes: [shifted_x_3, x_150], Original ATen: [aten.constant_pad_nd, aten.roll]
triton_poi_fused_constant_pad_nd_roll_62.run(buf237, buf238, buf239, arg113_1, arg114_1, buf241, 2508800, grid=grid(2508800), stream=stream0)
del arg113_1
del arg114_1
buf242 = reinterpret_tensor(buf227, (100, 49, 512), (25088, 512, 1), 0); del buf227 # reuse
# Source Nodes: [linear_30], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_63.run(buf241, buf242, 2508800, grid=grid(2508800), stream=stream0)
buf243 = buf215; del buf215 # reuse
# Source Nodes: [linear_30], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_52.run(arg115_1, buf243, 786432, grid=grid(786432), stream=stream0)
del arg115_1
buf244 = buf216; del buf216 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf242, (4900, 512), (512, 1), 0), reinterpret_tensor(buf243, (512, 1536), (1, 512), 0), out=buf244)
buf245 = buf223; del buf223 # reuse
# Source Nodes: [attn_34, q_15], Original ATen: [aten.clone, aten.mul]
triton_poi_fused_clone_mul_53.run(buf244, arg116_1, buf245, 2508800, grid=grid(2508800), stream=stream0)
buf246 = reinterpret_tensor(buf217, (100, 16, 32, 49), (25600, 1600, 49, 1), 0); del buf217 # reuse
# Source Nodes: [attn_34], Original ATen: [aten.clone]
triton_poi_fused_clone_54.run(buf244, arg116_1, buf246, 51200, 49, grid=grid(51200, 49), stream=stream0)
buf247 = buf219; del buf219 # reuse
# Source Nodes: [attn_34], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf245, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf246, (1600, 32, 49), (1600, 49, 1), 0), out=buf247)
buf251 = buf222; del buf222 # reuse
# Source Nodes: [attn_36, attn_38, matmul_15], Original ATen: [aten._softmax, aten._to_copy, aten.add]
triton_per_fused__softmax__to_copy_add_67.run(buf247, arg118_1, arg117_1, buf190, buf251, 78400, 49, grid=grid(78400), stream=stream0)
del arg117_1
del arg118_1
buf252 = reinterpret_tensor(buf246, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf246 # reuse
# Source Nodes: [matmul_15], Original ATen: [aten.clone]
triton_poi_fused_clone_56.run(buf244, arg116_1, buf252, 2508800, grid=grid(2508800), stream=stream0)
del arg116_1
buf253 = reinterpret_tensor(buf242, (1600, 49, 32), (1568, 32, 1), 0); del buf242 # reuse
# Source Nodes: [matmul_15], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf251, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf252, (1600, 49, 32), (1600, 32, 1), 0), out=buf253)
buf254 = buf225; del buf225 # reuse
# Source Nodes: [x_152], Original ATen: [aten.clone]
triton_poi_fused_clone_57.run(buf253, buf254, 2508800, grid=grid(2508800), stream=stream0)
buf255 = buf226; del buf226 # reuse
# Source Nodes: [x_153], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_37.run(arg119_1, buf255, 262144, grid=grid(262144), stream=stream0)
del arg119_1
buf256 = reinterpret_tensor(buf253, (4900, 512), (512, 1), 0); del buf253 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf254, (4900, 512), (512, 1), 0), reinterpret_tensor(buf255, (512, 512), (1, 512), 0), out=buf256)
buf261 = buf210; del buf210 # reuse
# Source Nodes: [layer_norm_20, x_160, x_161], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_red_fused__to_copy_add_native_layer_norm_68.run(buf237, buf256, arg120_1, arg121_1, arg122_1, buf261, 4624, 512, grid=grid(4624), stream=stream0)
del arg121_1
del arg122_1
buf262 = reinterpret_tensor(buf235, (2048, 512), (512, 1), 0); del buf235 # reuse
# Source Nodes: [x_161], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg123_1, buf262, 1048576, grid=grid(1048576), stream=stream0)
del arg123_1
buf263 = reinterpret_tensor(buf234, (4624, 2048), (2048, 1), 0); del buf234 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf261, (4624, 512), (512, 1), 0), reinterpret_tensor(buf262, (512, 2048), (1, 512), 0), out=buf263)
buf264 = reinterpret_tensor(buf263, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf263 # reuse
# Source Nodes: [x_162], Original ATen: [aten.gelu]
triton_poi_fused_gelu_60.run(buf264, arg124_1, 9469952, grid=grid(9469952), stream=stream0)
del arg124_1
buf265 = reinterpret_tensor(buf262, (512, 2048), (2048, 1), 0); del buf262 # reuse
# Source Nodes: [x_164], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg125_1, buf265, 1048576, grid=grid(1048576), stream=stream0)
del arg125_1
buf266 = reinterpret_tensor(buf261, (4624, 512), (512, 1), 0); del buf261 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf264, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf265, (2048, 512), (1, 2048), 0), out=buf266)
buf267 = reinterpret_tensor(buf266, (1, 4624, 512), (2367488, 512, 1), 0); del buf266 # reuse
buf268 = buf239; del buf239 # reuse
buf269 = buf238; del buf238 # reuse
# Source Nodes: [x_160, x_166, x_167], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_per_fused__to_copy_add_native_layer_norm_69.run(buf267, buf237, buf256, arg120_1, arg126_1, buf268, buf269, 4624, 512, grid=grid(4624), stream=stream0)
del arg120_1
del arg126_1
buf271 = reinterpret_tensor(buf256, (100, 49, 512), (25088, 512, 1), 0); del buf256 # reuse
# Source Nodes: [linear_34], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_51.run(buf267, buf268, buf269, arg127_1, arg128_1, buf271, 2508800, grid=grid(2508800), stream=stream0)
del arg127_1
del arg128_1
buf272 = buf243; del buf243 # reuse
# Source Nodes: [linear_34], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_52.run(arg129_1, buf272, 786432, grid=grid(786432), stream=stream0)
del arg129_1
buf273 = buf244; del buf244 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf271, (4900, 512), (512, 1), 0), reinterpret_tensor(buf272, (512, 1536), (1, 512), 0), out=buf273)
buf274 = buf252; del buf252 # reuse
# Source Nodes: [attn_40, q_17], Original ATen: [aten.clone, aten.mul]
triton_poi_fused_clone_mul_53.run(buf273, arg130_1, buf274, 2508800, grid=grid(2508800), stream=stream0)
buf275 = reinterpret_tensor(buf245, (100, 16, 32, 49), (25600, 1600, 49, 1), 0); del buf245 # reuse
# Source Nodes: [attn_40], Original ATen: [aten.clone]
triton_poi_fused_clone_54.run(buf273, arg130_1, buf275, 51200, 49, grid=grid(51200, 49), stream=stream0)
buf276 = buf247; del buf247 # reuse
# Source Nodes: [attn_40], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf274, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf275, (1600, 32, 49), (1600, 49, 1), 0), out=buf276)
buf279 = buf251; del buf251 # reuse
# Source Nodes: [attn_41, attn_42, matmul_17], Original ATen: [aten._softmax, aten._to_copy, aten.add]
triton_per_fused__softmax__to_copy_add_55.run(buf276, arg132_1, arg131_1, buf279, 78400, 49, grid=grid(78400), stream=stream0)
del arg131_1
del arg132_1
buf280 = reinterpret_tensor(buf275, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf275 # reuse
# Source Nodes: [matmul_17], Original ATen: [aten.clone]
triton_poi_fused_clone_56.run(buf273, arg130_1, buf280, 2508800, grid=grid(2508800), stream=stream0)
del arg130_1
buf281 = reinterpret_tensor(buf271, (1600, 49, 32), (1568, 32, 1), 0); del buf271 # reuse
# Source Nodes: [matmul_17], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf279, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf280, (1600, 49, 32), (1600, 32, 1), 0), out=buf281)
buf282 = buf254; del buf254 # reuse
# Source Nodes: [x_171], Original ATen: [aten.clone]
triton_poi_fused_clone_57.run(buf281, buf282, 2508800, grid=grid(2508800), stream=stream0)
buf283 = buf255; del buf255 # reuse
# Source Nodes: [x_172], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_37.run(arg133_1, buf283, 262144, grid=grid(262144), stream=stream0)
del arg133_1
buf284 = reinterpret_tensor(buf281, (4900, 512), (512, 1), 0); del buf281 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf282, (4900, 512), (512, 1), 0), reinterpret_tensor(buf283, (512, 512), (1, 512), 0), out=buf284)
buf288 = buf237; del buf237 # reuse
# Source Nodes: [layer_norm_22, x_178, x_179], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_red_fused__to_copy_add_native_layer_norm_58.run(buf267, buf284, arg134_1, arg135_1, arg136_1, buf288, 4624, 512, grid=grid(4624), stream=stream0)
del arg135_1
del arg136_1
buf289 = reinterpret_tensor(buf265, (2048, 512), (512, 1), 0); del buf265 # reuse
# Source Nodes: [x_179], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg137_1, buf289, 1048576, grid=grid(1048576), stream=stream0)
del arg137_1
buf290 = reinterpret_tensor(buf264, (4624, 2048), (2048, 1), 0); del buf264 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf288, (4624, 512), (512, 1), 0), reinterpret_tensor(buf289, (512, 2048), (1, 512), 0), out=buf290)
buf291 = reinterpret_tensor(buf290, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf290 # reuse
# Source Nodes: [x_180], Original ATen: [aten.gelu]
triton_poi_fused_gelu_60.run(buf291, arg138_1, 9469952, grid=grid(9469952), stream=stream0)
del arg138_1
buf292 = reinterpret_tensor(buf289, (512, 2048), (2048, 1), 0); del buf289 # reuse
# Source Nodes: [x_182], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg139_1, buf292, 1048576, grid=grid(1048576), stream=stream0)
del arg139_1
buf293 = reinterpret_tensor(buf288, (4624, 512), (512, 1), 0); del buf288 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf291, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf292, (2048, 512), (1, 2048), 0), out=buf293)
buf294 = reinterpret_tensor(buf293, (1, 4624, 512), (2367488, 512, 1), 0); del buf293 # reuse
buf295 = buf269; del buf269 # reuse
buf296 = buf268; del buf268 # reuse
# Source Nodes: [x_178, x_184, x_185], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_per_fused__to_copy_add_native_layer_norm_61.run(buf294, buf267, buf284, arg134_1, arg140_1, buf295, buf296, 4624, 512, grid=grid(4624), stream=stream0)
del arg134_1
del arg140_1
buf298 = buf241; del buf241 # reuse
# Source Nodes: [shifted_x_4, x_187], Original ATen: [aten.constant_pad_nd, aten.roll]
triton_poi_fused_constant_pad_nd_roll_62.run(buf294, buf295, buf296, arg141_1, arg142_1, buf298, 2508800, grid=grid(2508800), stream=stream0)
del arg141_1
del arg142_1
buf299 = reinterpret_tensor(buf284, (100, 49, 512), (25088, 512, 1), 0); del buf284 # reuse
# Source Nodes: [linear_38], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_63.run(buf298, buf299, 2508800, grid=grid(2508800), stream=stream0)
buf300 = buf272; del buf272 # reuse
# Source Nodes: [linear_38], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_52.run(arg143_1, buf300, 786432, grid=grid(786432), stream=stream0)
del arg143_1
buf301 = buf273; del buf273 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf299, (4900, 512), (512, 1), 0), reinterpret_tensor(buf300, (512, 1536), (1, 512), 0), out=buf301)
buf302 = buf280; del buf280 # reuse
# Source Nodes: [attn_44, q_19], Original ATen: [aten.clone, aten.mul]
triton_poi_fused_clone_mul_53.run(buf301, arg144_1, buf302, 2508800, grid=grid(2508800), stream=stream0)
buf303 = reinterpret_tensor(buf274, (100, 16, 32, 49), (25600, 1600, 49, 1), 0); del buf274 # reuse
# Source Nodes: [attn_44], Original ATen: [aten.clone]
triton_poi_fused_clone_54.run(buf301, arg144_1, buf303, 51200, 49, grid=grid(51200, 49), stream=stream0)
buf304 = buf276; del buf276 # reuse
# Source Nodes: [attn_44], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf302, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf303, (1600, 32, 49), (1600, 49, 1), 0), out=buf304)
buf308 = buf279; del buf279 # reuse
# Source Nodes: [attn_46, attn_48, matmul_19], Original ATen: [aten._softmax, aten._to_copy, aten.add]
triton_per_fused__softmax__to_copy_add_67.run(buf304, arg146_1, arg145_1, buf190, buf308, 78400, 49, grid=grid(78400), stream=stream0)
del arg145_1
del arg146_1
buf309 = reinterpret_tensor(buf303, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf303 # reuse
# Source Nodes: [matmul_19], Original ATen: [aten.clone]
triton_poi_fused_clone_56.run(buf301, arg144_1, buf309, 2508800, grid=grid(2508800), stream=stream0)
del arg144_1
buf310 = reinterpret_tensor(buf299, (1600, 49, 32), (1568, 32, 1), 0); del buf299 # reuse
# Source Nodes: [matmul_19], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf308, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf309, (1600, 49, 32), (1600, 32, 1), 0), out=buf310)
buf311 = buf282; del buf282 # reuse
# Source Nodes: [x_189], Original ATen: [aten.clone]
triton_poi_fused_clone_57.run(buf310, buf311, 2508800, grid=grid(2508800), stream=stream0)
buf312 = buf283; del buf283 # reuse
# Source Nodes: [x_190], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_37.run(arg147_1, buf312, 262144, grid=grid(262144), stream=stream0)
del arg147_1
buf313 = reinterpret_tensor(buf310, (4900, 512), (512, 1), 0); del buf310 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf311, (4900, 512), (512, 1), 0), reinterpret_tensor(buf312, (512, 512), (1, 512), 0), out=buf313)
buf318 = buf267; del buf267 # reuse
# Source Nodes: [layer_norm_24, x_197, x_198], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_red_fused__to_copy_add_native_layer_norm_68.run(buf294, buf313, arg148_1, arg149_1, arg150_1, buf318, 4624, 512, grid=grid(4624), stream=stream0)
del arg149_1
del arg150_1
buf319 = reinterpret_tensor(buf292, (2048, 512), (512, 1), 0); del buf292 # reuse
# Source Nodes: [x_198], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg151_1, buf319, 1048576, grid=grid(1048576), stream=stream0)
del arg151_1
buf320 = reinterpret_tensor(buf291, (4624, 2048), (2048, 1), 0); del buf291 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf318, (4624, 512), (512, 1), 0), reinterpret_tensor(buf319, (512, 2048), (1, 512), 0), out=buf320)
buf321 = reinterpret_tensor(buf320, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf320 # reuse
# Source Nodes: [x_199], Original ATen: [aten.gelu]
triton_poi_fused_gelu_60.run(buf321, arg152_1, 9469952, grid=grid(9469952), stream=stream0)
del arg152_1
buf322 = reinterpret_tensor(buf319, (512, 2048), (2048, 1), 0); del buf319 # reuse
# Source Nodes: [x_201], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg153_1, buf322, 1048576, grid=grid(1048576), stream=stream0)
del arg153_1
buf323 = reinterpret_tensor(buf318, (4624, 512), (512, 1), 0); del buf318 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf321, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf322, (2048, 512), (1, 2048), 0), out=buf323)
buf324 = reinterpret_tensor(buf323, (1, 4624, 512), (2367488, 512, 1), 0); del buf323 # reuse
buf325 = buf296; del buf296 # reuse
buf326 = buf295; del buf295 # reuse
# Source Nodes: [x_197, x_203, x_204], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_per_fused__to_copy_add_native_layer_norm_69.run(buf324, buf294, buf313, arg148_1, arg154_1, buf325, buf326, 4624, 512, grid=grid(4624), stream=stream0)
del arg148_1
del arg154_1
buf328 = reinterpret_tensor(buf313, (100, 49, 512), (25088, 512, 1), 0); del buf313 # reuse
# Source Nodes: [linear_42], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_51.run(buf324, buf325, buf326, arg155_1, arg156_1, buf328, 2508800, grid=grid(2508800), stream=stream0)
del arg155_1
del arg156_1
buf329 = buf300; del buf300 # reuse
# Source Nodes: [linear_42], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_52.run(arg157_1, buf329, 786432, grid=grid(786432), stream=stream0)
del arg157_1
buf330 = buf301; del buf301 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf328, (4900, 512), (512, 1), 0), reinterpret_tensor(buf329, (512, 1536), (1, 512), 0), out=buf330)
buf331 = buf309; del buf309 # reuse
# Source Nodes: [attn_50, q_21], Original ATen: [aten.clone, aten.mul]
triton_poi_fused_clone_mul_53.run(buf330, arg158_1, buf331, 2508800, grid=grid(2508800), stream=stream0)
buf332 = reinterpret_tensor(buf302, (100, 16, 32, 49), (25600, 1600, 49, 1), 0); del buf302 # reuse
# Source Nodes: [attn_50], Original ATen: [aten.clone]
triton_poi_fused_clone_54.run(buf330, arg158_1, buf332, 51200, 49, grid=grid(51200, 49), stream=stream0)
buf333 = buf304; del buf304 # reuse
# Source Nodes: [attn_50], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf331, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf332, (1600, 32, 49), (1600, 49, 1), 0), out=buf333)
buf336 = buf308; del buf308 # reuse
# Source Nodes: [attn_51, attn_52, matmul_21], Original ATen: [aten._softmax, aten._to_copy, aten.add]
triton_per_fused__softmax__to_copy_add_55.run(buf333, arg160_1, arg159_1, buf336, 78400, 49, grid=grid(78400), stream=stream0)
del arg159_1
del arg160_1
buf337 = reinterpret_tensor(buf332, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf332 # reuse
# Source Nodes: [matmul_21], Original ATen: [aten.clone]
triton_poi_fused_clone_56.run(buf330, arg158_1, buf337, 2508800, grid=grid(2508800), stream=stream0)
del arg158_1
buf338 = reinterpret_tensor(buf328, (1600, 49, 32), (1568, 32, 1), 0); del buf328 # reuse
# Source Nodes: [matmul_21], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf336, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf337, (1600, 49, 32), (1600, 32, 1), 0), out=buf338)
buf339 = buf311; del buf311 # reuse
# Source Nodes: [x_208], Original ATen: [aten.clone]
triton_poi_fused_clone_57.run(buf338, buf339, 2508800, grid=grid(2508800), stream=stream0)
buf340 = buf312; del buf312 # reuse
# Source Nodes: [x_209], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_37.run(arg161_1, buf340, 262144, grid=grid(262144), stream=stream0)
del arg161_1
buf341 = reinterpret_tensor(buf338, (4900, 512), (512, 1), 0); del buf338 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf339, (4900, 512), (512, 1), 0), reinterpret_tensor(buf340, (512, 512), (1, 512), 0), out=buf341)
buf345 = buf294; del buf294 # reuse
# Source Nodes: [layer_norm_26, x_215, x_216], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_red_fused__to_copy_add_native_layer_norm_58.run(buf324, buf341, arg162_1, arg163_1, arg164_1, buf345, 4624, 512, grid=grid(4624), stream=stream0)
del arg163_1
del arg164_1
buf346 = reinterpret_tensor(buf322, (2048, 512), (512, 1), 0); del buf322 # reuse
# Source Nodes: [x_216], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg165_1, buf346, 1048576, grid=grid(1048576), stream=stream0)
del arg165_1
buf347 = reinterpret_tensor(buf321, (4624, 2048), (2048, 1), 0); del buf321 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf345, (4624, 512), (512, 1), 0), reinterpret_tensor(buf346, (512, 2048), (1, 512), 0), out=buf347)
buf348 = reinterpret_tensor(buf347, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf347 # reuse
# Source Nodes: [x_217], Original ATen: [aten.gelu]
triton_poi_fused_gelu_60.run(buf348, arg166_1, 9469952, grid=grid(9469952), stream=stream0)
del arg166_1
buf349 = reinterpret_tensor(buf346, (512, 2048), (2048, 1), 0); del buf346 # reuse
# Source Nodes: [x_219], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg167_1, buf349, 1048576, grid=grid(1048576), stream=stream0)
del arg167_1
buf350 = reinterpret_tensor(buf345, (4624, 512), (512, 1), 0); del buf345 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf348, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf349, (2048, 512), (1, 2048), 0), out=buf350)
buf351 = reinterpret_tensor(buf350, (1, 4624, 512), (2367488, 512, 1), 0); del buf350 # reuse
buf352 = buf326; del buf326 # reuse
buf353 = buf325; del buf325 # reuse
# Source Nodes: [x_215, x_221, x_222], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_per_fused__to_copy_add_native_layer_norm_61.run(buf351, buf324, buf341, arg162_1, arg168_1, buf352, buf353, 4624, 512, grid=grid(4624), stream=stream0)
del arg162_1
del arg168_1
buf355 = buf298; del buf298 # reuse
# Source Nodes: [shifted_x_5, x_224], Original ATen: [aten.constant_pad_nd, aten.roll]
triton_poi_fused_constant_pad_nd_roll_62.run(buf351, buf352, buf353, arg169_1, arg170_1, buf355, 2508800, grid=grid(2508800), stream=stream0)
del arg169_1
del arg170_1
buf356 = reinterpret_tensor(buf341, (100, 49, 512), (25088, 512, 1), 0); del buf341 # reuse
# Source Nodes: [linear_46], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_63.run(buf355, buf356, 2508800, grid=grid(2508800), stream=stream0)
buf357 = buf329; del buf329 # reuse
# Source Nodes: [linear_46], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_52.run(arg171_1, buf357, 786432, grid=grid(786432), stream=stream0)
del arg171_1
buf358 = buf330; del buf330 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf356, (4900, 512), (512, 1), 0), reinterpret_tensor(buf357, (512, 1536), (1, 512), 0), out=buf358)
buf359 = buf337; del buf337 # reuse
# Source Nodes: [attn_54, q_23], Original ATen: [aten.clone, aten.mul]
triton_poi_fused_clone_mul_53.run(buf358, arg172_1, buf359, 2508800, grid=grid(2508800), stream=stream0)
buf360 = reinterpret_tensor(buf331, (100, 16, 32, 49), (25600, 1600, 49, 1), 0); del buf331 # reuse
# Source Nodes: [attn_54], Original ATen: [aten.clone]
triton_poi_fused_clone_54.run(buf358, arg172_1, buf360, 51200, 49, grid=grid(51200, 49), stream=stream0)
buf361 = buf333; del buf333 # reuse
# Source Nodes: [attn_54], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf359, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf360, (1600, 32, 49), (1600, 49, 1), 0), out=buf361)
buf365 = buf336; del buf336 # reuse
# Source Nodes: [attn_56, attn_58, matmul_23], Original ATen: [aten._softmax, aten._to_copy, aten.add]
triton_per_fused__softmax__to_copy_add_67.run(buf361, arg174_1, arg173_1, buf190, buf365, 78400, 49, grid=grid(78400), stream=stream0)
del arg173_1
del arg174_1
buf366 = reinterpret_tensor(buf360, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf360 # reuse
# Source Nodes: [matmul_23], Original ATen: [aten.clone]
triton_poi_fused_clone_56.run(buf358, arg172_1, buf366, 2508800, grid=grid(2508800), stream=stream0)
del arg172_1
buf367 = reinterpret_tensor(buf356, (1600, 49, 32), (1568, 32, 1), 0); del buf356 # reuse
# Source Nodes: [matmul_23], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf365, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf366, (1600, 49, 32), (1600, 32, 1), 0), out=buf367)
buf368 = buf339; del buf339 # reuse
# Source Nodes: [x_226], Original ATen: [aten.clone]
triton_poi_fused_clone_57.run(buf367, buf368, 2508800, grid=grid(2508800), stream=stream0)
buf369 = buf340; del buf340 # reuse
# Source Nodes: [x_227], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_37.run(arg175_1, buf369, 262144, grid=grid(262144), stream=stream0)
del arg175_1
buf370 = reinterpret_tensor(buf367, (4900, 512), (512, 1), 0); del buf367 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf368, (4900, 512), (512, 1), 0), reinterpret_tensor(buf369, (512, 512), (1, 512), 0), out=buf370)
buf375 = buf324; del buf324 # reuse
# Source Nodes: [layer_norm_28, x_234, x_235], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_red_fused__to_copy_add_native_layer_norm_68.run(buf351, buf370, arg176_1, arg177_1, arg178_1, buf375, 4624, 512, grid=grid(4624), stream=stream0)
del arg177_1
del arg178_1
buf376 = reinterpret_tensor(buf349, (2048, 512), (512, 1), 0); del buf349 # reuse
# Source Nodes: [x_235], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg179_1, buf376, 1048576, grid=grid(1048576), stream=stream0)
del arg179_1
buf377 = reinterpret_tensor(buf348, (4624, 2048), (2048, 1), 0); del buf348 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf375, (4624, 512), (512, 1), 0), reinterpret_tensor(buf376, (512, 2048), (1, 512), 0), out=buf377)
buf378 = reinterpret_tensor(buf377, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf377 # reuse
# Source Nodes: [x_236], Original ATen: [aten.gelu]
triton_poi_fused_gelu_60.run(buf378, arg180_1, 9469952, grid=grid(9469952), stream=stream0)
del arg180_1
buf379 = reinterpret_tensor(buf376, (512, 2048), (2048, 1), 0); del buf376 # reuse
# Source Nodes: [x_238], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg181_1, buf379, 1048576, grid=grid(1048576), stream=stream0)
del arg181_1
buf380 = reinterpret_tensor(buf375, (4624, 512), (512, 1), 0); del buf375 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf378, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf379, (2048, 512), (1, 2048), 0), out=buf380)
buf381 = reinterpret_tensor(buf380, (1, 4624, 512), (2367488, 512, 1), 0); del buf380 # reuse
buf382 = buf353; del buf353 # reuse
buf383 = buf352; del buf352 # reuse
# Source Nodes: [x_234, x_240, x_241], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_per_fused__to_copy_add_native_layer_norm_69.run(buf381, buf351, buf370, arg176_1, arg182_1, buf382, buf383, 4624, 512, grid=grid(4624), stream=stream0)
del arg176_1
del arg182_1
buf385 = reinterpret_tensor(buf370, (100, 49, 512), (25088, 512, 1), 0); del buf370 # reuse
# Source Nodes: [linear_50], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_51.run(buf381, buf382, buf383, arg183_1, arg184_1, buf385, 2508800, grid=grid(2508800), stream=stream0)
del arg183_1
del arg184_1
buf386 = buf357; del buf357 # reuse
# Source Nodes: [linear_50], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_52.run(arg185_1, buf386, 786432, grid=grid(786432), stream=stream0)
del arg185_1
buf387 = buf358; del buf358 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf385, (4900, 512), (512, 1), 0), reinterpret_tensor(buf386, (512, 1536), (1, 512), 0), out=buf387)
buf388 = buf366; del buf366 # reuse
# Source Nodes: [attn_60, q_25], Original ATen: [aten.clone, aten.mul]
triton_poi_fused_clone_mul_53.run(buf387, arg186_1, buf388, 2508800, grid=grid(2508800), stream=stream0)
buf389 = reinterpret_tensor(buf359, (100, 16, 32, 49), (25600, 1600, 49, 1), 0); del buf359 # reuse
# Source Nodes: [attn_60], Original ATen: [aten.clone]
triton_poi_fused_clone_54.run(buf387, arg186_1, buf389, 51200, 49, grid=grid(51200, 49), stream=stream0)
buf390 = buf361; del buf361 # reuse
# Source Nodes: [attn_60], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf388, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf389, (1600, 32, 49), (1600, 49, 1), 0), out=buf390)
buf393 = buf365; del buf365 # reuse
# Source Nodes: [attn_61, attn_62, matmul_25], Original ATen: [aten._softmax, aten._to_copy, aten.add]
triton_per_fused__softmax__to_copy_add_55.run(buf390, arg188_1, arg187_1, buf393, 78400, 49, grid=grid(78400), stream=stream0)
del arg187_1
del arg188_1
buf394 = reinterpret_tensor(buf389, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf389 # reuse
# Source Nodes: [matmul_25], Original ATen: [aten.clone]
triton_poi_fused_clone_56.run(buf387, arg186_1, buf394, 2508800, grid=grid(2508800), stream=stream0)
del arg186_1
buf395 = reinterpret_tensor(buf385, (1600, 49, 32), (1568, 32, 1), 0); del buf385 # reuse
# Source Nodes: [matmul_25], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf393, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf394, (1600, 49, 32), (1600, 32, 1), 0), out=buf395)
buf396 = buf368; del buf368 # reuse
# Source Nodes: [x_245], Original ATen: [aten.clone]
triton_poi_fused_clone_57.run(buf395, buf396, 2508800, grid=grid(2508800), stream=stream0)
buf397 = buf369; del buf369 # reuse
# Source Nodes: [x_246], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_37.run(arg189_1, buf397, 262144, grid=grid(262144), stream=stream0)
del arg189_1
buf398 = reinterpret_tensor(buf395, (4900, 512), (512, 1), 0); del buf395 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf396, (4900, 512), (512, 1), 0), reinterpret_tensor(buf397, (512, 512), (1, 512), 0), out=buf398)
buf402 = buf351; del buf351 # reuse
# Source Nodes: [layer_norm_30, x_252, x_253], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_red_fused__to_copy_add_native_layer_norm_58.run(buf381, buf398, arg190_1, arg191_1, arg192_1, buf402, 4624, 512, grid=grid(4624), stream=stream0)
del arg191_1
del arg192_1
buf403 = reinterpret_tensor(buf379, (2048, 512), (512, 1), 0); del buf379 # reuse
# Source Nodes: [x_253], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg193_1, buf403, 1048576, grid=grid(1048576), stream=stream0)
del arg193_1
buf404 = reinterpret_tensor(buf378, (4624, 2048), (2048, 1), 0); del buf378 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf402, (4624, 512), (512, 1), 0), reinterpret_tensor(buf403, (512, 2048), (1, 512), 0), out=buf404)
buf405 = reinterpret_tensor(buf404, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf404 # reuse
# Source Nodes: [x_254], Original ATen: [aten.gelu]
triton_poi_fused_gelu_60.run(buf405, arg194_1, 9469952, grid=grid(9469952), stream=stream0)
del arg194_1
buf406 = reinterpret_tensor(buf403, (512, 2048), (2048, 1), 0); del buf403 # reuse
# Source Nodes: [x_256], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg195_1, buf406, 1048576, grid=grid(1048576), stream=stream0)
del arg195_1
buf407 = reinterpret_tensor(buf402, (4624, 512), (512, 1), 0); del buf402 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf405, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf406, (2048, 512), (1, 2048), 0), out=buf407)
buf408 = reinterpret_tensor(buf407, (1, 4624, 512), (2367488, 512, 1), 0); del buf407 # reuse
buf409 = buf383; del buf383 # reuse
buf410 = buf382; del buf382 # reuse
# Source Nodes: [x_252, x_258, x_259], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_per_fused__to_copy_add_native_layer_norm_61.run(buf408, buf381, buf398, arg190_1, arg196_1, buf409, buf410, 4624, 512, grid=grid(4624), stream=stream0)
del arg190_1
del arg196_1
buf412 = buf355; del buf355 # reuse
# Source Nodes: [shifted_x_6, x_261], Original ATen: [aten.constant_pad_nd, aten.roll]
triton_poi_fused_constant_pad_nd_roll_62.run(buf408, buf409, buf410, arg197_1, arg198_1, buf412, 2508800, grid=grid(2508800), stream=stream0)
del arg197_1
del arg198_1
buf413 = reinterpret_tensor(buf398, (100, 49, 512), (25088, 512, 1), 0); del buf398 # reuse
# Source Nodes: [linear_54], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_63.run(buf412, buf413, 2508800, grid=grid(2508800), stream=stream0)
buf414 = buf386; del buf386 # reuse
# Source Nodes: [linear_54], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_52.run(arg199_1, buf414, 786432, grid=grid(786432), stream=stream0)
del arg199_1
buf415 = buf387; del buf387 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf413, (4900, 512), (512, 1), 0), reinterpret_tensor(buf414, (512, 1536), (1, 512), 0), out=buf415)
buf416 = buf394; del buf394 # reuse
# Source Nodes: [attn_64, q_27], Original ATen: [aten.clone, aten.mul]
triton_poi_fused_clone_mul_53.run(buf415, arg200_1, buf416, 2508800, grid=grid(2508800), stream=stream0)
buf417 = reinterpret_tensor(buf388, (100, 16, 32, 49), (25600, 1600, 49, 1), 0); del buf388 # reuse
# Source Nodes: [attn_64], Original ATen: [aten.clone]
triton_poi_fused_clone_54.run(buf415, arg200_1, buf417, 51200, 49, grid=grid(51200, 49), stream=stream0)
buf418 = buf390; del buf390 # reuse
# Source Nodes: [attn_64], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf416, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf417, (1600, 32, 49), (1600, 49, 1), 0), out=buf418)
buf422 = buf393; del buf393 # reuse
# Source Nodes: [attn_66, attn_68, matmul_27], Original ATen: [aten._softmax, aten._to_copy, aten.add]
triton_per_fused__softmax__to_copy_add_67.run(buf418, arg202_1, arg201_1, buf190, buf422, 78400, 49, grid=grid(78400), stream=stream0)
del arg201_1
del arg202_1
buf423 = reinterpret_tensor(buf417, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf417 # reuse
# Source Nodes: [matmul_27], Original ATen: [aten.clone]
triton_poi_fused_clone_56.run(buf415, arg200_1, buf423, 2508800, grid=grid(2508800), stream=stream0)
del arg200_1
buf424 = reinterpret_tensor(buf413, (1600, 49, 32), (1568, 32, 1), 0); del buf413 # reuse
# Source Nodes: [matmul_27], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf422, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf423, (1600, 49, 32), (1600, 32, 1), 0), out=buf424)
buf425 = buf396; del buf396 # reuse
# Source Nodes: [x_263], Original ATen: [aten.clone]
triton_poi_fused_clone_57.run(buf424, buf425, 2508800, grid=grid(2508800), stream=stream0)
buf426 = buf397; del buf397 # reuse
# Source Nodes: [x_264], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_37.run(arg203_1, buf426, 262144, grid=grid(262144), stream=stream0)
del arg203_1
buf427 = reinterpret_tensor(buf424, (4900, 512), (512, 1), 0); del buf424 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf425, (4900, 512), (512, 1), 0), reinterpret_tensor(buf426, (512, 512), (1, 512), 0), out=buf427)
buf432 = buf381; del buf381 # reuse
# Source Nodes: [layer_norm_32, x_271, x_272], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_red_fused__to_copy_add_native_layer_norm_68.run(buf408, buf427, arg204_1, arg205_1, arg206_1, buf432, 4624, 512, grid=grid(4624), stream=stream0)
del arg205_1
del arg206_1
buf433 = reinterpret_tensor(buf406, (2048, 512), (512, 1), 0); del buf406 # reuse
# Source Nodes: [x_272], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg207_1, buf433, 1048576, grid=grid(1048576), stream=stream0)
del arg207_1
buf434 = reinterpret_tensor(buf405, (4624, 2048), (2048, 1), 0); del buf405 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf432, (4624, 512), (512, 1), 0), reinterpret_tensor(buf433, (512, 2048), (1, 512), 0), out=buf434)
buf435 = reinterpret_tensor(buf434, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf434 # reuse
# Source Nodes: [x_273], Original ATen: [aten.gelu]
triton_poi_fused_gelu_60.run(buf435, arg208_1, 9469952, grid=grid(9469952), stream=stream0)
del arg208_1
buf436 = reinterpret_tensor(buf433, (512, 2048), (2048, 1), 0); del buf433 # reuse
# Source Nodes: [x_275], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg209_1, buf436, 1048576, grid=grid(1048576), stream=stream0)
del arg209_1
buf437 = reinterpret_tensor(buf432, (4624, 512), (512, 1), 0); del buf432 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf435, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf436, (2048, 512), (1, 2048), 0), out=buf437)
buf438 = reinterpret_tensor(buf437, (1, 4624, 512), (2367488, 512, 1), 0); del buf437 # reuse
buf439 = buf410; del buf410 # reuse
buf440 = buf409; del buf409 # reuse
# Source Nodes: [x_271, x_277, x_278], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_per_fused__to_copy_add_native_layer_norm_69.run(buf438, buf408, buf427, arg204_1, arg210_1, buf439, buf440, 4624, 512, grid=grid(4624), stream=stream0)
del arg204_1
del arg210_1
buf442 = reinterpret_tensor(buf427, (100, 49, 512), (25088, 512, 1), 0); del buf427 # reuse
# Source Nodes: [linear_58], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_51.run(buf438, buf439, buf440, arg211_1, arg212_1, buf442, 2508800, grid=grid(2508800), stream=stream0)
del arg211_1
del arg212_1
buf443 = buf414; del buf414 # reuse
# Source Nodes: [linear_58], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_52.run(arg213_1, buf443, 786432, grid=grid(786432), stream=stream0)
del arg213_1
buf444 = buf415; del buf415 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf442, (4900, 512), (512, 1), 0), reinterpret_tensor(buf443, (512, 1536), (1, 512), 0), out=buf444)
buf445 = buf423; del buf423 # reuse
# Source Nodes: [attn_70, q_29], Original ATen: [aten.clone, aten.mul]
triton_poi_fused_clone_mul_53.run(buf444, arg214_1, buf445, 2508800, grid=grid(2508800), stream=stream0)
buf446 = reinterpret_tensor(buf416, (100, 16, 32, 49), (25600, 1600, 49, 1), 0); del buf416 # reuse
# Source Nodes: [attn_70], Original ATen: [aten.clone]
triton_poi_fused_clone_54.run(buf444, arg214_1, buf446, 51200, 49, grid=grid(51200, 49), stream=stream0)
buf447 = buf418; del buf418 # reuse
# Source Nodes: [attn_70], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf445, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf446, (1600, 32, 49), (1600, 49, 1), 0), out=buf447)
buf450 = buf422; del buf422 # reuse
# Source Nodes: [attn_71, attn_72, matmul_29], Original ATen: [aten._softmax, aten._to_copy, aten.add]
triton_per_fused__softmax__to_copy_add_55.run(buf447, arg216_1, arg215_1, buf450, 78400, 49, grid=grid(78400), stream=stream0)
del arg215_1
del arg216_1
buf451 = reinterpret_tensor(buf446, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf446 # reuse
# Source Nodes: [matmul_29], Original ATen: [aten.clone]
triton_poi_fused_clone_56.run(buf444, arg214_1, buf451, 2508800, grid=grid(2508800), stream=stream0)
del arg214_1
buf452 = reinterpret_tensor(buf442, (1600, 49, 32), (1568, 32, 1), 0); del buf442 # reuse
# Source Nodes: [matmul_29], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf450, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf451, (1600, 49, 32), (1600, 32, 1), 0), out=buf452)
buf453 = buf425; del buf425 # reuse
# Source Nodes: [x_282], Original ATen: [aten.clone]
triton_poi_fused_clone_57.run(buf452, buf453, 2508800, grid=grid(2508800), stream=stream0)
buf454 = buf426; del buf426 # reuse
# Source Nodes: [x_283], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_37.run(arg217_1, buf454, 262144, grid=grid(262144), stream=stream0)
del arg217_1
buf455 = reinterpret_tensor(buf452, (4900, 512), (512, 1), 0); del buf452 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf453, (4900, 512), (512, 1), 0), reinterpret_tensor(buf454, (512, 512), (1, 512), 0), out=buf455)
buf459 = buf408; del buf408 # reuse
# Source Nodes: [layer_norm_34, x_289, x_290], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_red_fused__to_copy_add_native_layer_norm_58.run(buf438, buf455, arg218_1, arg219_1, arg220_1, buf459, 4624, 512, grid=grid(4624), stream=stream0)
del arg219_1
del arg220_1
buf460 = reinterpret_tensor(buf436, (2048, 512), (512, 1), 0); del buf436 # reuse
# Source Nodes: [x_290], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg221_1, buf460, 1048576, grid=grid(1048576), stream=stream0)
del arg221_1
buf461 = reinterpret_tensor(buf435, (4624, 2048), (2048, 1), 0); del buf435 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf459, (4624, 512), (512, 1), 0), reinterpret_tensor(buf460, (512, 2048), (1, 512), 0), out=buf461)
buf462 = reinterpret_tensor(buf461, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf461 # reuse
# Source Nodes: [x_291], Original ATen: [aten.gelu]
triton_poi_fused_gelu_60.run(buf462, arg222_1, 9469952, grid=grid(9469952), stream=stream0)
del arg222_1
buf463 = reinterpret_tensor(buf460, (512, 2048), (2048, 1), 0); del buf460 # reuse
# Source Nodes: [x_293], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg223_1, buf463, 1048576, grid=grid(1048576), stream=stream0)
del arg223_1
buf464 = reinterpret_tensor(buf459, (4624, 512), (512, 1), 0); del buf459 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf462, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf463, (2048, 512), (1, 2048), 0), out=buf464)
buf465 = reinterpret_tensor(buf464, (1, 4624, 512), (2367488, 512, 1), 0); del buf464 # reuse
buf466 = buf440; del buf440 # reuse
buf467 = buf439; del buf439 # reuse
# Source Nodes: [x_289, x_295, x_296], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_per_fused__to_copy_add_native_layer_norm_61.run(buf465, buf438, buf455, arg218_1, arg224_1, buf466, buf467, 4624, 512, grid=grid(4624), stream=stream0)
del arg218_1
del arg224_1
buf469 = buf412; del buf412 # reuse
# Source Nodes: [shifted_x_7, x_298], Original ATen: [aten.constant_pad_nd, aten.roll]
triton_poi_fused_constant_pad_nd_roll_62.run(buf465, buf466, buf467, arg225_1, arg226_1, buf469, 2508800, grid=grid(2508800), stream=stream0)
del arg225_1
del arg226_1
buf470 = reinterpret_tensor(buf455, (100, 49, 512), (25088, 512, 1), 0); del buf455 # reuse
# Source Nodes: [linear_62], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_63.run(buf469, buf470, 2508800, grid=grid(2508800), stream=stream0)
buf471 = buf443; del buf443 # reuse
# Source Nodes: [linear_62], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_52.run(arg227_1, buf471, 786432, grid=grid(786432), stream=stream0)
del arg227_1
buf472 = buf444; del buf444 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf470, (4900, 512), (512, 1), 0), reinterpret_tensor(buf471, (512, 1536), (1, 512), 0), out=buf472)
buf473 = buf451; del buf451 # reuse
# Source Nodes: [attn_74, q_31], Original ATen: [aten.clone, aten.mul]
triton_poi_fused_clone_mul_53.run(buf472, arg228_1, buf473, 2508800, grid=grid(2508800), stream=stream0)
buf474 = reinterpret_tensor(buf445, (100, 16, 32, 49), (25600, 1600, 49, 1), 0); del buf445 # reuse
# Source Nodes: [attn_74], Original ATen: [aten.clone]
triton_poi_fused_clone_54.run(buf472, arg228_1, buf474, 51200, 49, grid=grid(51200, 49), stream=stream0)
buf475 = buf447; del buf447 # reuse
# Source Nodes: [attn_74], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf473, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf474, (1600, 32, 49), (1600, 49, 1), 0), out=buf475)
buf479 = buf450; del buf450 # reuse
# Source Nodes: [attn_76, attn_78, matmul_31], Original ATen: [aten._softmax, aten._to_copy, aten.add]
triton_per_fused__softmax__to_copy_add_67.run(buf475, arg230_1, arg229_1, buf190, buf479, 78400, 49, grid=grid(78400), stream=stream0)
del arg229_1
del arg230_1
buf480 = reinterpret_tensor(buf474, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf474 # reuse
# Source Nodes: [matmul_31], Original ATen: [aten.clone]
triton_poi_fused_clone_56.run(buf472, arg228_1, buf480, 2508800, grid=grid(2508800), stream=stream0)
del arg228_1
buf481 = reinterpret_tensor(buf470, (1600, 49, 32), (1568, 32, 1), 0); del buf470 # reuse
# Source Nodes: [matmul_31], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf479, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf480, (1600, 49, 32), (1600, 32, 1), 0), out=buf481)
buf482 = buf453; del buf453 # reuse
# Source Nodes: [x_300], Original ATen: [aten.clone]
triton_poi_fused_clone_57.run(buf481, buf482, 2508800, grid=grid(2508800), stream=stream0)
buf483 = buf454; del buf454 # reuse
# Source Nodes: [x_301], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_37.run(arg231_1, buf483, 262144, grid=grid(262144), stream=stream0)
del arg231_1
buf484 = reinterpret_tensor(buf481, (4900, 512), (512, 1), 0); del buf481 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf482, (4900, 512), (512, 1), 0), reinterpret_tensor(buf483, (512, 512), (1, 512), 0), out=buf484)
buf489 = buf438; del buf438 # reuse
# Source Nodes: [layer_norm_36, x_308, x_309], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_red_fused__to_copy_add_native_layer_norm_68.run(buf465, buf484, arg232_1, arg233_1, arg234_1, buf489, 4624, 512, grid=grid(4624), stream=stream0)
del arg233_1
del arg234_1
buf490 = reinterpret_tensor(buf463, (2048, 512), (512, 1), 0); del buf463 # reuse
# Source Nodes: [x_309], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg235_1, buf490, 1048576, grid=grid(1048576), stream=stream0)
del arg235_1
buf491 = reinterpret_tensor(buf462, (4624, 2048), (2048, 1), 0); del buf462 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf489, (4624, 512), (512, 1), 0), reinterpret_tensor(buf490, (512, 2048), (1, 512), 0), out=buf491)
buf492 = reinterpret_tensor(buf491, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf491 # reuse
# Source Nodes: [x_310], Original ATen: [aten.gelu]
triton_poi_fused_gelu_60.run(buf492, arg236_1, 9469952, grid=grid(9469952), stream=stream0)
del arg236_1
buf493 = reinterpret_tensor(buf490, (512, 2048), (2048, 1), 0); del buf490 # reuse
# Source Nodes: [x_312], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg237_1, buf493, 1048576, grid=grid(1048576), stream=stream0)
del arg237_1
buf494 = reinterpret_tensor(buf489, (4624, 512), (512, 1), 0); del buf489 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf492, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf493, (2048, 512), (1, 2048), 0), out=buf494)
buf495 = reinterpret_tensor(buf494, (1, 4624, 512), (2367488, 512, 1), 0); del buf494 # reuse
buf496 = buf467; del buf467 # reuse
buf497 = buf466; del buf466 # reuse
# Source Nodes: [x_308, x_314, x_315], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_per_fused__to_copy_add_native_layer_norm_69.run(buf495, buf465, buf484, arg232_1, arg238_1, buf496, buf497, 4624, 512, grid=grid(4624), stream=stream0)
del arg232_1
del arg238_1
buf499 = reinterpret_tensor(buf484, (100, 49, 512), (25088, 512, 1), 0); del buf484 # reuse
# Source Nodes: [linear_66], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_51.run(buf495, buf496, buf497, arg239_1, arg240_1, buf499, 2508800, grid=grid(2508800), stream=stream0)
del arg239_1
del arg240_1
buf500 = buf471; del buf471 # reuse
# Source Nodes: [linear_66], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_52.run(arg241_1, buf500, 786432, grid=grid(786432), stream=stream0)
del arg241_1
buf501 = buf472; del buf472 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf499, (4900, 512), (512, 1), 0), reinterpret_tensor(buf500, (512, 1536), (1, 512), 0), out=buf501)
buf502 = buf480; del buf480 # reuse
# Source Nodes: [attn_80, q_33], Original ATen: [aten.clone, aten.mul]
triton_poi_fused_clone_mul_53.run(buf501, arg242_1, buf502, 2508800, grid=grid(2508800), stream=stream0)
buf503 = reinterpret_tensor(buf473, (100, 16, 32, 49), (25600, 1600, 49, 1), 0); del buf473 # reuse
# Source Nodes: [attn_80], Original ATen: [aten.clone]
triton_poi_fused_clone_54.run(buf501, arg242_1, buf503, 51200, 49, grid=grid(51200, 49), stream=stream0)
buf504 = buf475; del buf475 # reuse
# Source Nodes: [attn_80], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf502, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf503, (1600, 32, 49), (1600, 49, 1), 0), out=buf504)
buf507 = buf479; del buf479 # reuse
# Source Nodes: [attn_81, attn_82, matmul_33], Original ATen: [aten._softmax, aten._to_copy, aten.add]
triton_per_fused__softmax__to_copy_add_55.run(buf504, arg244_1, arg243_1, buf507, 78400, 49, grid=grid(78400), stream=stream0)
del arg243_1
del arg244_1
buf508 = reinterpret_tensor(buf503, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf503 # reuse
# Source Nodes: [matmul_33], Original ATen: [aten.clone]
triton_poi_fused_clone_56.run(buf501, arg242_1, buf508, 2508800, grid=grid(2508800), stream=stream0)
del arg242_1
buf509 = reinterpret_tensor(buf499, (1600, 49, 32), (1568, 32, 1), 0); del buf499 # reuse
# Source Nodes: [matmul_33], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf507, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf508, (1600, 49, 32), (1600, 32, 1), 0), out=buf509)
buf510 = buf482; del buf482 # reuse
# Source Nodes: [x_319], Original ATen: [aten.clone]
triton_poi_fused_clone_57.run(buf509, buf510, 2508800, grid=grid(2508800), stream=stream0)
buf511 = buf483; del buf483 # reuse
# Source Nodes: [x_320], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_37.run(arg245_1, buf511, 262144, grid=grid(262144), stream=stream0)
del arg245_1
buf512 = reinterpret_tensor(buf509, (4900, 512), (512, 1), 0); del buf509 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf510, (4900, 512), (512, 1), 0), reinterpret_tensor(buf511, (512, 512), (1, 512), 0), out=buf512)
buf516 = buf465; del buf465 # reuse
# Source Nodes: [layer_norm_38, x_326, x_327], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_red_fused__to_copy_add_native_layer_norm_58.run(buf495, buf512, arg246_1, arg247_1, arg248_1, buf516, 4624, 512, grid=grid(4624), stream=stream0)
del arg247_1
del arg248_1
buf517 = reinterpret_tensor(buf493, (2048, 512), (512, 1), 0); del buf493 # reuse
# Source Nodes: [x_327], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg249_1, buf517, 1048576, grid=grid(1048576), stream=stream0)
del arg249_1
buf518 = reinterpret_tensor(buf492, (4624, 2048), (2048, 1), 0); del buf492 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf516, (4624, 512), (512, 1), 0), reinterpret_tensor(buf517, (512, 2048), (1, 512), 0), out=buf518)
buf519 = reinterpret_tensor(buf518, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf518 # reuse
# Source Nodes: [x_328], Original ATen: [aten.gelu]
triton_poi_fused_gelu_60.run(buf519, arg250_1, 9469952, grid=grid(9469952), stream=stream0)
del arg250_1
buf520 = reinterpret_tensor(buf517, (512, 2048), (2048, 1), 0); del buf517 # reuse
# Source Nodes: [x_330], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg251_1, buf520, 1048576, grid=grid(1048576), stream=stream0)
del arg251_1
buf521 = reinterpret_tensor(buf516, (4624, 512), (512, 1), 0); del buf516 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf519, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf520, (2048, 512), (1, 2048), 0), out=buf521)
buf522 = reinterpret_tensor(buf521, (1, 4624, 512), (2367488, 512, 1), 0); del buf521 # reuse
buf523 = buf497; del buf497 # reuse
buf524 = buf496; del buf496 # reuse
# Source Nodes: [x_326, x_332, x_333], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_per_fused__to_copy_add_native_layer_norm_61.run(buf522, buf495, buf512, arg246_1, arg252_1, buf523, buf524, 4624, 512, grid=grid(4624), stream=stream0)
del arg246_1
del arg252_1
buf526 = buf469; del buf469 # reuse
# Source Nodes: [shifted_x_8, x_335], Original ATen: [aten.constant_pad_nd, aten.roll]
triton_poi_fused_constant_pad_nd_roll_62.run(buf522, buf523, buf524, arg253_1, arg254_1, buf526, 2508800, grid=grid(2508800), stream=stream0)
del arg253_1
del arg254_1
buf527 = reinterpret_tensor(buf512, (100, 49, 512), (25088, 512, 1), 0); del buf512 # reuse
# Source Nodes: [linear_70], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_63.run(buf526, buf527, 2508800, grid=grid(2508800), stream=stream0)
buf528 = buf500; del buf500 # reuse
# Source Nodes: [linear_70], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_52.run(arg255_1, buf528, 786432, grid=grid(786432), stream=stream0)
del arg255_1
buf529 = buf501; del buf501 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf527, (4900, 512), (512, 1), 0), reinterpret_tensor(buf528, (512, 1536), (1, 512), 0), out=buf529)
buf530 = buf508; del buf508 # reuse
# Source Nodes: [attn_84, q_35], Original ATen: [aten.clone, aten.mul]
triton_poi_fused_clone_mul_53.run(buf529, arg256_1, buf530, 2508800, grid=grid(2508800), stream=stream0)
buf531 = reinterpret_tensor(buf502, (100, 16, 32, 49), (25600, 1600, 49, 1), 0); del buf502 # reuse
# Source Nodes: [attn_84], Original ATen: [aten.clone]
triton_poi_fused_clone_54.run(buf529, arg256_1, buf531, 51200, 49, grid=grid(51200, 49), stream=stream0)
buf532 = buf504; del buf504 # reuse
# Source Nodes: [attn_84], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf530, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf531, (1600, 32, 49), (1600, 49, 1), 0), out=buf532)
buf536 = buf507; del buf507 # reuse
# Source Nodes: [attn_86, attn_88, matmul_35], Original ATen: [aten._softmax, aten._to_copy, aten.add]
triton_per_fused__softmax__to_copy_add_67.run(buf532, arg258_1, arg257_1, buf190, buf536, 78400, 49, grid=grid(78400), stream=stream0)
del arg257_1
del arg258_1
buf537 = reinterpret_tensor(buf531, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf531 # reuse
# Source Nodes: [matmul_35], Original ATen: [aten.clone]
triton_poi_fused_clone_56.run(buf529, arg256_1, buf537, 2508800, grid=grid(2508800), stream=stream0)
del arg256_1
buf538 = reinterpret_tensor(buf527, (1600, 49, 32), (1568, 32, 1), 0); del buf527 # reuse
# Source Nodes: [matmul_35], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf536, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf537, (1600, 49, 32), (1600, 32, 1), 0), out=buf538)
buf539 = buf510; del buf510 # reuse
# Source Nodes: [x_337], Original ATen: [aten.clone]
triton_poi_fused_clone_57.run(buf538, buf539, 2508800, grid=grid(2508800), stream=stream0)
buf540 = buf511; del buf511 # reuse
# Source Nodes: [x_338], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_37.run(arg259_1, buf540, 262144, grid=grid(262144), stream=stream0)
del arg259_1
buf541 = reinterpret_tensor(buf538, (4900, 512), (512, 1), 0); del buf538 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf539, (4900, 512), (512, 1), 0), reinterpret_tensor(buf540, (512, 512), (1, 512), 0), out=buf541)
buf546 = buf495; del buf495 # reuse
# Source Nodes: [layer_norm_40, x_345, x_346], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_red_fused__to_copy_add_native_layer_norm_68.run(buf522, buf541, arg260_1, arg261_1, arg262_1, buf546, 4624, 512, grid=grid(4624), stream=stream0)
del arg261_1
del arg262_1
buf547 = reinterpret_tensor(buf520, (2048, 512), (512, 1), 0); del buf520 # reuse
# Source Nodes: [x_346], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg263_1, buf547, 1048576, grid=grid(1048576), stream=stream0)
del arg263_1
buf548 = reinterpret_tensor(buf519, (4624, 2048), (2048, 1), 0); del buf519 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf546, (4624, 512), (512, 1), 0), reinterpret_tensor(buf547, (512, 2048), (1, 512), 0), out=buf548)
buf549 = reinterpret_tensor(buf548, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf548 # reuse
# Source Nodes: [x_347], Original ATen: [aten.gelu]
triton_poi_fused_gelu_60.run(buf549, arg264_1, 9469952, grid=grid(9469952), stream=stream0)
del arg264_1
buf550 = reinterpret_tensor(buf547, (512, 2048), (2048, 1), 0); del buf547 # reuse
# Source Nodes: [x_349], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg265_1, buf550, 1048576, grid=grid(1048576), stream=stream0)
del arg265_1
buf551 = reinterpret_tensor(buf546, (4624, 512), (512, 1), 0); del buf546 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf549, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf550, (2048, 512), (1, 2048), 0), out=buf551)
buf552 = reinterpret_tensor(buf551, (1, 4624, 512), (2367488, 512, 1), 0); del buf551 # reuse
buf553 = buf524; del buf524 # reuse
buf554 = buf523; del buf523 # reuse
# Source Nodes: [x_345, x_351, x_352], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_per_fused__to_copy_add_native_layer_norm_69.run(buf552, buf522, buf541, arg260_1, arg266_1, buf553, buf554, 4624, 512, grid=grid(4624), stream=stream0)
del arg260_1
del arg266_1
buf556 = reinterpret_tensor(buf541, (100, 49, 512), (25088, 512, 1), 0); del buf541 # reuse
# Source Nodes: [linear_74], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_51.run(buf552, buf553, buf554, arg267_1, arg268_1, buf556, 2508800, grid=grid(2508800), stream=stream0)
del arg267_1
del arg268_1
buf557 = buf528; del buf528 # reuse
# Source Nodes: [linear_74], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_52.run(arg269_1, buf557, 786432, grid=grid(786432), stream=stream0)
del arg269_1
buf558 = buf529; del buf529 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf556, (4900, 512), (512, 1), 0), reinterpret_tensor(buf557, (512, 1536), (1, 512), 0), out=buf558)
buf559 = buf537; del buf537 # reuse
# Source Nodes: [attn_90, q_37], Original ATen: [aten.clone, aten.mul]
triton_poi_fused_clone_mul_53.run(buf558, arg270_1, buf559, 2508800, grid=grid(2508800), stream=stream0)
buf560 = reinterpret_tensor(buf530, (100, 16, 32, 49), (25600, 1600, 49, 1), 0); del buf530 # reuse
# Source Nodes: [attn_90], Original ATen: [aten.clone]
triton_poi_fused_clone_54.run(buf558, arg270_1, buf560, 51200, 49, grid=grid(51200, 49), stream=stream0)
buf561 = buf532; del buf532 # reuse
# Source Nodes: [attn_90], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf559, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf560, (1600, 32, 49), (1600, 49, 1), 0), out=buf561)
buf564 = buf536; del buf536 # reuse
# Source Nodes: [attn_91, attn_92, matmul_37], Original ATen: [aten._softmax, aten._to_copy, aten.add]
triton_per_fused__softmax__to_copy_add_55.run(buf561, arg272_1, arg271_1, buf564, 78400, 49, grid=grid(78400), stream=stream0)
del arg271_1
del arg272_1
buf565 = reinterpret_tensor(buf560, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf560 # reuse
# Source Nodes: [matmul_37], Original ATen: [aten.clone]
triton_poi_fused_clone_56.run(buf558, arg270_1, buf565, 2508800, grid=grid(2508800), stream=stream0)
del arg270_1
buf566 = reinterpret_tensor(buf556, (1600, 49, 32), (1568, 32, 1), 0); del buf556 # reuse
# Source Nodes: [matmul_37], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf564, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf565, (1600, 49, 32), (1600, 32, 1), 0), out=buf566)
buf567 = buf539; del buf539 # reuse
# Source Nodes: [x_356], Original ATen: [aten.clone]
triton_poi_fused_clone_57.run(buf566, buf567, 2508800, grid=grid(2508800), stream=stream0)
buf568 = buf540; del buf540 # reuse
# Source Nodes: [x_357], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_37.run(arg273_1, buf568, 262144, grid=grid(262144), stream=stream0)
del arg273_1
buf569 = reinterpret_tensor(buf566, (4900, 512), (512, 1), 0); del buf566 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf567, (4900, 512), (512, 1), 0), reinterpret_tensor(buf568, (512, 512), (1, 512), 0), out=buf569)
buf573 = buf522; del buf522 # reuse
# Source Nodes: [layer_norm_42, x_363, x_364], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_red_fused__to_copy_add_native_layer_norm_58.run(buf552, buf569, arg274_1, arg275_1, arg276_1, buf573, 4624, 512, grid=grid(4624), stream=stream0)
del arg275_1
del arg276_1
buf574 = reinterpret_tensor(buf550, (2048, 512), (512, 1), 0); del buf550 # reuse
# Source Nodes: [x_364], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg277_1, buf574, 1048576, grid=grid(1048576), stream=stream0)
del arg277_1
buf575 = reinterpret_tensor(buf549, (4624, 2048), (2048, 1), 0); del buf549 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf573, (4624, 512), (512, 1), 0), reinterpret_tensor(buf574, (512, 2048), (1, 512), 0), out=buf575)
buf576 = reinterpret_tensor(buf575, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf575 # reuse
# Source Nodes: [x_365], Original ATen: [aten.gelu]
triton_poi_fused_gelu_60.run(buf576, arg278_1, 9469952, grid=grid(9469952), stream=stream0)
del arg278_1
buf577 = reinterpret_tensor(buf574, (512, 2048), (2048, 1), 0); del buf574 # reuse
# Source Nodes: [x_367], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg279_1, buf577, 1048576, grid=grid(1048576), stream=stream0)
del arg279_1
buf578 = reinterpret_tensor(buf573, (4624, 512), (512, 1), 0); del buf573 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf576, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf577, (2048, 512), (1, 2048), 0), out=buf578)
buf579 = reinterpret_tensor(buf578, (1, 4624, 512), (2367488, 512, 1), 0); del buf578 # reuse
buf580 = buf554; del buf554 # reuse
buf581 = buf553; del buf553 # reuse
# Source Nodes: [x_363, x_369, x_370], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_per_fused__to_copy_add_native_layer_norm_61.run(buf579, buf552, buf569, arg274_1, arg280_1, buf580, buf581, 4624, 512, grid=grid(4624), stream=stream0)
del arg274_1
del arg280_1
buf583 = buf526; del buf526 # reuse
# Source Nodes: [shifted_x_9, x_372], Original ATen: [aten.constant_pad_nd, aten.roll]
triton_poi_fused_constant_pad_nd_roll_62.run(buf579, buf580, buf581, arg281_1, arg282_1, buf583, 2508800, grid=grid(2508800), stream=stream0)
del arg281_1
del arg282_1
buf584 = reinterpret_tensor(buf569, (100, 49, 512), (25088, 512, 1), 0); del buf569 # reuse
# Source Nodes: [linear_78], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_63.run(buf583, buf584, 2508800, grid=grid(2508800), stream=stream0)
buf585 = buf557; del buf557 # reuse
# Source Nodes: [linear_78], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_52.run(arg283_1, buf585, 786432, grid=grid(786432), stream=stream0)
del arg283_1
buf586 = buf558; del buf558 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf584, (4900, 512), (512, 1), 0), reinterpret_tensor(buf585, (512, 1536), (1, 512), 0), out=buf586)
buf587 = buf565; del buf565 # reuse
# Source Nodes: [attn_94, q_39], Original ATen: [aten.clone, aten.mul]
triton_poi_fused_clone_mul_53.run(buf586, arg284_1, buf587, 2508800, grid=grid(2508800), stream=stream0)
buf588 = reinterpret_tensor(buf559, (100, 16, 32, 49), (25600, 1600, 49, 1), 0); del buf559 # reuse
# Source Nodes: [attn_94], Original ATen: [aten.clone]
triton_poi_fused_clone_54.run(buf586, arg284_1, buf588, 51200, 49, grid=grid(51200, 49), stream=stream0)
buf589 = buf561; del buf561 # reuse
# Source Nodes: [attn_94], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf587, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf588, (1600, 32, 49), (1600, 49, 1), 0), out=buf589)
buf593 = buf564; del buf564 # reuse
# Source Nodes: [attn_96, attn_98, matmul_39], Original ATen: [aten._softmax, aten._to_copy, aten.add]
triton_per_fused__softmax__to_copy_add_67.run(buf589, arg286_1, arg285_1, buf190, buf593, 78400, 49, grid=grid(78400), stream=stream0)
del arg285_1
del arg286_1
buf594 = reinterpret_tensor(buf588, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf588 # reuse
# Source Nodes: [matmul_39], Original ATen: [aten.clone]
triton_poi_fused_clone_56.run(buf586, arg284_1, buf594, 2508800, grid=grid(2508800), stream=stream0)
del arg284_1
buf595 = reinterpret_tensor(buf584, (1600, 49, 32), (1568, 32, 1), 0); del buf584 # reuse
# Source Nodes: [matmul_39], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf593, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf594, (1600, 49, 32), (1600, 32, 1), 0), out=buf595)
buf596 = buf567; del buf567 # reuse
# Source Nodes: [x_374], Original ATen: [aten.clone]
triton_poi_fused_clone_57.run(buf595, buf596, 2508800, grid=grid(2508800), stream=stream0)
buf597 = buf568; del buf568 # reuse
# Source Nodes: [x_375], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_37.run(arg287_1, buf597, 262144, grid=grid(262144), stream=stream0)
del arg287_1
buf598 = reinterpret_tensor(buf595, (4900, 512), (512, 1), 0); del buf595 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf596, (4900, 512), (512, 1), 0), reinterpret_tensor(buf597, (512, 512), (1, 512), 0), out=buf598)
buf603 = buf552; del buf552 # reuse
# Source Nodes: [layer_norm_44, x_382, x_383], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_red_fused__to_copy_add_native_layer_norm_68.run(buf579, buf598, arg288_1, arg289_1, arg290_1, buf603, 4624, 512, grid=grid(4624), stream=stream0)
del arg289_1
del arg290_1
buf604 = reinterpret_tensor(buf577, (2048, 512), (512, 1), 0); del buf577 # reuse
# Source Nodes: [x_383], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg291_1, buf604, 1048576, grid=grid(1048576), stream=stream0)
del arg291_1
buf605 = reinterpret_tensor(buf576, (4624, 2048), (2048, 1), 0); del buf576 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf603, (4624, 512), (512, 1), 0), reinterpret_tensor(buf604, (512, 2048), (1, 512), 0), out=buf605)
buf606 = reinterpret_tensor(buf605, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf605 # reuse
# Source Nodes: [x_384], Original ATen: [aten.gelu]
triton_poi_fused_gelu_60.run(buf606, arg292_1, 9469952, grid=grid(9469952), stream=stream0)
del arg292_1
buf607 = reinterpret_tensor(buf604, (512, 2048), (2048, 1), 0); del buf604 # reuse
# Source Nodes: [x_386], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg293_1, buf607, 1048576, grid=grid(1048576), stream=stream0)
del arg293_1
buf608 = reinterpret_tensor(buf603, (4624, 512), (512, 1), 0); del buf603 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf606, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf607, (2048, 512), (1, 2048), 0), out=buf608)
buf609 = reinterpret_tensor(buf608, (1, 4624, 512), (2367488, 512, 1), 0); del buf608 # reuse
buf610 = buf581; del buf581 # reuse
buf611 = buf580; del buf580 # reuse
# Source Nodes: [x_382, x_388, x_389], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_per_fused__to_copy_add_native_layer_norm_69.run(buf609, buf579, buf598, arg288_1, arg294_1, buf610, buf611, 4624, 512, grid=grid(4624), stream=stream0)
del arg288_1
del arg294_1
buf613 = reinterpret_tensor(buf598, (100, 49, 512), (25088, 512, 1), 0); del buf598 # reuse
# Source Nodes: [linear_82], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_51.run(buf609, buf610, buf611, arg295_1, arg296_1, buf613, 2508800, grid=grid(2508800), stream=stream0)
del arg295_1
del arg296_1
buf614 = buf585; del buf585 # reuse
# Source Nodes: [linear_82], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_52.run(arg297_1, buf614, 786432, grid=grid(786432), stream=stream0)
del arg297_1
buf615 = buf586; del buf586 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf613, (4900, 512), (512, 1), 0), reinterpret_tensor(buf614, (512, 1536), (1, 512), 0), out=buf615)
buf616 = buf594; del buf594 # reuse
# Source Nodes: [attn_100, q_41], Original ATen: [aten.clone, aten.mul]
triton_poi_fused_clone_mul_53.run(buf615, arg298_1, buf616, 2508800, grid=grid(2508800), stream=stream0)
buf617 = reinterpret_tensor(buf587, (100, 16, 32, 49), (25600, 1600, 49, 1), 0); del buf587 # reuse
# Source Nodes: [attn_100], Original ATen: [aten.clone]
triton_poi_fused_clone_54.run(buf615, arg298_1, buf617, 51200, 49, grid=grid(51200, 49), stream=stream0)
buf618 = buf589; del buf589 # reuse
# Source Nodes: [attn_100], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf616, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf617, (1600, 32, 49), (1600, 49, 1), 0), out=buf618)
buf621 = buf593; del buf593 # reuse
# Source Nodes: [attn_101, attn_102, matmul_41], Original ATen: [aten._softmax, aten._to_copy, aten.add]
triton_per_fused__softmax__to_copy_add_55.run(buf618, arg300_1, arg299_1, buf621, 78400, 49, grid=grid(78400), stream=stream0)
del arg299_1
del arg300_1
buf622 = reinterpret_tensor(buf617, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf617 # reuse
# Source Nodes: [matmul_41], Original ATen: [aten.clone]
triton_poi_fused_clone_56.run(buf615, arg298_1, buf622, 2508800, grid=grid(2508800), stream=stream0)
del arg298_1
buf623 = reinterpret_tensor(buf613, (1600, 49, 32), (1568, 32, 1), 0); del buf613 # reuse
# Source Nodes: [matmul_41], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf621, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf622, (1600, 49, 32), (1600, 32, 1), 0), out=buf623)
buf624 = buf596; del buf596 # reuse
# Source Nodes: [x_393], Original ATen: [aten.clone]
triton_poi_fused_clone_57.run(buf623, buf624, 2508800, grid=grid(2508800), stream=stream0)
buf625 = buf597; del buf597 # reuse
# Source Nodes: [x_394], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_37.run(arg301_1, buf625, 262144, grid=grid(262144), stream=stream0)
del arg301_1
buf626 = reinterpret_tensor(buf623, (4900, 512), (512, 1), 0); del buf623 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf624, (4900, 512), (512, 1), 0), reinterpret_tensor(buf625, (512, 512), (1, 512), 0), out=buf626)
buf630 = buf579; del buf579 # reuse
# Source Nodes: [layer_norm_46, x_400, x_401], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_red_fused__to_copy_add_native_layer_norm_58.run(buf609, buf626, arg302_1, arg303_1, arg304_1, buf630, 4624, 512, grid=grid(4624), stream=stream0)
del arg303_1
del arg304_1
buf631 = reinterpret_tensor(buf607, (2048, 512), (512, 1), 0); del buf607 # reuse
# Source Nodes: [x_401], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg305_1, buf631, 1048576, grid=grid(1048576), stream=stream0)
del arg305_1
buf632 = reinterpret_tensor(buf606, (4624, 2048), (2048, 1), 0); del buf606 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf630, (4624, 512), (512, 1), 0), reinterpret_tensor(buf631, (512, 2048), (1, 512), 0), out=buf632)
buf633 = reinterpret_tensor(buf632, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf632 # reuse
# Source Nodes: [x_402], Original ATen: [aten.gelu]
triton_poi_fused_gelu_60.run(buf633, arg306_1, 9469952, grid=grid(9469952), stream=stream0)
del arg306_1
buf634 = reinterpret_tensor(buf631, (512, 2048), (2048, 1), 0); del buf631 # reuse
# Source Nodes: [x_404], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg307_1, buf634, 1048576, grid=grid(1048576), stream=stream0)
del arg307_1
buf635 = reinterpret_tensor(buf630, (4624, 512), (512, 1), 0); del buf630 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf633, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf634, (2048, 512), (1, 2048), 0), out=buf635)
buf636 = reinterpret_tensor(buf635, (1, 4624, 512), (2367488, 512, 1), 0); del buf635 # reuse
buf637 = buf611; del buf611 # reuse
buf638 = buf610; del buf610 # reuse
# Source Nodes: [x_400, x_406, x_407], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_per_fused__to_copy_add_native_layer_norm_61.run(buf636, buf609, buf626, arg302_1, arg308_1, buf637, buf638, 4624, 512, grid=grid(4624), stream=stream0)
del arg302_1
del arg308_1
buf640 = buf583; del buf583 # reuse
# Source Nodes: [shifted_x_10, x_409], Original ATen: [aten.constant_pad_nd, aten.roll]
triton_poi_fused_constant_pad_nd_roll_62.run(buf636, buf637, buf638, arg309_1, arg310_1, buf640, 2508800, grid=grid(2508800), stream=stream0)
del arg309_1
del arg310_1
buf641 = reinterpret_tensor(buf626, (100, 49, 512), (25088, 512, 1), 0); del buf626 # reuse
# Source Nodes: [linear_86], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_63.run(buf640, buf641, 2508800, grid=grid(2508800), stream=stream0)
del buf640
buf642 = buf614; del buf614 # reuse
# Source Nodes: [linear_86], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_52.run(arg311_1, buf642, 786432, grid=grid(786432), stream=stream0)
del arg311_1
buf643 = buf615; del buf615 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf641, (4900, 512), (512, 1), 0), reinterpret_tensor(buf642, (512, 1536), (1, 512), 0), out=buf643)
del buf642
buf644 = buf622; del buf622 # reuse
# Source Nodes: [attn_104, q_43], Original ATen: [aten.clone, aten.mul]
triton_poi_fused_clone_mul_53.run(buf643, arg312_1, buf644, 2508800, grid=grid(2508800), stream=stream0)
buf645 = reinterpret_tensor(buf616, (100, 16, 32, 49), (25600, 1600, 49, 1), 0); del buf616 # reuse
# Source Nodes: [attn_104], Original ATen: [aten.clone]
triton_poi_fused_clone_54.run(buf643, arg312_1, buf645, 51200, 49, grid=grid(51200, 49), stream=stream0)
buf646 = buf618; del buf618 # reuse
# Source Nodes: [attn_104], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf644, (1600, 49, 32), (1600, 32, 1), 0), reinterpret_tensor(buf645, (1600, 32, 49), (1600, 49, 1), 0), out=buf646)
del buf644
buf650 = buf621; del buf621 # reuse
# Source Nodes: [attn_106, attn_108, matmul_43], Original ATen: [aten._softmax, aten._to_copy, aten.add]
triton_per_fused__softmax__to_copy_add_67.run(buf646, arg314_1, arg313_1, buf190, buf650, 78400, 49, grid=grid(78400), stream=stream0)
del arg313_1
del arg314_1
del buf190
del buf646
buf651 = reinterpret_tensor(buf645, (100, 16, 49, 32), (25600, 1600, 32, 1), 0); del buf645 # reuse
# Source Nodes: [matmul_43], Original ATen: [aten.clone]
triton_poi_fused_clone_56.run(buf643, arg312_1, buf651, 2508800, grid=grid(2508800), stream=stream0)
del arg312_1
del buf643
buf652 = reinterpret_tensor(buf641, (1600, 49, 32), (1568, 32, 1), 0); del buf641 # reuse
# Source Nodes: [matmul_43], Original ATen: [aten.bmm]
extern_kernels.bmm(reinterpret_tensor(buf650, (1600, 49, 49), (2432, 49, 1), 0), reinterpret_tensor(buf651, (1600, 49, 32), (1600, 32, 1), 0), out=buf652)
del buf650
del buf651
buf653 = buf624; del buf624 # reuse
# Source Nodes: [x_411], Original ATen: [aten.clone]
triton_poi_fused_clone_57.run(buf652, buf653, 2508800, grid=grid(2508800), stream=stream0)
buf654 = buf625; del buf625 # reuse
# Source Nodes: [x_412], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_37.run(arg315_1, buf654, 262144, grid=grid(262144), stream=stream0)
del arg315_1
buf655 = reinterpret_tensor(buf652, (4900, 512), (512, 1), 0); del buf652 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf653, (4900, 512), (512, 1), 0), reinterpret_tensor(buf654, (512, 512), (1, 512), 0), out=buf655)
del buf653
del buf654
buf660 = buf609; del buf609 # reuse
# Source Nodes: [layer_norm_48, x_419, x_420], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_red_fused__to_copy_add_native_layer_norm_68.run(buf636, buf655, arg316_1, arg317_1, arg318_1, buf660, 4624, 512, grid=grid(4624), stream=stream0)
del arg317_1
del arg318_1
buf661 = reinterpret_tensor(buf634, (2048, 512), (512, 1), 0); del buf634 # reuse
# Source Nodes: [x_420], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg319_1, buf661, 1048576, grid=grid(1048576), stream=stream0)
del arg319_1
buf662 = reinterpret_tensor(buf633, (4624, 2048), (2048, 1), 0); del buf633 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf660, (4624, 512), (512, 1), 0), reinterpret_tensor(buf661, (512, 2048), (1, 512), 0), out=buf662)
buf663 = reinterpret_tensor(buf662, (1, 4624, 2048), (9469952, 2048, 1), 0); del buf662 # reuse
# Source Nodes: [x_421], Original ATen: [aten.gelu]
triton_poi_fused_gelu_60.run(buf663, arg320_1, 9469952, grid=grid(9469952), stream=stream0)
del arg320_1
buf664 = reinterpret_tensor(buf661, (512, 2048), (2048, 1), 0); del buf661 # reuse
# Source Nodes: [x_423], Original ATen: [aten._to_copy]
triton_poi_fused__to_copy_59.run(arg321_1, buf664, 1048576, grid=grid(1048576), stream=stream0)
del arg321_1
buf665 = reinterpret_tensor(buf660, (4624, 512), (512, 1), 0); del buf660 # reuse
# Source Nodes: [], Original ATen: []
extern_kernels.mm(reinterpret_tensor(buf663, (4624, 2048), (2048, 1), 0), reinterpret_tensor(buf664, (2048, 512), (1, 2048), 0), out=buf665)
del buf663
del buf664
buf666 = empty_strided_cuda((1, 4624, 512), (2367488, 512, 1), torch.float32)
buf667 = buf638; del buf638 # reuse
buf668 = buf637; del buf637 # reuse
# Source Nodes: [x_419, x_425, x_out_2], Original ATen: [aten._to_copy, aten.add, aten.native_layer_norm]
triton_per_fused__to_copy_add_native_layer_norm_70.run(buf636, buf655, arg316_1, buf665, arg322_1, buf666, buf667, buf668, 4624, 512, grid=grid(4624), stream=stream0)
del arg316_1
del arg322_1
del buf636
del buf655
del buf665
buf670 = reinterpret_tensor(buf34, (1, 128, 272, 272), (9469952, 73984, 272, 1), 0); del buf34 # reuse
# Source Nodes: [out], Original ATen: [aten.clone]
triton_poi_fused_clone_71.run(buf68, buf72, buf73, arg36_1, arg37_1, buf670, 128, 73984, grid=grid(128, 73984), stream=stream0)
del arg36_1
del arg37_1
del buf68
del buf72
del buf73
buf671 = empty_strided_cuda((1, 256, 136, 136), (4734976, 18496, 136, 1), torch.float32)
# Source Nodes: [out_1], Original ATen: [aten.clone]
triton_poi_fused_clone_72.run(buf139, buf143, buf144, arg69_1, arg70_1, buf671, 256, 18496, grid=grid(256, 18496), stream=stream0)
del arg69_1
del arg70_1
del buf139
del buf143
del buf144
buf672 = empty_strided_cuda((1, 512, 68, 68), (2367488, 4624, 68, 1), torch.float32)
# Source Nodes: [out_2], Original ATen: [aten.clone]
triton_poi_fused_clone_73.run(buf666, buf667, buf668, arg323_1, arg324_1, buf672, 512, 4624, grid=grid(512, 4624), stream=stream0)
del arg323_1
del arg324_1
del buf666
del buf667
del buf668
return (buf670, buf671, buf672, )
def benchmark_compiled_module(times=10, repeat=10):
from torch._dynamo.testing import rand_strided
from torch._inductor.utils import print_performance
arg0_1 = rand_strided((1, 3, 1088, 1088), (3551232, 1183744, 1088, 1), device='cuda:0', dtype=torch.float32)
arg1_1 = rand_strided((128, 3, 4, 4), (48, 16, 4, 1), device='cuda:0', dtype=torch.float32)
arg2_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
arg3_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
arg4_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
arg5_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
arg6_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
arg7_1 = rand_strided((384, 128), (128, 1), device='cuda:0', dtype=torch.float32)
arg8_1 = rand_strided((384, ), (1, ), device='cuda:0', dtype=torch.float32)
arg9_1 = rand_strided((169, 4), (4, 1), device='cuda:0', dtype=torch.float32)
arg10_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
arg11_1 = rand_strided((128, 128), (128, 1), device='cuda:0', dtype=torch.float32)
arg12_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
arg13_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
arg14_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
arg15_1 = rand_strided((512, 128), (128, 1), device='cuda:0', dtype=torch.float32)
arg16_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg17_1 = rand_strided((128, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg18_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
arg19_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
arg20_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
arg21_1 = rand_strided((384, 128), (128, 1), device='cuda:0', dtype=torch.float32)
arg22_1 = rand_strided((384, ), (1, ), device='cuda:0', dtype=torch.float32)
arg23_1 = rand_strided((169, 4), (4, 1), device='cuda:0', dtype=torch.float32)
arg24_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
arg25_1 = rand_strided((128, 128), (128, 1), device='cuda:0', dtype=torch.float32)
arg26_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
arg27_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
arg28_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
arg29_1 = rand_strided((512, 128), (128, 1), device='cuda:0', dtype=torch.float32)
arg30_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg31_1 = rand_strided((128, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg32_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
arg33_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg34_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg35_1 = rand_strided((256, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg36_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
arg37_1 = rand_strided((128, ), (1, ), device='cuda:0', dtype=torch.float32)
arg38_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
arg39_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
arg40_1 = rand_strided((768, 256), (256, 1), device='cuda:0', dtype=torch.float32)
arg41_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
arg42_1 = rand_strided((169, 8), (8, 1), device='cuda:0', dtype=torch.float32)
arg43_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
arg44_1 = rand_strided((256, 256), (256, 1), device='cuda:0', dtype=torch.float32)
arg45_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
arg46_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
arg47_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
arg48_1 = rand_strided((1024, 256), (256, 1), device='cuda:0', dtype=torch.float32)
arg49_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.float32)
arg50_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.float32)
arg51_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
arg52_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
arg53_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
arg54_1 = rand_strided((768, 256), (256, 1), device='cuda:0', dtype=torch.float32)
arg55_1 = rand_strided((768, ), (1, ), device='cuda:0', dtype=torch.float32)
arg56_1 = rand_strided((169, 8), (8, 1), device='cuda:0', dtype=torch.float32)
arg57_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
arg58_1 = rand_strided((256, 256), (256, 1), device='cuda:0', dtype=torch.float32)
arg59_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
arg60_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
arg61_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
arg62_1 = rand_strided((1024, 256), (256, 1), device='cuda:0', dtype=torch.float32)
arg63_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.float32)
arg64_1 = rand_strided((256, 1024), (1024, 1), device='cuda:0', dtype=torch.float32)
arg65_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
arg66_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.float32)
arg67_1 = rand_strided((1024, ), (1, ), device='cuda:0', dtype=torch.float32)
arg68_1 = rand_strided((512, 1024), (1024, 1), device='cuda:0', dtype=torch.float32)
arg69_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
arg70_1 = rand_strided((256, ), (1, ), device='cuda:0', dtype=torch.float32)
arg71_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg72_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg73_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg74_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
arg75_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
arg76_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
arg77_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg78_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg79_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg80_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg81_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg82_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
arg83_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
arg84_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg85_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg86_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg87_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg88_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
arg89_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
arg90_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
arg91_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg92_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg93_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg94_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg95_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg96_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
arg97_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
arg98_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg99_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg100_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg101_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg102_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
arg103_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
arg104_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
arg105_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg106_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg107_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg108_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg109_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg110_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
arg111_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
arg112_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg113_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg114_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg115_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg116_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
arg117_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
arg118_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
arg119_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg120_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg121_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg122_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg123_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg124_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
arg125_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
arg126_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg127_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg128_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg129_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg130_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
arg131_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
arg132_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
arg133_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg134_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg135_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg136_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg137_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg138_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
arg139_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
arg140_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg141_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg142_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg143_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg144_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
arg145_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
arg146_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
arg147_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg148_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg149_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg150_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg151_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg152_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
arg153_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
arg154_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg155_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg156_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg157_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg158_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
arg159_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
arg160_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
arg161_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg162_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg163_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg164_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg165_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg166_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
arg167_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
arg168_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg169_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg170_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg171_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg172_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
arg173_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
arg174_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
arg175_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg176_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg177_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg178_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg179_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg180_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
arg181_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
arg182_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg183_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg184_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg185_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg186_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
arg187_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
arg188_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
arg189_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg190_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg191_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg192_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg193_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg194_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
arg195_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
arg196_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg197_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg198_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg199_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg200_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
arg201_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
arg202_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
arg203_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg204_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg205_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg206_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg207_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg208_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
arg209_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
arg210_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg211_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg212_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg213_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg214_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
arg215_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
arg216_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
arg217_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg218_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg219_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg220_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg221_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg222_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
arg223_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
arg224_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg225_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg226_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg227_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg228_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
arg229_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
arg230_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
arg231_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg232_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg233_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg234_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg235_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg236_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
arg237_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
arg238_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg239_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg240_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg241_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg242_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
arg243_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
arg244_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
arg245_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg246_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg247_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg248_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg249_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg250_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
arg251_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
arg252_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg253_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg254_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg255_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg256_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
arg257_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
arg258_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
arg259_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg260_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg261_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg262_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg263_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg264_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
arg265_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
arg266_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg267_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg268_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg269_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg270_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
arg271_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
arg272_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
arg273_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg274_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg275_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg276_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg277_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg278_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
arg279_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
arg280_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg281_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg282_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg283_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg284_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
arg285_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
arg286_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
arg287_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg288_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg289_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg290_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg291_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg292_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
arg293_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
arg294_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg295_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg296_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg297_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg298_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
arg299_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
arg300_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
arg301_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg302_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg303_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg304_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg305_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg306_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
arg307_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
arg308_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg309_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg310_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg311_1 = rand_strided((1536, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg312_1 = rand_strided((1536, ), (1, ), device='cuda:0', dtype=torch.float32)
arg313_1 = rand_strided((169, 16), (16, 1), device='cuda:0', dtype=torch.float32)
arg314_1 = rand_strided((49, 49), (49, 1), device='cuda:0', dtype=torch.int64)
arg315_1 = rand_strided((512, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg316_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg317_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg318_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg319_1 = rand_strided((2048, 512), (512, 1), device='cuda:0', dtype=torch.float32)
arg320_1 = rand_strided((2048, ), (1, ), device='cuda:0', dtype=torch.float32)
arg321_1 = rand_strided((512, 2048), (2048, 1), device='cuda:0', dtype=torch.float32)
arg322_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg323_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
arg324_1 = rand_strided((512, ), (1, ), device='cuda:0', dtype=torch.float32)
fn = lambda: call([arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1, arg60_1, arg61_1, arg62_1, arg63_1, arg64_1, arg65_1, arg66_1, arg67_1, arg68_1, arg69_1, arg70_1, arg71_1, arg72_1, arg73_1, arg74_1, arg75_1, arg76_1, arg77_1, arg78_1, arg79_1, arg80_1, arg81_1, arg82_1, arg83_1, arg84_1, arg85_1, arg86_1, arg87_1, arg88_1, arg89_1, arg90_1, arg91_1, arg92_1, arg93_1, arg94_1, arg95_1, arg96_1, arg97_1, arg98_1, arg99_1, arg100_1, arg101_1, arg102_1, arg103_1, arg104_1, arg105_1, arg106_1, arg107_1, arg108_1, arg109_1, arg110_1, arg111_1, arg112_1, arg113_1, arg114_1, arg115_1, arg116_1, arg117_1, arg118_1, arg119_1, arg120_1, arg121_1, arg122_1, arg123_1, arg124_1, arg125_1, arg126_1, arg127_1, arg128_1, arg129_1, arg130_1, arg131_1, arg132_1, arg133_1, arg134_1, arg135_1, arg136_1, arg137_1, arg138_1, arg139_1, arg140_1, arg141_1, arg142_1, arg143_1, arg144_1, arg145_1, arg146_1, arg147_1, arg148_1, arg149_1, arg150_1, arg151_1, arg152_1, arg153_1, arg154_1, arg155_1, arg156_1, arg157_1, arg158_1, arg159_1, arg160_1, arg161_1, arg162_1, arg163_1, arg164_1, arg165_1, arg166_1, arg167_1, arg168_1, arg169_1, arg170_1, arg171_1, arg172_1, arg173_1, arg174_1, arg175_1, arg176_1, arg177_1, arg178_1, arg179_1, arg180_1, arg181_1, arg182_1, arg183_1, arg184_1, arg185_1, arg186_1, arg187_1, arg188_1, arg189_1, arg190_1, arg191_1, arg192_1, arg193_1, arg194_1, arg195_1, arg196_1, arg197_1, arg198_1, arg199_1, arg200_1, arg201_1, arg202_1, arg203_1, arg204_1, arg205_1, arg206_1, arg207_1, arg208_1, 
arg209_1, arg210_1, arg211_1, arg212_1, arg213_1, arg214_1, arg215_1, arg216_1, arg217_1, arg218_1, arg219_1, arg220_1, arg221_1, arg222_1, arg223_1, arg224_1, arg225_1, arg226_1, arg227_1, arg228_1, arg229_1, arg230_1, arg231_1, arg232_1, arg233_1, arg234_1, arg235_1, arg236_1, arg237_1, arg238_1, arg239_1, arg240_1, arg241_1, arg242_1, arg243_1, arg244_1, arg245_1, arg246_1, arg247_1, arg248_1, arg249_1, arg250_1, arg251_1, arg252_1, arg253_1, arg254_1, arg255_1, arg256_1, arg257_1, arg258_1, arg259_1, arg260_1, arg261_1, arg262_1, arg263_1, arg264_1, arg265_1, arg266_1, arg267_1, arg268_1, arg269_1, arg270_1, arg271_1, arg272_1, arg273_1, arg274_1, arg275_1, arg276_1, arg277_1, arg278_1, arg279_1, arg280_1, arg281_1, arg282_1, arg283_1, arg284_1, arg285_1, arg286_1, arg287_1, arg288_1, arg289_1, arg290_1, arg291_1, arg292_1, arg293_1, arg294_1, arg295_1, arg296_1, arg297_1, arg298_1, arg299_1, arg300_1, arg301_1, arg302_1, arg303_1, arg304_1, arg305_1, arg306_1, arg307_1, arg308_1, arg309_1, arg310_1, arg311_1, arg312_1, arg313_1, arg314_1, arg315_1, arg316_1, arg317_1, arg318_1, arg319_1, arg320_1, arg321_1, arg322_1, arg323_1, arg324_1])
return print_performance(fn, times=times, repeat=repeat)
if __name__ == "__main__":
    # Script entry point: runs only when this generated module is executed
    # directly, never when it is imported.
    #
    # Control is handed to Inductor's shared benchmark driver, which receives
    # the benchmark name and the module's `benchmark_compiled_module` callable
    # (defined earlier in this file). NOTE(review): the name is the literal
    # string 'None' as emitted by codegen; presumably no benchmark name was
    # recorded at compile time. Confirm before relying on it in reports.
    # Command-line handling is presumably done inside compiled_module_main;
    # see torch._inductor.wrapper_benchmark.
    from torch._inductor import wrapper_benchmark
    wrapper_benchmark.compiled_module_main('None', benchmark_compiled_module)