"""Sparse (top-k indexed) attention forward kernel for Ascend NPU, written in TileLang.

Running this file compiles the kernel, builds random inputs on the NPU, runs the
kernel, and compares its output against a dense PyTorch reference that masks out
every key position not listed in ``indices``.
"""

import tilelang
from tilelang import language as T
import torch

# Factory calls below (randn, full, arange, ...) without an explicit device
# allocate on the NPU.
torch.set_default_device("npu")
torch.manual_seed(0)

tilelang.disable_cache()

# Ascend-specific compiler passes. The loop body in `main` mixes cube (matmul)
# and vector statements in one program; presumably AUTO_CV_COMBINE /
# AUTO_CV_SYNC split it into cube and vector parts and insert the cross-core
# synchronisation -- confirm against the tilelang-ascend documentation.
pass_configs = {
    tilelang.PassConfigKey.TL_ASCEND_AUTO_CV_COMBINE: True,
    tilelang.PassConfigKey.TL_ASCEND_AUTO_CV_SYNC: True,
    tilelang.PassConfigKey.TL_ASCEND_AUTO_SYNC: True,
    tilelang.PassConfigKey.TL_ASCEND_MEMORY_PLANNING: True,
}


@tilelang.jit(out_idx=[3], pass_configs=pass_configs)
def sparse_attention_fwd(
    heads,
    dim,
    tail_dim,
    topk,
    kv_stride,
    kv_group=1,
    sm_scale=None,
    is_causal=True,
    block_I=64,
    batch=1,
    seq_len=128,
    seq_len_kv=32768,
):
    """Build the TileLang sparse-attention forward kernel.

    Each query row attends only to the ``topk`` KV rows named by ``Indices``.
    The compiled kernel takes ``(Q, KV, Indices)`` and returns ``Output``
    (``out_idx=[3]``).

    Args:
        heads: number of query heads.
        dim: width of the part of each KV row that is used both as key and as
            value; also the output width.
        tail_dim: width of the extra key-only tail of each Q/KV row (it only
            contributes to the attention scores).
        topk: number of KV rows each query attends to; must be a multiple of
            ``block_I``.
        kv_stride: accepted for interface compatibility; not used by the kernel.
        kv_group: number of KV groups; ``heads // kv_group`` query heads share
            one group.
        sm_scale: softmax scale; defaults to ``1 / sqrt(dim + tail_dim)``.
        is_causal: must be True. The kernel applies no mask itself and relies
            on ``Indices`` holding only positions the query may attend to.
        block_I: number of gathered KV rows processed per loop iteration.
        batch: static batch size compiled into the kernel.
        seq_len: static number of query positions compiled into the kernel.
        seq_len_kv: static number of KV positions compiled into the kernel.

    Tensor layouts (from the declarations below):
        Q:       [batch, seq_len, heads, dim + tail_dim], float16
        KV:      [batch, seq_len_kv, kv_group, dim + tail_dim], float16
        Indices: [batch, seq_len, kv_group, topk], int32
        Output:  [batch, seq_len, heads, dim], float16

    Raises:
        AssertionError: on unsupported configurations (see the checks below).
    """
    assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}"
    assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, tail_dim={tail_dim}"
    assert is_causal, "non-causal is not supported"
    assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded"
    # NOTE: ascend only support exp interface instead of exp2, so the scale stays
    # in natural-log units and the softmax below uses T.exp.
    sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 if sm_scale is None else sm_scale

    # batch / seq_len / seq_len_kv are compile-time constants here;
    # T.symbolic("...") would be the dynamic alternative if the Ascend backend
    # supports it (not verified).

    head_kv = heads // kv_group
    q_shape = [batch, seq_len, heads, dim + tail_dim]
    kv_shape = [batch, seq_len_kv, kv_group, dim + tail_dim]
    o_shape = [batch, seq_len, heads, dim]
    indices_shape = [batch, seq_len, kv_group, topk]
    indices_dtype = "int32"
    dtype = "float16"
    accum_dtype = "float"

    H = head_kv
    padded_H = max(tilelang.math.next_power_of_2(head_kv), 16)
    if padded_H != H:
        # NOTE(review): when padded_H > H the Q / Output slices below cover
        # padded_H heads, more than exist, and no masking is visible in this
        # kernel -- confirm padding is really handled before relying on head
        # counts that are not a power of two >= 16.
        assert kv_group == 1, (
            "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)"
        )
    BI = block_I
    NI = tilelang.cdiv(topk, block_I)  # number of BI-sized chunks of Indices per query
    D = dim
    D_tail = tail_dim

    # More than 64 heads per group are split into chunks of 64, each handled by
    # its own program instance.
    if head_kv > 64:
        assert head_kv % 64 == 0, "head_kv should be a multiple of 64"
        REPLICATE_H = head_kv // 64
    else:
        REPLICATE_H = 1

    H_per_block = padded_H if REPLICATE_H == 1 else 64

    v_block = H_per_block

    # One program instance per (query position, head chunk, batch, kv_group).
    block_num = seq_len * REPLICATE_H * batch * kv_group

    @T.prim_func
    def main(
        Q: T.Tensor(q_shape, dtype),  # type: ignore
        KV: T.Tensor(kv_shape, dtype),  # type: ignore
        Indices: T.Tensor(indices_shape, indices_dtype),  # type: ignore
        Output: T.Tensor(o_shape, dtype),  # type: ignore
    ):
        with T.Kernel(block_num, threads=2, is_npu=True) as (cid):
            # Decompose the flat instance id into (head chunk + query), batch, group.
            bx = cid % (seq_len * REPLICATE_H)
            by = cid // (seq_len * REPLICATE_H) % batch
            bz = cid // (seq_len * REPLICATE_H) // batch % kv_group

            # 1. Matmul-side buffers (names suggest L1 / L0C placement -- presumably cube side).
            q_l1 = T.alloc_shared([H_per_block, D], dtype)
            q_tail_l1 = T.alloc_shared([H_per_block, D_tail], dtype)
            kv_l1 = T.alloc_shared([BI, D], dtype)
            kv_tail_l1 = T.alloc_shared([BI, D_tail], dtype)
            acc_s_l1 = T.alloc_shared([H_per_block, BI], dtype)
            acc_s_l0c = T.alloc_fragment([H_per_block, BI], accum_dtype)
            acc_o_l0c = T.alloc_fragment([H_per_block, D], accum_dtype)
            ## 2. Vector
            acc_o = T.alloc_shared([v_block, D], accum_dtype)  # running (unnormalised) output
            sumexp = T.alloc_shared([v_block], accum_dtype)  # running softmax denominator
            m_i = T.alloc_shared([v_block], accum_dtype)  # running row max
            indices_ub_ = T.alloc_shared([BI], indices_dtype)
            kv_ub = T.alloc_shared([D], dtype)
            kv_tail_ub = T.alloc_shared([D_tail], dtype)
            acc_s_ub = T.alloc_shared([v_block, BI], accum_dtype)
            m_i_prev = T.alloc_shared([v_block], accum_dtype)
            acc_s_ub_ = T.alloc_shared([v_block, BI], accum_dtype)
            sumexp_i_ub = T.alloc_shared([v_block], accum_dtype)
            acc_s_half = T.alloc_shared([v_block, BI], dtype)
            acc_o_ub = T.alloc_shared([v_block, D], accum_dtype)
            acc_o_half = T.alloc_shared([v_block, D], dtype)

            b_i = by
            g_i = bz
            s_i = bx // REPLICATE_H
            # Head range [H0, H1) handled by this instance.
            H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64)
            H1 = H0 + H_per_block

            T.copy(Q[b_i, s_i, H0:H1, :D], q_l1)
            T.copy(Q[b_i, s_i, H0:H1, D:], q_tail_l1)
            T.tile.fill(acc_o, 0.0)
            T.tile.fill(sumexp, 0.0)
            # Large negative value used as the initial running max.
            T.tile.fill(m_i, -(2.0**30))

            # NOTE(review): statements in this loop are not in data-flow order --
            # the Q.K^T gemm reads kv_l1 before the gather below fills it, and
            # the P.V gemm reads acc_s_l1 before it is written at the end of the
            # iteration. This presumably relies on the auto CV-combine / sync
            # passes ordering cube and vector work by their inserted syncs.
            # Confirm before reordering anything here.
            for i_i in T.serial(NI):
                # Scores S = Q[:, :D] @ K^T + Q[:, D:] @ K_tail^T.
                T.gemm_v0(q_l1, kv_l1, acc_s_l0c, transpose_B=True, init=True)
                T.gemm_v0(q_tail_l1, kv_tail_l1, acc_s_l0c, transpose_B=True)
                T.copy(acc_s_l0c, acc_s_ub_)

                # Partial output P @ V, where V is the first D columns of KV.
                T.gemm_v0(acc_s_l1, kv_l1, acc_o_l0c, init=True)
                T.copy(acc_o_l0c, acc_o_ub)

                # Gather the BI KV rows selected by this chunk of Indices.
                T.copy(Indices[b_i, s_i, g_i, i_i * BI : i_i * BI + BI], indices_ub_)
                for bi_i in range(BI):
                    T.copy(KV[b_i, indices_ub_[bi_i], g_i, :D], kv_ub)
                    T.copy(KV[b_i, indices_ub_[bi_i], g_i, D:], kv_tail_ub)
                    T.copy(kv_ub, kv_l1[bi_i, :])
                    T.copy(kv_tail_ub, kv_tail_l1[bi_i, :])

                # Online softmax update.
                T.tile.fill(acc_s_ub, 0.0)
                T.copy(m_i, m_i_prev)
                for i, j in T.Parallel(v_block, BI):
                    acc_s_ub[i, j] = acc_s_ub[i, j] + acc_s_ub_[i, j]
                for i, j in T.Parallel(v_block, BI):
                    acc_s_ub[i, j] = acc_s_ub[i, j] * sm_scale
                T.reduce_max(acc_s_ub, m_i, dim=-1)
                for i in T.Parallel(v_block):
                    # New max, then m_i_prev becomes the rescale factor exp(m_old - m_new).
                    m_i[i] = T.max(m_i[i], m_i_prev[i])
                    m_i_prev[i] = m_i_prev[i] - m_i[i]
                    m_i_prev[i] = T.exp(m_i_prev[i])
                for h_i, j in T.Parallel(v_block, BI):
                    acc_s_ub[h_i, j] = acc_s_ub[h_i, j] - m_i[h_i]
                    acc_s_ub[h_i, j] = T.exp(acc_s_ub[h_i, j])
                T.reduce_sum(acc_s_ub, sumexp_i_ub, dim=-1)
                for i in T.Parallel(v_block):
                    sumexp[i] *= m_i_prev[i]
                    sumexp[i] += sumexp_i_ub[i]
                for h_i, j in T.Parallel(v_block, D):
                    acc_o[h_i, j] = acc_o[h_i, j] * m_i_prev[h_i]
                # Hand the float16 probabilities to the P @ V gemm.
                T.copy(acc_s_ub, acc_s_half)
                T.copy(acc_s_half, acc_s_l1)

                for i, j in T.Parallel(v_block, D):
                    acc_o[i, j] += acc_o_ub[i, j]

            # Final normalisation by the softmax denominator, then store.
            for h_i, j in T.Parallel(v_block, D):
                acc_o[h_i, j] = acc_o[h_i, j] / sumexp[h_i]
            T.copy(acc_o, acc_o_half)
            T.copy(acc_o_half, Output[b_i, s_i, H0 : H0 + v_block, :])

    return main


