[Mlir-commits] [mlir] 6026ca3 - [mlir][XeGPU] add unroll patterns for load_matrix and store_matrix (#154637)

llvmlistbot at llvm.org
Wed Sep 3 11:56:44 PDT 2025


Author: Chao Chen
Date: 2025-09-03T13:56:41-05:00
New Revision: 6026ca301d0759031736c0b4a6642ea94f922287

URL: https://github.com/llvm/llvm-project/commit/6026ca301d0759031736c0b4a6642ea94f922287
DIFF: https://github.com/llvm/llvm-project/commit/6026ca301d0759031736c0b4a6642ea94f922287.diff

LOG: [mlir][XeGPU] add unroll patterns for load_matrix and store_matrix (#154637)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
    mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
    mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
    mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
    mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 3a88dae041dd1..ddf6b4ac85a90 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -67,8 +67,8 @@ def XeGPUBlocking: Pass<"xegpu-blocking"> {
     to a hardware instruction.
   }];
   let dependentDialects = [
-      "memref::MemRefDialect", "xegpu::XeGPUDialect", "vector::VectorDialect"
-  ];
+      "memref::MemRefDialect", "xegpu::XeGPUDialect", "vector::VectorDialect",
+      "index::IndexDialect"];
 }
 
 #endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD

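The new index::IndexDialect entry is needed because the unroll patterns added below emit index.add ops through xegpu::addElementwise, and a pass must declare every dialect its rewrites may create. As a hedged sketch (the generated code may differ in detail), the TableGen list above amounts to registration along these lines:

    // Rough equivalent of the generated dependent-dialect hook.
    void getDependentDialects(DialectRegistry &registry) const override {
      registry.insert<memref::MemRefDialect, xegpu::XeGPUDialect,
                      vector::VectorDialect, index::IndexDialect>();
    }
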
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index bad734dbfd9f0..04cfd58d846a7 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -20,6 +20,7 @@ class OpResult;
 class OpBuilder;
 class ValueRange;
 class TypeConverter;
+class OpFoldResult;
 
 namespace xegpu {
 class DistributeLayoutAttr;
@@ -143,6 +144,11 @@ void doSCFStructuralTypeConversionWithTensorType(Operation *op,
 /// if no GPU module parent or XeVM target attribute exists.
 std::optional<std::string> getChipStr(Operation *op);
 
+/// Generates element-wise addition ops of two arrays with same length.
+SmallVector<OpFoldResult> addElementwise(OpBuilder &builder, Location loc,
+                                         ArrayRef<OpFoldResult> lhs,
+                                         ArrayRef<OpFoldResult> rhs);
+
 /// Generates element-wise addition ops of two arrays with automatic alignment.
 /// When the input arrays have different sizes, the shorter array is
 /// right-aligned with the longer array, and the unmatched leading elements from
@@ -156,7 +162,6 @@ std::optional<std::string> getChipStr(Operation *op);
 SmallVector<OpFoldResult> addWithRightAligned(OpBuilder &builder, Location loc,
                                               ArrayRef<OpFoldResult> lhs,
                                               ArrayRef<OpFoldResult> rhs);
-
 } // namespace xegpu
 
 } // namespace mlir

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 9ee002ede7838..5d5ff69e06886 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
 
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
@@ -157,10 +158,10 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
 std::optional<SmallVector<int64_t>>
 XeGPUBlockingPass::getTileShape(Operation *op) const {
   if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp, xegpu::CreateDescOp,
-          xegpu::UpdateOffsetOp>(op))
+          xegpu::UpdateOffsetOp, xegpu::LoadMatrixOp>(op))
     return getTileShape(op->getOpResult(0));
   if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp,
-          xegpu::LoadGatherOp>(op))
+          xegpu::LoadGatherOp, xegpu::StoreMatrixOp>(op))
     return getTileShape(op->getOpOperand(0));
   if (isa<xegpu::StoreNdOp, xegpu::StoreScatterOp>(op))
     return getTileShape(op->getOpOperand(1));

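For context: the tile shape computed here comes from the inst_data field of the layout attached to the op (e.g. #xegpu.layout<inst_data = [8, 16]>). A hypothetical sketch of that extraction, assuming LayoutAttr exposes a getInstData() accessor; this is illustrative, not the pass's verbatim code:

    // Pull the per-instruction tile sizes out of a layout attribute.
    if (auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(attr))
      if (DenseI32ArrayAttr instData = layout.getInstData())
        return llvm::to_vector_of<int64_t>(instData.asArrayRef());
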
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index c793b71639e86..d24d82780ebaa 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -682,13 +682,90 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
   }
 };
 
+struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> {
+  using UnrollPattern<xegpu::LoadMatrixOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    VectorType valueTy = op.getType();
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape || targetShape->size() != (size_t)valueTy.getRank())
+      return failure();
+
+    Type elemTy = valueTy.getElementType();
+    ArrayRef<int64_t> shape = valueTy.getShape();
+    auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());
+
+    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+
+    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
+    SmallVector<SmallVector<OpFoldResult>> offsetsList;
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(shape, *targetShape)) {
+      auto adds = xegpu::addElementwise(
+          rewriter, loc, mixedOffsets,
+          getAsIndexOpFoldResult(op.getContext(), offsets));
+      offsetsList.push_back(adds);
+    }
+
+    SmallVector<Value> newOps;
+    layout = layout.dropInstData();
+    for (SmallVector<OpFoldResult> offsets : offsetsList) {
+      auto newOp = rewriter.create<xegpu::LoadMatrixOp>(
+          op.getLoc(), newValueTy, op.getMemDesc(), offsets, layout);
+      newOps.push_back(newOp);
+    }
+    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
+    rewriter.replaceOp(op, castOp);
+    return success();
+  }
+};
+
+struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> {
+  using UnrollPattern<xegpu::StoreMatrixOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::StoreMatrixOp op,
+                                PatternRewriter &rewriter) const override {
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    Location loc = op.getLoc();
+    VectorType valueTy = op.getData().getType();
+    ArrayRef<int64_t> shape = valueTy.getShape();
+    auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());
+
+    SmallVector<Type> convertedValTypes =
+        getUnrolledTypes(valueTy, *targetShape);
+    SmallVector<Value> convertedValues =
+        pack(op.getData(), convertedValTypes, *targetShape, loc, rewriter);
+
+    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
+    SmallVector<SmallVector<OpFoldResult>> offsetsList;
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(shape, *targetShape)) {
+      auto adds = xegpu::addElementwise(
+          rewriter, loc, mixedOffsets,
+          getAsIndexOpFoldResult(op.getContext(), offsets));
+      offsetsList.push_back(adds);
+    }
+
+    for (auto [v, offsets] : llvm::zip_equal(convertedValues, offsetsList))
+      rewriter.create<xegpu::StoreMatrixOp>(loc, v, op.getMemDesc(), offsets,
+                                            layout.dropInstData());
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::xegpu::populateXeGPUUnrollPatterns(
     RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
-  patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
-               UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp,
-               UnrollCreateDescOp, UnrollLoadGatherOp, UnrollStoreScatterOp,
-               UnrollPrefetchOp, UnrollUpdateOffsetOp>(patterns.getContext(),
-                                                       options);
+  patterns
+      .add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
+           UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp, UnrollCreateDescOp,
+           UnrollLoadGatherOp, UnrollStoreScatterOp, UnrollPrefetchOp,
+           UnrollUpdateOffsetOp, UnrollLoadMatrixOp, UnrollStoreMatrixOp>(
+          patterns.getContext(), options);
 }

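Clients can exercise the two new patterns outside the blocking pass through this same entry point. A minimal sketch, assuming xegpu::UnrollOptions exposes a native-shape callback (setNativeShapeFn) in the style of vector::UnrollVectorOptions; the callback wiring here is an assumption, not verbatim XeGPU API:

    // Unroll the matrix ops to a fixed 8x16 instruction tile.
    RewritePatternSet patterns(op->getContext());
    xegpu::UnrollOptions options;
    options.setNativeShapeFn(
        [](Operation *op) -> std::optional<SmallVector<int64_t>> {
          if (isa<xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>(op))
            return SmallVector<int64_t>{8, 16};
          return std::nullopt;
        });
    xegpu::populateXeGPUUnrollPatterns(patterns, options);
    (void)applyPatternsGreedily(op, std::move(patterns));
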
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index cac1ffe4d3bc3..b72d5648b29f9 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -134,6 +134,14 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
     if (auto loadNd = dyn_cast<xegpu::LoadNdOp>(defOp))
       return getDistributeLayoutAttr(loadNd.getTensorDesc());
 
+    // for LoadMatrixOp, the layout is attached to the property of the op
+    if (auto loadOp = dyn_cast<xegpu::LoadMatrixOp>(defOp))
+      return loadOp.getLayoutAttr();
+
+    // for StoreMatrixOp, the layout is attached to the property of the op
+    if (auto storeOp = dyn_cast<xegpu::StoreMatrixOp>(defOp))
+      return storeOp.getLayoutAttr();
+
     std::string layoutName = getLayoutName(result);
     if (defOp->hasAttr(layoutName))
       return defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
@@ -154,6 +162,13 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
 xegpu::DistributeLayoutAttr
 xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
   Operation *op = opr.getOwner();
+
+  if (auto loadOp = dyn_cast<xegpu::LoadMatrixOp>(op))
+    return loadOp.getLayoutAttr();
+
+  if (auto storeOp = dyn_cast<xegpu::StoreMatrixOp>(op))
+    return storeOp.getLayoutAttr();
+
   std::string layoutName = xegpu::getLayoutName(opr);
   if (op->hasAttr(layoutName))
     return op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
@@ -182,6 +197,9 @@ template void xegpu::setDistributeLayoutAttr<mlir::OpOperand>(
 void xegpu::setDistributeLayoutAttrs(
     Operation *op, function_ref<DistributeLayoutAttr(Value)> getLayoutImpl) {
   op->walk([&](Operation *nestOp) {
+    if (isa<xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>(nestOp))
+      return;
+
     for (OpOperand &opr : nestOp->getOpOperands()) {
       auto layout = getLayoutImpl(opr.get());
       setDistributeLayoutAttr(opr, layout);
@@ -429,6 +447,21 @@ std::optional<std::string> xegpu::getChipStr(Operation *op) {
   return std::nullopt;
 }
 
+/// Generates element-wise addition ops of two arrays with same length.
+SmallVector<OpFoldResult> xegpu::addElementwise(OpBuilder &builder,
+                                                Location loc,
+                                                ArrayRef<OpFoldResult> lhs,
+                                                ArrayRef<OpFoldResult> rhs) {
+  assert(lhs.size() == rhs.size() && "lhs and rhs must have the same size");
+  SmallVector<OpFoldResult> results;
+  for (auto [l, r] : llvm::zip_equal(lhs, rhs)) {
+    auto lval = getValueOrCreateConstantIndexOp(builder, loc, l);
+    auto rval = getValueOrCreateConstantIndexOp(builder, loc, r);
+    results.push_back(builder.createOrFold<index::AddOp>(loc, lval, rval));
+  }
+  return results;
+}
+
 /// Generates element-wise addition ops of two arrays with automatic alignment.
 /// When the input arrays have different sizes, the shorter array is
 /// right-aligned with the longer array, and the unmatched leading elements from
@@ -448,11 +481,6 @@ xegpu::addWithRightAligned(OpBuilder &builder, Location loc,
   ArrayRef<OpFoldResult> b = lhs.size() >= rhs.size() ? rhs : lhs;
   SmallVector<OpFoldResult> results(a.take_front(a.size() - b.size()));
   a = a.slice(a.size() - b.size());
-  for (auto [l, r] : llvm::zip(a, b)) {
-    auto lval = getValueOrCreateConstantIndexOp(builder, loc, l);
-    auto rval = getValueOrCreateConstantIndexOp(builder, loc, r);
-    results.push_back(builder.createOrFold<index::AddOp>(loc, lval, rval));
-  }
+  results.append(addElementwise(builder, loc, a, b));
   return results;
-  return {};
 }

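The refactoring also lets addWithRightAligned reuse the new helper for its aligned tail. A minimal usage sketch for addElementwise itself, with illustrative offsets; op, builder, and loc are assumed to be in scope:

    // Add a static [8, 16] tile offset to an op's mixed (static or
    // dynamic) base offsets; folds to constants where possible.
    SmallVector<OpFoldResult> base = op.getMixedOffsets();
    SmallVector<OpFoldResult> delta = getAsIndexOpFoldResult(
        builder.getContext(), ArrayRef<int64_t>{8, 16});
    SmallVector<OpFoldResult> sum =
        xegpu::addElementwise(builder, loc, base, delta);
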
diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index d986e5bd1cfb4..9d63c2ddd4895 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -561,3 +561,26 @@ gpu.module @test_kernel {
     gpu.return %e : vector<8x32x2xf16>
   }
 }
+
+// -----
+gpu.module @test_kernel {
+  //CHECK-LABEL: unroll_load_matrix
+  gpu.func @unroll_load_matrix(%arg0: memref<4096xi8, 3>) -> vector<32x32xf32> {
+    %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>
+    //CHECK-COUNT-8: xegpu.load_matrix {{.*}} : !xegpu.mem_desc<32x32xf32>, index, index -> vector<8x16xf32>
+    //CHECK-COUNT-8: vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<32x32xf32>
+    %1 = xegpu.load_matrix %0[0, 0] <{layout = #xegpu.layout<inst_data = [8, 16]>}>: !xegpu.mem_desc<32x32xf32> -> vector<32x32xf32>
+    gpu.return %1: vector<32x32xf32>
+  }
+}
+
+// -----
+gpu.module @test_kernel {
+  // CHECK-LABEL: unroll_store_matrix
+  gpu.func @unroll_store_matrix(%value: vector<32x32xf32>, %arg0 : memref<32768xi8, 3>) {
+    %mdesc = xegpu.create_mem_desc %arg0 : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
+    // CHECK-COUNT-8:  xegpu.store_matrix {{.*}} : vector<8x16xf32>, !xegpu.mem_desc<64x128xf32>, index, index
+    xegpu.store_matrix %value, %mdesc[0, 0] {layout = #xegpu.layout<inst_data = [8, 16]>} : vector<32x32xf32>, !xegpu.mem_desc<64x128xf32>
+    gpu.return
+  }
+}

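Both CHECK-COUNT-8 counts follow from the tiling arithmetic: a 32x32 value split into 8x16 instruction tiles gives (32/8) * (32/16) = 4 * 2 = 8 tiles, hence eight unrolled matrix ops in each test. Restated at compile time:

    // 32x32 tiled by inst_data = [8, 16]: 4 row tiles * 2 column tiles.
    static_assert((32 / 8) * (32 / 16) == 8, "eight unrolled matrix ops");
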
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index f4a49da71605f..c0fb373835e3d 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -26,6 +26,33 @@ gpu.module @test_1_1_assignment {
     gpu.return
   }
 
+  // CHECK-LABEL: create_nd_tdesc_from_higher_rank_memref
+  // CHECK-SAME: [[ARG_0:%.*]]: memref<3x256x128xf32>
+  gpu.func @create_nd_tdesc_from_higher_rank_memref(%src: memref<3x256x128xf32>) {
+    //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
+    //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
+    //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
+    //CHECK: [[C32:%.+]] = arith.constant 32 : index
+    //CHECK: [[LY:%.+]] = index.mul [[SGIDY]], [[C32]]
+    //CHECK: [[LX:%.+]] = index.mul [[SGIDX]], [[C32]]
+    //CHECK: [[C0:%.+]] = arith.constant 0 : index
+    //CHECK: [[C0_2:%.+]] = arith.constant 0 : index
+    //CHECK: [[UY:%.+]] = arith.addi [[LY]], [[C0]] : index
+    //CHECK: [[UX:%.+]] = arith.addi [[LX]], [[C0_2]] : index
+    //CHECK: [[C256:%.+]] = arith.constant 256 : index
+    //CHECK: [[MODY:%.+]] = index.remu [[UY]], [[C256]]
+    //CHECK: [[C128:%.+]] = arith.constant 128 : index
+    //CHECK: [[MODX:%.+]] = index.remu [[UX]], [[C128]]
+    //CHECK: [[C0_3:%.+]] = arith.constant 0 : index
+    //CHECK: [[Y:%.+]] = index.add [[MODY]], [[C0_3]]
+    //CHECK: [[C0_4:%.+]] = arith.constant 0 : index
+    //CHECK: [[X:%.+]] = index.add [[MODX]], [[C0_4]]
+    //CHECK: [[TDESC:%.+]] = xegpu.create_nd_tdesc [[ARG_0]][1, [[Y]], [[X]]] : memref<3x256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %tdesc = xegpu.create_nd_tdesc %src[1, 0, 0] : memref<3x256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
+
   // CHECK-LABEL: load_nd_tdesc
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) {
