[Mlir-commits] [mlir] c3579f0 - [MLIR][XeGPU][Conversion] Add 2D block op support for sub byte types (#169099)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 8 09:23:12 PST 2025


Author: Sang Ik Lee
Date: 2025-12-08T09:23:07-08:00
New Revision: c3579f01996a0ac018b9d452b8ae1cc2cac36acf

URL: https://github.com/llvm/llvm-project/commit/c3579f01996a0ac018b9d452b8ae1cc2cac36acf
DIFF: https://github.com/llvm/llvm-project/commit/c3579f01996a0ac018b9d452b8ae1cc2cac36acf.diff

LOG: [MLIR][XeGPU][Conversion] Add 2D block op support for sub byte types (#169099)

Some use cases (tile shapes) of 2D block ops with sub-byte element types
can be emulated with 2D block operations on non-sub-byte types. Add the
sub-byte type i4 as a valid XeGPU type, and add lowering of certain 2D
block operations by emulating them with larger element types.

Added: 
    mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_sub_byte.mlir
    mlir/test/Conversion/XeGPUToXeVM/prefetch_nd_sub_byte.mlir

Modified: 
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
    mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 7f7f7d065c50e..716681fe9e187 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -13,8 +13,9 @@ include "mlir/Dialect/XeGPU/IR/XeGPUAttrs.td"
 include "mlir/Dialect/XeGPU/IR/XeGPUDialect.td"
 include "mlir/IR/BuiltinTypes.td"
 
-def XeGPU_IntType: AnyTypeOf<[I1, I8, I16, I32, I64, SI1, SI8, SI16, SI32, SI64, UI1, UI8, UI16, UI32, UI64]>;
-def XeGPU_FloatType: AnyTypeOf<[F16, F32, F64, BF16, TF32]>;
+def XeGPU_IntType : AnyTypeOf<[I1, I<4>, I8, I16, I32, I64, SI1, SI8, SI16,
+                               SI32, SI64, UI1, UI8, UI16, UI32, UI64]>;
+def XeGPU_FloatType : AnyTypeOf<[F16, F32, F64, BF16, TF32]>;
 def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
 def XeGPU_PointerType : AnyTypeOf<[UI64, UI32, I64, I32]>;
 def XeGPU_BaseAddrType

diff  --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 9c99a24bea8cd..2b162ec3f3bf4 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -150,6 +150,14 @@ translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
   }
 }
 
+//
+// Note:
+// Block operations on tiles of sub-byte element types are handled by
+// emulating them with larger element types.
+// Tensor descriptors are kept intact; only the ops consuming them are
+// emulated.
+//
+
 class CreateNdDescToXeVMPattern
     : public OpConversionPattern<xegpu::CreateNdDescOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -262,9 +270,57 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
           op, "Expected offset rank to match descriptor rank.");
     auto elemType = tdescTy.getElementType();
     auto elemBitSize = elemType.getIntOrFloatBitWidth();
-    if (elemBitSize % 8 != 0)
+    bool isSubByte = elemBitSize < 8;
+    uint64_t wScaleFactor = 1;
+
+    if (!isSubByte && (elemBitSize % 8 != 0))
       return rewriter.notifyMatchFailure(
           op, "Expected element type bit width to be multiple of 8.");
+    auto tileW = tdescTy.getDimSize(tileRank - 1);
+    // For sub byte types, only 4bits are currently supported.
+    if (isSubByte) {
+      if (elemBitSize != 4)
+        return rewriter.notifyMatchFailure(
+            op, "Only sub byte types of 4bits are supported.");
+      if (tileRank != 2)
+        return rewriter.notifyMatchFailure(
+            op, "Sub byte types are only supported for 2D tensor descriptors.");
+      auto subByteFactor = 8 / elemBitSize;
+      auto tileH = tdescTy.getDimSize(0);
+      // Handle special case for packed load.
+      if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
+        if (op.getPacked().value_or(false)) {
+          // packed load is implemented as packed loads of 8bit elements.
+          if (tileH == systolicDepth * 4 &&
+              tileW == executionSize * subByteFactor) {
+            // Usage case for loading as Matrix B with pack request.
+            // source is assumed to pre-packed into 8bit elements
+            // Emulate with 8bit loads with pack request.
+            //   scaled_tileW = executionSize
+            elemType = rewriter.getIntegerType(8);
+            tileW = executionSize;
+            wScaleFactor = subByteFactor;
+          }
+        }
+      }
+      // If not handled by packed load case above, handle other cases.
+      if (wScaleFactor == 1) {
+        auto sub16BitFactor = subByteFactor * 2;
+        if (tileW == executionSize * sub16BitFactor) {
+          // Usage case for loading as Matrix A operand
+          // Emulate with 16bit loads/stores.
+          //   scaled_tileW = executionSize
+          elemType = rewriter.getIntegerType(16);
+          tileW = executionSize;
+          wScaleFactor = sub16BitFactor;
+        } else {
+          return rewriter.notifyMatchFailure(
+              op, "Unsupported tile shape for sub byte types.");
+        }
+      }
+      // recompute element bit size for emulation.
+      elemBitSize = elemType.getIntOrFloatBitWidth();
+    }
 
     // Get address space from tensor descriptor memory space.
     auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
