[Mlir-commits] [mlir] [MLIR][XeGPU][Conversion] Add 2D block op support for sub byte types (PR #169099)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Nov 21 13:32:26 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Sang Ik Lee (silee2)

<details>
<summary>Changes</summary>

Some use cases or shapes of 2D block operations on sub-byte element types can be emulated with 2D block operations on non-sub-byte types. This change adds the sub-byte type F4E2M1FN as a valid XeGPU float type, and adds lowering for certain 2D block operations by emulating them with larger element types.

---
Full diff: https://github.com/llvm/llvm-project/pull/169099.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td (+1-1) 
- (modified) mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp (+78-3) 
- (added) mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_sub_byte.mlir (+115) 
- (added) mlir/test/Conversion/XeGPUToXeVM/prefetch_nd_sub_byte.mlir (+39) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 7f7f7d065c50e..da577824ce114 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -14,7 +14,7 @@ include "mlir/Dialect/XeGPU/IR/XeGPUDialect.td"
 include "mlir/IR/BuiltinTypes.td"
 
 def XeGPU_IntType: AnyTypeOf<[I1, I8, I16, I32, I64, SI1, SI8, SI16, SI32, SI64, UI1, UI8, UI16, UI32, UI64]>;
-def XeGPU_FloatType: AnyTypeOf<[F16, F32, F64, BF16, TF32]>;
+def XeGPU_FloatType : AnyTypeOf<[F4E2M1FN, F16, F32, F64, BF16, TF32]>;
 def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
 def XeGPU_PointerType : AnyTypeOf<[UI64, UI32, I64, I32]>;
 def XeGPU_BaseAddrType
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 7f1ec17ce0ae8..7f95501958955 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -151,6 +151,14 @@ translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
   }
 }
 
+//
+// Note:
+// Block operations on tiles of sub-byte element types are handled by
+// emulating them with larger element types.
+// Tensor descriptors are kept intact; only the ops consuming them are
+// emulated.
+//
+
 class CreateNdDescToXeVMPattern
     : public OpConversionPattern<xegpu::CreateNdDescOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -268,9 +276,64 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
           op, "Expected offset rank to match descriptor rank.");
     auto elemType = tdescTy.getElementType();
     auto elemBitSize = elemType.getIntOrFloatBitWidth();
-    if (elemBitSize % 8 != 0)
+    bool isSubByte = elemBitSize < 8;
+    uint64_t wScaleFactor = 1;
+
+    if (!isSubByte && (elemBitSize % 8 != 0))
       return rewriter.notifyMatchFailure(
           op, "Expected element type bit width to be multiple of 8.");
+    auto tileW = tdescTy.getDimSize(tileRank - 1);
+    // For sub-byte types, only 4-bit element types are currently supported.
+    if (isSubByte) {
+      if (elemBitSize != 4)
+        return rewriter.notifyMatchFailure(
+            op, "Only sub byte types of 4bits are supported.");
+      if (tileRank != 2)
+        return rewriter.notifyMatchFailure(
+            op, "Sub byte types are only supported for 2D tensor descriptors.");
+      auto subByteFactor = 8 / elemBitSize;
+      auto sub16BitFactor = subByteFactor * 2;
+      auto sub32BitFactor = sub16BitFactor * 2;
+      auto tileH = tdescTy.getDimSize(0);
+      if (tileW == executionSize * sub16BitFactor) {
+        // Usage case for loading as Matrix A operand
+        // Emulate with 16bit loads/stores.
+        //   scaled_tileW = executionSize
+        elemType = rewriter.getIntegerType(16);
+        tileW = executionSize;
+        wScaleFactor = sub16BitFactor;
+      } else if (tileW == executionSize * sub32BitFactor) {
+        // Usage case for loading as pre-packed Matrix B operand
+        // Emulate with 32bit loads/stores.
+        //  scaled_tileW = executionSize
+        elemType = rewriter.getIntegerType(32);
+        tileW = executionSize;
+        wScaleFactor = sub32BitFactor;
+      } else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
+        if (!(tileH == systolicDepth * 4 &&
+              tileW == executionSize * subByteFactor)) {
+          return rewriter.notifyMatchFailure(
+              op, "Unsupported tile shape for sub byte types.");
+        }
+        const bool vnni = op.getPacked().value_or(false);
+        if (!vnni) {
+          return rewriter.notifyMatchFailure(
+              op, "Unsupported tile shape for sub byte types without pack.");
+        }
+        // Usage case for loading as Matrix B with pack request.
+        // Source is assumed to be pre-packed into 8-bit elements.
+        // Emulate with 8bit loads with pack request.
+        //   scaled_tileW = executionSize
+        elemType = rewriter.getIntegerType(8);
+        tileW = executionSize;
+        wScaleFactor = subByteFactor;
+      } else {
+        return rewriter.notifyMatchFailure(
+            op, "Unsupported tile width for sub byte types.");
+      }
+      // recompute element bit size for emulation.
+      elemBitSize = elemType.getIntOrFloatBitWidth();
+    }
 
     // Get address space from tensor descriptor memory space.
     auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
