[Mlir-commits] [mlir] [MLIR][XeGPU] Add unroll pass for XeGPU (PR #137010)

Chao Chen llvmlistbot at llvm.org
Tue Apr 29 13:36:22 PDT 2025


https://github.com/chencha3 updated https://github.com/llvm/llvm-project/pull/137010

>From 7d332da6ff21af66e87b39c0c9bae299c292bb6c Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 17 Apr 2025 17:54:04 +0000
Subject: [PATCH 01/12] init

---
 mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td | 11 +++++++++++
 mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt     |  1 +
 2 files changed, 12 insertions(+)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 3e81f2d0ed786..007dd81d1dfac 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -38,4 +38,15 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
   ];
 }
 
+def XeGPUUnroll: Pass<"xegpu-unroll"> {
+  let summary = "Unroll operations into smaller shapes";
+  let description = [{
+    The pass unrolls operations into smaller shapes that can each be
+    distributed to a single SIMD instruction.
+  }];
+  let dependentDialects = [
+    "arith::ArithDialect", "memref::MemRefDialect", "xegpu::XeGPUDialect",
+    "vector::VectorDialect"];
+}
+
 #endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 9f041aae511df..82aeb3a80c50c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRXeGPUTransforms
   XeGPUFoldAliasOps.cpp
   XeGPUSubgroupDistribute.cpp
+  XeGPUUnroll.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU

>From 47f9b3d5199ff1db2c0a0dabe5819999e7356cd9 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Fri, 18 Apr 2025 21:21:38 +0000
Subject: [PATCH 02/12] add patterns for createNdOp and StoreNdOp

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       |  22 +
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td |   7 +-
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        |  43 +-
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  | 403 ++++++++++++++++++
 4 files changed, 442 insertions(+), 33 deletions(-)
 create mode 100644 mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index f1bed70253ef3..cab9fffdbbcd2 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -253,6 +253,28 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
         return $_get($_ctxt, sg_layout, sg_data, inst_data,
                      DenseI32ArrayAttr::get($_ctxt, lane_layout),
                      DenseI32ArrayAttr::get($_ctxt, lane_data), order);
+      }]>,
+    AttrBuilder<(ins "llvm::ArrayRef<int>": $lane_layout,
+                     "llvm::ArrayRef<int>": $lane_data,
+                     "llvm::ArrayRef<int>": $order),
+      [{
+        auto sg_layout = DenseI32ArrayAttr();
+        auto sg_data = DenseI32ArrayAttr();
+        auto inst_data = DenseI32ArrayAttr();
+        return $_get($_ctxt, sg_layout, sg_data, inst_data,
+                     DenseI32ArrayAttr::get($_ctxt, lane_layout),
+                     DenseI32ArrayAttr::get($_ctxt, lane_data),
+                     DenseI32ArrayAttr::get($_ctxt, order));
+      }]>,
+    AttrBuilder<(ins "DenseI32ArrayAttr": $lane_layout,
+                     "DenseI32ArrayAttr": $lane_data,
+                     "DenseI32ArrayAttr": $order),
+      [{
+        auto sg_layout = DenseI32ArrayAttr();
+        auto sg_data = DenseI32ArrayAttr();
+        auto inst_data = DenseI32ArrayAttr();
+        return $_get($_ctxt, sg_layout, sg_data, inst_data,
+                     lane_layout, lane_data, order);
       }]>
   ];
 
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 5fa18754305ca..627de858d94aa 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -142,12 +142,7 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
     OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source,
                    "llvm::ArrayRef<OpFoldResult>": $offsets)>,
 
-    OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType> ": $source,
-                   "llvm::ArrayRef<OpFoldResult>": $offsets,
-                   "llvm::ArrayRef<OpFoldResult>": $shape,
-                   "llvm::ArrayRef<OpFoldResult>": $strides)>,
-
-    OpBuilder<(ins "Type": $tdesc, "TypedValue<IntegerType> ": $source,
+    OpBuilder<(ins "Type": $tdesc, "Value": $source,
                    "llvm::ArrayRef<OpFoldResult>": $offsets,
                    "llvm::ArrayRef<OpFoldResult>": $shape,
                    "llvm::ArrayRef<OpFoldResult>": $strides)>
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index e0e25365220b5..70f32314c67ce 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -169,46 +169,24 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
 }
 
 void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