@@ -298,15 +354,27 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
       // Convert base pointer (i64) to LLVM pointer type.
       Value basePtrLLVM =
           LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
+      // FIXME: width or pitch is not the same as baseShapeW it should be the
+      // stride of the second to last dimension in row major layout.
       // Compute width in bytes.
-      Value baseWidthByte =
+      Value baseShapeWInBytes =
           arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
       // Compute pitch in bytes.
-      Value basePitchByte =
+      Value basePitchBytes =
           arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize);
 
-      // Get tile width from the tensor descriptor type.
-      auto tileW = tdescTy.getDimSize(tileRank - 1);
+      if (wScaleFactor > 1) {
+        // Scale offsetW, baseShapeWInBytes for sub byte emulation.
+        // Note: tileW is already scaled above.
+        Value wScaleFactorValLog2 = arith::ConstantIntOp::create(
+            rewriter, loc, rewriter.getI32Type(), llvm::Log2_64(wScaleFactor));
+        baseShapeWInBytes = arith::ShRSIOp::create(
+            rewriter, loc, baseShapeWInBytes, wScaleFactorValLog2);
+        basePitchBytes = arith::ShRSIOp::create(rewriter, loc, basePitchBytes,
+                                                wScaleFactorValLog2);
+        offsetW =
+            arith::ShRSIOp::create(rewriter, loc, offsetW, wScaleFactorValLog2);
+      }
       // Get tile height from the tensor descriptor type.
       auto tileH = tdescTy.getDimSize(0);
       // Get vblocks from the tensor descriptor type.