@@ -302,12 +365,24 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
       // Convert base pointer (i64) to LLVM pointer type.
       Value basePtrLLVM =
           LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
+      // FIXME: width/pitch is not necessarily the same as baseShapeW; it should
+      // be the stride of the second-to-last dimension in row-major layout.
       // Compute width in bytes.
       Value surfaceW =
           arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
 
-      // Get tile width from the tensor descriptor type.
-      auto tileW = tdescTy.getDimSize(tileRank - 1);
+      if (wScaleFactor > 1) {
+        // Scale baseShapeW, offsetW, surfaceW for sub byte emulation.
+        // Note: tileW is already scaled above.
+        Value wScaleFactorValLog2 = arith::ConstantIntOp::create(
+            rewriter, loc, rewriter.getI32Type(), llvm::Log2_64(wScaleFactor));
+        baseShapeW = arith::ShRSIOp::create(rewriter, loc, baseShapeW,
+                                            wScaleFactorValLog2);
+        offsetW =
+            arith::ShRSIOp::create(rewriter, loc, offsetW, wScaleFactorValLog2);
+        surfaceW = arith::ShRSIOp::create(rewriter, loc, surfaceW,
+                                          wScaleFactorValLog2);
+      }
       // Get tile height from the tensor descriptor type.
       auto tileH = tdescTy.getDimSize(0);
       // Get vblocks from the tensor descriptor type.
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_sub_byte.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_sub_byte.mlir
new file mode 100644
index 0000000000000..a6471414ec422
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_sub_byte.mlir
@@ -0,0 +1,115 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
+
+gpu.module @load_store_check {
+    // CHECK-LABEL: gpu.func @load_store_matrix_a
+    // CHECK-SAME: %[[ARG0:.*]]: memref<16x128xf4E2M1FN, 1>, %[[ARG1:.*]]: memref<16x128xf4E2M1FN, 1>
+    gpu.func @load_store_matrix_a(%src: memref<16x128xf4E2M1FN, 1>, %dst: memref<16x128xf4E2M1FN, 1>) kernel {
+        // CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32
+        // CHECK: %[[C8_I32:.*]] = arith.constant 8 : i32
+        // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4xi64>
+        // CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32
+        // CHECK: %[[C128_I32:.*]] = arith.constant 128 : i32
+        // CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
+        // CHECK: %[[SRCCE:.*]] = memref.memory_space_cast %[[ARG0]]
+        // CHECK: %[[SRCINDEX:.*]] = memref.extract_aligned_pointer_as_index %[[SRCCE]]
+        // CHECK: %[[SRCPTR64:.*]] = arith.index_castui %[[SRCINDEX]] : index to i64
+        %srcce = memref.memory_space_cast %src : memref<16x128xf4E2M1FN, 1> to memref<16x128xf4E2M1FN>
+        // CHECK: %[[DSTTE:.*]] = memref.memory_space_cast %[[ARG1]]
+        // CHECK: %[[DSTINDEX:.*]] = memref.extract_aligned_pointer_as_index %[[DSTTE]]
+        // CHECK: %[[DSTPTR64:.*]] = arith.index_castui %[[DSTINDEX]] : index to i64
+        %dstte = memref.memory_space_cast %dst : memref<16x128xf4E2M1FN, 1> to memref<16x128xf4E2M1FN>
+
+        // CHECK: %[[PAYLOAD_SRC:.*]] = vector.insert %[[SRCPTR64]], %[[CST]] [0] : i64 into vector<4xi64>
+        // CHECK: %[[BITCAST1_SRC:.*]] = vector.bitcast %[[PAYLOAD_SRC]] : vector<4xi64> to vector<8xi32>
+        // CHECK: %[[PAYLOAD1_SRC:.*]] = vector.insert %[[C128_I32]], %[[BITCAST1_SRC]] [2] : i32 into vector<8xi32>
+        // CHECK: %[[PAYLOAD2_SRC:.*]] = vector.insert %[[C16_I32]], %[[PAYLOAD1_SRC]] [3] : i32 into vector<8xi32>
+        // CHECK: %[[PAYLOAD3_SRC:.*]] = vector.insert %[[C0_I32]], %[[PAYLOAD2_SRC]] [4] : i32 into vector<8xi32>
+        // CHECK: %[[PAYLOAD4_SRC:.*]] = vector.insert %[[C0_I32]], %[[PAYLOAD3_SRC]] [5] : i32 into vector<8xi32>
+        %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x128xf4E2M1FN> -> !xegpu.tensor_desc<8x64xf4E2M1FN>
+
+        // CHECK: %[[BITCAST2:.*]] = vector.bitcast %[[PAYLOAD4_SRC]] : vector<8xi32> to vector<4xi64>
+        // CHECK: %[[SRCPTR64:.*]] = vector.extract %[[BITCAST2]][0] : i64 from vector<4xi64>
+        // CHECK: %[[SRCLLVMPTR:.*]] = llvm.inttoptr %[[SRCPTR64]] : i64 to !llvm.ptr<1>
+        // CHECK: %[[LOADED:.*]] = xevm.blockload2d %[[SRCLLVMPTR]], %[[C64_I32]],
+        // CHECK-SAME: %[[C16_I32]], %[[C64_I32]], %[[C16_I32]], %[[C8_I32]] <{
+        // CHECK-SAME:   cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 16 : i32,
+        // CHECK-SAME:   pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false,
+        // CHECK-SAME:   v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
+        %loaded = xegpu.load_nd %src_tdesc[8, 64] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+            : !xegpu.tensor_desc<8x64xf4E2M1FN> -> vector<32xf4E2M1FN>
+
+        // CHECK: %[[PAYLOAD_DST:.*]] = vector.insert %[[DSTPTR64]], %[[CST]] [0] : i64 into vector<4xi64>
+        // CHECK: %[[BITCAST1_DST:.*]] = vector.bitcast %[[PAYLOAD_DST]] : vector<4xi64> to vector<8xi32>
+        // CHECK: %[[PAYLOAD1_DST:.*]] = vector.insert %[[C128_I32]], %[[BITCAST1_DST]] [2] : i32 into vector<8xi32>
+        // CHECK: %[[PAYLOAD2_DST:.*]] = vector.insert %[[C16_I32]], %[[PAYLOAD1_DST]] [3] : i32 into vector<8xi32>
+        // CHECK: %[[PAYLOAD3_DST:.*]] = vector.insert %[[C0_I32]], %[[PAYLOAD2_DST]] [4] : i32 into vector<8xi32>
+        // CHECK: %[[PAYLOAD4_DST:.*]] = vector.insert %[[C0_I32]], %[[PAYLOAD3_DST]] [5] : i32 into vector<8xi32>
+        %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<16x128xf4E2M1FN> -> !xegpu.tensor_desc<8x64xf4E2M1FN, #xegpu.block_tdesc_attr<memory_space = global>>
+
+        // CHECK: %[[BITCAST2_DST:.*]] = vector.bitcast %[[PAYLOAD4_DST]] : vector<8xi32> to vector<4xi64>
+        // CHECK: %[[DSTPTR64:.*]] = vector.extract %[[BITCAST2_DST]][0] : i64 from vector<4xi64>
+        // CHECK: %[[DSTLLVMPTR:.*]] = llvm.inttoptr %[[DSTPTR64]] : i64 to !llvm.ptr<1>
+        // CHECK: xevm.blockstore2d %[[DSTLLVMPTR]], %[[C64_I32]], %[[C16_I32]],
+        // CHECK-SAME:  %[[C64_I32]], %[[C16_I32]], %[[C8_I32]], %[[LOADED]] <{
+        // CHECK-SAME:    cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 16 : i32,
+        // CHECK-SAME:    tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
+        xegpu.store_nd %loaded, %dst_tdesc[8, 64] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
+            : vector<32xf4E2M1FN>, !xegpu.tensor_desc<8x64xf4E2M1FN, #xegpu.block_tdesc_attr<memory_space = global>>
+        gpu.return
+    }
+
+    // CHECK-LABEL: gpu.func @load_store_matrix_b_prepacked
+    gpu.func @load_store_matrix_b_prepacked(%src: memref<16x256xf4E2M1FN, 1>, %dst: memref<16x256xf4E2M1FN, 1>) kernel {
+        // CHECK: %[[C128_I32:.*]] = arith.constant 128 : i32
+        // CHECK: %[[C8_I32:.*]] = arith.constant 8 : i32
+        // CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32
+        %srcce = memref.memory_space_cast %src : memref<16x256xf4E2M1FN, 1> to memref<16x256xf4E2M1FN>
+        %dstte = memref.memory_space_cast %dst : memref<16x256xf4E2M1FN, 1> to memref<16x256xf4E2M1FN>
+
+        %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x256xf4E2M1FN> -> !xegpu.tensor_desc<8x128xf4E2M1FN>
+
+        // CHECK: xevm.blockload2d %{{.*}}, %[[C128_I32]], %[[C16_I32]], %[[C128_I32]], %[[C16_I32]], %[[C8_I32]] <{
+        // CHECK-SAME:   cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
+        // CHECK-SAME:   pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false,
+        // CHECK-SAME:   v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
+        %loaded = xegpu.load_nd %src_tdesc[8, 128] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+            : !xegpu.tensor_desc<8x128xf4E2M1FN> -> vector<64xf4E2M1FN>
+
+        %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<16x256xf4E2M1FN> -> !xegpu.tensor_desc<8x128xf4E2M1FN, #xegpu.block_tdesc_attr<memory_space = global>>
+
+        // CHECK: xevm.blockstore2d %{{.*}}, %[[C128_I32]], %[[C16_I32]], %[[C128_I32]], %[[C16_I32]], %[[C8_I32]], %{{.*}} <{
+        // CHECK-SAME:   cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
+        // CHECK-SAME:   tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
+        xegpu.store_nd %loaded, %dst_tdesc[8, 128] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
+            : vector<64xf4E2M1FN>, !xegpu.tensor_desc<8x128xf4E2M1FN, #xegpu.block_tdesc_attr<memory_space = global>>
+        gpu.return
+    }
+
+    // CHECK-LABEL: gpu.func @load_store_matrix_b_request_pack
+    gpu.func @load_store_matrix_b_request_pack(%src: memref<64x128xf4E2M1FN, 1>, %dst: memref<64x128xf4E2M1FN, 1>) kernel {
+        // CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32
+        // CHECK: %[[C32_I32:.*]] = arith.constant 32 : i32
+        // CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32
+        // CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
+        %srcce = memref.memory_space_cast %src : memref<64x128xf4E2M1FN, 1> to memref<64x128xf4E2M1FN>
+        %dstte = memref.memory_space_cast %dst : memref<64x128xf4E2M1FN, 1> to memref<64x128xf4E2M1FN>
+
+        %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<64x128xf4E2M1FN> -> !xegpu.tensor_desc<32x32xf4E2M1FN>
+
+        // CHECK: xevm.blockload2d %{{.*}}, %[[C64_I32]], %[[C64_I32]], %[[C64_I32]], %[[C16_I32]], %[[C32_I32]] <{
+        // CHECK-SAME:   cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 8 : i32,
+        // CHECK-SAME:   pack_register = true, tile_height = 32 : i32, tile_width = 16 : i32, transpose = false,
+        // CHECK-SAME:   v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
+        %loaded = xegpu.load_nd %src_tdesc[32, 32] <{packed, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+            : !xegpu.tensor_desc<32x32xf4E2M1FN> -> vector<64xf4E2M1FN>
+
+        %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<64x128xf4E2M1FN> -> !xegpu.tensor_desc<8x128xf4E2M1FN, #xegpu.block_tdesc_attr<memory_space = global>>
+
+        // CHECK: xevm.blockstore2d %{{.*}}, %[[C64_I32]], %[[C64_I32]], %[[C64_I32]], %[[C0_I32]], %[[C32_I32]], %{{.*}} <{
+        // CHECK-SAME:   cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
+        // CHECK-SAME:   tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
+        xegpu.store_nd %loaded, %dst_tdesc[32, 0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
+            : vector<64xf4E2M1FN>, !xegpu.tensor_desc<8x128xf4E2M1FN, #xegpu.block_tdesc_attr<memory_space = global>>
+        gpu.return
+    }
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd_sub_byte.mlir b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd_sub_byte.mlir
new file mode 100644
index 0000000000000..e6fedce3c1f74
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd_sub_byte.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
+
+gpu.module @prefetch_check {
+    // CHECK-LABEL: gpu.func @prefetch_matrix_a
+    gpu.func @prefetch_matrix_a(%src: memref<16x128xf4E2M1FN, 1>) kernel {
+        // CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32
+        // CHECK: %[[C8_I32:.*]] = arith.constant 8 : i32
+        // CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32
+        %srcce = memref.memory_space_cast %src : memref<16x128xf4E2M1FN, 1> to memref<16x128xf4E2M1FN>
+
+        %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x128xf4E2M1FN> -> !xegpu.tensor_desc<8x64xf4E2M1FN>
+
+        // CHECK: xevm.blockprefetch2d %{{.*}}, %[[C64_I32]], %[[C16_I32]], %[[C64_I32]], %[[C16_I32]], %[[C8_I32]]
+        // CHECK-SAME:  <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 16 : i32,
+        // CHECK-SAME:    tile_height = 8 : i32, tile_width = 16 : i32, v_blocks = 1 : i32}> : (!llvm.ptr<1>
+        xegpu.prefetch_nd %src_tdesc[8, 64] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+            : !xegpu.tensor_desc<8x64xf4E2M1FN>
+
+        gpu.return
+    }
+
+    // CHECK-LABEL: gpu.func @prefetch_matrix_b_prepacked
+    gpu.func @prefetch_matrix_b_prepacked(%src: memref<16x256xf4E2M1FN, 1>) kernel {
+        // CHECK: %[[C128_I32:.*]] = arith.constant 128 : i32
+        // CHECK: %[[C8_I32:.*]] = arith.constant 8 : i32
+        // CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32
+        %srcce = memref.memory_space_cast %src : memref<16x256xf4E2M1FN, 1> to memref<16x256xf4E2M1FN>
+
+        %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x256xf4E2M1FN> -> !xegpu.tensor_desc<8x128xf4E2M1FN>
+
+        // CHECK: xevm.blockprefetch2d %{{.*}}, %[[C128_I32]], %[[C16_I32]], %[[C128_I32]], %[[C16_I32]], %[[C8_I32]]
+        // CHECK-SAME:  <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
+        // CHECK-SAME:    tile_height = 8 : i32, tile_width = 16 : i32, v_blocks = 1 : i32}> : (!llvm.ptr<1>
+        xegpu.prefetch_nd %src_tdesc[8, 128] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+            : !xegpu.tensor_desc<8x128xf4E2M1FN>
+
+        gpu.return
+    }
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/169099


More information about the Mlir-commits mailing list