[Mlir-commits] [mlir] [mlir][XeGPU][XeGPUUnroll] Support new syntax with offsets moved to load_nd/store_nd/prefetch_nd (PR #160323)
Dmitry Chigarev
llvmlistbot at llvm.org
Wed Sep 24 03:09:50 PDT 2025
https://github.com/dchigarev updated https://github.com/llvm/llvm-project/pull/160323
>From f7eee8847ebe967c170361eae93429f9ee339451 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Tue, 23 Sep 2025 14:37:58 +0000
Subject: [PATCH 1/3] [mlir][XeGPU][XeGPUUnroll] Support new syntax with
offsets moved to load_nd/store_nd/prefetch_nd
Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
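
For reference, a minimal sketch of the builder-level shape of the new syntax
(variable names are illustrative; the builder calls mirror the ones used in
the patch below): with the new placement the tensor descriptor can be created
without offsets, and each unrolled load_nd/store_nd/prefetch_nd supplies its
own (base + tile) offsets instead.

    // Descriptor created without offsets ...
    auto tdesc = xegpu::CreateNdDescOp::create(rewriter, loc, tdescTy, source,
                                               mixedSizes, mixedStrides);
    // ... while every unrolled load_nd carries per-tile offsets itself.
    auto tile = xegpu::LoadNdOp::create(rewriter, loc, tileVecTy, tdesc,
                                        tileOffsets, packedAttr, transposeAttr,
                                        l1Hint, l2Hint, l3Hint);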
---
.../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 177 +++++++++++++-----
...xegpu-unroll-patterns-no-desc-offsets.mlir | 61 ++++++
2 files changed, 186 insertions(+), 52 deletions(-)
create mode 100644 mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 29c9fcdfebcdb..cad7436f23762 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -121,54 +121,81 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
xegpu::UnrollOptions options;
};
+// Generic helper function for unrolling operations with offsets.
+//
+// Iterates over tile offsets within the tensor descriptor shape and calls
+// the provided createOp function for each computed offset. This is used by
+// operations like LoadNd, StoreNd, CreateNdDesc, and PrefetchNd when they
+// have explicit offsets that need to be adjusted for each unrolled tile.
+SmallVector<Value> computeUnrolledOffsets(
+ SmallVector<OpFoldResult> mixedOffsets, xegpu::TensorDescType tdescTy,
+ ArrayRef<int64_t> targetShape,
+ const std::function<Value(SmallVector<OpFoldResult>)> &createOp,
+ Location loc, PatternRewriter &rewriter) {
+ int64_t rank = tdescTy.getRank();
+ ArrayRef<int64_t> shape = tdescTy.getShape();
+
+ auto addi = [&](OpFoldResult a, int64_t b) -> Value {
+ std::optional<int64_t> maybeInt = getConstantIntValue(a);
+ if (maybeInt) {
+ return arith::ConstantIndexOp::create(rewriter, loc, *maybeInt + b);
+ } else {
+ auto aV = llvm::cast<Value>(a);
+ auto bV = arith::ConstantIndexOp::create(rewriter, loc, b);
+ return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
+ }
+ };
+
+ SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
+ llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
+ auto validIdxes =
+ llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());
+
+ SmallVector<Value> newOps;
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange(shape, targetShape)) {
+
+ for (auto [idx, oldOff, offset] :
+ llvm::zip(validIdxes, oldOffsets, offsets))
+ mixedOffsets[idx] = addi(oldOff, offset);
+
+ auto newOp = createOp(mixedOffsets);
+ newOps.push_back(newOp);
+ }
+ return newOps;
+}
+
struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getType();
- int64_t rank = tdescTy.getRank();
- ArrayRef<int64_t> shape = tdescTy.getShape();
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
return failure();
- auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
-
- auto addi = [&](OpFoldResult a, int64_t b) -> Value {
- std::optional<int64_t> maybeInt = getConstantIntValue(a);
- if (maybeInt) {
- return arith::ConstantIndexOp::create(rewriter, loc, *maybeInt + b);
- } else {
- auto aV = llvm::cast<Value>(a);
- auto bV = arith::ConstantIndexOp::create(rewriter, loc, b);
- return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
- }
- };
-
- SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
-
- // For n-D memrefs where n > rank, we need to handle the last `rank`
- // dimensions only, and keep the first `n-rank` dimensions as is.
- SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
- llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
- auto validIdxes =
- llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());
-
SmallVector<Value> newOps;
- for (SmallVector<int64_t> offsets :
- StaticTileOffsetRange(shape, *targetShape)) {
-
- for (auto [idx, oldOff, offset] :
- llvm::zip(validIdxes, oldOffsets, offsets))
- mixedOffsets[idx] = addi(oldOff, offset);
+ auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
+ bool hasOffsets = op.getMixedOffsets().size() != 0;
+ if (!hasOffsets) {
auto newOp = xegpu::CreateNdDescOp::create(
- rewriter, loc, newTdescTy, op.getSource(), mixedOffsets,
- op.getMixedSizes(), op.getMixedStrides());
+ rewriter, loc, newTdescTy, op.getSource(), op.getMixedSizes(),
+ op.getMixedStrides());
newOps.push_back(newOp);
+ } else {
+ auto createOp = [&](SmallVector<OpFoldResult> offsets) -> Value {
+ return xegpu::CreateNdDescOp::create(
+ rewriter, loc, newTdescTy, op.getSource(), offsets,
+ op.getMixedSizes(), op.getMixedStrides());
+ };
+
+ newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
+ *targetShape, createOp, loc, rewriter);
}
+
Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
rewriter.replaceOp(op, castOp);
@@ -216,17 +243,33 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
return failure();
int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
- if ((offsetSize != 0) || op.getConstOffsetsAttr())
- return failure();
+ bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
SmallVector<Type> convertedTdescTypes =
getUnrolledTypes(tdescTy, *targetShape);
+
+ if (hasOffsets)
+ convertedTdescTypes.resize(1);
+
SmallVector<Value> convertedTdesc = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
- for (auto t : convertedTdesc)
- xegpu::PrefetchNdOp::create(rewriter, loc, TypeRange(), t,
- op->getAttrs());
+ if (!hasOffsets) {
+ for (auto t : convertedTdesc)
+ xegpu::PrefetchNdOp::create(rewriter, loc, TypeRange(), t,
+ op->getAttrs());
+ } else {
+ auto createPrefetch = [&](SmallVector<OpFoldResult> offsets) -> Value {
+ xegpu::PrefetchNdOp::create(rewriter, loc, convertedTdesc[0], offsets,
+ op.getL1HintAttr(), op.getL2HintAttr(),
+ op.getL3HintAttr());
+ // Return a dummy Value to satisfy the function's signature.
+ return nullptr;
+ };
+
+ computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
+ createPrefetch, loc, rewriter);
+ }
rewriter.eraseOp(op);
return success();
@@ -247,26 +290,39 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
return failure();
int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
- if ((offsetSize != 0) || op.getConstOffsetsAttr())
- return failure();
+ bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
Type elemTy = tdescTy.getElementType();
VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
SmallVector<Type> convertedTdescTypes =
getUnrolledTypes(tdescTy, *targetShape);
+
+ if (hasOffsets)
+ convertedTdescTypes.resize(1);
+
SmallVector<Value> convertedTdescs = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
-
SmallVector<Value> newOps;
- for (auto t : convertedTdescs) {
- auto newOp =
- xegpu::LoadNdOp::create(rewriter, loc, newValueTy, t, op->getAttrs());
- newOps.push_back(newOp);
+
+ if (!hasOffsets) {
+ for (auto t : convertedTdescs) {
+ auto newOp = xegpu::LoadNdOp::create(rewriter, loc, newValueTy, t,
+ op->getAttrs());
+ newOps.push_back(newOp);
+ }
+ } else {
+ auto createLoad = [&](SmallVector<OpFoldResult> offsets) {
+ return xegpu::LoadNdOp::create(
+ rewriter, loc, newValueTy, convertedTdescs[0], offsets,
+ op.getPackedAttr(), op.getTransposeAttr(), op.getL1HintAttr(),
+ op.getL2HintAttr(), op.getL3HintAttr());
+ };
+ newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
+ *targetShape, createLoad, loc, rewriter);
}
Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
-
rewriter.replaceOp(op, castOp);
return success();
}
@@ -285,22 +341,39 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
return failure();
int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
- if ((offsetSize != 0) || op.getConstOffsetsAttr())
- return failure();
+ bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
SmallVector<Type> convertedValTypes =
getUnrolledTypes(valueTy, *targetShape);
SmallVector<Type> convertedTdescTypes =
getUnrolledTypes(tdescTy, *targetShape);
- SmallVector<Value> convertedValues =
- pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
+ if (hasOffsets)
+ convertedTdescTypes.resize(1);
+
SmallVector<Value> convertedTdescs = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
- for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
- xegpu::StoreNdOp::create(rewriter, loc, v, t, op.getL1HintAttr(),
- op.getL2HintAttr(), op.getL3HintAttr());
+ SmallVector<Value> convertedValues =
+ pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
+ if (!hasOffsets) {
+ for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
+ xegpu::StoreNdOp::create(rewriter, loc, v, t, op.getL1HintAttr(),
+ op.getL2HintAttr(), op.getL3HintAttr());
+ } else {
+ size_t valueIndex = 0;
+ auto createStore = [&](SmallVector<OpFoldResult> offsets) {
+ xegpu::StoreNdOp::create(rewriter, loc, convertedValues[valueIndex++],
+ convertedTdescs[0], offsets,
+ op.getL1HintAttr(), op.getL2HintAttr(),
+ op.getL3HintAttr());
+ // Return a dummy Value to satisfy the function's signature.
+ return nullptr;
+ };
+
+ computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
+ createStore, loc, rewriter);
+ }
rewriter.eraseOp(op);
return success();
diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir
new file mode 100644
index 0000000000000..f28e82a2a4c76
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir
@@ -0,0 +1,61 @@
+// RUN: mlir-opt --test-xegpu-unrolling-patterns -split-input-file %s | FileCheck %s
+
+gpu.module @xevm_test {
+
+ // CHECK-LABEL: create_nd_tdesc
+ // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
+ // CHECK-COUNT-1: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+ // CHECK: [[cast:%.+]] = builtin.unrealized_conversion_cast
+ // CHECK-SAME: !xegpu.tensor_desc<8x16xf32>
+ // CHECK-SAME: to !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {__xegpu_blocking_tile_shape__ = array<i64: 8, 16>, __xegpu_blocking_unpack__}
+ gpu.func @create_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
+ %tdesc = xegpu.create_nd_tdesc %src : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+ gpu.return %tdesc : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+ }
+
+//-----
+ // CHECK-LABEL: load_nd
+ // CHECK-SAME: [[arg0:%.+]]: memref<256x318xf32>
+ // CHECK-COUNT-1: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32>
+ // CHECK-COUNT-6: [[ld:%.+]] = xegpu.load_nd {{.*}}[{{.*}}] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+ // CHECK-COUNT-6: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<24x32xf32>
+ gpu.func @load_nd(%src: memref<256x318xf32>) -> vector<24x32xf32> {
+ %tdesc = xegpu.create_nd_tdesc %src : memref<256x318xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+ %ld = xegpu.load_nd %tdesc[8, 16]: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
+ gpu.return %ld : vector<24x32xf32>
+ }
+
+//-----
+ // CHECK-LABEL: load_nd_store_nd
+ // CHECK-SAME: [[arg0:%.+]]: memref<256x318xf32>
+ //CHECK-COUNT-1: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32>
+ //CHECK-COUNT-6: [[data:%.+]] = xegpu.load_nd {{.*}}[{{.*}}] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+ //CHECK-COUNT-6: xegpu.store_nd {{.*}}[{{.*}}] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+ gpu.func @load_nd_store_nd(%src: memref<256x318xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src : memref<256x318xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+ %ld = xegpu.load_nd %tdesc[8, 16]: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
+ xegpu.store_nd %ld, %tdesc[0, 0] : vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+ gpu.return
+ }
+
+//-----
+ // CHECK-LABEL: prefetch_nd_tdesc
+ // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
+ // CHECK-COUNT-1: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+ // CHECK-COUNT-6: xegpu.prefetch_nd {{.*}}[{{.*}}] : !xegpu.tensor_desc<8x16xf32>
+ gpu.func @prefetch_nd_tdesc(%src: memref<24x32xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+ xegpu.prefetch_nd %tdesc[8, 16] : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+ gpu.return
+ }
+
+//-----
+
+ // CHECK-LABEL: load_nd_offsets_at_both_places
+ // CHECK-COUNT-2: builtin.unrealized_conversion_cast
+ gpu.func @load_nd_offsets_at_both_places(%src: memref<256x318xf32>) -> vector<24x32xf32> {
+ %tdesc = xegpu.create_nd_tdesc %src[16, 8] : memref<256x318xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+ %ld = xegpu.load_nd %tdesc[8, 16]: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
+ gpu.return %ld : vector<24x32xf32>
+ }
+}
\ No newline at end of file
>From f45f04735e0db6a662154a15b28e82682b2c6d86 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Tue, 23 Sep 2025 14:50:24 +0000
Subject: [PATCH 2/3] fix formatting
Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
.../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 17 +++++++++++++----
.../xegpu-unroll-patterns-no-desc-offsets.mlir | 2 +-
2 files changed, 14 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index cad7436f23762..80d1cb12dff80 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -195,7 +195,6 @@ struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
*targetShape, createOp, loc, rewriter);
}
-
Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
rewriter.replaceOp(op, castOp);
@@ -248,8 +247,11 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
SmallVector<Type> convertedTdescTypes =
getUnrolledTypes(tdescTy, *targetShape);
- if (hasOffsets)
+ if (hasOffsets) {
+ // only need one tdesc, tile offsets will be computed
+ // at the operation level
convertedTdescTypes.resize(1);
+ }
SmallVector<Value> convertedTdesc = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
@@ -298,8 +300,11 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
SmallVector<Type> convertedTdescTypes =
getUnrolledTypes(tdescTy, *targetShape);
- if (hasOffsets)
+ if (hasOffsets) {
+ // only need one tdesc, tile offsets will be computed
+ // at the operation level
convertedTdescTypes.resize(1);
+ }
SmallVector<Value> convertedTdescs = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
@@ -323,6 +328,7 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
}
Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
+
rewriter.replaceOp(op, castOp);
return success();
}
@@ -348,8 +354,11 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
SmallVector<Type> convertedTdescTypes =
getUnrolledTypes(tdescTy, *targetShape);
- if (hasOffsets)
+ if (hasOffsets) {
+ // only need one tdesc, tile offsets will be computed
+ // at the operation level
convertedTdescTypes.resize(1);
+ }
SmallVector<Value> convertedTdescs = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir
index f28e82a2a4c76..cbfd991b5557e 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir
@@ -58,4 +58,4 @@ gpu.module @xevm_test {
%ld = xegpu.load_nd %tdesc[8, 16]: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
gpu.return %ld : vector<24x32xf32>
}
-}
\ No newline at end of file
+}
>From 932346e77594318552ae60d64473dda041192888 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Wed, 24 Sep 2025 10:09:36 +0000
Subject: [PATCH 3/3] Modify 'unrolledTypeFn' to return a single type
Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
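
To illustrate the updated hook, a simplified sketch of an UnrolledTypeFnType
callback under the new signature (it omits the layout-attribute handling the
real callbacks below perform for TensorDescType; `options` is assumed to be an
xegpu::UnrollOptions instance and the manual product loop is illustrative):

    options.setUnrolledTypesFn([&](ShapedType type, ArrayRef<int64_t> tileShape,
                                   bool returnSingleType) -> SmallVector<Type> {
      Type newTy = type.clone(tileShape, type.getElementType());
      if (returnSingleType)
        return SmallVector<Type>{newTy}; // one tdesc; offsets handled per op
      std::optional<SmallVector<int64_t>> ratio =
          computeShapeRatio(type.getShape(), tileShape);
      assert(ratio && "shape must be a multiple of tileShape");
      int64_t numTiles = 1;
      for (int64_t r : *ratio)
        numTiles *= r;
      return SmallVector<Type>(numTiles, newTy);
    });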
---
.../Dialect/XeGPU/Transforms/Transforms.h | 8 +++--
.../XeGPU/Transforms/XeGPUBlocking.cpp | 5 ++-
.../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 35 +++++--------------
.../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 5 ++-
4 files changed, 22 insertions(+), 31 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index 44b81796b1313..b74c15e5b7ac1 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -9,9 +9,9 @@
#ifndef MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H
#define MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H
+#include "mlir/IR/Operation.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/LogicalResult.h"
-#include "mlir/IR/Operation.h"
#include <functional>
#include <optional>
@@ -47,9 +47,11 @@ struct UnrollOptions {
/// Function that converts a ShapedType (TensorDescType or VectorType)
/// into the unrolled type based on the tileShape. It returns a vector of
- /// types representing the unrolled types for simplicity.
+ /// types representing the unrolled types for simplicity. When
+ /// `returnSingleType` is true, it returns a vector containing only a single
+ /// unrolled type.
using UnrolledTypeFnType = std::function<SmallVector<Type>(
- ShapedType type, ArrayRef<int64_t> tileShape)>;
+ ShapedType type, ArrayRef<int64_t> tileShape, bool returnSingleType)>;
UnrolledTypeFnType getUnrolledTypes = nullptr;
UnrollOptions &setUnrolledTypesFn(UnrolledTypeFnType fn) {
getUnrolledTypes = std::move(fn);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 7efa4b9fbd934..36c498e8b849d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -319,7 +319,8 @@ void XeGPUBlockingPass::runOnOperation() {
options.setNativeShapeFn([&](Operation *op) { return getTileShape(op); });
- options.setUnrolledTypesFn([&](ShapedType type, ArrayRef<int64_t> tileShape) {
+ options.setUnrolledTypesFn([&](ShapedType type, ArrayRef<int64_t> tileShape,
+ bool returnSingleType = false) {
Type elemTy = type.getElementType();
Type newTy;
@@ -352,6 +353,8 @@ void XeGPUBlockingPass::runOnOperation() {
newTy = type.clone(tileShape, elemTy);
}
+ if (returnSingleType)
+ return SmallVector<Type>{newTy};
std::optional<SmallVector<int64_t>> ratio =
computeShapeRatio(type.getShape(), tileShape);
assert(ratio && "The shape of the type must be a multiple of tileShape.");
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 80d1cb12dff80..f738effe46a72 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -56,8 +56,9 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
}
SmallVector<Type> getUnrolledTypes(ShapedType type,
- ArrayRef<int64_t> tileShape) const {
- return options.getUnrolledTypes(type, tileShape);
+ ArrayRef<int64_t> tileShape,
+ bool returnSingleType = false) const {
+ return options.getUnrolledTypes(type, tileShape, returnSingleType);
}
/// Emulate the unpack behavior using insert_strided_slice for VectorType
@@ -244,14 +245,8 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
- SmallVector<Type> convertedTdescTypes =
- getUnrolledTypes(tdescTy, *targetShape);
-
- if (hasOffsets) {
- // only need one tdesc, tile offsets will be computed
- // at the operation level
- convertedTdescTypes.resize(1);
- }
+ SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
+ tdescTy, *targetShape, /*returnSingleType=*/hasOffsets);
SmallVector<Value> convertedTdesc = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
@@ -297,14 +292,8 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
Type elemTy = tdescTy.getElementType();
VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
- SmallVector<Type> convertedTdescTypes =
- getUnrolledTypes(tdescTy, *targetShape);
-
- if (hasOffsets) {
- // only need one tdesc, tile offsets will be computed
- // at the operation level
- convertedTdescTypes.resize(1);
- }
+ SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
+ tdescTy, *targetShape, /*returnSingleType=*/hasOffsets);
SmallVector<Value> convertedTdescs = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
@@ -351,14 +340,8 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
SmallVector<Type> convertedValTypes =
getUnrolledTypes(valueTy, *targetShape);
- SmallVector<Type> convertedTdescTypes =
- getUnrolledTypes(tdescTy, *targetShape);
-
- if (hasOffsets) {
- // only need one tdesc, tile offsets will be computed
- // at the operation level
- convertedTdescTypes.resize(1);
- }
+ SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
+ tdescTy, *targetShape, /*returnSingleType=*/hasOffsets);
SmallVector<Value> convertedTdescs = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index e1ba45c60ac36..b2bdf3efc65f7 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -95,7 +95,8 @@ struct TestXeGPUUnrollingPatterns
});
options.setUnrolledTypesFn(
- [&](ShapedType type, ArrayRef<int64_t> tileShape) -> SmallVector<Type> {
+ [&](ShapedType type, ArrayRef<int64_t> tileShape,
+ bool returnSingleType = false) -> SmallVector<Type> {
Type elemTy = type.getElementType();
Type newTy;
@@ -137,6 +138,8 @@ struct TestXeGPUUnrollingPatterns
newTy = type.clone(tileShape, elemTy);
}
+ if (returnSingleType)
+ return SmallVector<Type>{newTy};
std::optional<SmallVector<int64_t>> ratio =
computeShapeRatio(type.getShape(), tileShape);
assert(ratio && "Expecting the ratio to be valid.");