[Mlir-commits] [mlir] db42345 - [MLIR][XeGPU] Add unroll patterns for XeGPU (1/N) (#137010)

llvmlistbot at llvm.org
Mon May 12 07:16:25 PDT 2025


Author: Chao Chen
Date: 2025-05-12T09:16:21-05:00
New Revision: db42345dc660329e34fd119fc8edab74521f7c06

URL: https://github.com/llvm/llvm-project/commit/db42345dc660329e34fd119fc8edab74521f7c06
DIFF: https://github.com/llvm/llvm-project/commit/db42345dc660329e34fd119fc8edab74521f7c06.diff

LOG: [MLIR][XeGPU] Add unroll patterns for XeGPU (1/N) (#137010)

Similar to vector ops, XeGPU ops need to be unrolled into smaller shapes
so that each piece can be mapped to a hardware instruction. This PR is the
first in a series that adds unroll patterns for XeGPU operations. This
installment introduces patterns for the following operations (a usage
sketch follows the list):
1. CreateNdDescOp
2. UpdateNdOffsetOp
3. PrefetchNdOp
4. LoadNdOp
5. StoreNdOp
6. DpasOp
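
For reference, here is a minimal sketch of how a client pass might wire up the
new API (the 8x16 native shape, the `ctx`/`getOperation()` pass boilerplate,
and the simplified type callback are illustrative assumptions; see
TestXeGPUTransforms.cpp below for the actual test pass):

  xegpu::UnrollOptions options;
  // Only unroll load/store ops in this sketch; returning std::nullopt leaves
  // an op untouched.
  options.setNativeShapeFn(
      [](Operation *op) -> std::optional<SmallVector<int64_t>> {
        if (isa<xegpu::LoadNdOp, xegpu::StoreNdOp>(op))
          return SmallVector<int64_t>{8, 16};
        return std::nullopt;
      });
  // Produce one unrolled type per tile. This only covers VectorType;
  // TensorDescType needs extra handling of its layout attribute, as done in
  // the test pass.
  options.setUnrolledTypesFn(
      [](ShapedType type, ArrayRef<int64_t> tileShape) -> SmallVector<Type> {
        std::optional<SmallVector<int64_t>> ratio =
            computeShapeRatio(type.getShape(), tileShape);
        assert(ratio && "tileShape must evenly divide the original shape");
        return SmallVector<Type>(computeProduct(*ratio),
                                 type.clone(tileShape, type.getElementType()));
      });

  RewritePatternSet patterns(ctx);
  xegpu::populateXeGPUUnrollPatterns(patterns, options);
  (void)applyPatternsGreedily(getOperation(), std::move(patterns));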

Added: 
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
    mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
    mlir/test/lib/Dialect/XeGPU/CMakeLists.txt
    mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp

Modified: 
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
    mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
    mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
    mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
    mlir/test/lib/Dialect/CMakeLists.txt
    mlir/tools/mlir-opt/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 6d04ee5599a23..032ce5bc18334 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -303,7 +303,6 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
       return LayoutAttr::get(getContext(), getSgLayout(), getSgData(), nullptr,
                              getLaneLayout(), getLaneData(), getOrder());
     }
-
   }];
 
   let assemblyFormat = "`<` struct(params) `>`";

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index 3e94021c7a1ea..09f9ce1e716c0 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -14,11 +14,67 @@ class RewritePatternSet;
 
 namespace xegpu {
 
+/// Options to control XeGPU unrolling. Its main purpose is to provide a
+/// way to customize the native shape of an operation.
+struct UnrollOptions {
+  /// Callback function that indicates whether unrolling should be
+  /// attempted on the operation.
+  using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>;
+  FilterConstraintFnType filterConstraint = nullptr;
+  UnrollOptions &setFilterConstraint(FilterConstraintFnType constraint) {
+    filterConstraint = std::move(constraint);
+    return *this;
+  }
+
+  /// Function that computes the target shape for unrolling. It returns an
+  /// optional vector of integers representing the shape. If it returns
+  /// `std::nullopt`, unrolling is aborted for the given operation.
+  using NativeShapeFnType =
+      std::function<std::optional<SmallVector<int64_t>>(Operation *op)>;
+  NativeShapeFnType nativeShape = nullptr;
+  UnrollOptions &setNativeShapeFn(NativeShapeFnType fn) {
+    nativeShape = std::move(fn);
+    return *this;
+  }
+
+  /// Function that converts a ShapedType (TensorDescType or VectorType)
+  /// into unrolled types based on the tileShape. It returns a vector of
+  /// types, one for each unrolled instance.
+  using UnrolledTypeFnType = std::function<SmallVector<Type>(
+      ShapedType type, ArrayRef<int64_t> tileShape)>;
+  UnrolledTypeFnType getUnrolledTypes = nullptr;
+  UnrollOptions &setUnrolledTypesFn(UnrolledTypeFnType fn) {
+    getUnrolledTypes = std::move(fn);
+    return *this;
+  }
+};
+
 /// Appends patterns for folding aliasing ops into XeGPU ops into `patterns`.
 void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
+
 /// Appends patterns for XeGPU SIMT distribution into `patterns`.
 void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns);
 
+/// Collect a set of patterns to unroll XeGPU operations to smaller shapes.
+/// Users can control whether an operation should be unrolled, as well as its
+/// target shape, via the `options` structure (by setting `filterConstraint`
+/// and `nativeShape` respectively; both are function refs taking `op` as
+/// input).
+/// An `op` is unrolled to the `targetShape` as follows, for each of its
+/// operands:
+///   1. the unrolled type `unrolledType` and number of unrolled instances
+///   `numUnrolledInstances` are computed from the `targetShape`.
+///   2. pack each operand: vector::ExtractStridedSliceOps are created to
+///   break up vector operands, and UnrealizedConversionCastOps are created
+///   to break up TensorDesc operands.
+///   3. the original op is cloned `numUnrolledInstances` times, once for each
+///   result.
+///   4. unpack the results: vector::InsertStridedSliceOps are inserted for
+///   VectorType results, and UnrealizedConversionCastOps are inserted for
+///   TensorDescType results, to re-assemble the slices into the original
+///   shape.
+void populateXeGPUUnrollPatterns(RewritePatternSet &patterns,
+                                 const UnrollOptions &options);
+
 } // namespace xegpu
 } // namespace mlir
 

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index f2cfa50e102f8..c99e925a97633 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/IR/Builders.h"

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 901e02d3c9cf5..892eb791c46e7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRXeGPUTransforms
   XeGPUFoldAliasOps.cpp
   XeGPUSubgroupDistribute.cpp
+  XeGPUUnroll.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
new file mode 100644
index 0000000000000..44d45dd2eaec0
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -0,0 +1,427 @@
+//===- XeGPUUnroll.cpp - patterns to do unrolling ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains patterns for unrolling XeGPU operations. It follows a
+// concept and design similar to the vector unroll patterns, and serves as a
+// complement to them.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Debug.h"
+#include <numeric>
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUUNROLL
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+#define DEBUG_TYPE "xegpu-unroll"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
+using namespace mlir;
+
+namespace {
+
+template <typename SourceOp>
+struct UnrollPattern : public OpRewritePattern<SourceOp> {
+  UnrollPattern(MLIRContext *context, const xegpu::UnrollOptions &options,
+                PatternBenefit benefit = 1)
+      : OpRewritePattern<SourceOp>(context, benefit), options(options) {}
+
+protected:
+  /// Return the target shape for the given `op`. Return std::nullopt if the
+  /// op shouldn't be or cannot be unrolled.
+  std::optional<SmallVector<int64_t>> getTargetShape(Operation *op) const {
+    LDBG("");
+    LDBG("Get unroll shape for: " << *op);
+
+    if (options.filterConstraint && failed(options.filterConstraint(op))) {
+      LDBG("--filter constraint failed -> BAIL");
+      return std::nullopt;
+    }
+
+    assert(options.nativeShape &&
+           "expects the nativeShape callback function to be set.");
+    auto nativeShape = options.nativeShape(op);
+    return nativeShape;
+  }
+
+  SmallVector<Type> getUnrolledTypes(ShapedType type,
+                                     ArrayRef<int64_t> tileShape) const {
+    return options.getUnrolledTypes(type, tileShape);
+  }
+
+  /// Emulate the unpack behavior using insert_strided_slice for VectorType
+  /// values and unrealized_conversion_cast for TensorDescType values.
+  Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
+               Location loc, PatternRewriter &rewriter) const {
+    if (auto vecTy = dyn_cast<VectorType>(destTy)) {
+      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
+             "Expecting blockSize's size to match the rank of destTy.");
+      auto shape = vecTy.getShape();
+      auto zeroAttr = rewriter.getZeroAttr(vecTy.getElementType());
+
+      Value result = rewriter.create<arith::ConstantOp>(
+          loc, vecTy, DenseElementsAttr::get(vecTy, zeroAttr));
+      for (auto [src, offsets] :
+           llvm::zip_equal(srcs, StaticTileOffsetRange(shape, blockSize))) {
+        SmallVector<int64_t> staticStrides(offsets.size(), 1);
+        result = rewriter.create<vector::InsertStridedSliceOp>(
+            loc, src, result, offsets, staticStrides);
+      }
+      return result;
+    }
+
+    if (isa<xegpu::TensorDescType>(destTy)) {
+      auto attr = NamedAttribute(rewriter.getStringAttr(unpackAttrName),
+                                 rewriter.getUnitAttr());
+      auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
+                                    rewriter.getDenseI64ArrayAttr(blockSize));
+      auto castOp = rewriter.create<UnrealizedConversionCastOp>(
+          loc, destTy, srcs, ArrayRef<NamedAttribute>({attr, blkAttr}));
+      return castOp.getResult(0);
+    }
+
+    llvm_unreachable("Unexpected destTy.");
+    return Value();
+  }
+
+  /// Emulate the pack behavior using extract_strided_slice for VectorType
+  /// values and unrealized_conversion_cast for TensorDescType values.
+  SmallVector<Value> pack(Value src, TypeRange destTypes,
+                          ArrayRef<int64_t> blockSize, Location loc,
+                          PatternRewriter &rewriter) const {
+    if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
+      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
+             "Expecting blockSize's size to match the rank of src.");
+      auto shape = vecTy.getShape();
+      SmallVector<Value> results;
+      for (SmallVector<int64_t> offsets :
+           StaticTileOffsetRange(shape, blockSize)) {
+        SmallVector<int64_t> staticStrides(offsets.size(), 1);
+        auto slice = rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, src, offsets, blockSize, staticStrides);
+        results.push_back(slice);
+      }
+      return results;
+    }
+
+    if (isa<xegpu::TensorDescType>(src.getType())) {
+      auto attr = NamedAttribute(rewriter.getStringAttr(packAttrName),
+                                 rewriter.getUnitAttr());
+      auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
+                                    rewriter.getDenseI64ArrayAttr(blockSize));
+      auto castOp = rewriter.create<UnrealizedConversionCastOp>(
+          loc, destTypes, src, ArrayRef<NamedAttribute>({attr, blkAttr}));
+      return castOp.getResults();
+    }
+
+    llvm_unreachable("Unexpected src type.");
+    return SmallVector<Value>();
+  }
+
+private:
+  const char *const packAttrName = "__xegpu_blocking_pack__";
+  const char *const unpackAttrName = "__xegpu_blocking_unpack__";
+  const char *const blockAttrName = "__xegpu_blocking_tile_shape__";
+
+  xegpu::UnrollOptions options;
+};
+
+struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
+  using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    xegpu::TensorDescType tdescTy = op.getType();
+    int64_t rank = tdescTy.getRank();
+    ArrayRef<int64_t> shape = tdescTy.getShape();
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape || llvm::equal(*targetShape, shape))
+      return failure();
+
+    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
+
+    auto addi = [&](OpFoldResult a, int64_t b) -> Value {
+      std::optional<int64_t> maybeInt = getConstantIntValue(a);
+      if (maybeInt) {
+        return rewriter.create<arith::ConstantIndexOp>(loc, *maybeInt + b);
+      } else {
+        auto aV = llvm::cast<Value>(a);
+        auto bV = rewriter.create<arith::ConstantIndexOp>(loc, b);
+        return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
+      }
+    };
+
+    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
+
+    // For n-D memrefs where n > rank, we need to handle the last `rank`
+    // dimensions only, and keep the first `n-rank` dimensions as is.
+    SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
+        llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
+    auto validIdxes =
+        llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());
+
+    SmallVector<Value> newOps;
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(shape, *targetShape)) {
+
+      for (auto [idx, oldOff, offset] :
+           llvm::zip(validIdxes, oldOffsets, offsets))
+        mixedOffsets[idx] = addi(oldOff, offset);
+
+      auto newOp = rewriter.create<xegpu::CreateNdDescOp>(
+          loc, newTdescTy, op.getSource(), mixedOffsets, op.getMixedSizes(),
+          op.getMixedStrides());
+      newOps.push_back(newOp);
+    }
+    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
+    rewriter.replaceOp(op, castOp);
+
+    return success();
+  }
+};
+
+struct UnrollUpdateNdOffsetOp : public UnrollPattern<xegpu::UpdateNdOffsetOp> {
+  using UnrollPattern<xegpu::UpdateNdOffsetOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::UpdateNdOffsetOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    xegpu::TensorDescType tdescTy = op.getTensorDescType();
+    ArrayRef<int64_t> shape = tdescTy.getShape();
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape || llvm::equal(*targetShape, shape))
+      return failure();
+
+    SmallVector<Type> convertedTdescTypes =
+        getUnrolledTypes(tdescTy, *targetShape);
+    SmallVector<Value> convertedTdesc = pack(
+        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+    SmallVector<Value> newOps;
+    for (auto t : convertedTdesc) {
+      auto newOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
+          loc, t.getType(), t, op.getOffsets(), op.getConstOffsets());
+      newOps.push_back(newOp);
+    }
+    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
+    rewriter.replaceOp(op, castOp);
+    return success();
+  }
+};
+
+struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
+  using UnrollPattern<xegpu::PrefetchNdOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    xegpu::TensorDescType tdescTy = op.getTensorDescType();
+    ArrayRef<int64_t> shape = tdescTy.getShape();
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape || llvm::equal(*targetShape, shape))
+      return failure();
+
+    SmallVector<Type> convertedTdescTypes =
+        getUnrolledTypes(tdescTy, *targetShape);
+    SmallVector<Value> convertedTdesc = pack(
+        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+    for (auto t : convertedTdesc)
+      rewriter.create<xegpu::PrefetchNdOp>(loc, TypeRange(), t, op->getAttrs());
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
+  using UnrollPattern<xegpu::LoadNdOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    VectorType valueTy = op.getType();
+    xegpu::TensorDescType tdescTy = op.getTensorDescType();
+    ArrayRef<int64_t> shape = tdescTy.getShape();
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape || llvm::equal(*targetShape, shape))
+      return failure();
+
+    Type elemTy = tdescTy.getElementType();
+    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+
+    SmallVector<Type> convertedTdescTypes =
+        getUnrolledTypes(tdescTy, *targetShape);
+    SmallVector<Value> convertedTdescs = pack(
+        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+    SmallVector<Value> newOps;
+    for (auto t : convertedTdescs) {
+      auto newOp =
+          rewriter.create<xegpu::LoadNdOp>(loc, newValueTy, t, op->getAttrs());
+      newOps.push_back(newOp);
+    }
+
+    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
+
+    rewriter.replaceOp(op, castOp);
+    return success();
+  }
+};
+
+struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
+  using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    VectorType valueTy = op.getValueType();
+    xegpu::TensorDescType tdescTy = op.getTensorDescType();
+    ArrayRef<int64_t> shape = tdescTy.getShape();
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape || llvm::equal(*targetShape, shape))
+      return failure();
+
+    SmallVector<Type> convertedValTypes =
+        getUnrolledTypes(valueTy, *targetShape);
+    SmallVector<Type> convertedTdescTypes =
+        getUnrolledTypes(tdescTy, *targetShape);
+
+    SmallVector<Value> convertedValues =
+        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
+    SmallVector<Value> convertedTdescs = pack(
+        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+    for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
+      rewriter.create<xegpu::StoreNdOp>(loc, v, t, op.getL1HintAttr(),
+                                        op.getL2HintAttr(), op.getL3HintAttr());
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
+  using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::DpasOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    // Expecting every operand to be a 2D vector.
+    if (llvm::any_of(op->getOperandTypes(), [&](Type type) {
+          auto vecTy = dyn_cast<VectorType>(type);
+          return !vecTy || vecTy.getRank() != 2;
+        }))
+      return failure();
+
+    // A vector of 3 elements should be returned, representing M, K, N
+    // respectively.
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape || targetShape->size() != 3)
+      return failure();
+    auto M = (*targetShape)[0];
+    auto K = (*targetShape)[1];
+    auto N = (*targetShape)[2];
+
+    int64_t aBlockSize[2] = {M, K};
+    int64_t bBlockSize[2] = {K, N};
+    int64_t cBlockSize[2] = {M, N};
+
+    auto packWrapper = [&](TypedValue<VectorType> val,
+                           ArrayRef<int64_t> blockSize) {
+      VectorType type = val.getType();
+      std::optional<SmallVector<int64_t>> grids =
+          computeShapeRatio(type.getShape(), blockSize);
+      assert(grids && "Expecting grids to be computed.");
+      auto numNewOps = computeProduct(*grids);
+      if (numNewOps == 1)
+        return SmallVector<Value>({val});
+      VectorType newVecTy = type.cloneWith(blockSize, type.getElementType());
+      SmallVector<Type> convertedTypes(numNewOps, newVecTy);
+      SmallVector<Value> values =
+          pack(val, convertedTypes, blockSize, loc, rewriter);
+      return values;
+    };
+
+    auto a = op.getLhs();
+    auto b = op.getRhs();
+    auto c = op.getAcc();
+
+    auto aShape = a.getType().getShape();
+    auto bShape = b.getType().getShape();
+
+    SmallVector<Value> aVals, bVals, cVals;
+    aVals = packWrapper(a, aBlockSize);
+    bVals = packWrapper(b, bBlockSize);
+
+    if (c)
+      cVals = packWrapper(c, cBlockSize);
+
+    // Skip the operation if any operand has an invalid blocking size (empty),
+    // or if every operand's shape already matches its blocking size
+    // (size == 1).
+    auto ranges = c ? SmallVector<ValueRange>({aVals, bVals, cVals})
+                    : SmallVector<ValueRange>({aVals, bVals});
+    if (llvm::any_of(ranges, [](auto &v) { return v.size() == 0; }) ||
+        llvm::all_of(ranges, [](auto &v) { return v.size() == 1; }))
+      return failure();
+
+    VectorType resultTy = op.getResult().getType();
+    auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());
+
+    int64_t mIters = aShape[0] / M;
+    int64_t kIters = aShape[1] / K;
+    int64_t nIters = bShape[1] / N;
+
+    SmallVector<Value> newOps;
+    for (int64_t i = 0; i < mIters; ++i) {
+      for (int64_t j = 0; j < nIters; ++j) {
+        Value tmpC;
+        if (c)
+          tmpC = cVals[i * nIters + j]; // init with acc
+
+        for (int64_t k = 0; k < kIters; ++k) {
+          Value aVec = aVals[i * kIters + k];
+          Value bVec = bVals[k * nIters + j];
+          SmallVector<Value> operands({aVec, bVec});
+          if (tmpC)
+            operands.push_back(tmpC);
+
+          tmpC = rewriter.create<xegpu::DpasOp>(loc, vecTy, operands,
+                                                op->getAttrs());
+        }
+        newOps.push_back(tmpC);
+      }
+    }
+    Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
+    rewriter.replaceOp(op, castOp);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::xegpu::populateXeGPUUnrollPatterns(
+    RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
+  patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
+               UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp>(
+      patterns.getContext(), options);
+}

diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
new file mode 100644
index 0000000000000..b911bb3bbdc1c
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -0,0 +1,161 @@
+// RUN: mlir-opt --test-xegpu-unrolling-patterns -split-input-file %s | FileCheck %s
+
+gpu.module @test {
+
+  // CHECK-LABEL: test_create_nd_tdesc
+  // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
+  // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+  // CHECK: [[cast:%.+]] = builtin.unrealized_conversion_cast
+  // CHECK-SAME: !xegpu.tensor_desc<8x16xf32>, !xegpu.tensor_desc<8x16xf32>,
+  // CHECK-SAME: !xegpu.tensor_desc<8x16xf32>, !xegpu.tensor_desc<8x16xf32>,
+  // CHECK-SAME: !xegpu.tensor_desc<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  // CHECK-SAME: to !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {__xegpu_blocking_tile_shape__ = array<i64: 8, 16>, __xegpu_blocking_unpack__}
+  gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+    gpu.return %tdesc : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+  }
+
+  //-----
+
+  // CHECK-LABEL: test_create_nd_tdesc_1d
+  // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
+  // CHECK-COUNT-2: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
+  // CHECK: [[cast:%.+]] = builtin.unrealized_conversion_cast
+  // CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
+  // CHECK-SAME: to !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {__xegpu_blocking_tile_shape__ = array<i64: 16>, __xegpu_blocking_unpack__}
+  gpu.func @test_create_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {
+    %tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
+    gpu.return %tdesc : !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
+  }
+
+  //-----
+
+  // CHECK-LABEL: test_update_nd_tdesc
+  // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
+  // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+  // CHECK-COUNT-6: [[update:%.+]] = xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<8x16xf32>
+  gpu.func @test_update_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+    %update = xegpu.update_nd_offset %tdesc, [0, 16] : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+    gpu.return %update : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+  }
+
+  //-----
+
+  // CHECK-LABEL: test_update_nd_tdesc_1d
+  // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
+  // CHECK-COUNT-2: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
+  // CHECK-COUNT-2: [[update:%.+]] = xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<16xf32>
+  gpu.func @test_update_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {
+    %tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
+    %update = xegpu.update_nd_offset %tdesc, [32] : !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
+    gpu.return %update : !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
+  }
+
+  //-----
+
+  // CHECK-LABEL: test_prefetch_nd_tdesc
+  // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
+  // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+  // CHECK-COUNT-6: xegpu.prefetch_nd {{.*}} : !xegpu.tensor_desc<8x16xf32>
+  gpu.func @test_prefetch_nd_tdesc(%src: memref<24x32xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+    xegpu.prefetch_nd %tdesc : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+    gpu.return
+  }
+
+  //-----
+
+  // CHECK-LABEL: test_prefetch_nd_tdesc_1d
+  // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
+  // CHECK-COUNT-4: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
+  // CHECK-COUNT-4: xegpu.prefetch_nd {{.*}} : !xegpu.tensor_desc<16xf32>
+  gpu.func @test_prefetch_nd_tdesc_1d(%src: memref<64xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
+    xegpu.prefetch_nd %tdesc : !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
+    gpu.return
+  }
+
+  //-----
+  // CHECK-LABEL: test_load_nd
+  // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
+  // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+  // CHECK-COUNT-6: [[ld:%.+]] = xegpu.load_nd {{.*}}  : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+  // CHECK-COUNT-6: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<24x32xf32>
+  gpu.func @test_load_nd(%src: memref<24x32xf32>) -> vector<24x32xf32> {
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+    %ld = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
+    gpu.return %ld : vector<24x32xf32>
+  }
+
+  //-----
+
+  // CHECK-LABEL: test_load_nd_1d
+  // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
+  // CHECK-COUNT-4: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
+  // CHECK-COUNT-4: [[ld:%.+]] = xegpu.load_nd {{.*}}  : !xegpu.tensor_desc<16xf32> -> vector<16xf32>
+  // CHECK-COUNT-4: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<16xf32> into vector<64xf32>
+  gpu.func @test_load_nd_1d(%src: memref<64xf32>) -> vector<64xf32> {
+    %tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
+    %data = xegpu.load_nd %tdesc: !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>> -> vector<64xf32>
+    gpu.return %data : vector<64xf32>
+  }
+
+  //-----
+
+  // CHECK-LABEL: test_store_nd
+  // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
+  // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+  // CHECK-COUNT-6: xegpu.store_nd {{.*}}  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  gpu.func @test_store_nd(%src: memref<24x32xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+    %data = arith.constant dense<9.0> : vector<24x32xf32>
+    xegpu.store_nd %data, %tdesc: vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+    gpu.return
+  }
+
+  //-----
+
+  // CHECK-LABEL: test_store_nd_1d
+  // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
+  // CHECK-COUNT-4: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
+  // CHECK-COUNT-4: xegpu.store_nd {{.*}}  : vector<16xf32>, !xegpu.tensor_desc<16xf32>
+  gpu.func @test_store_nd_1d(%src: memref<64xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
+    %data = arith.constant dense<9.0> : vector<64xf32>
+    xegpu.store_nd %data, %tdesc: vector<64xf32>, !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
+    gpu.return
+  }
+
+  //-----
+
+  // CHECK-LABEL: test_createNd_loadNd_storeNd
+  // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
+  //CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+  //CHECK-COUNT-6: [[data:%.+]] = xegpu.load_nd {{.*}}  : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+  //CHECK-COUNT-6: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<24x32xf32>
+  //CHECK: [[add:%.+]] = arith.addf {{.*}} : vector<24x32xf32>
+  //CHECK-COUNT-6: [[extract:%.+]] = vector.extract_strided_slice {{.*}} : vector<24x32xf32> to vector<8x16xf32>
+  //CHECK-COUNT-6: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  gpu.func @test_createNd_loadNd_storeNd(%src: memref<24x32xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+    %data = arith.constant dense<9.0> : vector<24x32xf32>
+    %ld = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
+    %add = arith.addf %data, %ld : vector<24x32xf32>
+    xegpu.store_nd %add, %tdesc: vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+    gpu.return
+  }
+
+  //-----
+
+  // CHECK-LABEL: test_dpas
+  // CHECK-SAME: [[arg0:%.+]]: vector<32x32xf16>, [[arg1:%.+]]: vector<32x32xf16>
+  //CHECK-COUNT-8: [[extract1:%.+]] = vector.extract_strided_slice [[arg0]] {{.*}} : vector<32x32xf16> to vector<8x16xf16>
+  //CHECK-COUNT-4: [[extract2:%.+]] = vector.extract_strided_slice [[arg1]] {{.*}} : vector<32x32xf16> to vector<16x16xf16>
+  //CHECK-COUNT-16: [[dpas:%.+]] = xegpu.dpas {{.*}} -> vector<8x16xf32>
+  //CHECK-COUNT-8: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<32x32xf32>
+  gpu.func @test_dpas(%a: vector<32x32xf16>, %b: vector<32x32xf16>) -> vector<32x32xf32> {
+    %c = xegpu.dpas %a, %b : vector<32x32xf16>, vector<32x32xf16> -> vector<32x32xf32>
+    gpu.return %c : vector<32x32xf32>
+  }
+}

diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index 29fb4441a24fd..a8fd70e6397a5 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -22,3 +22,4 @@ add_subdirectory(TestDyn)
 add_subdirectory(Tosa)
 add_subdirectory(Transform)
 add_subdirectory(Vector)
+add_subdirectory(XeGPU)

diff --git a/mlir/test/lib/Dialect/XeGPU/CMakeLists.txt b/mlir/test/lib/Dialect/XeGPU/CMakeLists.txt
new file mode 100644
index 0000000000000..5236d8765eac8
--- /dev/null
+++ b/mlir/test/lib/Dialect/XeGPU/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_dialect_library(MLIRXeGPUTestPasses
+  TestXeGPUTransforms.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+)
+
+mlir_target_link_libraries(MLIRXeGPUTestPasses PUBLIC
+  MLIRAffineUtils
+  MLIRIR
+  MLIRMemRefDialect
+  MLIRXeGPUDialect
+  MLIRPass
+  MLIRTransforms
+  MLIRGPUDialect
+  MLIRXeGPUTransforms
+)

diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
new file mode 100644
index 0000000000000..eaa3b988cad82
--- /dev/null
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -0,0 +1,124 @@
+//===- TestXeGPUTransforms.cpp -- Test XeGPU transforms and lowerings -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::xegpu;
+
+namespace {
+
+struct TestXeGPUUnrollingPatterns
+    : public PassWrapper<TestXeGPUUnrollingPatterns,
+                         OperationPass<gpu::GPUModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestXeGPUUnrollingPatterns)
+
+  StringRef getArgument() const final {
+    return "test-xegpu-unrolling-patterns";
+  }
+
+  StringRef getDescription() const final {
+    return "Test lowering patterns to unroll ops in the xegpu dialect";
+  }
+
+  void getDependentDialects(::mlir::DialectRegistry &registry) const override {
+    registry.insert<memref::MemRefDialect>();
+    registry.insert<xegpu::XeGPUDialect>();
+    registry.insert<vector::VectorDialect>();
+  }
+
+  TestXeGPUUnrollingPatterns() = default;
+  TestXeGPUUnrollingPatterns(const TestXeGPUUnrollingPatterns &pass)
+      : PassWrapper(pass) {}
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    xegpu::UnrollOptions options;
+    options.setNativeShapeFn(
+        [&](Operation *op) -> std::optional<SmallVector<int64_t>> {
+          if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp,
+                  xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::StoreNdOp>(op)) {
+            xegpu::TensorDescType tdescTy;
+            if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
+              tdescTy = createNdOp.getType();
+            } else if (auto updateNdOp =
+                           dyn_cast<xegpu::UpdateNdOffsetOp>(op)) {
+              tdescTy = updateNdOp.getTensorDescType();
+            } else if (auto prefetchNdOp = dyn_cast<xegpu::PrefetchNdOp>(op)) {
+              tdescTy = prefetchNdOp.getTensorDescType();
+            } else if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) {
+              tdescTy = loadNdOp.getTensorDescType();
+            } else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
+              tdescTy = storeNdOp.getTensorDescType();
+            }
+
+            if (auto layout = tdescTy.getLayoutAttr()) {
+              auto inst_data = layout.getInstData();
+              if (inst_data && layout.isSgLayout())
+                return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
+                                            inst_data.asArrayRef().end());
+            }
+          }
+
+          if (isa<xegpu::DpasOp>(op))
+            return SmallVector<int64_t>{8, 16, 16};
+
+          return std::nullopt;
+        });
+
+    options.setUnrolledTypesFn(
+        [&](ShapedType type, ArrayRef<int64_t> tileShape) -> SmallVector<Type> {
+          Type elemTy = type.getElementType();
+          Type newTy;
+
+          // TensorDescType needs to drop the inst_data field in the layout
+          // attribute
+          if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type)) {
+            Attribute encoding = tdescTy.getEncoding();
+            auto layout = llvm::dyn_cast_if_present<xegpu::LayoutAttr>(
+                tdescTy.getLayout());
+            if (layout) {
+              if (layout.getLaneLayout() == nullptr)
+                layout = xegpu::LayoutAttr();
+              else
+                layout = layout.dropInstData();
+            }
+            newTy = xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
+                                               layout);
+          } else {
+            newTy = type.clone(tileShape, elemTy);
+          }
+
+          std::optional<SmallVector<int64_t>> ratio =
+              computeShapeRatio(type.getShape(), tileShape);
+          assert(ratio && "Expecting the ratio to be valid.");
+          return SmallVector<Type>(computeProduct(*ratio), newTy);
+        });
+
+    RewritePatternSet patterns(ctx);
+
+    populateXeGPUUnrollPatterns(patterns, options);
+    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestXeGPULowerings() {
+  PassRegistration<TestXeGPUUnrollingPatterns>();
+}
+} // namespace test
+} // namespace mlir
\ No newline at end of file

diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index a5a442909fc6d..3220dca282eac 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -46,6 +46,7 @@ if(MLIR_INCLUDE_TESTS)
     MLIRTilingInterfaceTestPasses
     MLIRTosaTestPasses
     MLIRVectorTestPasses
+    MLIRXeGPUTestPasses
     MLIRTestVectorToSPIRV
     MLIRLLVMTestPasses
     )

diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 344576a44ca41..cdcf59b2add13 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -158,6 +158,7 @@ void registerTestVectorLowerings();
 void registerTestVectorReductionToSPIRVDotProd();
 void registerTestVulkanRunnerPipeline();
 void registerTestWrittenToPass();
+void registerTestXeGPULowerings();
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
 void registerTestDialectConversionPasses();
 void registerTestPDLByteCodePass();
@@ -301,6 +302,7 @@ void registerTestPasses() {
   mlir::test::registerTestVectorReductionToSPIRVDotProd();
   mlir::test::registerTestVulkanRunnerPipeline();
   mlir::test::registerTestWrittenToPass();
+  mlir::test::registerTestXeGPULowerings();
 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
   mlir::test::registerTestDialectConversionPasses();
   mlir::test::registerTestPDLByteCodePass();


        