-                           Type tdesc, TypedValue<MemRefType> source,
+                           Type tdesc, Value source,
                            llvm::ArrayRef<OpFoldResult> offsets,
                            llvm::ArrayRef<OpFoldResult> shape,
                            llvm::ArrayRef<OpFoldResult> strides) {
   assert(shape.size() && offsets.size() && strides.size() &&
          shape.size() == strides.size() && shape.size() == offsets.size());
 
-  llvm::SmallVector<int64_t> staticOffsets;
-  llvm::SmallVector<int64_t> staticShape;
-  llvm::SmallVector<int64_t> staticStrides;
+  [[maybe_unused]] auto intTy = dyn_cast<IntegerType>(source.getType());
+  auto memrefTy = dyn_cast<MemRefType>(source.getType());
+  assert((intTy || memrefTy) && "Source has to be either int or memref.");
+
   llvm::SmallVector<Value> dynamicOffsets;
   llvm::SmallVector<Value> dynamicShape;
   llvm::SmallVector<Value> dynamicStrides;
 
-  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
-  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
-  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
-
-  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
-  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
-  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
-
-  build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
-        dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
-}
-
-void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
-                           Type tdesc, TypedValue<IntegerType> source,
-                           llvm::ArrayRef<OpFoldResult> offsets,
-                           llvm::ArrayRef<OpFoldResult> shape,
-                           llvm::ArrayRef<OpFoldResult> strides) {
-  assert(shape.size() && offsets.size() && strides.size() &&
-         shape.size() == strides.size() && shape.size() == offsets.size());
-
   llvm::SmallVector<int64_t> staticOffsets;
   llvm::SmallVector<int64_t> staticShape;
   llvm::SmallVector<int64_t> staticStrides;
-  llvm::SmallVector<Value> dynamicOffsets;
-  llvm::SmallVector<Value> dynamicShape;
-  llvm::SmallVector<Value> dynamicStrides;
 
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
   dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
@@ -218,6 +196,17 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
   auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
   auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
 
+  if (memrefTy) {
+    auto memrefShape = memrefTy.getShape();
+    auto [memrefStrides, offset] = memrefTy.getStridesAndOffset();
+
+    // Shape/strides identical to the memref's own are implied: drop the attrs.
+    if (staticShape == memrefShape && staticStrides == memrefStrides) {
+      staticShapeAttr = DenseI64ArrayAttr();
+      staticStridesAttr = DenseI64ArrayAttr();
+    }
+  }
+
   build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
         dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
 }
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
new file mode 100644
index 0000000000000..6148052401c97
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -0,0 +1,403 @@
+//===- XeGPUUnroll.cpp - patterns to do unrolling ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+#include <numeric>
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUUNROLL
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+#define DEBUG_TYPE "xegpu-unroll"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
+using namespace mlir;
+
+namespace {
+
+constexpr llvm::StringLiteral packAttrName = "__xetile_blocking_pack__";
+constexpr llvm::StringLiteral unpackAttrName = "__xetile_blocking_unpack__";
+constexpr llvm::StringLiteral blockAttrName = "__xetile_blocking_inner_block__";
+
+// Emulate the unpack behavior using insert_strided_slice for VectorType
+// values and unrealized_conversion_cast for TensorDescType values.
+static Value addUnpackOp(ValueRange srcs, Type destTy,
+                         llvm::ArrayRef<int64_t> innerBlock, Location loc,
+                         PatternRewriter &rewriter) {
+  if (auto vecTy = dyn_cast<VectorType>(destTy)) {
+    assert(vecTy.getRank() == 2 && innerBlock.size() == 2 &&
+           "Expecting innerBlock size to match the rank of destTy.");
+    auto shape = vecTy.getShape();
+    auto zeroAttr = rewriter.getZeroAttr(vecTy.getElementType());
+    // Start from a zero vector and insert each tile at its (i, j) offset.
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, vecTy, DenseElementsAttr::get(vecTy, zeroAttr));
+    int64_t idx = 0; // srcs are consumed in row-major tile order.
+    for (int64_t i = 0; i < shape[0]; i += innerBlock[0]) {
+      for (int64_t j = 0; j < shape[1]; j += innerBlock[1]) {
+        result = rewriter.create<vector::InsertStridedSliceOp>(
+            loc, srcs[idx++], result, llvm::ArrayRef<int64_t>({i, j}),
+            llvm::ArrayRef<int64_t>({1, 1}));
+      }
+    }
+    return result;
+
+  } else if (isa<xegpu::TensorDescType>(destTy)) {
+    auto attr = NamedAttribute(rewriter.getStringAttr(unpackAttrName),
+                               rewriter.getUnitAttr());
+    auto innerBlkAttr =
+        NamedAttribute(rewriter.getStringAttr(blockAttrName),
+                       rewriter.getDenseI64ArrayAttr(innerBlock));
+    auto castOp = rewriter.create<UnrealizedConversionCastOp>(
+        loc, destTy, srcs,
+        llvm::ArrayRef<NamedAttribute>({attr, innerBlkAttr}));
+    return castOp.getResult(0);
+  }
+
+  llvm_unreachable("Unexpected destTy.");
+  return Value();
+}
+
+// Emulate the pack behavior using extract_strided_slice for VectorType
+// values and unrealized_conversion_cast for TensorDescType values.
+static llvm::SmallVector<Value> addPackOp(Value src, TypeRange destTypes,
+                                          llvm::ArrayRef<int64_t> innerBlock,
+                                          Location loc,
+                                          PatternRewriter &rewriter) {
+  if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
+    assert(vecTy.getRank() == 2 && innerBlock.size() == 2 &&
+           "Expecting innerBlock size to match the rank of src.");
+    auto shape = vecTy.getShape();
+    llvm::SmallVector<Value> results; // Row-major; destTypes unused here.
+    for (int64_t i = 0; i < shape[0]; i += innerBlock[0]) {
+      for (int64_t j = 0; j < shape[1]; j += innerBlock[1]) {
+        auto slice = rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, src, llvm::ArrayRef<int64_t>({i, j}), innerBlock,
+            llvm::ArrayRef<int64_t>({1, 1}));
+        results.push_back(slice);
+      }
+    }
+    return results;
+  } else if (isa<xegpu::TensorDescType>(src.getType())) {
+    auto attr = NamedAttribute(rewriter.getStringAttr(packAttrName),
+                               rewriter.getUnitAttr());
+    auto innerBlkAttr =
+        NamedAttribute(rewriter.getStringAttr(blockAttrName),
+                       rewriter.getDenseI64ArrayAttr(innerBlock));
+    auto castOp = rewriter.create<UnrealizedConversionCastOp>(
+        loc, destTypes, src,
+        llvm::ArrayRef<NamedAttribute>({attr, innerBlkAttr}));
+    return castOp.getResults();
+  }
+
+  llvm_unreachable("Unexpected src type.");
+  return llvm::SmallVector<Value>();
+}
+
+template <typename SourceOp>
+struct UnrollPattern : public OpRewritePattern<SourceOp> {
+  UnrollPattern(MLIRContext *context,
+                const vector::UnrollVectorOptions &options,
+                PatternBenefit benefit = 1)
+      : OpRewritePattern<SourceOp>(context, benefit), options(options) {}
+  // Helpers for subclasses; getTargetShape asks options.nativeShape(op).
+protected:
+  std::optional<SmallVector<int64_t>>
+  getTargetShape(const vector::UnrollVectorOptions &options,
+                 Operation *op) const {
+    LDBG("");
+    LDBG("Get unroll shape for: " << *op);
+    assert(options.nativeShape &&
+           "expects the native shape for native shape call back function.");
+    auto nativeShape = options.nativeShape(op);
+    return nativeShape;
+  }
+  // Returns shape / subShape per dim; nullopt if equal or not divisible.
+  std::optional<SmallVector<int64_t>>
+  computeGrids(llvm::ArrayRef<int64_t> shape,
+               llvm::ArrayRef<int64_t> subShape) const {
+    // If shape == subShape, there is nothing to unroll.
+    if (shape == subShape)
+      return std::nullopt;
+    return computeShapeRatio(shape, subShape);
+  }
+  // True only for an xegpu::LayoutAttr whose isSgLayout() holds.
+  bool isUnrollable(Attribute attr) const {
+    auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(attr);
+    return layout && layout.isSgLayout();
+  }
+  // Keeps only lane_layout/lane_data/order of attr; null if not a LayoutAttr.
+  xegpu::LayoutAttr getLaneLayoutAttr(Attribute attr) const {
+    auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(attr);
+    if (!layout)
+      return xegpu::LayoutAttr();
+    return xegpu::LayoutAttr::get(
+        layout.getContext(), nullptr /* sg_layout */, nullptr /* sg_data */,
+        nullptr /* inst_data */, layout.getLaneLayout(), layout.getLaneData(),
+        layout.getOrder());
+  }
+  // Supplies the nativeShape callback queried by getTargetShape.
+  vector::UnrollVectorOptions options;
+};
+
+struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
+  using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto ctx = op.getContext();
+    auto tdescTy = op.getType();
+    auto shape = tdescTy.getShape();
+    auto layout = tdescTy.getLayout();
+    // Only 2D descriptors are handled: the tiling below indexes dims 0 and 1.
+    if (!isUnrollable(layout) || shape.size() != 2)
+      return failure();
+
+    auto maybeTargetShape = getTargetShape(options, op);
+    if (!maybeTargetShape || maybeTargetShape->size() != shape.size())
+      return failure();
+    auto targetShape = *maybeTargetShape;
+
+    auto maybeGrids = computeGrids(shape, targetShape);
+    if (!maybeGrids)
+      return failure();
+    auto grids = *maybeGrids;
+    // Tile type: targetShape, same element type/encoding, lane-only layout.
+    auto encoding = tdescTy.getEncoding();
+    auto newLayout = getLaneLayoutAttr(layout);
+    auto newTdescTy = xegpu::TensorDescType::get(
+        ctx, targetShape, tdescTy.getElementType(), encoding, newLayout);
+    // Returns `a + b` as an index Value, materializing constants as needed.
+    auto addi = [&](OpFoldResult a, int64_t b) -> Value {
+      // Static offset: fold the addition into a single constant.
+      if (std::optional<int64_t> maybeInt = getConstantIntValue(a))
+        return rewriter.create<arith::ConstantIndexOp>(loc, *maybeInt + b);
+      // Dynamic offset: emit arith.addi (createOrFold may simplify it).
+      auto aV = llvm::cast<Value>(a);
+      auto bV = rewriter.create<arith::ConstantIndexOp>(loc, b);
+      return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
+    };
+    // The offsets must cover at least the two tiled dimensions.
+    auto mixedOffsets = op.getMixedOffsets();
+    if (mixedOffsets.size() < 2)
+      return failure();
+    // For n-D sources (n > 2) only the last two offsets are advanced; the
+    // leading n-2 offsets are kept as is.
+    int64_t x = mixedOffsets.size() - 2;
+    int64_t y = mixedOffsets.size() - 1;
+    OpFoldResult oldX = mixedOffsets[x];
+    OpFoldResult oldY = mixedOffsets[y];
+    // Create one descriptor per tile, in row-major tile order.
+    SmallVector<Value> newOps;
+    for (int64_t i = 0; i < grids[0]; i++) {
+      for (int64_t j = 0; j < grids[1]; j++) {
+        mixedOffsets[x] = addi(oldX, targetShape[0] * i);
+        mixedOffsets[y] = addi(oldY, targetShape[1] * j);
+        auto newOp = rewriter.create<xegpu::CreateNdDescOp>(
+            loc, newTdescTy, op.getSource(), mixedOffsets,
+            op.getMixedSizes(), op.getMixedStrides());
+        newOps.push_back(newOp);
+      }
+    }
+    // Recombine the tiles into a value of the original descriptor type.
+    auto castOp = addUnpackOp(newOps, tdescTy, targetShape, loc, rewriter);
+    rewriter.replaceOp(op, castOp);
+    return success();
+  }
+};
+
+struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
+  using UnrollPattern<xegpu::PrefetchNdOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op,
+                                PatternRewriter &rewriter) const override {
+    return failure(); // Not implemented yet; never rewrites.
+  }
+};
+
+struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
+  using UnrollPattern<xegpu::LoadNdOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
+                                PatternRewriter &rewriter) const override {
+    return failure(); // Not implemented yet; never rewrites.
+  }
+};
+
+struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
+  using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto ctx = op.getContext();
+    auto valueTy = op.getValueType();
+    auto tdescTy = op.getTensorDescType();
+    auto layout = tdescTy.getLayout();
+    // addPackOp only splits 2D values, and value/tdesc must tile identically.
+    if (!isUnrollable(layout) || tdescTy.getShape().size() != 2 ||
+        valueTy.getShape() != tdescTy.getShape())
+      return failure();
+
+    auto maybeTargetShape = getTargetShape(options, op);
+    if (!maybeTargetShape || maybeTargetShape->size() != 2)
+      return failure();
+    auto targetShape = *maybeTargetShape;
+
+    auto maybeGrids = computeGrids(tdescTy.getShape(), targetShape);
+    if (!maybeGrids)
+      return failure();
+    auto grids = *maybeGrids;
+    // Per-tile types: targetShape with the original element type/encoding.
+    auto elemTy = tdescTy.getElementType();
+    auto newValueTy = valueTy.cloneWith(targetShape, elemTy);
+    auto newTdescTy = xegpu::TensorDescType::get(ctx, targetShape, elemTy,
+        tdescTy.getEncoding(), getLaneLayoutAttr(layout));
+    int64_t numNewOps = computeProduct(grids); // total number of tiles
+    llvm::SmallVector<Type> convertedValTypes(numNewOps, newValueTy);
+    llvm::SmallVector<Type> convertedTileTypes(numNewOps, newTdescTy);
+    auto convertedValues = addPackOp(op.getValue(), convertedValTypes,
+                                     targetShape, loc, rewriter);
+    auto convertedTdescs = addPackOp(op.getTensorDesc(), convertedTileTypes,
+                                     targetShape, loc, rewriter);
+    // One store per tile; zip_equal asserts both packs yield the same count.
+    for (auto [v, t] : llvm::zip_equal(convertedValues, convertedTdescs))
+      rewriter.create<xegpu::StoreNdOp>(loc, v, t, op.getL1HintAttr(),
+                                        op.getL2HintAttr(), op.getL3HintAttr());
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+struct UnrollUpdateNdOffsetOp : public UnrollPattern<xegpu::UpdateNdOffsetOp> {
+  using UnrollPattern<xegpu::UpdateNdOffsetOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::UpdateNdOffsetOp op,
+                                PatternRewriter &rewriter) const override {
+    return failure(); // Not implemented yet; never rewrites.
+  }
+};
+
+struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
+  using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::CreateDescOp op,
+                                PatternRewriter &rewriter) const override {
+    return failure(); // Not implemented yet; never rewrites.
+  }
+};
+
+struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
+  using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
+                                PatternRewriter &rewriter) const override {
+    return failure(); // Not implemented yet; never rewrites.
+  }
+};
+
+struct UnrollLoadOp : public UnrollPattern<xegpu::LoadGatherOp> {
+  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
+                                PatternRewriter &rewriter) const override {
+    return failure(); // Not implemented yet; never rewrites.
+  }
+};
+
+struct UnrollStoreOp : public UnrollPattern<xegpu::StoreScatterOp> {
+  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
+                                PatternRewriter &rewriter) const override {
+    return failure(); // Not implemented yet; never rewrites.
+  }
+};
+
+struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
+  using UnrollPattern<xegpu::UpdateOffsetOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::UpdateOffsetOp op,
+                                PatternRewriter &rewriter) const override {
+    return failure(); // Not implemented yet; never rewrites.
+  }
+};
+
+struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
+  using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::DpasOp op,
+                                PatternRewriter &rewriter) const override {
+    return failure(); // Not implemented yet; never rewrites.
+  }
+};
+
+struct UnrollAtomicRMWOp : public UnrollPattern<xegpu::AtomicRMWOp> {
+  using UnrollPattern<xegpu::AtomicRMWOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::AtomicRMWOp op,
+                                PatternRewriter &rewriter) const override {
+    return failure(); // Not implemented yet; never rewrites.
+  }
+};
+
+} // namespace
+
+namespace {
+
+struct XeGPUUnrollPass final
+    : public xegpu::impl::XeGPUUnrollBase<XeGPUUnrollPass> {
+  XeGPUUnrollPass() = default;
+  XeGPUUnrollPass(const XeGPUUnrollPass &pass) = default;
+
+  void runOnOperation() override {
+    vector::UnrollVectorOptions options;
+    options.setNativeShapeFn(
+        [&](Operation *op) -> std::optional<SmallVector<int64_t>> {
+          if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
+            auto tdescTy = createNdOp.getType();
+            if (auto layout = tdescTy.getLayoutAttr()) {
+              if (auto inst_data = layout.getInstData())
+                return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
+                                            inst_data.asArrayRef().end());
+            }
+          }
+
+          if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) {
+            auto tdescTy = loadNdOp.getTensorDescType();
+            if (auto layout = tdescTy.getLayoutAttr()) {
+              if (auto inst_data = layout.getInstData())
+                return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
+                                            inst_data.asArrayRef().end());
+            }
+          }
+
+          if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
+            auto tdescTy = storeNdOp.getTensorDescType();
+            if (auto layout = tdescTy.getLayoutAttr()) {
+              if (auto inst_data = layout.getInstData())
+                return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
+                                            inst_data.asArrayRef().end());
+            }
+          }
+
+          return std::nullopt;
+        });
+
+    // The pass is not anchored to a specific op type, so the root is a
+    // generic Operation* rather than a func op.
+    Operation *root = getOperation();
+    RewritePatternSet patterns(&getContext());
+    // Only CreateNdDescOp and StoreNdOp have working unroll patterns so far.
+    patterns.add<UnrollCreateNdOp, UnrollStoreNdOp>(patterns.getContext(),
+                                                    options);
+    // Greedy rewrite with the default config; non-convergence is ignored.
+    (void)applyPatternsGreedily(root, std::move(patterns));
+  }
+};
+
+} // namespace
\ No newline at end of file