func = sparse_attention_fwd(
    heads=128,
    dim=512,
    tail_dim=64,
    topk=2048,
    kv_stride=1,
)


def ref_sparse_attention_fwd_interface(q, kv, indices, q_start_index_s, kv_stride=4, sm_scale=None, is_casual=True):
    """Dense PyTorch reference for the sparse attention forward pass.

    Builds a boolean mask that is True only at the key positions listed in
    ``indices`` and where the strided causal condition holds. It then runs
    ordinary masked softmax attention in float32 and returns float16.

    Args:
        q: [b, sq, h, dim_q] queries.
        kv: [b, sk, g, 576]; all 576 columns act as keys, the first 512 as values.
        indices: [b, sq, g, topk] selected key positions per query; an entry
            equal to ``sk`` selects a dummy column that is dropped.
        q_start_index_s: absolute position of the first query; if None it is
            ``sk * kv_stride - sq``.
        kv_stride: stride between KV positions used in the causal mask.
        sm_scale: softmax scale; defaults to ``dim_q ** -0.5``.
        is_casual: accepted but not used.

    Returns:
        [b, sq, h, 512] float16 attention output.
    """
    q = q.float()
    kv = kv.float()
    indices = indices.transpose(1, 2)
    b, sq, h, dim_q = q.shape
    b, sk, g, _ = kv.shape
    if q_start_index_s is None:
        q_start_index_s = sk * kv_stride - sq

    assert kv.shape[-1] == 576, "you should assign dim otherwise"
    dim = 512
    k = kv
    v = kv[..., :dim]

    b, _, _, dim_v = v.shape
    g_index = g
    h_index = h // g
    # Key j is visible to query i when q_start_index_s + i >= j * kv_stride + kv_stride - 1.
    compressed_casual_mask = torch.arange(q_start_index_s, sq + q_start_index_s, dtype=torch.int32).view(-1, 1) >= torch.arange(
        kv_stride - 1, sk * kv_stride, kv_stride, dtype=torch.int32
    ).view(1, -1)

    # One extra column absorbs sentinel indices equal to sk; it is dropped next.
    mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1)
    mask = mask[..., :-1]
    mask = mask & compressed_casual_mask.view(1, 1, sq, sk)
    # Force key 0 visible for the first kv_stride - 1 queries (presumably so no
    # softmax row is entirely -inf); a no-op when kv_stride == 1.
    mask[:, :, : kv_stride - 1, 0] = True
    mask = mask.view(b, g_index, 1, sq, sk)

    q = q.view(b, sq, g, -1, dim_q)
    score = torch.einsum("bmghd,bngd->bghmn", q, k)
    sm_scale = dim_q**-0.5 if sm_scale is None else sm_scale
    score = score.masked_fill(~mask, float("-inf")).mul(sm_scale)
    p = score.softmax(dim=-1)
    p = p.view(b, g_index, h_index, -1, sq, sk)
    p = p.view(b, g, -1, sq, sk)
    o = torch.einsum("bghmn,bngd->bmghd", p.type(v.dtype), v)
    o = o.reshape(b, sq, h, dim_v)
    return o.to(torch.float16)