@@ -330,8 +398,8 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
         auto storeCacheControl =
             translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
         xevm::BlockStore2dOp::create(
-            rewriter, loc, basePtrLLVM, baseWidthByte, baseShapeH,
-            basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH, src,
+            rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
+            basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH, src,
             xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
         rewriter.eraseOp(op);
       } else {
@@ -339,8 +407,8 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
             translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
         if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
           xevm::BlockPrefetch2dOp::create(
-              rewriter, loc, basePtrLLVM, baseWidthByte, baseShapeH,
-              basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH,
+              rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
+              basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH,
               vblocks, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
           rewriter.eraseOp(op);
         } else {
@@ -354,9 +422,9 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
                              : rewriter.getIntegerType(elemBitSize));
 
           Value resultFlatVec = xevm::BlockLoad2dOp::create(
-              rewriter, loc, loadedTy, basePtrLLVM, baseWidthByte, baseShapeH,
-              basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH,
-              vblocks, transpose, vnni,
+              rewriter, loc, loadedTy, basePtrLLVM, baseShapeWInBytes,
+              baseShapeH, basePitchBytes, offsetW, offsetH, elemBitSize, tileW,
+              tileH, vblocks, transpose, vnni,
               xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
           resultFlatVec = vector::BitCastOp::create(
               rewriter, loc,

diff  --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_sub_byte.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_sub_byte.mlir
new file mode 100644
index 0000000000000..97e5ce14f8539
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_sub_byte.mlir
@@ -0,0 +1,80 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
+
+gpu.module @load_store_check {
+    // CHECK-LABEL: gpu.func @load_store_matrix_a
+    // CHECK-SAME: %[[ARG0:.*]]: memref<16x128xi4, 1>, %[[ARG1:.*]]: memref<16x128xi4, 1>
+    gpu.func @load_store_matrix_a(%src: memref<16x128xi4, 1>, %dst: memref<16x128xi4, 1>) kernel {
+        // CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32
+        // CHECK: %[[C8_I32:.*]] = arith.constant 8 : i32
+        // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4xi64>
+        // CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32
+        // CHECK: %[[C128_I32:.*]] = arith.constant 128 : i32
+        // CHECK: %[[SRCCE:.*]] = memref.memory_space_cast %[[ARG0]]
+        // CHECK: %[[SRCINDEX:.*]] = memref.extract_aligned_pointer_as_index %[[SRCCE]]
+        // CHECK: %[[SRCPTR64:.*]] = arith.index_castui %[[SRCINDEX]] : index to i64
+        %srcce = memref.memory_space_cast %src : memref<16x128xi4, 1> to memref<16x128xi4>
+        // CHECK: %[[DSTTE:.*]] = memref.memory_space_cast %[[ARG1]]
+        // CHECK: %[[DSTINDEX:.*]] = memref.extract_aligned_pointer_as_index %[[DSTTE]]
+        // CHECK: %[[DSTPTR64:.*]] = arith.index_castui %[[DSTINDEX]] : index to i64
+        %dstte = memref.memory_space_cast %dst : memref<16x128xi4, 1> to memref<16x128xi4>
+
+        // CHECK: %[[PAYLOAD_SRC:.*]] = vector.insert %[[SRCPTR64]], %[[CST]] [0] : i64 into vector<4xi64>
+        // CHECK: %[[BITCAST1_SRC:.*]] = vector.bitcast %[[PAYLOAD_SRC]] : vector<4xi64> to vector<8xi32>
+        // CHECK: %[[PAYLOAD1_SRC:.*]] = vector.insert %[[C128_I32]], %[[BITCAST1_SRC]] [2] : i32 into vector<8xi32>
+        // CHECK: %[[PAYLOAD2_SRC:.*]] = vector.insert %[[C16_I32]], %[[PAYLOAD1_SRC]] [3] : i32 into vector<8xi32>
+        // CHECK: %[[PAYLOAD3_SRC:.*]] = vector.insert %[[C128_I32]], %[[PAYLOAD2_SRC]] [4] : i32 into vector<8xi32>
+        %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x128xi4> -> !xegpu.tensor_desc<8x64xi4>
+
+        // CHECK: %[[BITCAST2:.*]] = vector.bitcast %[[PAYLOAD3_SRC]] : vector<8xi32> to vector<4xi64>
+        // CHECK: %[[SRCPTR64:.*]] = vector.extract %[[BITCAST2]][0] : i64 from vector<4xi64>
+        // CHECK: %[[SRCLLVMPTR:.*]] = llvm.inttoptr %[[SRCPTR64]] : i64 to !llvm.ptr<1>
+        // CHECK: %[[LOADED:.*]] = xevm.blockload2d %[[SRCLLVMPTR]], %[[C64_I32]],
+        // CHECK-SAME: %[[C16_I32]], %[[C64_I32]], %[[C16_I32]], %[[C8_I32]] <{
+        // CHECK-SAME:   cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 16 : i32,
+        // CHECK-SAME:   pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false,
+        // CHECK-SAME:   v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
+        %loaded = xegpu.load_nd %src_tdesc[8, 64] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+            : !xegpu.tensor_desc<8x64xi4> -> vector<32xi4>
+
+        // CHECK: %[[PAYLOAD_DST:.*]] = vector.insert %[[DSTPTR64]], %[[CST]] [0] : i64 into vector<4xi64>
+        // CHECK: %[[BITCAST1_DST:.*]] = vector.bitcast %[[PAYLOAD_DST]] : vector<4xi64> to vector<8xi32>
+        // CHECK: %[[PAYLOAD1_DST:.*]] = vector.insert %[[C128_I32]], %[[BITCAST1_DST]] [2] : i32 into vector<8xi32>
+        // CHECK: %[[PAYLOAD2_DST:.*]] = vector.insert %[[C16_I32]], %[[PAYLOAD1_DST]] [3] : i32 into vector<8xi32>
+        // CHECK: %[[PAYLOAD3_DST:.*]] = vector.insert %[[C128_I32]], %[[PAYLOAD2_DST]] [4] : i32 into vector<8xi32>
+        %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<16x128xi4> -> !xegpu.tensor_desc<8x64xi4, #xegpu.block_tdesc_attr<memory_space = global>>
+
+        // CHECK: %[[BITCAST2_DST:.*]] = vector.bitcast %[[PAYLOAD3_DST]] : vector<8xi32> to vector<4xi64>
+        // CHECK: %[[DSTPTR64:.*]] = vector.extract %[[BITCAST2_DST]][0] : i64 from vector<4xi64>
+        // CHECK: %[[DSTLLVMPTR:.*]] = llvm.inttoptr %[[DSTPTR64]] : i64 to !llvm.ptr<1>
+        // CHECK: xevm.blockstore2d %[[DSTLLVMPTR]], %[[C64_I32]], %[[C16_I32]],
+        // CHECK-SAME:  %[[C64_I32]], %[[C16_I32]], %[[C8_I32]], %[[LOADED]] <{
+        // CHECK-SAME:    cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 16 : i32,
+        // CHECK-SAME:    tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
+        xegpu.store_nd %loaded, %dst_tdesc[8, 64] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
+            : vector<32xi4>, !xegpu.tensor_desc<8x64xi4, #xegpu.block_tdesc_attr<memory_space = global>>
+        gpu.return
+    }
+
+    // CHECK-LABEL: gpu.func @load_matrix_b_request_pack
+    gpu.func @load_matrix_b_request_pack(%src: memref<64x128xi4, 1>, %dst: memref<64x128xi4, 1>) kernel {
+        // CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32
+        // CHECK: %[[C32_I32:.*]] = arith.constant 32 : i32
+        // CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32
+        %srcce = memref.memory_space_cast %src : memref<64x128xi4, 1> to memref<64x128xi4>
+        %dstte = memref.memory_space_cast %dst : memref<64x128xi4, 1> to memref<64x128xi4>
+
+        %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<64x128xi4> -> !xegpu.tensor_desc<32x32xi4>
+
+        // CHECK: xevm.blockload2d %{{.*}}, %[[C64_I32]], %[[C64_I32]], %[[C64_I32]], %[[C16_I32]], %[[C32_I32]] <{
+        // CHECK-SAME:   cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 8 : i32,
+        // CHECK-SAME:   pack_register = true, tile_height = 32 : i32, tile_width = 16 : i32, transpose = false,
+        // CHECK-SAME:   v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
+        %loaded = xegpu.load_nd %src_tdesc[32, 32] <{packed, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+            : !xegpu.tensor_desc<32x32xi4> -> vector<64xi4>
+
+        %c32 = arith.constant 32 : index
+        %c0 = arith.constant 0 : index
+        vector.store %loaded, %dstte[%c32, %c0] : memref<64x128xi4>, vector<64xi4>
+        gpu.return
+    }
+}

diff  --git a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd_sub_byte.mlir b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd_sub_byte.mlir
new file mode 100644
index 0000000000000..f9254728bab41
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd_sub_byte.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
+
+gpu.module @prefetch_check {
+    // CHECK-LABEL: gpu.func @prefetch_matrix_a
+    gpu.func @prefetch_matrix_a(%src: memref<16x128xi4, 1>) kernel {
+        // CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32
+        // CHECK: %[[C8_I32:.*]] = arith.constant 8 : i32
+        // CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32
+        %srcce = memref.memory_space_cast %src : memref<16x128xi4, 1> to memref<16x128xi4>
+
+        %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x128xi4> -> !xegpu.tensor_desc<8x64xi4>
+
+        // CHECK: xevm.blockprefetch2d %{{.*}}, %[[C64_I32]], %[[C16_I32]], %[[C64_I32]], %[[C16_I32]], %[[C8_I32]]
+        // CHECK-SAME:  <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 16 : i32,
+        // CHECK-SAME:    tile_height = 8 : i32, tile_width = 16 : i32, v_blocks = 1 : i32}> : (!llvm.ptr<1>
+        xegpu.prefetch_nd %src_tdesc[8, 64] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+            : !xegpu.tensor_desc<8x64xi4>
+
+        gpu.return
+    }
+}


        


More information about the Mlir-commits mailing list