[Mlir-commits] [mlir] c3579f0 - [MLIR][XeGPU][Conversion] Add 2D block op support for sub byte types (#169099)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Dec 8 09:23:12 PST 2025
Author: Sang Ik Lee
Date: 2025-12-08T09:23:07-08:00
New Revision: c3579f01996a0ac018b9d452b8ae1cc2cac36acf
URL: https://github.com/llvm/llvm-project/commit/c3579f01996a0ac018b9d452b8ae1cc2cac36acf
DIFF: https://github.com/llvm/llvm-project/commit/c3579f01996a0ac018b9d452b8ae1cc2cac36acf.diff
LOG: [MLIR][XeGPU][Conversion] Add 2D block op support for sub byte types (#169099)
Some use cases or shapes for 2D block ops with sub-byte types can be
emulated with 2D block operations on non-sub-byte types. Add the sub-byte
type i4 as a valid XeGPU type, and add lowering of certain 2D
block operations by emulating them with larger element types.
Added:
mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_sub_byte.mlir
mlir/test/Conversion/XeGPUToXeVM/prefetch_nd_sub_byte.mlir
Modified:
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 7f7f7d065c50e..716681fe9e187 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -13,8 +13,9 @@ include "mlir/Dialect/XeGPU/IR/XeGPUAttrs.td"
include "mlir/Dialect/XeGPU/IR/XeGPUDialect.td"
include "mlir/IR/BuiltinTypes.td"
-def XeGPU_IntType: AnyTypeOf<[I1, I8, I16, I32, I64, SI1, SI8, SI16, SI32, SI64, UI1, UI8, UI16, UI32, UI64]>;
-def XeGPU_FloatType: AnyTypeOf<[F16, F32, F64, BF16, TF32]>;
+def XeGPU_IntType : AnyTypeOf<[I1, I<4>, I8, I16, I32, I64, SI1, SI8, SI16,
+ SI32, SI64, UI1, UI8, UI16, UI32, UI64]>;
+def XeGPU_FloatType : AnyTypeOf<[F16, F32, F64, BF16, TF32]>;
def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
def XeGPU_PointerType : AnyTypeOf<[UI64, UI32, I64, I32]>;
def XeGPU_BaseAddrType
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 9c99a24bea8cd..2b162ec3f3bf4 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -150,6 +150,14 @@ translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
}
}
+//
+// Note:
+// Block operations for tiles of sub-byte element types are handled by
+// emulating them with larger element types.
+// Tensor descriptors are kept intact and only the ops consuming them are
+// emulated.
+//
+
class CreateNdDescToXeVMPattern
: public OpConversionPattern<xegpu::CreateNdDescOp> {
using OpConversionPattern::OpConversionPattern;
@@ -262,9 +270,57 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
op, "Expected offset rank to match descriptor rank.");
auto elemType = tdescTy.getElementType();
auto elemBitSize = elemType.getIntOrFloatBitWidth();
- if (elemBitSize % 8 != 0)
+ bool isSubByte = elemBitSize < 8;
+ uint64_t wScaleFactor = 1;
+
+ if (!isSubByte && (elemBitSize % 8 != 0))
return rewriter.notifyMatchFailure(
op, "Expected element type bit width to be multiple of 8.");
+ auto tileW = tdescTy.getDimSize(tileRank - 1);
+ // For sub byte types, only 4bits are currently supported.
+ if (isSubByte) {
+ if (elemBitSize != 4)
+ return rewriter.notifyMatchFailure(
+ op, "Only sub byte types of 4bits are supported.");
+ if (tileRank != 2)
+ return rewriter.notifyMatchFailure(
+ op, "Sub byte types are only supported for 2D tensor descriptors.");
+ auto subByteFactor = 8 / elemBitSize;
+ auto tileH = tdescTy.getDimSize(0);
+ // Handle special case for packed load.
+ if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
+ if (op.getPacked().value_or(false)) {
+ // packed load is implemented as packed loads of 8bit elements.
+ if (tileH == systolicDepth * 4 &&
+ tileW == executionSize * subByteFactor) {
+ // Usage case for loading as Matrix B with pack request.
+ // source is assumed to pre-packed into 8bit elements
+ // Emulate with 8bit loads with pack request.
+ // scaled_tileW = executionSize
+ elemType = rewriter.getIntegerType(8);
+ tileW = executionSize;
+ wScaleFactor = subByteFactor;
+ }
+ }
+ }
+ // If not handled by packed load case above, handle other cases.
+ if (wScaleFactor == 1) {
+ auto sub16BitFactor = subByteFactor * 2;
+ if (tileW == executionSize * sub16BitFactor) {
+ // Usage case for loading as Matrix A operand
+ // Emulate with 16bit loads/stores.
+ // scaled_tileW = executionSize
+ elemType = rewriter.getIntegerType(16);
+ tileW = executionSize;
+ wScaleFactor = sub16BitFactor;
+ } else {
+ return rewriter.notifyMatchFailure(
+ op, "Unsupported tile shape for sub byte types.");
+ }
+ }
+ // recompute element bit size for emulation.
+ elemBitSize = elemType.getIntOrFloatBitWidth();
+ }
// Get address space from tensor descriptor memory space.
auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
@@ -298,15 +354,27 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
// Convert base pointer (i64) to LLVM pointer type.
Value basePtrLLVM =
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
+ // FIXME: width or pitch is not the same as baseShapeW it should be the
+ // stride of the second to last dimension in row major layout.
// Compute width in bytes.
- Value baseWidthByte =
+ Value baseShapeWInBytes =
arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
// Compute pitch in bytes.
- Value basePitchByte =
+ Value basePitchBytes =
arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize);
- // Get tile width from the tensor descriptor type.
- auto tileW = tdescTy.getDimSize(tileRank - 1);
+ if (wScaleFactor > 1) {
+ // Scale offsetW, baseShapeWInBytes for sub byte emulation.
+ // Note: tileW is already scaled above.
+ Value wScaleFactorValLog2 = arith::ConstantIntOp::create(
+ rewriter, loc, rewriter.getI32Type(), llvm::Log2_64(wScaleFactor));
+ baseShapeWInBytes = arith::ShRSIOp::create(
+ rewriter, loc, baseShapeWInBytes, wScaleFactorValLog2);
+ basePitchBytes = arith::ShRSIOp::create(rewriter, loc, basePitchBytes,
+ wScaleFactorValLog2);
+ offsetW =
+ arith::ShRSIOp::create(rewriter, loc, offsetW, wScaleFactorValLog2);
+ }
// Get tile height from the tensor descriptor type.
auto tileH = tdescTy.getDimSize(0);
// Get vblocks from the tensor descriptor type.
@@ -330,8 +398,8 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
auto storeCacheControl =
translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
xevm::BlockStore2dOp::create(
- rewriter, loc, basePtrLLVM, baseWidthByte, baseShapeH,
- basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH, src,
+ rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
+ basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH, src,
xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
rewriter.eraseOp(op);
} else {
@@ -339,8 +407,8 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
xevm::BlockPrefetch2dOp::create(
- rewriter, loc, basePtrLLVM, baseWidthByte, baseShapeH,
- basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH,
+ rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
+ basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH,
vblocks, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
rewriter.eraseOp(op);
} else {
@@ -354,9 +422,9 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
: rewriter.getIntegerType(elemBitSize));
Value resultFlatVec = xevm::BlockLoad2dOp::create(
- rewriter, loc, loadedTy, basePtrLLVM, baseWidthByte, baseShapeH,
- basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH,
- vblocks, transpose, vnni,
+ rewriter, loc, loadedTy, basePtrLLVM, baseShapeWInBytes,
+ baseShapeH, basePitchBytes, offsetW, offsetH, elemBitSize, tileW,
+ tileH, vblocks, transpose, vnni,
xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
resultFlatVec = vector::BitCastOp::create(
rewriter, loc,
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_sub_byte.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_sub_byte.mlir
new file mode 100644
index 0000000000000..97e5ce14f8539
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_sub_byte.mlir
@@ -0,0 +1,80 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
+
+gpu.module @load_store_check {
+ // CHECK-LABEL: gpu.func @load_store_matrix_a
+ // CHECK-SAME: %[[ARG0:.*]]: memref<16x128xi4, 1>, %[[ARG1:.*]]: memref<16x128xi4, 1>
+ gpu.func @load_store_matrix_a(%src: memref<16x128xi4, 1>, %dst: memref<16x128xi4, 1>) kernel {
+ // CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32
+ // CHECK: %[[C8_I32:.*]] = arith.constant 8 : i32
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4xi64>
+ // CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32
+ // CHECK: %[[C128_I32:.*]] = arith.constant 128 : i32
+ // CHECK: %[[SRCCE:.*]] = memref.memory_space_cast %[[ARG0]]
+ // CHECK: %[[SRCINDEX:.*]] = memref.extract_aligned_pointer_as_index %[[SRCCE]]
+ // CHECK: %[[SRCPTR64:.*]] = arith.index_castui %[[SRCINDEX]] : index to i64
+ %srcce = memref.memory_space_cast %src : memref<16x128xi4, 1> to memref<16x128xi4>
+ // CHECK: %[[DSTTE:.*]] = memref.memory_space_cast %[[ARG1]]
+ // CHECK: %[[DSTINDEX:.*]] = memref.extract_aligned_pointer_as_index %[[DSTTE]]
+ // CHECK: %[[DSTPTR64:.*]] = arith.index_castui %[[DSTINDEX]] : index to i64
+ %dstte = memref.memory_space_cast %dst : memref<16x128xi4, 1> to memref<16x128xi4>
+
+ // CHECK: %[[PAYLOAD_SRC:.*]] = vector.insert %[[SRCPTR64]], %[[CST]] [0] : i64 into vector<4xi64>
+ // CHECK: %[[BITCAST1_SRC:.*]] = vector.bitcast %[[PAYLOAD_SRC]] : vector<4xi64> to vector<8xi32>
+ // CHECK: %[[PAYLOAD1_SRC:.*]] = vector.insert %[[C128_I32]], %[[BITCAST1_SRC]] [2] : i32 into vector<8xi32>
+ // CHECK: %[[PAYLOAD2_SRC:.*]] = vector.insert %[[C16_I32]], %[[PAYLOAD1_SRC]] [3] : i32 into vector<8xi32>
+ // CHECK: %[[PAYLOAD3_SRC:.*]] = vector.insert %[[C128_I32]], %[[PAYLOAD2_SRC]] [4] : i32 into vector<8xi32>
+ %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x128xi4> -> !xegpu.tensor_desc<8x64xi4>
+
+ // CHECK: %[[BITCAST2:.*]] = vector.bitcast %[[PAYLOAD3_SRC]] : vector<8xi32> to vector<4xi64>
+ // CHECK: %[[SRCPTR64:.*]] = vector.extract %[[BITCAST2]][0] : i64 from vector<4xi64>
+ // CHECK: %[[SRCLLVMPTR:.*]] = llvm.inttoptr %[[SRCPTR64]] : i64 to !llvm.ptr<1>
+ // CHECK: %[[LOADED:.*]] = xevm.blockload2d %[[SRCLLVMPTR]], %[[C64_I32]],
+ // CHECK-SAME: %[[C16_I32]], %[[C64_I32]], %[[C16_I32]], %[[C8_I32]] <{
+ // CHECK-SAME: cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 16 : i32,
+ // CHECK-SAME: pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false,
+ // CHECK-SAME: v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
+ %loaded = xegpu.load_nd %src_tdesc[8, 64] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<8x64xi4> -> vector<32xi4>
+
+ // CHECK: %[[PAYLOAD_DST:.*]] = vector.insert %[[DSTPTR64]], %[[CST]] [0] : i64 into vector<4xi64>
+ // CHECK: %[[BITCAST1_DST:.*]] = vector.bitcast %[[PAYLOAD_DST]] : vector<4xi64> to vector<8xi32>
+ // CHECK: %[[PAYLOAD1_DST:.*]] = vector.insert %[[C128_I32]], %[[BITCAST1_DST]] [2] : i32 into vector<8xi32>
+ // CHECK: %[[PAYLOAD2_DST:.*]] = vector.insert %[[C16_I32]], %[[PAYLOAD1_DST]] [3] : i32 into vector<8xi32>
+ // CHECK: %[[PAYLOAD3_DST:.*]] = vector.insert %[[C128_I32]], %[[PAYLOAD2_DST]] [4] : i32 into vector<8xi32>
+ %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<16x128xi4> -> !xegpu.tensor_desc<8x64xi4, #xegpu.block_tdesc_attr<memory_space = global>>
+
+ // CHECK: %[[BITCAST2_DST:.*]] = vector.bitcast %[[PAYLOAD3_DST]] : vector<8xi32> to vector<4xi64>
+ // CHECK: %[[DSTPTR64:.*]] = vector.extract %[[BITCAST2_DST]][0] : i64 from vector<4xi64>
+ // CHECK: %[[DSTLLVMPTR:.*]] = llvm.inttoptr %[[DSTPTR64]] : i64 to !llvm.ptr<1>
+ // CHECK: xevm.blockstore2d %[[DSTLLVMPTR]], %[[C64_I32]], %[[C16_I32]],
+ // CHECK-SAME: %[[C64_I32]], %[[C16_I32]], %[[C8_I32]], %[[LOADED]] <{
+ // CHECK-SAME: cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 16 : i32,
+ // CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
+ xegpu.store_nd %loaded, %dst_tdesc[8, 64] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : vector<32xi4>, !xegpu.tensor_desc<8x64xi4, #xegpu.block_tdesc_attr<memory_space = global>>
+ gpu.return
+ }
+
+ // CHECK-LABEL: gpu.func @load_matrix_b_request_pack
+ gpu.func @load_matrix_b_request_pack(%src: memref<64x128xi4, 1>, %dst: memref<64x128xi4, 1>) kernel {
+ // CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32
+ // CHECK: %[[C32_I32:.*]] = arith.constant 32 : i32
+ // CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32
+ %srcce = memref.memory_space_cast %src : memref<64x128xi4, 1> to memref<64x128xi4>
+ %dstte = memref.memory_space_cast %dst : memref<64x128xi4, 1> to memref<64x128xi4>
+
+ %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<64x128xi4> -> !xegpu.tensor_desc<32x32xi4>
+
+ // CHECK: xevm.blockload2d %{{.*}}, %[[C64_I32]], %[[C64_I32]], %[[C64_I32]], %[[C16_I32]], %[[C32_I32]] <{
+ // CHECK-SAME: cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 8 : i32,
+ // CHECK-SAME: pack_register = true, tile_height = 32 : i32, tile_width = 16 : i32, transpose = false,
+ // CHECK-SAME: v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
+ %loaded = xegpu.load_nd %src_tdesc[32, 32] <{packed, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<32x32xi4> -> vector<64xi4>
+
+ %c32 = arith.constant 32 : index
+ %c0 = arith.constant 0 : index
+ vector.store %loaded, %dstte[%c32, %c0] : memref<64x128xi4>, vector<64xi4>
+ gpu.return
+ }
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd_sub_byte.mlir b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd_sub_byte.mlir
new file mode 100644
index 0000000000000..f9254728bab41
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd_sub_byte.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
+
+gpu.module @prefetch_check {
+ // CHECK-LABEL: gpu.func @prefetch_matrix_a
+ gpu.func @prefetch_matrix_a(%src: memref<16x128xi4, 1>) kernel {
+ // CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32
+ // CHECK: %[[C8_I32:.*]] = arith.constant 8 : i32
+ // CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32
+ %srcce = memref.memory_space_cast %src : memref<16x128xi4, 1> to memref<16x128xi4>
+
+ %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x128xi4> -> !xegpu.tensor_desc<8x64xi4>
+
+ // CHECK: xevm.blockprefetch2d %{{.*}}, %[[C64_I32]], %[[C16_I32]], %[[C64_I32]], %[[C16_I32]], %[[C8_I32]]
+ // CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 16 : i32,
+ // CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32, v_blocks = 1 : i32}> : (!llvm.ptr<1>
+ xegpu.prefetch_nd %src_tdesc[8, 64] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<8x64xi4>
+
+ gpu.return
+ }
+}
More information about the Mlir-commits
mailing list