>From 932747e741cc7af15f8b8ff696fc090fcff9194c Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Fri, 18 Apr 2025 21:31:16 +0000
Subject: [PATCH 03/12] refine nativeShapeFn

---
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  | 26 ++++++-------------
 1 file changed, 8 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 6148052401c97..94a236cb13d95 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -358,26 +358,16 @@ struct XeGPUUnrollPass final
     vector::UnrollVectorOptions options;
     options.setNativeShapeFn(
         [&](Operation *op) -> std::optional<SmallVector<int64_t>> {
-          if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
-            auto tdescTy = createNdOp.getType();
-            if (auto layout = tdescTy.getLayoutAttr()) {
-              if (auto inst_data = layout.getInstData())
-                return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
-                                            inst_data.asArrayRef().end());
-            }
-          }
-
-          if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) {
-            auto tdescTy = loadNdOp.getTensorDescType();
-            if (auto layout = tdescTy.getLayoutAttr()) {
-              if (auto inst_data = layout.getInstData())
-                return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
-                                            inst_data.asArrayRef().end());
+          if (isa<xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp>(op)) {
+            xegpu::TensorDescType tdescTy;
+            if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
+              tdescTy = createNdOp.getType();
+            } else if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) {
+              tdescTy = loadNdOp.getTensorDescType();
+            } else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
+              tdescTy = storeNdOp.getTensorDescType();
             }
-          }
 
-          if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
-            auto tdescTy = storeNdOp.getTensorDescType();
             if (auto layout = tdescTy.getLayoutAttr()) {
               if (auto inst_data = layout.getInstData())
                 return SmallVector<int64_t>(inst_data.asArrayRef().begin(),

>From f843d980079638f10a378cdad2fb08eceabde5b9 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Wed, 23 Apr 2025 16:09:23 +0000
Subject: [PATCH 04/12] refine verifier for TensorDescType

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h    |  4 +-
 .../mlir/Dialect/XeGPU/IR/XeGPUDialect.td     |  6 ++
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    | 87 +++++++++++++++----
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 33 +------
 mlir/test/Dialect/XeGPU/invalid.mlir          |  8 +-
 5 files changed, 86 insertions(+), 52 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
index d6c51d20571fd..8e2784f40ad39 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
@@ -25,12 +25,14 @@ class TensorDescType;
 } // namespace xegpu
 } // namespace mlir
 
-#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.h.inc>
 #include <mlir/Dialect/XeGPU/IR/XeGPUEnums.h.inc>
 #define GET_ATTRDEF_CLASSES
 #include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.h.inc>
 #define GET_TYPEDEF_CLASSES
 #include <mlir/Dialect/XeGPU/IR/XeGPUTypes.h.inc>
+
+#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.h.inc>
+
 #define GET_OP_CLASSES
 #include <mlir/Dialect/XeGPU/IR/XeGPU.h.inc>
 
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
index fb5a1e6f1db0c..549018b61d6fb 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
@@ -36,6 +36,12 @@ def XeGPU_Dialect : Dialect {
 
     let useDefaultTypePrinterParser = true;
     let useDefaultAttributePrinterParser = true;
+
+    let extraClassDeclaration = [{
+      /// Checks if the given shape can be evenly distributed based on the layout
+      /// and data factors provided by the LayoutAttr.
+      static bool isEvenlyDistributable(llvm::ArrayRef<int64_t> shape, xegpu::LayoutAttr attr);
+    }];
 }
 
 #endif // MLIR_DIALECT_XEGPU_IR_XEGPUDIALECT_TD
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index b865b80f0075e..8694d2f950dd9 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -30,6 +31,61 @@ void XeGPUDialect::initialize() {
       >();
 }
 
+bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
+                                         xegpu::LayoutAttr attr) {
+  assert(attr && "Layout attribute is missing.");
+  // Returns the shape passed to the next level, or nullopt on a bad split.
+  auto getSubShapeOrNull =
+      [&](llvm::ArrayRef<int64_t> shape, DenseI32ArrayAttr layout,
+          DenseI32ArrayAttr data,
+          bool use_rr = true) -> std::optional<SmallVector<int64_t>> {
+    llvm::SmallVector<int64_t> newShape(shape);
+    if (layout) {
+      auto vec = llvm::to_vector_of<int64_t>(layout.asArrayRef());
+      if (vec.size() != shape.size())
+        return std::nullopt;
+      auto ratio = computeShapeRatio(shape, vec);
+      if (!ratio.has_value())
+        return std::nullopt;
+      newShape = ratio.value();
+    }
+    // newShape must be divisible by data, or (with use_rr) data by newShape.
+    if (data) {
+      auto vec = llvm::to_vector_of<int64_t>(data.asArrayRef());
+      if (vec.size() != shape.size())
+        return std::nullopt;
+      auto ratio = computeShapeRatio(newShape, vec);
+      if (!ratio.has_value() && use_rr)
+        ratio = computeShapeRatio(vec, newShape);
+      if (!ratio.has_value())
+        return std::nullopt;
+
+      // When data is set, it becomes the shape checked at the next level.
+      newShape = vec;
+    }
+    return newShape;
+  };
+
+  // Subgroup level: sg_layout / sg_data (round-robin allowed).
+  auto maybeSgShape =
+      getSubShapeOrNull(shape, attr.getSgLayout(), attr.getSgData());
+  if (!maybeSgShape)
+    return false;
+  auto sgShape = maybeSgShape.value();
+
+  // Instruction level: inst_data only, with no layout and no round-robin.
+  auto maybeInstShape =
+      getSubShapeOrNull(sgShape, nullptr, attr.getInstData(), false);
+  if (!maybeInstShape)
+    return false;
+  auto instShape = maybeInstShape.value();
+
+  // Lane level: lane_layout / lane_data, no round-robin.
+  auto maybeLaneShape = getSubShapeOrNull(instShape, attr.getLaneLayout(),
+                                          attr.getLaneData(), false);
+  return maybeLaneShape.has_value();
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_BlockTensorDescAttr
 //===----------------------------------------------------------------------===//
@@ -241,7 +297,7 @@ LogicalResult TensorDescType::verify(
     llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
     mlir::Attribute encoding, mlir::Attribute layout) {
   size_t rank = shape.size();
-  // Low-pressure types are packed in 32-bit units.
+  // Low-precision types are packed in 32-bit units.
   int32_t packingFactor = 32 / elementType.getIntOrFloatBitWidth();
   if (rank != 1 && rank != 2)
     return emitError() << "expected 1D or 2D tensor";
@@ -268,23 +324,21 @@ LogicalResult TensorDescType::verify(
     }
   }
 
-  if (auto blockAttr =
-          mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding)) {
+  auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
+  if (blockAttr) {
     MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
     if (rank == 2 && memorySpaceAttr &&
         memorySpaceAttr.getValue() == MemorySpace::SLM)
       return emitError() << "SLM is not supported for 2D block tensor";
   }
 
-  if (auto layoutAttr = llvm::dyn_cast_if_present<LayoutAttr>(layout)) {
-
+  auto layoutAttr = llvm::dyn_cast_if_present<LayoutAttr>(layout);
+  if (layoutAttr) {
     if (rank != (size_t)layoutAttr.getRank())
       return emitError() << "expected layout rank to match tensor rank";
 
-    ArrayRef<int32_t> laneLayout = layoutAttr.getLaneLayout().asArrayRef();
-    ArrayRef<int32_t> laneData = layoutAttr.getLaneData().asArrayRef();
-
-    if (scatterAttr) {
+    auto laneData = layoutAttr.getLaneData();
+    if (scatterAttr && laneData) {
       // Validate subgroup mapping rules for scattered tensors.
       // A work-item's slice of the tensor with shape [sg_size] or
       // [sg_size, chunk_size] will be [1] or [1, 32/element_ty_bit_width]
@@ -294,20 +348,19 @@ LogicalResult TensorDescType::verify(
       if (rank > 1 && laneData[0] != 1)
         return emitError()
                << "cannot map over non-contiguous scattered row elements";
-      if (laneData.back() != packingFactor)
+      if (laneData[rank - 1] != packingFactor)
         return emitError() << "work item data mapping must match the number of "
                               "contiguous elements";
     }
 
-    for (size_t i = 0; i < shape.size(); ++i) {
-      uint32_t numElemPerWi = laneLayout[i] * laneData[i];
-      if (shape[i] < numElemPerWi || shape[i] % numElemPerWi != 0)
-        return emitError() << "cannot distribute " << shape[i] << " over "
-                           << laneLayout[i] << " work items with "
-                           << laneData[i] << " elements each";
+    if (!XeGPUDialect::isEvenlyDistributable(shape, layoutAttr)) {
+      std::string shapeStr;
+      llvm::raw_string_ostream stream(shapeStr);
+      llvm::interleaveComma(shape, stream);
+      return emitError() << "cannot distribute [" << shapeStr << "] using "
+                         << layoutAttr;
     }
   }
-
   return success();
 }
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 70f32314c67ce..540cae1028102 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/IR/Builders.h"
@@ -73,34 +74,6 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
          kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
 }
 
-// Checks if the given shape is evenly distributed based on the layout
-// and data factors provided by the LayoutAttr. The function ensures that
-// each dimension of the shape can be evenly divided by the corresponding
-// data factor, and the resulting quotient can be evenly divided by the
-// layout factor. Returns `true` if the shape is evenly distributed,
-// otherwise `false`.
-static bool isEvenDistributed(llvm::ArrayRef<int64_t> shape,
-                              xegpu::LayoutAttr attr) {
-  assert(attr && "Layout attribute is missing.");
-  llvm::SmallVector<int32_t> defaults(shape.size(), 1);
-  llvm::ArrayRef<int32_t> layout, data;
-  if (auto sg_layout = attr.getSgLayout()) {
-    layout = sg_layout.asArrayRef();
-    auto sg_data = attr.getSgData();
-    data = sg_data ? sg_data.asArrayRef() : defaults;
-  } else {
-    layout = attr.getLaneLayout().asArrayRef();
-    auto lane_data = attr.getLaneData();
-    data = lane_data ? lane_data.asArrayRef() : defaults;
-  }
-  for (auto [dimSize, dataFactor, layoutFactor] :
-       llvm::zip_equal(shape, data, layout)) {
-    if (dimSize % dataFactor != 0 || (dimSize / dataFactor) % layoutFactor != 0)
-      return false;
-  }
-  return true;
-}
-
 static LogicalResult
 isValidGatherScatterParams(Type maskTy, VectorType valueTy,
                            TensorDescType tdescTy, UnitAttr transposeAttr,
@@ -674,10 +647,10 @@ LogicalResult ConvertLayoutOp::verify() {
         "expected srcMap and resMap be WgLayout or SgLayout at the same time.");
 
   auto shape = getSource().getType().getShape();
-  if (!isEvenDistributed(shape, srcMap))
+  if (!XeGPUDialect::isEvenlyDistributable(shape, srcMap))
     return emitOpError("invalid srcMap, data cannot be evenly distributed.");
 
-  if (!isEvenDistributed(shape, resMap))
+  if (!XeGPUDialect::isEvenlyDistributable(shape, resMap))
     return emitOpError("invalid resMap, data cannot be evenly distributed.");
 
   return mlir::success();
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 67ed89e11b4c9..2fd4d6280649c 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -404,7 +404,7 @@ func.func @tensor_desc_1D_invalid_map_data(%src: memref<24x32xf32>) {
 // -----
 func.func @tensor_desc_invalid_map_layout(%src: memref<24x32xf32>) {
   %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
-      // expected-error at +1 {{cannot distribute 8 over 16 work items with 1 elements each}}
+      // expected-error at +1 {{cannot distribute [4, 8] using #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}}
       !xegpu.tensor_desc<4x8xf32,  #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
   return
 }
@@ -412,7 +412,7 @@ func.func @tensor_desc_invalid_map_layout(%src: memref<24x32xf32>) {
 // -----
 func.func @tensor_desc_invalid_map_layout_1(%src: memref<24x32xf32>) {
   %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
-      // expected-error at +1 {{cannot distribute 4 over 8 work items with 1 elements each}}
+      // expected-error at +1 {{cannot distribute [4, 8] using #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>}}
       !xegpu.tensor_desc<4x8xf32,  #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
   return
 }
@@ -420,7 +420,7 @@ func.func @tensor_desc_invalid_map_layout_1(%src: memref<24x32xf32>) {
 // -----
 func.func @tensor_desc_invalid_map_data(%src: memref<24x32xf32>) {
   %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
-      // expected-error at +1 {{cannot distribute 4 over 2 work items with 4 elements each}}
+      // expected-error at +1 {{cannot distribute [4, 8] using #xegpu.layout<lane_layout = [2, 8], lane_data = [4, 1]>}}
       !xegpu.tensor_desc<4x8xf32,  #xegpu.layout<lane_layout = [2, 8], lane_data = [4, 1]>>
   return
 }
@@ -428,7 +428,7 @@ func.func @tensor_desc_invalid_map_data(%src: memref<24x32xf32>) {
 // -----
 func.func @tensor_desc_invalid_map_data_1(%src: memref<24x32xf32>) {
   %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
-      // expected-error at +1 {{cannot distribute 4 over 8 work items with 1 elements each}}
+      // expected-error at +1 {{cannot distribute [4, 8] using #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 2]>}}
       !xegpu.tensor_desc<4x8xf32,  #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 2]>>
   return
 }

>From c6bdd3c1440a72b6f9990d22b97ed4b43645f5e3 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Wed, 23 Apr 2025 18:29:23 +0000
Subject: [PATCH 05/12] add loadNd pattern

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       |  2 +-
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  | 57 ++++++++++++++++---
 2 files changed, 51 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index cab9fffdbbcd2..2873ae619e65c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -284,7 +284,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
     }
 
     bool isSgLayout() {
-      return getSgLayout() == nullptr && getLaneLayout() != nullptr;
+      return !isWgLayout();
     }
 
     int64_t getRank() {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 94a236cb13d95..afad0e9105ed1 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -139,12 +139,12 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
 
   bool isUnrollable(Attribute attr) const {
     auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(attr);
-    return layout && layout.isSgLayout();
+    return layout && layout.isSgLayout() && layout.getInstData() != nullptr;
   }
 
   xegpu::LayoutAttr getLaneLayoutAttr(Attribute attr) const {
     auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(attr);
-    if (!layout)
+    if (!layout || layout.getLaneLayout() == nullptr)
       return xegpu::LayoutAttr();
     return xegpu::LayoutAttr::get(
         layout.getContext(), nullptr /* sg_layout */, nullptr /* sg_data */,
@@ -233,7 +233,48 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
   using UnrollPattern<xegpu::LoadNdOp>::UnrollPattern;
   LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
                                 PatternRewriter &rewriter) const override {
-    return failure();
+
+    auto loc = op.getLoc();
+    auto ctx = op.getContext();
+    auto valueTy = op.getType();
+    auto tdescTy = op.getTensorDescType();
+    auto layout = tdescTy.getLayout();
+
+    if (!isUnrollable(layout))
+      return failure();
+
+    auto maybeTargetShape = getTargetShape(options, op);
+    if (!maybeTargetShape)
+      return failure();
+    auto targetShape = *maybeTargetShape;
+
+    auto maybeGrids = computeGrids(tdescTy.getShape(), targetShape);
+    if (!maybeGrids)
+      return failure();
+    auto grids = *maybeGrids;
+
+    auto elemTy = tdescTy.getElementType();
+    auto newValueTy = valueTy.cloneWith(targetShape, elemTy);
+    auto newTdescTy = xegpu::TensorDescType::get(ctx, targetShape, elemTy,
+                                                 tdescTy.getEncoding(),
+                                                 getLaneLayoutAttr(layout));
+
+    auto numNewOps = computeProduct(grids);
+    llvm::SmallVector<Type> convertedTdescTypes(numNewOps, newTdescTy);
+    auto convertedTdescs = addPackOp(op.getTensorDesc(), convertedTdescTypes,
+                                     targetShape, loc, rewriter);
+
+    llvm::SmallVector<Value> newOps;
+    for (auto t : convertedTdescs) {
+      auto newOp =
+          rewriter.create<xegpu::LoadNdOp>(loc, newValueTy, t, op->getAttrs());
+      newOps.push_back(newOp);
+    }
+
+    auto castOp = addUnpackOp(newOps, op.getType(), targetShape, loc, rewriter);
+
+    rewriter.replaceOp(op, castOp);
+    return success();
   }
 };
 
@@ -265,11 +306,12 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
     auto newTdescTy = xegpu::TensorDescType::get(ctx, targetShape, elemTy, tdescTy.getEncoding(),
         getLaneLayoutAttr(layout));
 
-    auto numNewOps = std::accumulate(grids.begin(), grids.end(), 1, std::multiplies<int64_t>());
+    auto numNewOps = computeProduct(grids);
     llvm::SmallVector<Type> convertedValTypes(numNewOps, newValueTy);
-    llvm::SmallVector<Type> convertedTileTypes(numNewOps, newTdescTy);
+    llvm::SmallVector<Type> convertedTdescTypes(numNewOps, newTdescTy);
     auto convertedValues = addPackOp(op.getValue(), convertedValTypes, targetShape, loc, rewriter);
-    auto convertedTdescs = addPackOp(op.getTensorDesc(), convertedTileTypes, targetShape, loc, rewriter);
+    auto convertedTdescs = addPackOp(op.getTensorDesc(), convertedTdescTypes,
+                                     targetShape, loc, rewriter);
 
     for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs)) {
       rewriter.create<xegpu::StoreNdOp>(loc, v, t, op.getL1HintAttr(),
@@ -380,7 +422,8 @@ struct XeGPUUnrollPass final
 
     auto funcOp = getOperation();
     RewritePatternSet patterns(&getContext());
-    patterns.add<UnrollCreateNdOp, UnrollStoreNdOp>(patterns.getContext(), options);
+    patterns.add<UnrollCreateNdOp, UnrollLoadNdOp, UnrollStoreNdOp>(
+        patterns.getContext(), options);
 
     // GreedyRewriteConfig config;
     // config.fold = false;

>From 1d4dc72e2b00ff4f1ef19a0f7413520363710a02 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Wed, 23 Apr 2025 20:32:46 +0000
Subject: [PATCH 06/12] add test pass

---
 .../Dialect/XeGPU/Transforms/Transforms.h     |  7 ++
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  | 12 +--
 .../Dialect/XeGPU/xegpu-unroll-patterns.mlir  | 21 +++++
 mlir/test/lib/Dialect/CMakeLists.txt          |  1 +
 mlir/test/lib/Dialect/XeGPU/CMakeLists.txt    | 15 ++++
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 86 +++++++++++++++++++
 mlir/tools/mlir-opt/mlir-opt.cpp              |  2 +
 7 files changed, 139 insertions(+), 5 deletions(-)
 create mode 100644 mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
 create mode 100644 mlir/test/lib/Dialect/XeGPU/CMakeLists.txt
 create mode 100644 mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index 63ea26df06937..72a7ab0467aad 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -12,11 +12,18 @@
 namespace mlir {
 class RewritePatternSet;
 
+namespace vector {
+  struct UnrollVectorOptions;
+} // namespace vector
+
 namespace xegpu {
 
 /// Appends patterns for folding aliasing ops into XeGPU ops into `patterns`.
 void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
 
+void populateXeGPUUnrollPatterns(RewritePatternSet &patterns,
+                                 const vector::UnrollVectorOptions &options);
+
 } // namespace xegpu
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index afad0e9105ed1..27d104db5fbbb 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/Debug.h"
 #include <numeric>
@@ -387,10 +388,6 @@ struct UnrollAtomicRMWOp : public UnrollPattern<xegpu::AtomicRMWOp> {
   }
 };
 
-} // namespace
-
-namespace {
-
 struct XeGPUUnrollPass final
     : public xegpu::impl::XeGPUUnrollBase<XeGPUUnrollPass> {
   XeGPUUnrollPass() = default;
@@ -432,5 +429,10 @@ struct XeGPUUnrollPass final
     return;
   }
 };
+} // namespace
 
-} // namespace
\ No newline at end of file
+void mlir::xegpu::populateXeGPUUnrollPatterns(
+    RewritePatternSet &patterns, const mlir::vector::UnrollVectorOptions &options) {
+  patterns.add<UnrollCreateNdOp, UnrollLoadNdOp, UnrollStoreNdOp>(
+        patterns.getContext(), options);
+}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
new file mode 100644
index 0000000000000..825bd3ff9f042
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt --test-xegpu-unrolling-patterns -split-input-file %s | FileCheck %s
+
+gpu.module @test {
+  // CHECK-LABEL: test_create_nd_tdesc_vc_1
+  // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
+  //CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+  //CHECK-COUNT-6: [[data:%.+]] = xegpu.load_nd {{.*}}  : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+  //CHECK-COUNT-6: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<24x32xf32>
+  //CHECK: [[add:%.+]] = arith.addf {{.*}} : vector<24x32xf32>
+  //CHECK-COUNT-6: %[[extract:%.+]] = vector.extract_strided_slice {{.*}} : vector<24x32xf32> to vector<8x16xf32>
+  //CHECK-COUNT-6: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  gpu.func @test_create_nd_tdesc_vc_1(%src: memref<24x32xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+    %data = arith.constant dense<9.0> : vector<24x32xf32>
+    %ld = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
+    %add = arith.addf %data, %ld : vector<24x32xf32>
+    xegpu.store_nd %add, %tdesc: vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+    gpu.return
+  }
+
+}
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index 29fb4441a24fd..a8fd70e6397a5 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -22,3 +22,4 @@ add_subdirectory(TestDyn)
 add_subdirectory(Tosa)
 add_subdirectory(Transform)
 add_subdirectory(Vector)
+add_subdirectory(XeGPU)
diff --git a/mlir/test/lib/Dialect/XeGPU/CMakeLists.txt b/mlir/test/lib/Dialect/XeGPU/CMakeLists.txt
new file mode 100644
index 0000000000000..9223f2860ce15
--- /dev/null
+++ b/mlir/test/lib/Dialect/XeGPU/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_mlir_dialect_library(MLIRXeGPUTestPasses
+  TestXeGPUTransforms.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+)
+
+mlir_target_link_libraries(MLIRXeGPUTestPasses PUBLIC
+  MLIRIR
+  MLIRMemRefDialect
+  MLIRXeGPUDialect
+  MLIRXeGPUTransforms
+  MLIRPass
+  MLIRTransforms
+  MLIRGPUDialect
+)
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
new file mode 100644
index 0000000000000..c82a280b67d91
--- /dev/null
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -0,0 +1,86 @@
+//===- TestXeGPUTransforms.cpp -- Test XeGPU transforms and lowerings -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+
+using namespace mlir;
+using namespace mlir::xegpu;
+
+namespace {
+
+struct TestXeGPUUnrollingPatterns
+    : public PassWrapper<TestXeGPUUnrollingPatterns,
+                         OperationPass<gpu::GPUModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestXeGPUUnrollingPatterns)
+
+  StringRef getArgument() const final {
+    return "test-xegpu-unrolling-patterns";
+  }
+
+  StringRef getDescription() const final {
+    return "Test lowering patterns to unroll ops in the xegpu dialect";
+  }
+
+  void getDependentDialects(::mlir::DialectRegistry &registry) const override {
+    registry.insert<memref::MemRefDialect>();
+    registry.insert<xegpu::XeGPUDialect>();
+    registry.insert<vector::VectorDialect>();
+  }
+
+  TestXeGPUUnrollingPatterns() = default;
+  TestXeGPUUnrollingPatterns(const TestXeGPUUnrollingPatterns &pass)
+      : PassWrapper(pass) {}
+
+  void runOnOperation() override {
+    vector::UnrollVectorOptions options;
+    options.setNativeShapeFn(
+        [&](Operation *op) -> std::optional<SmallVector<int64_t>> {
+          if (isa<xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp>(op)) {
+            xegpu::TensorDescType tdescTy;
+            if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
+              tdescTy = createNdOp.getType();
+            } else if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) {
+              tdescTy = loadNdOp.getTensorDescType();
+            } else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
+              tdescTy = storeNdOp.getTensorDescType();
+            }
+
+            if (auto layout = tdescTy.getLayoutAttr()) {
+              if (auto inst_data = layout.getInstData())
+                return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
+                                            inst_data.asArrayRef().end());
+            }
+          }
+
+          return std::nullopt;
+        });
+
+
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+
+    populateXeGPUUnrollPatterns(patterns, options);
+    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestXeGPULowerings() {
+  PassRegistration<TestXeGPUUnrollingPatterns>();
+}
+}
+}
\ No newline at end of file
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 344576a44ca41..cdcf59b2add13 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -158,6 +158,7 @@ void registerTestVectorLowerings();
 void registerTestVectorReductionToSPIRVDotProd();
 void registerTestVulkanRunnerPipeline();
 void registerTestWrittenToPass();