# S and SKV must match the sizes compiled into `func` (seq_len=128 and
# seq_len_kv=32768 by default); B must match its batch (default 1).
B, S, SKV, H, HKV, DQK, DV, topk = 1, 128, 32768, 128, 1, 576, 512, 2048
dtype = torch.float16

KV_stride = 1
# Because q_start_s_index >= topk, every query has at least topk candidate
# positions. Every row of `indices` is therefore fully overwritten below, and the
# SKV sentinel (out of range for the kernel's KV gather) never reaches the kernel.
q_start_s_index = 4096 * 7

q = torch.randn((B, S, H, DQK), dtype=dtype)
kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype)
indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32)
for b in range(B):
    for t in range(S):
        for h in range(HKV):
            # Positions drawn from [0, (t + q_start_s_index) // KV_stride), i.e.
            # inside the reference's causal window.
            i_i = torch.randperm(max(1, ((t + q_start_s_index) // KV_stride)))[:topk]
            indices[b, t, h, : len(i_i)] = i_i

torch.npu.synchronize()
print("init successful!")

output = func(q, kv, indices)
torch.npu.synchronize()

ref_output = ref_sparse_attention_fwd_interface(q, kv, indices, q_start_s_index, KV_stride)
torch.npu.synchronize()
torch.testing.assert_close(ref_output, output, rtol=1e-2, atol=1e-2)
print("Test Passed!")