for _ in T.Pipelined(1, num_stages=num_stages):
T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
T.copy(x_shared, x_local)
T.reduce_absmax(x_local, amax_local, dim=1) for i in T.Parallel(blk_m):
amax_local[i] = T.max(amax_local[i], 1e-4) if round_scale:
s_local[i] = fast_round_scale(amax_local[i], fp8_max_inv) else:
s_local[i] = amax_local[i] * fp8_max_inv for i, j in T.Parallel(blk_m, group_size):
y_local[i, j] = T.clamp(
x_local[i, j] / s_local[i], fp8_min, fp8_max
) for i in T.Parallel(blk_m):
S[pid_m * blk_m + i, pid_n] = s_local[i]
T.copy(y_local, y_shared)
T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])
return act_quant_kernel_
def act_quant(
x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None
) -> Tuple[torch.Tensor, torch.Tensor]: """
Quantizes the input tensor `x` using block-wise quantization.
Args:
x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
scale_fmt (Optional[str], optional): The format of the scale. Default isNone.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- The quantized tensor with dtype `torch.float8_e4m3fn`.
- A tensor of scaling factors with dtype `torch.float32`. """ assert x.is_contiguous(), "Input tensor must be contiguous" assert x.size(-1) % block_size == 0, (
f"Last dimension size must be divisible by block_size (block_size={block_size})"
)
N = x.size(-1)
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32)
kernel = act_quant_kernel(N, round_scale=scale_fmt isnotNone)
kernel(x.view(-1, N), y.view(-1, N), s.view(-1, N // block_size)) return y, s
T.clear(C_local)
T.clear(C_local_accum)
K_iters = T.ceildiv(K, block_K) for k in T.Pipelined(K_iters, num_stages=4): # Load A into shared memory
T.copy(A[by * block_M, k * block_K], A_shared) # Load B into shared memory
T.copy(B[bx * block_N, k * block_K], B_shared) # Load scale into shared memory
Scale_B = scales_b[bx * block_N // group_size, k] for i in T.Parallel(block_M):
Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B
T.gemm(A_shared, B_shared, C_local, transpose_B=True) # Promote to enable 2xAcc for i, j in T.Parallel(block_M, block_N):
C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
T.clear(C_local) # TMA store
T.copy(C_local_accum, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
Args:
a (torch.Tensor): The first input matrix, must be contiguous.
a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
b (torch.Tensor): The second input matrix, must be contiguous.
b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.
Returns:
torch.Tensor: The result of the matrix multiplication. """ assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous" assert a_s.is_contiguous() and b_s.is_contiguous(), ( "Scaling factor tensors must be contiguous"
)
K = a.size(-1)
M = a.numel() // K
N = b.size(0)
c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
kernel = fp8_gemm_kernel(N, K)
kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s) return c
@tilelang.jit(out_idx=[4], pass_configs=pass_configs) def fp8_index_kernel(h: int, d: int):
b = T.symbolic("b")
m = T.symbolic("m")
n = T.symbolic("n")
blk_n1 = 512
blk_n2 = 128
@T.prim_func def fp8_index_kernel_(
q: T.Tensor[(b, m, h, d), FP8],
q_s: T.Tensor[(b, m, h), FP32],
k: T.Tensor[(b, n, d), FP8],
k_s: T.Tensor[(b, n), FP32],
o: T.Tensor[(b, m, n), FP32],
) -> None: with T.Kernel(b, m, T.ceildiv(n, blk_n1)) as (i_b, i_m, i1_n):
q_smem = T.alloc_shared((h, d), FP8)
T.copy(q[i_b, i_m, 0, 0], q_smem)
def fp8_index(
q: torch.Tensor,
q_s: torch.Tensor,
k: torch.Tensor,
k_s: torch.Tensor,
) -> torch.Tensor: """
Perform index score using FP8 precision.
Args:
q (torch.Tensor): The Q tensor, must be contiguous.
q_s (torch.Tensor): The scaling factor for Q (float), must be contiguous.
k (torch.Tensor): The K tensor, must be contiguous.
k_s (torch.Tensor): The scaling factor for K (e8m0 here), must be contiguous.
Die Informationen auf dieser Webseite wurden
nach bestem Wissen sorgfältig zusammengestellt. Es wird jedoch weder Vollständigkeit, noch Richtigkeit,
noch Qualität der bereit gestellten Informationen zugesichert.
Bemerkung:
Die farbliche Syntaxdarstellung ist noch experimentell.