+void registerTestXeGPULowerings();
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
 void registerTestDialectConversionPasses();
 void registerTestPDLByteCodePass();
@@ -301,6 +302,7 @@ void registerTestPasses() {
   mlir::test::registerTestVectorReductionToSPIRVDotProd();
   mlir::test::registerTestVulkanRunnerPipeline();
   mlir::test::registerTestWrittenToPass();
+  mlir::test::registerTestXeGPULowerings();
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
   mlir::test::registerTestDialectConversionPasses();
   mlir::test::registerTestPDLByteCodePass();

>From 545f937af7180ea46a5b7914224e2720cb103dd4 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Wed, 23 Apr 2025 22:06:40 +0000
Subject: [PATCH 07/12] format code

---
 .../Dialect/XeGPU/Transforms/Transforms.h     |  2 +-
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  | 61 ++++++++++---------
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 51 ++++++++--------
 3 files changed, 58 insertions(+), 56 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index 72a7ab0467aad..4019c39b66163 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -13,7 +13,7 @@ namespace mlir {
 class RewritePatternSet;
 
 namespace vector {
-  struct UnrollVectorOptions;
+struct UnrollVectorOptions;
 } // namespace vector
 
 namespace xegpu {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 27d104db5fbbb..f3c53e36f40dc 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -211,7 +211,8 @@ struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
         mixedOffsets[x] = addi(oldX, subOffX);
         mixedOffsets[y] = addi(oldY, subOffY);
         auto newOp = rewriter.create<xegpu::CreateNdDescOp>(
-          loc, newTdescTy, op.getSource(), mixedOffsets, op.getMixedSizes(), op.getMixedStrides());
+            loc, newTdescTy, op.getSource(), mixedOffsets, op.getMixedSizes(),
+            op.getMixedStrides());
         newOps.push_back(newOp);
       }
     }
