[Mlir-commits] [mlir] 04258fe - [mlir][XeGPU][XeGPUUnroll] Support new syntax with offsets moved to load_nd/store_nd/prefetch_nd (#160323)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Sep 25 11:31:22 PDT 2025
Author: Dmitry Chigarev
Date: 2025-09-25T11:31:17-07:00
New Revision: 04258fe3b15c9ecf78848c9b1470e1048844989e
URL: https://github.com/llvm/llvm-project/commit/04258fe3b15c9ecf78848c9b1470e1048844989e
DIFF: https://github.com/llvm/llvm-project/commit/04258fe3b15c9ecf78848c9b1470e1048844989e.diff
LOG: [mlir][XeGPU][XeGPUUnroll] Support new syntax with offsets moved to load_nd/store_nd/prefetch_nd (#160323)
Adds support in XeGPUUnroll for the new syntax, in which offsets are given on
the load/store/prefetch ops instead of on the descriptor:
1. `create_nd_tdesc` without offsets
2. `load_nd` with offsets
3. `store_nd` with offsets
4. `prefetch_nd` with offsets
Combining `create_nd_tdesc` with offsets and `load_nd` with offsets is not
supported and won't be lowered correctly: the IR will still contain two
unrealized conversion casts, which will fail later in the pipeline.
The offset computation for each unrolled tile is now moved from the
descriptors to the load/store/prefetch operations. The resulting IR has a
single descriptor that is shared by all unrolled load/store/prefetch ops,
each of which addresses its own tile through its offsets.
<details><summary>old/new behavior examples</summary>
```mlir
// before unroll pass:
gpu.func @load_nd(%src: memref<256x318xf32>) -> vector<24x32xf32> {
%tdesc = xegpu.create_nd_tdesc %src : memref<256x318xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
%ld = xegpu.load_nd %tdesc[8, 16]: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
gpu.return %ld : vector<24x32xf32>
}
// after unroll pass (offsets in create_nd_tdesc):
gpu.func @create_nd_tdesc2(%arg0: memref<256x318xf32>) -> vector<24x32xf32> {
%cst = arith.constant dense<0.000000e+00> : vector<24x32xf32>
%c24 = arith.constant 24 : index
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
// create 6 descriptors, one for each tile
%0 = xegpu.create_nd_tdesc %arg0[%c8, %c16] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32>
%1 = xegpu.create_nd_tdesc %arg0[%c8, %c32] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32>
%2 = xegpu.create_nd_tdesc %arg0[%c16, %c16] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32>
%3 = xegpu.create_nd_tdesc %arg0[%c16, %c32] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32>
%4 = xegpu.create_nd_tdesc %arg0[%c24, %c16] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32>
%5 = xegpu.create_nd_tdesc %arg0[%c24, %c32] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32>
%6 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
%7 = xegpu.load_nd %1 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
%8 = xegpu.load_nd %2 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
%9 = xegpu.load_nd %3 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
%10 = xegpu.load_nd %4 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
%11 = xegpu.load_nd %5 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
...
}
// after unroll pass (offsets in load_nd):
gpu.func @load_nd(%arg0: memref<256x318xf32>) -> vector<24x32xf32> {
%cst = arith.constant dense<0.000000e+00> : vector<24x32xf32>
%c24 = arith.constant 24 : index
%c32 = arith.constant 32 : index
%c16 = arith.constant 16 : index
%c8 = arith.constant 8 : index
// create only one descriptor with proper tile shape
%0 = xegpu.create_nd_tdesc %arg0 : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32>
// compute tile offsets at the operation (using only one descriptor)
%1 = xegpu.load_nd %0[%c8, %c16] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
%2 = xegpu.load_nd %0[%c8, %c32] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
%3 = xegpu.load_nd %0[%c16, %c16] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
%4 = xegpu.load_nd %0[%c16, %c32] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
%5 = xegpu.load_nd %0[%c24, %c16] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
%6 = xegpu.load_nd %0[%c24, %c32] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
...
}
```
</details>
---------
Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
Added:
mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir
Modified:
mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index 44b81796b1313..b74c15e5b7ac1 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -9,9 +9,9 @@
#ifndef MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H
#define MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H
+#include "mlir/IR/Operation.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/LogicalResult.h"
-#include "mlir/IR/Operation.h"
#include <functional>
#include <optional>
@@ -47,9 +47,11 @@ struct UnrollOptions {
/// Function that converts a ShapedType (TensorDescType or VectorType)
/// into the unrolled type based on the tileShape. It returns a vector of
- /// types representing the unrolled types for simplicity.
+ /// types representing the unrolled types for simplicity. When
+ /// `returnSingleType` is true, it returns a vector containing only one single
+ /// unrolled type.
using UnrolledTypeFnType = std::function<SmallVector<Type>(
- ShapedType type, ArrayRef<int64_t> tileShape)>;
+ ShapedType type, ArrayRef<int64_t> tileShape, bool returnSingleType)>;
UnrolledTypeFnType getUnrolledTypes = nullptr;
UnrollOptions &setUnrolledTypesFn(UnrolledTypeFnType fn) {
getUnrolledTypes = std::move(fn);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 7efa4b9fbd934..36c498e8b849d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -319,7 +319,8 @@ void XeGPUBlockingPass::runOnOperation() {
options.setNativeShapeFn([&](Operation *op) { return getTileShape(op); });
- options.setUnrolledTypesFn([&](ShapedType type, ArrayRef<int64_t> tileShape) {
+ options.setUnrolledTypesFn([&](ShapedType type, ArrayRef<int64_t> tileShape,
+ bool returnSingleType = false) {
Type elemTy = type.getElementType();
Type newTy;
@@ -352,6 +353,8 @@ void XeGPUBlockingPass::runOnOperation() {
newTy = type.clone(tileShape, elemTy);
}
+ if (returnSingleType)
+ return SmallVector<Type>{newTy};
std::optional<SmallVector<int64_t>> ratio =
computeShapeRatio(type.getShape(), tileShape);
assert(ratio && "The shape of the type must be a multiple of tileShape.");
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index d7585fa5df8b3..a178d0fe4b0b0 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -56,8 +56,9 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
}
SmallVector<Type> getUnrolledTypes(ShapedType type,
- ArrayRef<int64_t> tileShape) const {
- return options.getUnrolledTypes(type, tileShape);
+ ArrayRef<int64_t> tileShape,
+ bool returnSingleType = false) const {
+ return options.getUnrolledTypes(type, tileShape, returnSingleType);
}
/// Emulate the the unpack behavior using insert_strided_slice for VectorType
@@ -121,53 +122,79 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
xegpu::UnrollOptions options;
};
+// Generic helper for unrolling ops with explicit offsets (CreateNdDesc,
+// LoadNd, StoreNd, PrefetchNd). For each tile of `targetShape` within the
+// shape of `tdescTy`, adds the tile's static offset to the trailing `rank`
+// entries of `mixedOffsets` (leading entries are kept as is, e.g. for n-D
+// memrefs with n > rank) and calls `createOp` with the updated offsets.
+// Returns the values produced by `createOp`, in iteration order.
+SmallVector<Value> computeUnrolledOffsets(
+ SmallVector<OpFoldResult> mixedOffsets, xegpu::TensorDescType tdescTy,
+ ArrayRef<int64_t> targetShape,
+ const std::function<Value(SmallVector<OpFoldResult>)> &createOp,
+ Location loc, PatternRewriter &rewriter) {
+ int64_t rank = tdescTy.getRank();
+ ArrayRef<int64_t> shape = tdescTy.getShape();
+ // Returns `a + b` as a Value; emits one constant when `a` is static.
+ auto addi = [&](OpFoldResult a, int64_t b) -> Value {
+ std::optional<int64_t> maybeInt = getConstantIntValue(a);
+ if (maybeInt) {
+ return arith::ConstantIndexOp::create(rewriter, loc, *maybeInt + b);
+ } else {
+ auto aV = llvm::cast<Value>(a);
+ auto bV = arith::ConstantIndexOp::create(rewriter, loc, b);
+ return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
+ }
+ };
+ // Copy the original trailing offsets: `mixedOffsets` is overwritten per tile.
+ SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
+ llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
+ auto validIdxes =
+ llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());
+
+ SmallVector<Value> newOps;
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange(shape, targetShape)) {
+
+ for (auto [idx, oldOff, offset] :
+ llvm::zip(validIdxes, oldOffsets, offsets))
+ mixedOffsets[idx] = addi(oldOff, offset);
+
+ auto newOp = createOp(mixedOffsets);
+ newOps.push_back(newOp);
+ }
+ return newOps;
+}
+
struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getType();
- int64_t rank = tdescTy.getRank();
- ArrayRef<int64_t> shape = tdescTy.getShape();
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
return failure();
- auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
-
- auto addi = [&](OpFoldResult a, int64_t b) -> Value {
- std::optional<int64_t> maybeInt = getConstantIntValue(a);
- if (maybeInt) {
- return arith::ConstantIndexOp::create(rewriter, loc, *maybeInt + b);
- } else {
- auto aV = llvm::cast<Value>(a);
- auto bV = arith::ConstantIndexOp::create(rewriter, loc, b);
- return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
- }
- };
-
- SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
-
- // For n-D memrefs where n > rank, we need to handle the last `rank`
- // dimensions only, and keep the first `n-rank` dimensions as is.
- SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
- llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
- auto validIdxes =
- llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());
-
SmallVector<Value> newOps;
- for (SmallVector<int64_t> offsets :
- StaticTileOffsetRange(shape, *targetShape)) {
-
- for (auto [idx, oldOff, offset] :
- llvm::zip(validIdxes, oldOffsets, offsets))
- mixedOffsets[idx] = addi(oldOff, offset);
+ auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
+ bool hasOffsets = op.getMixedOffsets().size() != 0;
+ if (!hasOffsets) {
auto newOp = xegpu::CreateNdDescOp::create(
- rewriter, loc, newTdescTy, op.getSource(), mixedOffsets,
- op.getMixedSizes(), op.getMixedStrides());
+ rewriter, loc, newTdescTy, op.getSource(), op.getMixedSizes(),
+ op.getMixedStrides());
newOps.push_back(newOp);
+ } else {
+ auto createOp = [&](SmallVector<OpFoldResult> offsets) -> Value {
+ return xegpu::CreateNdDescOp::create(
+ rewriter, loc, newTdescTy, op.getSource(), offsets,
+ op.getMixedSizes(), op.getMixedStrides());
+ };
+
+ newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
+ *targetShape, createOp, loc, rewriter);
}
Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
rewriter.replaceOp(op, castOp);
@@ -216,17 +243,30 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
return failure();
int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
- if ((offsetSize != 0) || op.getConstOffsetsAttr())
- return failure();
+ bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
+
+ SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
+ tdescTy, *targetShape, /*returnSingleType*/ hasOffsets);
- SmallVector<Type> convertedTdescTypes =
- getUnrolledTypes(tdescTy, *targetShape);
SmallVector<Value> convertedTdesc = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
- for (auto t : convertedTdesc)
- xegpu::PrefetchNdOp::create(rewriter, loc, TypeRange(), t,
- op->getAttrs());
+ if (!hasOffsets) {
+ for (auto t : convertedTdesc)
+ xegpu::PrefetchNdOp::create(rewriter, loc, TypeRange(), t,
+ op->getAttrs());
+ } else {
+ auto createPrefetch = [&](SmallVector<OpFoldResult> offsets) -> Value {
+ xegpu::PrefetchNdOp::create(rewriter, loc, convertedTdesc[0], offsets,
+ op.getL1HintAttr(), op.getL2HintAttr(),
+ op.getL3HintAttr());
+ // return dummy Value to satisfy function's signature
+ return nullptr;
+ };
+
+ computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
+ createPrefetch, loc, rewriter);
+ }
rewriter.eraseOp(op);
return success();
@@ -247,22 +287,33 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
return failure();
int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
- if ((offsetSize != 0) || op.getConstOffsetsAttr())
- return failure();
+ bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
Type elemTy = tdescTy.getElementType();
VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
- SmallVector<Type> convertedTdescTypes =
- getUnrolledTypes(tdescTy, *targetShape);
+ SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
+ tdescTy, *targetShape, /*returnSingleType*/ hasOffsets);
+
SmallVector<Value> convertedTdescs = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
-
SmallVector<Value> newOps;
- for (auto t : convertedTdescs) {
- auto newOp =
- xegpu::LoadNdOp::create(rewriter, loc, newValueTy, t, op->getAttrs());
- newOps.push_back(newOp);
+
+ if (!hasOffsets) {
+ for (auto t : convertedTdescs) {
+ auto newOp = xegpu::LoadNdOp::create(rewriter, loc, newValueTy, t,
+ op->getAttrs());
+ newOps.push_back(newOp);
+ }
+ } else {
+ auto createLoad = [&](SmallVector<OpFoldResult> offsets) {
+ return xegpu::LoadNdOp::create(
+ rewriter, loc, newValueTy, convertedTdescs[0], offsets,
+ op.getPackedAttr(), op.getTransposeAttr(), op.getL1HintAttr(),
+ op.getL2HintAttr(), op.getL3HintAttr());
+ };
+ newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
+ *targetShape, createLoad, loc, rewriter);
}
Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
@@ -285,22 +336,36 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
return failure();
int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
- if ((offsetSize != 0) || op.getConstOffsetsAttr())
- return failure();
+ bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();
SmallVector<Type> convertedValTypes =
getUnrolledTypes(valueTy, *targetShape);
- SmallVector<Type> convertedTdescTypes =
- getUnrolledTypes(tdescTy, *targetShape);
+ SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
+ tdescTy, *targetShape, /*returnSingleType*/ hasOffsets);
- SmallVector<Value> convertedValues =
- pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
SmallVector<Value> convertedTdescs = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
- for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
- xegpu::StoreNdOp::create(rewriter, loc, v, t, op.getL1HintAttr(),
- op.getL2HintAttr(), op.getL3HintAttr());
+ SmallVector<Value> convertedValues =
+ pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
+ if (!hasOffsets) {
+ for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
+ xegpu::StoreNdOp::create(rewriter, loc, v, t, op.getL1HintAttr(),
+ op.getL2HintAttr(), op.getL3HintAttr());
+ } else {
+ size_t valueIndex = 0;
+ auto createStore = [&](SmallVector<OpFoldResult> offsets) {
+ xegpu::StoreNdOp::create(rewriter, loc, convertedValues[valueIndex++],
+ convertedTdescs[0], offsets,
+ op.getL1HintAttr(), op.getL2HintAttr(),
+ op.getL3HintAttr());
+ // return dummy Value to satisfy function's signature
+ return nullptr;
+ };
+
+ computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
+ createStore, loc, rewriter);
+ }
rewriter.eraseOp(op);
return success();
diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir
new file mode 100644
index 0000000000000..6eee5a544e3f8
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir
@@ -0,0 +1,61 @@
+// RUN: mlir-opt --test-xegpu-unrolling-patterns -split-input-file %s | FileCheck %s
+
+gpu.module @xevm_test {
+
+ // CHECK-LABEL: create_nd_tdesc
+ // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
+ // CHECK: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+ // CHECK: [[cast:%.+]] = builtin.unrealized_conversion_cast
+ // CHECK-SAME: !xegpu.tensor_desc<8x16xf32>
+ // CHECK-SAME: to !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {__xegpu_blocking_tile_shape__ = array<i64: 8, 16>, __xegpu_blocking_unpack__}
+ gpu.func @create_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
+ %tdesc = xegpu.create_nd_tdesc %src : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+ gpu.return %tdesc : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+ }
+
+//-----
+ // CHECK-LABEL: load_nd
+ // CHECK-SAME: [[arg0:%.+]]: memref<256x318xf32>
+ // CHECK: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32>
+ // CHECK-COUNT-6: [[ld:%.+]] = xegpu.load_nd {{.*}}[{{.*}}] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+ // CHECK-COUNT-6: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<24x32xf32>
+ gpu.func @load_nd(%src: memref<256x318xf32>) -> vector<24x32xf32> {
+ %tdesc = xegpu.create_nd_tdesc %src : memref<256x318xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+ %ld = xegpu.load_nd %tdesc[8, 16]: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
+ gpu.return %ld : vector<24x32xf32>
+ }
+
+//-----
+ // CHECK-LABEL: load_nd_store_nd
+ // CHECK-SAME: [[arg0:%.+]]: memref<256x318xf32>
+ // CHECK: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32>
+ // CHECK-COUNT-6: [[data:%.+]] = xegpu.load_nd {{.*}}[{{.*}}] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+ // CHECK-COUNT-6: xegpu.store_nd {{.*}}[{{.*}}] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+ gpu.func @load_nd_store_nd(%src: memref<256x318xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src : memref<256x318xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+ %ld = xegpu.load_nd %tdesc[8, 16]: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
+ xegpu.store_nd %ld, %tdesc[0, 0] : vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+ gpu.return
+ }
+
+//-----
+ // CHECK-LABEL: prefetch_nd_tdesc
+ // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
+ // CHECK: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+ // CHECK-COUNT-6: xegpu.prefetch_nd {{.*}}[{{.*}}] : !xegpu.tensor_desc<8x16xf32>
+ gpu.func @prefetch_nd_tdesc(%src: memref<24x32xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+ xegpu.prefetch_nd %tdesc[8, 16] : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+ gpu.return
+ }
+
+//-----
+
+ // CHECK-LABEL: load_nd_offsets_at_both_places
+ // CHECK-COUNT-2: builtin.unrealized_conversion_cast
+ gpu.func @load_nd_offsets_at_both_places(%src: memref<256x318xf32>) -> vector<24x32xf32> {
+ %tdesc = xegpu.create_nd_tdesc %src[16, 8] : memref<256x318xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+ %ld = xegpu.load_nd %tdesc[8, 16]: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
+ gpu.return %ld : vector<24x32xf32>
+ }
+}
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index c83faea2e622c..094ef0a45b8d2 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -113,7 +113,8 @@ struct TestXeGPUUnrollingPatterns
});
options.setUnrolledTypesFn(
- [&](ShapedType type, ArrayRef<int64_t> tileShape) -> SmallVector<Type> {
+ [&](ShapedType type, ArrayRef<int64_t> tileShape,
+ bool returnSingleType = false) -> SmallVector<Type> {
Type elemTy = type.getElementType();
Type newTy;
@@ -155,6 +156,8 @@ struct TestXeGPUUnrollingPatterns
newTy = type.clone(tileShape, elemTy);
}
+ if (returnSingleType)
+ return SmallVector<Type>{newTy};
std::optional<SmallVector<int64_t>> ratio =
computeShapeRatio(type.getShape(), tileShape);
assert(ratio && "Expecting the ratio to be valid.");
More information about the Mlir-commits
mailing list