[Mlir-commits] [mlir] db42345 - [MLIR][XeGPU] Add unroll patterns for XeGPU (1/N) (#137010)
llvmlistbot at llvm.org
Mon May 12 07:16:25 PDT 2025
Author: Chao Chen
Date: 2025-05-12T09:16:21-05:00
New Revision: db42345dc660329e34fd119fc8edab74521f7c06
URL: https://github.com/llvm/llvm-project/commit/db42345dc660329e34fd119fc8edab74521f7c06
DIFF: https://github.com/llvm/llvm-project/commit/db42345dc660329e34fd119fc8edab74521f7c06.diff
LOG: [MLIR][XeGPU] Add unroll patterns for XeGPU (1/N) (#137010)
Similar to vector ops, XeGPU ops need to be unrolled into smaller shapes
so that each piece can be dispatched as a single hardware instruction.
This PR is the first in a series that adds unroll patterns for XeGPU
operations. In this installment, we introduce patterns for the following
operations (a brief usage sketch follows the list):
1. createNd
2. updateNd
3. prefetchNd
4. loadNd
5. storeNd
6. dpas
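For context, here is a minimal sketch of how these patterns are intended to be driven (editorial illustration, not part of this commit): a pass builds an `UnrollOptions`, installs the native-shape and unrolled-type callbacks, and hands the result to `populateXeGPUUnrollPatterns`. The fixed 8x16 native shape below is an assumption for illustration only; the in-tree test pass added in this PR instead derives the native shape from the tensor_desc layout's `inst_data` and also drops the layout when computing unrolled TensorDescTypes.

// Sketch only. Assumes a fixed 8x16 native shape; DpasOp would need an
// {M, K, N} triple and is therefore skipped by this configuration.
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult unrollToAssumedInstShape(Operation *root) {
  xegpu::UnrollOptions options;
  // Target every supported op with an assumed 8x16 tile.
  options.setNativeShapeFn(
      [](Operation *) -> std::optional<SmallVector<int64_t>> {
        return SmallVector<int64_t>{8, 16};
      });
  // One unrolled type per tile of the original shape (layout handling of
  // TensorDescType is ignored here for brevity).
  options.setUnrolledTypesFn(
      [](ShapedType type, ArrayRef<int64_t> tileShape) -> SmallVector<Type> {
        auto ratio = computeShapeRatio(type.getShape(), tileShape);
        assert(ratio && "tile shape must evenly divide the original shape");
        Type tileTy = type.clone(tileShape, type.getElementType());
        return SmallVector<Type>(computeProduct(*ratio), tileTy);
      });

  RewritePatternSet patterns(root->getContext());
  xegpu::populateXeGPUUnrollPatterns(patterns, options);
  return applyPatternsGreedily(root, std::move(patterns));
}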
Added:
mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
mlir/test/lib/Dialect/XeGPU/CMakeLists.txt
mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
Modified:
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
mlir/test/lib/Dialect/CMakeLists.txt
mlir/tools/mlir-opt/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 6d04ee5599a23..032ce5bc18334 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -303,7 +303,6 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
return LayoutAttr::get(getContext(), getSgLayout(), getSgData(), nullptr,
getLaneLayout(), getLaneData(), getOrder());
}
-
}];
let assemblyFormat = "`<` struct(params) `>`";
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index 3e94021c7a1ea..09f9ce1e716c0 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -14,11 +14,67 @@ class RewritePatternSet;
namespace xegpu {
+/// Options to control the XeGPU unrolling. Its main purpose is to
+/// provide a way to customize the native shape of the operation.
+struct UnrollOptions {
+ /// Callback function that indicates whether unrolling should be attempted
+ /// on the operation.
+ using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>;
+ FilterConstraintFnType filterConstraint = nullptr;
+ UnrollOptions &setFilterConstraint(FilterConstraintFnType constraint) {
+ filterConstraint = std::move(constraint);
+ return *this;
+ }
+
+ /// Function that computes the target shape for unrolling. It returns an
+ /// optional vector of integers representing the shape. If it returns
+ /// `std::nullopt`, unrolling is aborted for the given operation.
+ using NativeShapeFnType =
+ std::function<std::optional<SmallVector<int64_t>>(Operation *op)>;
+ NativeShapeFnType nativeShape = nullptr;
+ UnrollOptions &setNativeShapeFn(NativeShapeFnType fn) {
+ nativeShape = std::move(fn);
+ return *this;
+ }
+
+ /// Function that converts a ShapedType (TensorDescType or VectorType)
+ /// into the unrolled type based on the tileShape. It returns a vector of
+ /// types representing the unrolled types for simplicity.
+ using UnrolledTypeFnType = std::function<SmallVector<Type>(
+ ShapedType type, ArrayRef<int64_t> tileShape)>;
+ UnrolledTypeFnType getUnrolledTypes = nullptr;
+ UnrollOptions &setUnrolledTypesFn(UnrolledTypeFnType fn) {
+ getUnrolledTypes = std::move(fn);
+ return *this;
+ }
+};
+
/// Appends patterns for folding aliasing ops into XeGPU ops into `patterns`.
void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
+
/// Appends patterns for XeGPU SIMT distribution into `patterns`.
void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns);
+/// Collect a set of patterns to unroll XeGPU operations into smaller shapes.
+/// Users can control whether an operation should be unrolled, as well as its
+/// target shape, via the `options` structure (by setting `filterConstraint`
+/// and `nativeShape` respectively; both are callbacks taking the `op` as
+/// input).
+/// An `op` is unrolled to the `targetShape` as follows, for each of its
+/// operands:
+/// 1. the unrolled type `unrolledType` and the number of unrolled instances
+/// `numUnrolledInstances` are computed from the `targetShape`.
+/// 2. pack each operand: vector::ExtractStridedSliceOp is used to break up
+/// vector operands, and UnrealizedConversionCastOp is used to break up
+/// TensorDesc operands.
+/// 3. the original op is cloned `numUnrolledInstances` times, once for each
+/// set of unrolled operands.
+/// 4. unpack the results: vector::InsertStridedSliceOp is used to re-assemble
+/// VectorType results, and UnrealizedConversionCastOp is used to re-assemble
+/// TensorDescType results into the original shape.
+void populateXeGPUUnrollPatterns(RewritePatternSet &patterns,
+ const UnrollOptions &options);
+
} // namespace xegpu
} // namespace mlir
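To make step 1 of the scheme documented above concrete, here is a small standalone sketch (editorial, not part of the commit) using the same IndexingUtils helpers the implementation relies on: for a 24x32 operand and an 8x16 target shape, the shape ratio is {3, 2}, so six unrolled instances are created, and StaticTileOffsetRange enumerates the per-tile offsets used to slice operands and re-assemble results.

// Standalone illustration of the tiling arithmetic behind the unroll patterns.
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>

int main() {
  using namespace mlir;
  int64_t shape[] = {24, 32};
  int64_t tile[] = {8, 16};
  // {24/8, 32/16} == {3, 2} -> 6 unrolled instances.
  std::optional<SmallVector<int64_t>> ratio = computeShapeRatio(shape, tile);
  llvm::outs() << "numUnrolledInstances = " << computeProduct(*ratio) << "\n";
  // The same iteration the pack/unpack helpers use to walk the tiles.
  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(shape, tile))
    llvm::outs() << "offset = [" << offsets[0] << ", " << offsets[1] << "]\n";
  return 0;
}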
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index f2cfa50e102f8..c99e925a97633 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Builders.h"
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 901e02d3c9cf5..892eb791c46e7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRXeGPUTransforms
XeGPUFoldAliasOps.cpp
XeGPUSubgroupDistribute.cpp
+ XeGPUUnroll.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
new file mode 100644
index 0000000000000..44d45dd2eaec0
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -0,0 +1,427 @@
+//===- XeGPUUnroll.cpp - patterns to do unrolling ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains patterns for unrolling XeGPU operations. It follows a
+// similar concept and design as vector unroll patterns, serving as a complement
+// to them.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Debug.h"
+#include <numeric>
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUUNROLL
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+#define DEBUG_TYPE "xegpu-unroll"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
+using namespace mlir;
+
+namespace {
+
+template <typename SourceOp>
+struct UnrollPattern : public OpRewritePattern<SourceOp> {
+ UnrollPattern(MLIRContext *context, const xegpu::UnrollOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<SourceOp>(context, benefit), options(options) {}
+
+protected:
+ /// Return the target shape for the given `op`. Return std::nullopt if the
+ /// op shouldn't be or cannot be unrolled.
+ std::optional<SmallVector<int64_t>> getTargetShape(Operation *op) const {
+ LDBG("");
+ LDBG("Get unroll shape for: " << *op);
+
+ if (options.filterConstraint && failed(options.filterConstraint(op))) {
+ LDBG("--no filter constraint -> BAIL");
+ return std::nullopt;
+ }
+
+ assert(options.nativeShape &&
+ "expects a nativeShape callback to be provided.");
+ auto nativeShape = options.nativeShape(op);
+ return nativeShape;
+ }
+
+ SmallVector<Type> getUnrolledTypes(ShapedType type,
+ ArrayRef<int64_t> tileShape) const {
+ return options.getUnrolledTypes(type, tileShape);
+ }
+
+ /// Emulate the unpack behavior using insert_strided_slice for VectorType
+ /// values and unrealized_conversion_cast for TensorDescType values.
+ Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
+ Location loc, PatternRewriter &rewriter) const {
+ if (auto vecTy = dyn_cast<VectorType>(destTy)) {
+ assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
+ "Expecting blockSize size to match the rank of destTy.");
+ auto shape = vecTy.getShape();
+ auto zeroAttr = rewriter.getZeroAttr(vecTy.getElementType());
+
+ Value result = rewriter.create<arith::ConstantOp>(
+ loc, vecTy, DenseElementsAttr::get(vecTy, zeroAttr));
+ for (auto [src, offsets] :
+ llvm::zip_equal(srcs, StaticTileOffsetRange(shape, blockSize))) {
+ SmallVector<int64_t> staticStrides(offsets.size(), 1);
+ result = rewriter.create<vector::InsertStridedSliceOp>(
+ loc, src, result, offsets, staticStrides);
+ }
+ return result;
+ }
+
+ if (isa<xegpu::TensorDescType>(destTy)) {
+ auto attr = NamedAttribute(rewriter.getStringAttr(unpackAttrName),
+ rewriter.getUnitAttr());
+ auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
+ rewriter.getDenseI64ArrayAttr(blockSize));
+ auto castOp = rewriter.create<UnrealizedConversionCastOp>(
+ loc, destTy, srcs, ArrayRef<NamedAttribute>({attr, blkAttr}));
+ return castOp.getResult(0);
+ }
+
+ llvm_unreachable("Unexpected destTy.");
+ return Value();
+ }
+
+ /// Emulate the pack behavior using extract_strided_slice for VectorType
+ /// values and unrealized_conversion_cast for TensorDescType values.
+ SmallVector<Value> pack(Value src, TypeRange destTypes,
+ ArrayRef<int64_t> blockSize, Location loc,
+ PatternRewriter &rewriter) const {
+ if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
+ assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
+ "Expecting blockSize size to match the rank of src.");
+ auto shape = vecTy.getShape();
+ SmallVector<Value> results;
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange(shape, blockSize)) {
+ SmallVector<int64_t> staticStrides(offsets.size(), 1);
+ auto slice = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, src, offsets, blockSize, staticStrides);
+ results.push_back(slice);
+ }
+ return results;
+ }
+
+ if (isa<xegpu::TensorDescType>(src.getType())) {
+ auto attr = NamedAttribute(rewriter.getStringAttr(packAttrName),
+ rewriter.getUnitAttr());
+ auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName),
+ rewriter.getDenseI64ArrayAttr(blockSize));
+ auto castOp = rewriter.create<UnrealizedConversionCastOp>(
+ loc, destTypes, src, ArrayRef<NamedAttribute>({attr, blkAttr}));
+ return castOp.getResults();
+ }
+
+ llvm_unreachable("Unexpected src type.");
+ return SmallVector<Value>();
+ }
+
+private:
+ const char *const packAttrName = "__xegpu_blocking_pack__";
+ const char *const unpackAttrName = "__xegpu_blocking_unpack__";
+ const char *const blockAttrName = "__xegpu_blocking_tile_shape__";
+
+ xegpu::UnrollOptions options;
+};
+
+struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
+ using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
+ LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ xegpu::TensorDescType tdescTy = op.getType();
+ int64_t rank = tdescTy.getRank();
+ ArrayRef<int64_t> shape = tdescTy.getShape();
+
+ std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+ if (!targetShape || llvm::equal(*targetShape, shape))
+ return failure();
+
+ auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
+
+ auto addi = [&](OpFoldResult a, int64_t b) -> Value {
+ std::optional<int64_t> maybeInt = getConstantIntValue(a);
+ if (maybeInt) {
+ return rewriter.create<arith::ConstantIndexOp>(loc, *maybeInt + b);
+ } else {
+ auto aV = llvm::cast<Value>(a);
+ auto bV = rewriter.create<arith::ConstantIndexOp>(loc, b);
+ return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
+ }
+ };
+
+ SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
+
+ // For n-D memrefs where n > rank, we need to handle the last `rank`
+ // dimensions only, and keep the first `n-rank` dimensions as is.
+ SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
+ llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
+ auto validIdxes =
+ llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());
+
+ SmallVector<Value> newOps;
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange(shape, *targetShape)) {
+
+ for (auto [idx, oldOff, offset] :
+ llvm::zip(validIdxes, oldOffsets, offsets))
+ mixedOffsets[idx] = addi(oldOff, offset);
+
+ auto newOp = rewriter.create<xegpu::CreateNdDescOp>(
+ loc, newTdescTy, op.getSource(), mixedOffsets, op.getMixedSizes(),
+ op.getMixedStrides());
+ newOps.push_back(newOp);
+ }
+ Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
+ rewriter.replaceOp(op, castOp);
+
+ return success();
+ }
+};
+
+struct UnrollUpdateNdOffsetOp : public UnrollPattern<xegpu::UpdateNdOffsetOp> {
+ using UnrollPattern<xegpu::UpdateNdOffsetOp>::UnrollPattern;
+ LogicalResult matchAndRewrite(xegpu::UpdateNdOffsetOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ xegpu::TensorDescType tdescTy = op.getTensorDescType();
+ ArrayRef<int64_t> shape = tdescTy.getShape();
+
+ std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+ if (!targetShape || llvm::equal(*targetShape, shape))
+ return failure();
+
+ SmallVector<Type> convertedTdescTypes =
+ getUnrolledTypes(tdescTy, *targetShape);
+ SmallVector<Value> convertedTdesc = pack(
+ op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+ SmallVector<Value> newOps;
+ for (auto t : convertedTdesc) {
+ auto newOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
+ loc, t.getType(), t, op.getOffsets(), op.getConstOffsets());
+ newOps.push_back(newOp);
+ }
+ Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
+ rewriter.replaceOp(op, castOp);
+ return success();
+ }
+};
+
+struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
+ using UnrollPattern<xegpu::PrefetchNdOp>::UnrollPattern;
+ LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ xegpu::TensorDescType tdescTy = op.getTensorDescType();
+ ArrayRef<int64_t> shape = tdescTy.getShape();
+
+ std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+ if (!targetShape || llvm::equal(*targetShape, shape))
+ return failure();
+
+ SmallVector<Type> convertedTdescTypes =
+ getUnrolledTypes(tdescTy, *targetShape);
+ SmallVector<Value> convertedTdesc = pack(
+ op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+ for (auto t : convertedTdesc)
+ rewriter.create<xegpu::PrefetchNdOp>(loc, TypeRange(), t, op->getAttrs());
+
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
+struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
+ using UnrollPattern<xegpu::LoadNdOp>::UnrollPattern;
+ LogicalResult matchAndRewrite(xegpu::LoadNdOp op,
+ PatternRewriter &rewriter) const override {
+
+ Location loc = op.getLoc();
+ VectorType valueTy = op.getType();
+ xegpu::TensorDescType tdescTy = op.getTensorDescType();
+ ArrayRef<int64_t> shape = tdescTy.getShape();
+
+ std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+ if (!targetShape || llvm::equal(*targetShape, shape))
+ return failure();
+
+ Type elemTy = tdescTy.getElementType();
+ VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+
+ SmallVector<Type> convertedTdescTypes =
+ getUnrolledTypes(tdescTy, *targetShape);
+ SmallVector<Value> convertedTdescs = pack(
+ op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+ SmallVector<Value> newOps;
+ for (auto t : convertedTdescs) {
+ auto newOp =
+ rewriter.create<xegpu::LoadNdOp>(loc, newValueTy, t, op->getAttrs());
+ newOps.push_back(newOp);
+ }
+
+ Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
+
+ rewriter.replaceOp(op, castOp);
+ return success();
+ }
+};
+
+struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
+ using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
+ LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ VectorType valueTy = op.getValueType();
+ xegpu::TensorDescType tdescTy = op.getTensorDescType();
+ ArrayRef<int64_t> shape = tdescTy.getShape();
+
+ std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+ if (!targetShape || llvm::equal(*targetShape, shape))
+ return failure();
+
+ SmallVector<Type> convertedValTypes =
+ getUnrolledTypes(valueTy, *targetShape);
+ SmallVector<Type> convertedTdescTypes =
+ getUnrolledTypes(tdescTy, *targetShape);
+
+ SmallVector<Value> convertedValues =
+ pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
+ SmallVector<Value> convertedTdescs = pack(
+ op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+ for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
+ rewriter.create<xegpu::StoreNdOp>(loc, v, t, op.getL1HintAttr(),
+ op.getL2HintAttr(), op.getL3HintAttr());
+
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
+struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
+ using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
+ LogicalResult matchAndRewrite(xegpu::DpasOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+
+ // Expecting every operand to be a 2D vector.
+ if (llvm::any_of(op->getOperandTypes(), [&](Type type) {
+ auto vecTy = dyn_cast<VectorType>(type);
+ return !vecTy || vecTy.getRank() != 2;
+ }))
+ return failure();
+
+ // A vector of 3 elements should be returned, representing M, K, N
+ // respectively.
+ std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+ if (!targetShape || targetShape->size() != 3)
+ return failure();
+ auto M = (*targetShape)[0];
+ auto K = (*targetShape)[1];
+ auto N = (*targetShape)[2];
+
+ int64_t aBlockSize[2] = {M, K};
+ int64_t bBlockSize[2] = {K, N};
+ int64_t cBlockSize[2] = {M, N};
+
+ auto packWrapper = [&](TypedValue<VectorType> val,
+ ArrayRef<int64_t> blockSize) {
+ VectorType type = val.getType();
+ std::optional<SmallVector<int64_t>> grids =
+ computeShapeRatio(type.getShape(), blockSize);
+ assert(grids && "Expecting grids to be computed.");
+ auto numNewOps = computeProduct(*grids);
+ if (numNewOps == 1)
+ return SmallVector<Value>({val});
+ VectorType newVecTy = type.cloneWith(blockSize, type.getElementType());
+ SmallVector<Type> convertedTypes(numNewOps, newVecTy);
+ SmallVector<Value> values =
+ pack(val, convertedTypes, blockSize, loc, rewriter);
+ return values;
+ };
+
+ auto a = op.getLhs();
+ auto b = op.getRhs();
+ auto c = op.getAcc();
+
+ auto aShape = a.getType().getShape();
+ auto bShape = b.getType().getShape();
+
+ SmallVector<Value> aVals, bVals, cVals;
+ aVals = packWrapper(a, aBlockSize);
+ bVals = packWrapper(b, bBlockSize);
+
+ if (c)
+ cVals = packWrapper(c, cBlockSize);
+
+ // Skip the operation if any operand has an invalid blocking size (empty),
+ // or if every operand's original shape already matches its blocking size
+ // (size == 1).
+ auto ranges = c ? SmallVector<ValueRange>({aVals, bVals, cVals})
+ : SmallVector<ValueRange>({aVals, bVals});
+ if (llvm::any_of(ranges, [](auto &v) { return v.size() == 0; }) ||
+ llvm::all_of(ranges, [](auto &v) { return v.size() == 1; }))
+ return failure();
+
+ VectorType resultTy = op.getResult().getType();
+ auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType());
+
+ int64_t mIters = aShape[0] / M;
+ int64_t kIters = aShape[1] / K;
+ int64_t nIters = bShape[1] / N;
+
+ SmallVector<Value> newOps;
+ for (int64_t i = 0; i < mIters; ++i) {
+ for (int64_t j = 0; j < nIters; ++j) {
+ Value tmpC;
+ if (c)
+ tmpC = cVals[i * nIters + j]; // init with acc
+
+ for (int64_t k = 0; k < kIters; ++k) {
+ Value aVec = aVals[i * kIters + k];
+ Value bVec = bVals[k * nIters + j];
+ SmallVector<Value> operands({aVec, bVec});
+ if (tmpC)
+ operands.push_back(tmpC);
+
+ tmpC = rewriter.create<xegpu::DpasOp>(loc, vecTy, operands,
+ op->getAttrs());
+ }
+ newOps.push_back(tmpC);
+ }
+ }
+ Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter);
+ rewriter.replaceOp(op, castOp);
+ return success();
+ }
+};
+
+} // namespace
+
+void mlir::xegpu::populateXeGPUUnrollPatterns(
+ RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
+ patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
+ UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp>(
+ patterns.getContext(), options);
+}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
new file mode 100644
index 0000000000000..b911bb3bbdc1c
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -0,0 +1,161 @@
+// RUN: mlir-opt --test-xegpu-unrolling-patterns -split-input-file %s | FileCheck %s
+
+gpu.module @test {
+
+ // CHECK-LABEL: test_create_nd_tdesc
+ // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
+ // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+ // CHECK: [[cast:%.+]] = builtin.unrealized_conversion_cast
+ // CHECK-SAME: !xegpu.tensor_desc<8x16xf32>, !xegpu.tensor_desc<8x16xf32>,
+ // CHECK-SAME: !xegpu.tensor_desc<8x16xf32>, !xegpu.tensor_desc<8x16xf32>,
+ // CHECK-SAME: !xegpu.tensor_desc<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+ // CHECK-SAME: to !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {__xegpu_blocking_tile_shape__ = array<i64: 8, 16>, __xegpu_blocking_unpack__}
+ gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+ gpu.return %tdesc : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+ }
+
+ //-----
+
+ // CHECK-LABEL: test_create_nd_tdesc_1d
+ // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
+ // CHECK-COUNT-2: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
+ // CHECK: [[cast:%.+]] = builtin.unrealized_conversion_cast
+ // CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
+ // CHECK-SAME: to !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {__xegpu_blocking_tile_shape__ = array<i64: 16>, __xegpu_blocking_unpack__}
+ gpu.func @test_create_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {
+ %tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
+ gpu.return %tdesc : !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
+ }
+
+ //-----
+
+ // CHECK-LABEL: test_update_nd_tdesc
+ // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
+ // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+ // CHECK-COUNT-6: [[update:%.+]] = xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<8x16xf32>
+ gpu.func @test_update_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+ %update = xegpu.update_nd_offset %tdesc, [0, 16] : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+ gpu.return %update : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+ }
+
+ //-----
+
+ // CHECK-LABEL: test_update_nd_tdesc_1d
+ // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
+ // CHECK-COUNT-2: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
+ // CHECK-COUNT-2: [[update:%.+]] = xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<16xf32>
+ gpu.func @test_update_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {
+ %tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
+ %update = xegpu.update_nd_offset %tdesc, [32] : !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
+ gpu.return %update : !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
+ }
+
+ //-----
+
+ // CHECK-LABEL: test_prefetch_nd_tdesc
+ // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
+ // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+ // CHECK-COUNT-6: xegpu.prefetch_nd {{.*}} : !xegpu.tensor_desc<8x16xf32>
+ gpu.func @test_prefetch_nd_tdesc(%src: memref<24x32xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+ xegpu.prefetch_nd %tdesc : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+ gpu.return
+ }
+
+ //-----
+
+ // CHECK-LABEL: test_prefetch_nd_tdesc_1d
+ // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
+ // CHECK-COUNT-4: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
+ // CHECK-COUNT-4: xegpu.prefetch_nd {{.*}} : !xegpu.tensor_desc<16xf32>
+ gpu.func @test_prefetch_nd_tdesc_1d(%src: memref<64xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
+ xegpu.prefetch_nd %tdesc : !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
+ gpu.return
+ }
+
+ //-----
+ // CHECK-LABEL: test_load_nd
+ // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
+ // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+ // CHECK-COUNT-6: [[ld:%.+]] = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+ // CHECK-COUNT-6: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<24x32xf32>
+ gpu.func @test_load_nd(%src: memref<24x32xf32>) -> vector<24x32xf32> {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+ %ld = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
+ gpu.return %ld : vector<24x32xf32>
+ }
+
+ //-----
+
+ // CHECK-LABEL: test_load_nd_1d
+ // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
+ // CHECK-COUNT-4: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
+ // CHECK-COUNT-4: [[ld:%.+]] = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<16xf32> -> vector<16xf32>
+ // CHECK-COUNT-4: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<16xf32> into vector<64xf32>
+ gpu.func @test_load_nd_1d(%src: memref<64xf32>) -> vector<64xf32> {
+ %tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
+ %data = xegpu.load_nd %tdesc: !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>> -> vector<64xf32>
+ gpu.return %data : vector<64xf32>
+ }
+
+ //-----
+
+ // CHECK-LABEL: test_store_nd
+ // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
+ // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+ // CHECK-COUNT-6: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+ gpu.func @test_store_nd(%src: memref<24x32xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+ %data = arith.constant dense<9.0> : vector<24x32xf32>
+ xegpu.store_nd %data, %tdesc: vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+ gpu.return
+ }
+
+ //-----
+
+ // CHECK-LABEL: test_store_nd_1d
+ // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
+ // CHECK-COUNT-4: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
+ // CHECK-COUNT-4: xegpu.store_nd {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32>
+ gpu.func @test_store_nd_1d(%src: memref<64xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
+ %data = arith.constant dense<9.0> : vector<64xf32>
+ xegpu.store_nd %data, %tdesc: vector<64xf32>, !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
+ gpu.return
+ }
+
+ //-----
+
+ // CHECK-LABEL: test_createNd_loadNd_storeNd
+ // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
+ //CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+ //CHECK-COUNT-6: [[data:%.+]] = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+ //CHECK-COUNT-6: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<24x32xf32>
+ //CHECK: [[add:%.+]] = arith.addf {{.*}} : vector<24x32xf32>
+ //CHECK-COUNT-6: [[extract:%.+]] = vector.extract_strided_slice {{.*}} : vector<24x32xf32> to vector<8x16xf32>
+ //CHECK-COUNT-6: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+ gpu.func @test_createNd_loadNd_storeNd(%src: memref<24x32xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+ %data = arith.constant dense<9.0> : vector<24x32xf32>
+ %ld = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
+ %add = arith.addf %data, %ld : vector<24x32xf32>
+ xegpu.store_nd %add, %tdesc: vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+ gpu.return
+ }
+
+ //-----
+
+ // CHECK-LABEL: test_dpas
+ // CHECK-SAME: [[arg0:%.+]]: vector<32x32xf16>, [[arg1:%.+]]: vector<32x32xf16>
+ //CHECK-COUNT-8: [[extract1:%.+]] = vector.extract_strided_slice [[arg0]] {{.*}} : vector<32x32xf16> to vector<8x16xf16>
+ //CHECK-COUNT-4: [[extract2:%.+]] = vector.extract_strided_slice [[arg1]] {{.*}} : vector<32x32xf16> to vector<16x16xf16>
+ //CHECK-COUNT-16: [[dpas:%.+]] = xegpu.dpas {{.*}} -> vector<8x16xf32>
+ //CHECK-COUNT-8: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<32x32xf32>
+ gpu.func @test_dpas(%a: vector<32x32xf16>, %b: vector<32x32xf16>) -> vector<32x32xf32> {
+ %c = xegpu.dpas %a, %b : vector<32x32xf16>, vector<32x32xf16> -> vector<32x32xf32>
+ gpu.return %c : vector<32x32xf32>
+ }
+}
diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index 29fb4441a24fd..a8fd70e6397a5 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -22,3 +22,4 @@ add_subdirectory(TestDyn)
add_subdirectory(Tosa)
add_subdirectory(Transform)
add_subdirectory(Vector)
+add_subdirectory(XeGPU)
diff --git a/mlir/test/lib/Dialect/XeGPU/CMakeLists.txt b/mlir/test/lib/Dialect/XeGPU/CMakeLists.txt
new file mode 100644
index 0000000000000..5236d8765eac8
--- /dev/null
+++ b/mlir/test/lib/Dialect/XeGPU/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_dialect_library(MLIRXeGPUTestPasses
+ TestXeGPUTransforms.cpp
+
+ EXCLUDE_FROM_LIBMLIR
+)
+
+mlir_target_link_libraries(MLIRXeGPUTestPasses PUBLIC
+ MLIRAffineUtils
+ MLIRIR
+ MLIRMemRefDialect
+ MLIRXeGPUDialect
+ MLIRPass
+ MLIRTransforms
+ MLIRGPUDialect
+ MLIRXeGPUTransforms
+)
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
new file mode 100644
index 0000000000000..eaa3b988cad82
--- /dev/null
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -0,0 +1,124 @@
+//===- TestXeGPUTransforms.cpp -- Test XeGPU transforms and lowerings -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::xegpu;
+
+namespace {
+
+struct TestXeGPUUnrollingPatterns
+ : public PassWrapper<TestXeGPUUnrollingPatterns,
+ OperationPass<gpu::GPUModuleOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestXeGPUUnrollingPatterns)
+
+ StringRef getArgument() const final {
+ return "test-xegpu-unrolling-patterns";
+ }
+
+ StringRef getDescription() const final {
+ return "Test lowering patterns to unroll ops in the xegpu dialect";
+ }
+
+ void getDependentDialects(::mlir::DialectRegistry ®istry) const override {
+ registry.insert<memref::MemRefDialect>();
+ registry.insert<xegpu::XeGPUDialect>();
+ registry.insert<vector::VectorDialect>();
+ }
+
+ TestXeGPUUnrollingPatterns() = default;
+ TestXeGPUUnrollingPatterns(const TestXeGPUUnrollingPatterns &pass)
+ : PassWrapper(pass) {}
+
+ void runOnOperation() override {
+ MLIRContext *ctx = &getContext();
+ xegpu::UnrollOptions options;
+ options.setNativeShapeFn(
+ [&](Operation *op) -> std::optional<SmallVector<int64_t>> {
+ if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp,
+ xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::StoreNdOp>(op)) {
+ xegpu::TensorDescType tdescTy;
+ if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
+ tdescTy = createNdOp.getType();
+ } else if (auto updateNdOp =
+ dyn_cast<xegpu::UpdateNdOffsetOp>(op)) {
+ tdescTy = updateNdOp.getTensorDescType();
+ } else if (auto prefetchNdOp = dyn_cast<xegpu::PrefetchNdOp>(op)) {
+ tdescTy = prefetchNdOp.getTensorDescType();
+ } else if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) {
+ tdescTy = loadNdOp.getTensorDescType();
+ } else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
+ tdescTy = storeNdOp.getTensorDescType();
+ }
+
+ if (auto layout = tdescTy.getLayoutAttr()) {
+ auto inst_data = layout.getInstData();
+ if (inst_data && layout.isSgLayout())
+ return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
+ inst_data.asArrayRef().end());
+ }
+ }
+
+ if (isa<xegpu::DpasOp>(op))
+ return SmallVector<int64_t>{8, 16, 16};
+
+ return std::nullopt;
+ });
+
+ options.setUnrolledTypesFn(
+ [&](ShapedType type, ArrayRef<int64_t> tileShape) -> SmallVector<Type> {
+ Type elemTy = type.getElementType();
+ Type newTy;
+
+ // TensorDescType needs to drop the inst_data field in the layout
+ // attribute
+ if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type)) {
+ Attribute encoding = tdescTy.getEncoding();
+ auto layout = llvm::dyn_cast_if_present<xegpu::LayoutAttr>(
+ tdescTy.getLayout());
+ if (layout) {
+ if (layout.getLaneLayout() == nullptr)
+ layout = xegpu::LayoutAttr();
+ else
+ layout = layout.dropInstData();
+ }
+ newTy = xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
+ layout);
+ } else {
+ newTy = type.clone(tileShape, elemTy);
+ }
+
+ std::optional<SmallVector<int64_t>> ratio =
+ computeShapeRatio(type.getShape(), tileShape);
+ assert(ratio && "Expecting the ratio to be valid.");
+ return SmallVector<Type>(computeProduct(*ratio), newTy);
+ });
+
+ RewritePatternSet patterns(ctx);
+
+ populateXeGPUUnrollPatterns(patterns, options);
+ (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+ }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestXeGPULowerings() {
+ PassRegistration<TestXeGPUUnrollingPatterns>();
+}
+} // namespace test
+} // namespace mlir
\ No newline at end of file
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index a5a442909fc6d..3220dca282eac 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -46,6 +46,7 @@ if(MLIR_INCLUDE_TESTS)
MLIRTilingInterfaceTestPasses
MLIRTosaTestPasses
MLIRVectorTestPasses
+ MLIRXeGPUTestPasses
MLIRTestVectorToSPIRV
MLIRLLVMTestPasses
)
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 344576a44ca41..cdcf59b2add13 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -158,6 +158,7 @@ void registerTestVectorLowerings();
void registerTestVectorReductionToSPIRVDotProd();
void registerTestVulkanRunnerPipeline();
void registerTestWrittenToPass();
+void registerTestXeGPULowerings();
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
void registerTestDialectConversionPasses();
void registerTestPDLByteCodePass();
@@ -301,6 +302,7 @@ void registerTestPasses() {
mlir::test::registerTestVectorReductionToSPIRVDotProd();
mlir::test::registerTestVulkanRunnerPipeline();
mlir::test::registerTestWrittenToPass();
+ mlir::test::registerTestXeGPULowerings();
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
mlir::test::registerTestDialectConversionPasses();
mlir::test::registerTestPDLByteCodePass();