@@ -304,20 +305,21 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
 
     auto elemTy = tdescTy.getElementType();
     auto newValueTy = valueTy.cloneWith(targetShape, elemTy);
-    auto newTdescTy = xegpu::TensorDescType::get(ctx, targetShape, elemTy, tdescTy.getEncoding(),
-        getLaneLayoutAttr(layout));
+    auto newTdescTy = xegpu::TensorDescType::get(ctx, targetShape, elemTy,
+                                                 tdescTy.getEncoding(),
+                                                 getLaneLayoutAttr(layout));
 
     auto numNewOps = computeProduct(grids);
     llvm::SmallVector<Type> convertedValTypes(numNewOps, newValueTy);
     llvm::SmallVector<Type> convertedTdescTypes(numNewOps, newTdescTy);
-    auto convertedValues = addPackOp(op.getValue(), convertedValTypes, targetShape, loc, rewriter);
+    auto convertedValues =
+        addPackOp(op.getValue(), convertedValTypes, targetShape, loc, rewriter);
     auto convertedTdescs = addPackOp(op.getTensorDesc(), convertedTdescTypes,
                                      targetShape, loc, rewriter);
 
     for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs)) {
       rewriter.create<xegpu::StoreNdOp>(loc, v, t, op.getL1HintAttr(),
-                                           op.getL2HintAttr(),
-                                           op.getL3HintAttr());
+                                        op.getL2HintAttr(), op.getL3HintAttr());
     }
     rewriter.eraseOp(op);
     return success();
@@ -395,27 +397,27 @@ struct XeGPUUnrollPass final
 
   void runOnOperation() override {
     vector::UnrollVectorOptions options;
-    options.setNativeShapeFn(
-        [&](Operation *op) -> std::optional<SmallVector<int64_t>> {
-          if (isa<xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp>(op)) {
-            xegpu::TensorDescType tdescTy;
-            if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
-              tdescTy = createNdOp.getType();
-            } else if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) {
-              tdescTy = loadNdOp.getTensorDescType();
-            } else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
-              tdescTy = storeNdOp.getTensorDescType();
-            }
-
-            if (auto layout = tdescTy.getLayoutAttr()) {
-              if (auto inst_data = layout.getInstData())
-                return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
-                                            inst_data.asArrayRef().end());
-            }
-          }
-
-          return std::nullopt;
-        });
+    options.setNativeShapeFn([&](Operation *op)
+                                 -> std::optional<SmallVector<int64_t>> {
+      if (isa<xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp>(op)) {
+        xegpu::TensorDescType tdescTy;
+        if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
+          tdescTy = createNdOp.getType();
+        } else if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) {
+          tdescTy = loadNdOp.getTensorDescType();
+        } else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
+          tdescTy = storeNdOp.getTensorDescType();
+        }
+
+        if (auto layout = tdescTy.getLayoutAttr()) {
+          if (auto inst_data = layout.getInstData())
+            return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
+                                        inst_data.asArrayRef().end());
+        }
+      }
+
+      return std::nullopt;
+    });
 
     auto funcOp = getOperation();
     RewritePatternSet patterns(&getContext());
@@ -432,7 +434,8 @@ struct XeGPUUnrollPass final
 } // namespace
 
 void mlir::xegpu::populateXeGPUUnrollPatterns(
-    RewritePatternSet &patterns, const mlir::vector::UnrollVectorOptions &options) {
+    RewritePatternSet &patterns,
+    const mlir::vector::UnrollVectorOptions &options) {
   patterns.add<UnrollCreateNdOp, UnrollLoadNdOp, UnrollStoreNdOp>(
-        patterns.getContext(), options);
+      patterns.getContext(), options);
 }
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index c82a280b67d91..bfdfdd5bc5e51 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -7,12 +7,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
-#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
 using namespace mlir::xegpu;
@@ -44,28 +44,27 @@ struct TestXeGPUUnrollingPatterns
 
   void runOnOperation() override {
     vector::UnrollVectorOptions options;
-    options.setNativeShapeFn(
-        [&](Operation *op) -> std::optional<SmallVector<int64_t>> {
-          if (isa<xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp>(op)) {
-            xegpu::TensorDescType tdescTy;
-            if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
-              tdescTy = createNdOp.getType();
-            } else if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) {
-              tdescTy = loadNdOp.getTensorDescType();
-            } else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
-              tdescTy = storeNdOp.getTensorDescType();
-            }
-
-            if (auto layout = tdescTy.getLayoutAttr()) {
-              if (auto inst_data = layout.getInstData())
-                return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
-                                            inst_data.asArrayRef().end());
-            }
-          }
-
-          return std::nullopt;
-        });
-
+    options.setNativeShapeFn([&](Operation *op)
+                                 -> std::optional<SmallVector<int64_t>> {
+      if (isa<xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp>(op)) {
+        xegpu::TensorDescType tdescTy;
+        if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
+          tdescTy = createNdOp.getType();
+        } else if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) {
+          tdescTy = loadNdOp.getTensorDescType();
+        } else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
+          tdescTy = storeNdOp.getTensorDescType();
+        }
+
+        if (auto layout = tdescTy.getLayoutAttr()) {
+          if (auto inst_data = layout.getInstData())
+            return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
+                                        inst_data.asArrayRef().end());
+        }
+      }
+
+      return std::nullopt;
+    });
 
     MLIRContext *ctx = &getContext();
     RewritePatternSet patterns(ctx);
@@ -82,5 +81,5 @@ namespace test {
 void registerTestXeGPULowerings() {
   PassRegistration<TestXeGPUUnrollingPatterns>();
 }
-}
-}
\ No newline at end of file
+} // namespace test
+} // namespace mlir
\ No newline at end of file

>From 008dbc781cdf1819d8df78bb360f5d14f7a9c705 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Wed, 23 Apr 2025 22:29:45 +0000
Subject: [PATCH 08/12] add unit test

---
 .../Dialect/XeGPU/xegpu-unroll-patterns.mlir  | 47 +++++++++++++++++--
 1 file changed, 44 insertions(+), 3 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
index 825bd3ff9f042..2f5dcf3930628 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -1,15 +1,56 @@
 // RUN: mlir-opt --test-xegpu-unrolling-patterns -split-input-file %s | FileCheck %s
 
 gpu.module @test {
-  // CHECK-LABEL: test_create_nd_tdesc_vc_1
+
+  // CHECK-LABEL: test_create_nd_tdesc
+  // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
+  // CHECK-COUNT-6: [[data:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+  // CHECK: [[cast:%.+]] = builtin.unrealized_conversion_cast
+  // CHECK-SAME: !xegpu.tensor_desc<8x16xf32>, !xegpu.tensor_desc<8x16xf32>,
+  // CHECK-SAME: !xegpu.tensor_desc<8x16xf32>, !xegpu.tensor_desc<8x16xf32>,
+  // CHECK-SAME: !xegpu.tensor_desc<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  // CHECK-SAME: to !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {__xetile_blocking_inner_block__ = array<i64: 8, 16>, __xetile_blocking_unpack__}
+  gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+    gpu.return %tdesc : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+  }
+
+  //-----
+
+  // CHECK-LABEL: test_load_nd
+  // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
+  // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+  // CHECK-COUNT-6: [[ld:%.+]] = xegpu.load_nd {{.*}}  : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+  // CHECK-COUNT-6: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<24x32xf32>
+  gpu.func @test_load_nd(%src: memref<24x32xf32>) -> vector<24x32xf32> {
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+    %ld = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
+    gpu.return %ld : vector<24x32xf32>
+  }
+
+  //-----
+  // CHECK-LABEL: test_store_nd
+  // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
+  // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+  // CHECK-COUNT-6: xegpu.store_nd {{.*}}  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  gpu.func @test_store_nd(%src: memref<24x32xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+    %data = arith.constant dense<9.0> : vector<24x32xf32>
+    xegpu.store_nd %data, %tdesc: vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+    gpu.return
+  }
+
+  //-----
+
+  // CHECK-LABEL: test_createNd_loadNd_storeNd
   // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
   //CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
   //CHECK-COUNT-6: [[data:%.+]] = xegpu.load_nd {{.*}}  : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
   //CHECK-COUNT-6: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<24x32xf32>
   //CHECK: [[add:%.+]] = arith.addf {{.*}} : vector<24x32xf32>
