[Mlir-commits] [mlir] [mlir][vector] Add `actualRank` output parameter to `createUnrollIterator()` (PR #94197)
Benjamin Maxwell
llvmlistbot at llvm.org
Mon Jun 3 02:16:14 PDT 2024
https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/94197
This provides an easy way of finding the actual rank the vector type will/can be unrolled to (which may be greater than `targetRank`).
From afba555b002973e8f9eae5b7746e1a93c5c7ed42 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 3 Jun 2024 09:06:12 +0000
Subject: [PATCH] [mlir][vector] Add `actualRank` output parameter to
`createUnrollIterator()`
This provides an easy way of finding the actual rank the vector type
will/can be unrolled to (which may be greater than `targetRank`).
---
.../mlir/Dialect/Vector/Utils/VectorUtils.h | 10 +++++++---
mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 17 ++++++++++++++---
2 files changed, 21 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 9c83acc76e77a..571768dea8c16 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -86,20 +86,24 @@ bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
///
/// If no leading dimensions can be unrolled an empty optional will be returned.
///
+/// If the optional `actualRank` pointer is non-null, the rank the vector type
+/// can actually be unrolled to is written to the int64_t it points to.
+///
/// Examples:
///
/// For vType = vector<2x3x4> and targetRank = 1
///
/// The resulting iterator will yield:
-/// [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]
+/// [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2] (actualRank = 1)
///
/// For vType = vector<3x[4]x5> and targetRank = 0
///
/// The scalable dimension blocks unrolling so the iterator yields only:
-/// [0], [1], [2]
+/// [0], [1], [2] (actualRank = 2)
///
std::optional<StaticTileOffsetRange>
-createUnrollIterator(VectorType vType, int64_t targetRank = 1);
+createUnrollIterator(VectorType vType, int64_t targetRank = 1,
+ int64_t *actualRank = nullptr);
/// A wrapper for getMixedSizes for vector.transfer_read and
/// vector.transfer_write Ops (for source and destination, respectively).
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 6727f3f461722..392758ec6565a 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -285,9 +285,17 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
}
std::optional<StaticTileOffsetRange>
-vector::createUnrollIterator(VectorType vType, int64_t targetRank) {
- if (vType.getRank() <= targetRank)
+vector::createUnrollIterator(VectorType vType, int64_t targetRank,
+ int64_t *actualRank) {
+ auto reportActualRank = [&](int64_t rank) {
+ if (actualRank)
+ *actualRank = rank;
+ };
+ auto vectorRank = vType.getRank();
+ if (vectorRank <= targetRank) {
+ reportActualRank(vectorRank);
return {};
+ }
// Attempt to unroll until targetRank or the first scalable dimension (which
// cannot be unrolled).
auto shapeToUnroll = vType.getShape().drop_back(targetRank);
@@ -295,14 +303,17 @@ vector::createUnrollIterator(VectorType vType, int64_t targetRank) {
auto it =
std::find(scalableDimsToUnroll.begin(), scalableDimsToUnroll.end(), true);
auto firstScalableDim = it - scalableDimsToUnroll.begin();
- if (firstScalableDim == 0)
+ if (firstScalableDim == 0) {
+ reportActualRank(vectorRank);
return {};
+ }
// All scalable dimensions should be removed now.
scalableDimsToUnroll = scalableDimsToUnroll.slice(0, firstScalableDim);
assert(!llvm::is_contained(scalableDimsToUnroll, true) &&
"unexpected leading scalable dimension");
// Create an unroll iterator for leading dimensions.
shapeToUnroll = shapeToUnroll.slice(0, firstScalableDim);
+ reportActualRank(vectorRank - shapeToUnroll.size());
return StaticTileOffsetRange(shapeToUnroll, /*unrollStep=*/1);
}
More information about the Mlir-commits
mailing list