-  //CHECK-COUNT-6: %[[extract:%.+]] = vector.extract_strided_slice {{.*}} : vector<24x32xf32> to vector<8x16xf32>
+  //CHECK-COUNT-6: [[extract:%.+]] = vector.extract_strided_slice {{.*}} : vector<24x32xf32> to vector<8x16xf32>
   //CHECK-COUNT-6: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  gpu.func @test_create_nd_tdesc_vc_1(%src: memref<24x32xf32>) {
+  gpu.func @test_createNd_loadNd_storeNd(%src: memref<24x32xf32>) {
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
     %data = arith.constant dense<9.0> : vector<24x32xf32>
     %ld = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>

>From d077cb081668914776f029553459b0c8189f4597 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 24 Apr 2025 15:06:21 +0000
Subject: [PATCH 09/12] clean up: remove the XeGPUUnroll pass definition and link MLIRXeGPUTransforms into the XeGPU test passes

---
 .../mlir/Dialect/XeGPU/Transforms/Passes.td   | 11 -----
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  | 42 -------------------
 .../Dialect/XeGPU/xegpu-unroll-patterns.mlir  |  1 +
 mlir/test/lib/Dialect/XeGPU/CMakeLists.txt    |  1 +
 4 files changed, 2 insertions(+), 53 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 007dd81d1dfac..3e81f2d0ed786 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -38,15 +38,4 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
   ];
 }
 
-def XeGPUUnroll: Pass<"xegpu-unroll"> {
-  let summary = "Unroll operations into smaller shapes";
-  let description = [{
-    The pass unrolls operations into smaller shapes that can be distribute
-    to an SIMD instruction.
-  }];
-  let dependentDialects = [
-    "memref::MemRefDialect", "xegpu::XeGPUDialect", "vector::VectorDialect"
-  ];
-}
-
 #endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index f3c53e36f40dc..c86235717cca2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -389,48 +389,6 @@ struct UnrollAtomicRMWOp : public UnrollPattern<xegpu::AtomicRMWOp> {
     return failure();
   }
 };
-
-struct XeGPUUnrollPass final
-    : public xegpu::impl::XeGPUUnrollBase<XeGPUUnrollPass> {
-  XeGPUUnrollPass() = default;
-  XeGPUUnrollPass(const XeGPUUnrollPass &pass) = default;
-
-  void runOnOperation() override {
-    vector::UnrollVectorOptions options;
-    options.setNativeShapeFn([&](Operation *op)
-                                 -> std::optional<SmallVector<int64_t>> {
-      if (isa<xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp>(op)) {
-        xegpu::TensorDescType tdescTy;
-        if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
-          tdescTy = createNdOp.getType();
-        } else if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) {
-          tdescTy = loadNdOp.getTensorDescType();
-        } else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
-          tdescTy = storeNdOp.getTensorDescType();
-        }
-
-        if (auto layout = tdescTy.getLayoutAttr()) {
-          if (auto inst_data = layout.getInstData())
-            return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
-                                        inst_data.asArrayRef().end());
-        }
-      }
-
-      return std::nullopt;
-    });
-
-    auto funcOp = getOperation();
-    RewritePatternSet patterns(&getContext());
-    patterns.add<UnrollCreateNdOp, UnrollLoadNdOp, UnrollStoreNdOp>(
-        patterns.getContext(), options);
-
-    // GreedyRewriteConfig config;
-    // config.fold = false;
-    // config.cseConstants = false;
-    (void)applyPatternsGreedily(funcOp, std::move(patterns));
-    return;
-  }
-};
 } // namespace
 
 void mlir::xegpu::populateXeGPUUnrollPatterns(
diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
index 2f5dcf3930628..41f5c35e801a1 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -29,6 +29,7 @@ gpu.module @test {
   }
 
   //-----
+
   // CHECK-LABEL: test_store_nd
   // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
   // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
diff --git a/mlir/test/lib/Dialect/XeGPU/CMakeLists.txt b/mlir/test/lib/Dialect/XeGPU/CMakeLists.txt
index 9223f2860ce15..6d6a92323d018 100644
--- a/mlir/test/lib/Dialect/XeGPU/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/XeGPU/CMakeLists.txt
@@ -12,4 +12,5 @@ mlir_target_link_libraries(MLIRXeGPUTestPasses PUBLIC
   MLIRPass
   MLIRTransforms
   MLIRGPUDialect
+  MLIRXeGPUTransforms
 )
\ No newline at end of file

>From 0193a04ce106d5bfb99679ac4023b9ad5ff87c8a Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Mon, 28 Apr 2025 18:53:44 +0000
Subject: [PATCH 10/12] stage: add layout attribute name prefixes and a work-in-progress DPAS unroll pattern

---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td |  2 ++
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 15 +++++++++++++++
 2 files changed, 17 insertions(+)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
index 549018b61d6fb..c3edeb5983788 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
@@ -38,6 +38,8 @@ def XeGPU_Dialect : Dialect {
     let useDefaultAttributePrinterParser = true;
 
     let extraClassDeclaration = [{
+      static constexpr const char *operandLayoutNamePrefix = "layout_operand_";
+      static constexpr const char *resultLayoutNamePrefix = "layout_result_";
       /// Checks if the given shape can be evenly distributed based on the layout
       /// and data factors provided by the LayoutAttr.
       static bool isEvenlyDistributable(llvm::ArrayRef<int64_t> shape, xegpu::LayoutAttr attr);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index c86235717cca2..302f940f5a63a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -378,7 +378,22 @@ struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
   using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
   LogicalResult matchAndRewrite(xegpu::DpasOp op,
                                 PatternRewriter &rewriter) const override {
+
+    auto loc = op.getLoc();
+
+    // a vector of 3 elements should be returned, representing M, K, N respectively.
+    auto maybeTargetShape = getTargetShape(options, op);
+    if (!maybeTargetShape || maybeTargetShape->size() != 3)
+      return failure();
+    auto M = (*maybeTargetShape)[0];
+    auto K = (*maybeTargetShape)[1];
+    auto N = (*maybeTargetShape)[2];
+
+    llvm::dbgs() << "\nM: " << M << ", K: " << K << ", N: " << N << "\n";
+
     return failure();
+
+
   }
 };
 

>From 456465e49790acae7539f2784c257991ec1abcbe Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 29 Apr 2025 19:12:01 +0000
Subject: [PATCH 11/12] add dpas pattern and unit test

---
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  | 83 +++++++++++++++----
 .../Dialect/XeGPU/xegpu-unroll-patterns.mlir  | 13 +++
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp |  4 +
 3 files changed, 83 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 302f940f5a63a..20573421566e0 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -138,11 +138,6 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
     return computeShapeRatio(shape, subShape);
   }
 
-  bool isUnrollable(Attribute attr) const {
-    auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(attr);
-    return layout && layout.isSgLayout() && layout.getInstData() != nullptr;
-  }
-
   xegpu::LayoutAttr getLaneLayoutAttr(Attribute attr) const {
     auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(attr);
     if (!layout || layout.getLaneLayout() == nullptr)
@@ -166,9 +161,6 @@ struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
     auto shape = tdescTy.getShape();
     auto layout = tdescTy.getLayout();
 
-    if (!isUnrollable(layout))
-      return failure();
-
     auto maybeTargetShape = getTargetShape(options, op);
     if (!maybeTargetShape)
       return failure();
@@ -242,9 +234,6 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
     auto tdescTy = op.getTensorDescType();
     auto layout = tdescTy.getLayout();
 
-    if (!isUnrollable(layout))
-      return failure();
-
     auto maybeTargetShape = getTargetShape(options, op);
     if (!maybeTargetShape)
       return failure();
@@ -290,9 +279,6 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
     auto tdescTy = op.getTensorDescType();
     auto layout = tdescTy.getLayout();
 
-    if (!isUnrollable(layout))
-      return failure();
-
     auto maybeTargetShape = getTargetShape(options, op);
     if (!maybeTargetShape)
       return failure();
@@ -389,11 +375,74 @@ struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
     auto K = (*maybeTargetShape)[1];
     auto N = (*maybeTargetShape)[2];
 
-    llvm::dbgs() << "\nM: " << M << ", K: " << K << ", N: " << N << "\n";
+    int64_t aBlockSize[2] = {M, K};
+    int64_t bBlockSize[2] = {K, N};
+    int64_t cBlockSize[2] = {M, N};
+
+    auto pack = [&](TypedValue<VectorType> val,
+                    llvm::ArrayRef<int64_t> blockSize) {
+      VectorType type = val.getType();
+      auto maybeGrids = computeShapeRatio(type.getShape(), blockSize);
+      assert(maybeGrids && "Expecting grids to be computed.");
+      auto grids = *maybeGrids;
+      auto numNewOps = computeProduct(grids);
+      if (numNewOps == 1)
+        return llvm::SmallVector<Value>({val});
+      auto newVecTy = type.cloneWith(blockSize, type.getElementType());
+      llvm::SmallVector<Type> convertedTypes(numNewOps, newVecTy);
+      auto values = addPackOp(val, convertedTypes, blockSize, loc, rewriter);
+      return llvm::to_vector(values);
+    };
 
-    return failure();
+    auto a = op.getLhs();
+    auto b = op.getRhs();
+    auto c = op.getAcc();
 
+    auto aShape = a.getType().getShape();
+    auto bShape = b.getType().getShape();
 
+    llvm::SmallVector<Value> aVals, bVals, cVals;
+    aVals = pack(a, aBlockSize);
+    bVals = pack(b, bBlockSize);
+
+    if (c)
+      cVals = pack(c, cBlockSize);
+
+    // Vals are empty due to invalid blocking size, or with size 1 due to
+    // the original shape is the same with the blocking size. The op will
+    // be skipped if every operand got an invalid blocking size or the
+    // original shape is the same with the blocking size.
+    if (aVals.size() <= 1 && bVals.size() <= 1 && cVals.size() <= 1)
+      return failure();
+
+    auto resultTy = op.getResult().getType();
+    auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());
+
+    auto mIters = aShape[0] / M;
+    auto kIters = aShape[1] / K;
+    auto nIters = bShape[1] / N;
+
+    SmallVector<Value> newOps;
+    for (int64_t i = 0; i < mIters; i++) {
+      for (int64_t j = 0; j < nIters; j++) {
+        Value tmpC;
+        if (c)
+          tmpC = cVals[i * nIters + j]; // init with acc
+        for (int64_t k = 0; k < kIters; k++) {
+          auto aVec = aVals[i * kIters + k];
+          auto bVec = bVals[k * nIters + j];
+          llvm::SmallVector<Value> operands({aVec, bVec});
+          if (tmpC)
+            operands.push_back(tmpC);
+          tmpC = rewriter.create<xegpu::DpasOp>(loc, vecTy, operands,
+                                                op->getAttrs());
+        }
+        newOps.push_back(tmpC);
+      }
+    }
+    auto castOp = addUnpackOp(newOps, resultTy, cBlockSize, loc, rewriter);
+    rewriter.replaceOp(op, castOp);
+    return success();
   }
 };
 
@@ -409,6 +458,6 @@ struct UnrollAtomicRMWOp : public UnrollPattern<xegpu::AtomicRMWOp> {
 void mlir::xegpu::populateXeGPUUnrollPatterns(
     RewritePatternSet &patterns,
     const mlir::vector::UnrollVectorOptions &options) {
-  patterns.add<UnrollCreateNdOp, UnrollLoadNdOp, UnrollStoreNdOp>(
+  patterns.add<UnrollCreateNdOp, UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp>(
       patterns.getContext(), options);
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
index 41f5c35e801a1..126b7e67c06a1 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -60,4 +60,17 @@ gpu.module @test {
     gpu.return
   }
 
+  //-----
+
+  // CHECK-LABEL: test_dpas
+  // CHECK-SAME: [[arg0:%.+]]: vector<32x32xf16>, [[arg1:%.+]]: vector<32x32xf16>
+  //CHECK-COUNT-8: [[extract1:%.+]] = vector.extract_strided_slice [[arg0]] {{.*}} : vector<32x32xf16> to vector<8x16xf16>
+  //CHECK-COUNT-4: [[extract2:%.+]] = vector.extract_strided_slice [[arg1]] {{.*}} : vector<32x32xf16> to vector<16x16xf16>
+  //CHECK-COUNT-16: [[dpas:%.+]] = xegpu.dpas {{.*}} -> vector<8x16xf32>
+  //CHECK-COUNT-8: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<32x32xf32>
+  gpu.func @test_dpas(%a: vector<32x32xf16>, %b: vector<32x32xf16>) -> vector<32x32xf32> {
+    %c = xegpu.dpas %a, %b : vector<32x32xf16>, vector<32x32xf16> -> vector<32x32xf32>
+    gpu.return %c : vector<32x32xf32>
+  }
+
 }
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index bfdfdd5bc5e51..ec28a57a24292 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -63,6 +63,10 @@ struct TestXeGPUUnrollingPatterns
         }
       }
 
+      if (isa<xegpu::DpasOp>(op)) {
+        return SmallVector<int64_t>{8, 16, 16};
+      }
+
       return std::nullopt;
     });
 

>From 906d699998b27f8f2d1c127a542d1d765afba41e Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 29 Apr 2025 20:32:15 +0000
Subject: [PATCH 12/12] refactor: add a convertType helper for unrolled types and skip scattered tensor descriptors

---
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  | 62 +++++++++++--------
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp |  3 +-
 2 files changed, 39 insertions(+), 26 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 20573421566e0..3d55ab9de2d92 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -133,21 +133,43 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
   computeGrids(llvm::ArrayRef<int64_t> shape,
                llvm::ArrayRef<int64_t> subShape) const {
     // if the shape == subshape, we don't need to unroll.
-    if (shape == subShape)
+    if (shape == subShape) {
+      LDBG("shape == subshape, no unroll");
       return std::nullopt;
+    }
     return computeShapeRatio(shape, subShape);
   }
 
+  // copy the layout attribte and drops the inst_data field.
   xegpu::LayoutAttr getLaneLayoutAttr(Attribute attr) const {
     auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(attr);
     if (!layout || layout.getLaneLayout() == nullptr)
       return xegpu::LayoutAttr();
-    return xegpu::LayoutAttr::get(
-        layout.getContext(), nullptr /* sg_layout */, nullptr /* sg_data */,
+    return xegpu::LayoutAttr::get(layout.getContext(), nullptr /* sg_layout */, nullptr /* sg_data */,
         nullptr /* inst_data */, layout.getLaneLayout(), layout.getLaneData(),
         layout.getOrder());
+  };
+
+  std::optional<SmallVector<Type>> convertType(ShapedType type, llvm::ArrayRef<int64_t> blockSize) const {
+    auto elemTy = type.getElementType();
+    auto maybeGrids = computeGrids(type.getShape(), blockSize);
+
+    if (!maybeGrids)
+      return std::nullopt;
+
+    Type newTy;
+    if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type)) {
+      auto ctx = tdescTy.getContext();
+      auto encoding = tdescTy.getEncoding();
+      auto layout = tdescTy.getLayout();
+      newTy = xegpu::TensorDescType::get(ctx, blockSize, elemTy, encoding, getLaneLayoutAttr(layout));
+    } else {
+      newTy = type.clone(blockSize, elemTy);
+    }
+    return llvm::SmallVector<Type>(computeProduct(*maybeGrids), newTy);
   }
 
+
   vector::UnrollVectorOptions options;
 };
 
@@ -171,6 +193,10 @@ struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
       return failure();
     auto grids = *maybeGrids;
 
+    // TODO: enable scattered version later
+    if (tdescTy.isScattered())
+      return failure();
+
     auto encoding = tdescTy.getEncoding();
     auto newLayout = getLaneLayoutAttr(layout);
     auto newTdescTy = xegpu::TensorDescType::get(
@@ -229,10 +255,8 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
                                 PatternRewriter &rewriter) const override {
 
     auto loc = op.getLoc();
-    auto ctx = op.getContext();
     auto valueTy = op.getType();
     auto tdescTy = op.getTensorDescType();
-    auto layout = tdescTy.getLayout();
 
     auto maybeTargetShape = getTargetShape(options, op);
     if (!maybeTargetShape)
@@ -246,13 +270,9 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
 
     auto elemTy = tdescTy.getElementType();
     auto newValueTy = valueTy.cloneWith(targetShape, elemTy);
-    auto newTdescTy = xegpu::TensorDescType::get(ctx, targetShape, elemTy,
-                                                 tdescTy.getEncoding(),
-                                                 getLaneLayoutAttr(layout));
 
-    auto numNewOps = computeProduct(grids);
-    llvm::SmallVector<Type> convertedTdescTypes(numNewOps, newTdescTy);
-    auto convertedTdescs = addPackOp(op.getTensorDesc(), convertedTdescTypes,
+    auto convertedTdescTypes = convertType(tdescTy, targetShape);
+    auto convertedTdescs = addPackOp(op.getTensorDesc(), *convertedTdescTypes,
                                      targetShape, loc, rewriter);
 
     llvm::SmallVector<Value> newOps;
@@ -274,10 +294,8 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
   LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
                                 PatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
-    auto ctx = op.getContext();
     auto valueTy = op.getValueType();
     auto tdescTy = op.getTensorDescType();
-    auto layout = tdescTy.getLayout();
 
     auto maybeTargetShape = getTargetShape(options, op);
     if (!maybeTargetShape)
@@ -289,18 +307,12 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
       return failure();
     auto grids = *maybeGrids;
 
-    auto elemTy = tdescTy.getElementType();
-    auto newValueTy = valueTy.cloneWith(targetShape, elemTy);
-    auto newTdescTy = xegpu::TensorDescType::get(ctx, targetShape, elemTy,
-                                                 tdescTy.getEncoding(),
-                                                 getLaneLayoutAttr(layout));
-
-    auto numNewOps = computeProduct(grids);
-    llvm::SmallVector<Type> convertedValTypes(numNewOps, newValueTy);
-    llvm::SmallVector<Type> convertedTdescTypes(numNewOps, newTdescTy);
-    auto convertedValues =
-        addPackOp(op.getValue(), convertedValTypes, targetShape, loc, rewriter);
-    auto convertedTdescs = addPackOp(op.getTensorDesc(), convertedTdescTypes,
+    auto convertedValTypes = convertType(valueTy, targetShape);
+    auto convertedTdescTypes = convertType(tdescTy, targetShape);
+
+    auto convertedValues = addPackOp(op.getValue(), *convertedValTypes,
+                                     targetShape, loc, rewriter);
+    auto convertedTdescs = addPackOp(op.getTensorDesc(), *convertedTdescTypes,
                                      targetShape, loc, rewriter);
 
     for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs)) {
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index ec28a57a24292..d534e2cc3accd 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -57,7 +57,8 @@ struct TestXeGPUUnrollingPatterns
         }
 
         if (auto layout = tdescTy.getLayoutAttr()) {
-          if (auto inst_data = layout.getInstData())
+          auto inst_data = layout.getInstData();
+          if (inst_data && layout.isSgLayout())
             return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
                                         inst_data.asArrayRef().end());
         }



More information about the Mlir-commits mailing list