[Mlir-commits] [mlir] [mlir][vector] Add `actualRank` output parameter to `createUnrollIterator()` (PR #94197)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 3 02:16:43 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-vector

Author: Benjamin Maxwell (MacDue)

<details>
<summary>Changes</summary>

This provides an easy way to find the actual rank that the vector type will (or can) be unrolled to, which may be greater than `targetRank`.

---
Full diff: https://github.com/llvm/llvm-project/pull/94197.diff


2 Files Affected:

- (modified) mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h (+7-3) 
- (modified) mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp (+14-3) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 9c83acc76e77a..571768dea8c16 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -86,20 +86,24 @@ bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
 ///
 /// If no leading dimensions can be unrolled an empty optional will be returned.
 ///
+/// To discover the actual rank the vector type can be unrolled to, pass a
+/// pointer to an int64_t as the optional `actualRank` output parameter.
+///
 /// Examples:
 ///
 ///   For vType = vector<2x3x4> and targetRank = 1
 ///
 ///   The resulting iterator will yield:
-///     [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]
+///     [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2] (actualRank = 1)
 ///
 ///   For vType = vector<3x[4]x5> and targetRank = 0
 ///
 ///   The scalable dimension blocks unrolling so the iterator yields only:
-///     [0], [1], [2]
+///     [0], [1], [2] (actualRank = 2)
 ///
 std::optional<StaticTileOffsetRange>
-createUnrollIterator(VectorType vType, int64_t targetRank = 1);
+createUnrollIterator(VectorType vType, int64_t targetRank = 1,
+                     int64_t *actualRank = nullptr);
 
 /// A wrapper for getMixedSizes for vector.transfer_read and
 /// vector.transfer_write Ops (for source and destination, respectively).
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 6727f3f461722..392758ec6565a 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -285,9 +285,17 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
 }
 
 std::optional<StaticTileOffsetRange>
-vector::createUnrollIterator(VectorType vType, int64_t targetRank) {
-  if (vType.getRank() <= targetRank)
+vector::createUnrollIterator(VectorType vType, int64_t targetRank,
+                             int64_t *actualRank) {
+  auto reportActualRank = [&](int64_t rank) {
+    if (actualRank)
+      *actualRank = rank;
+  };
+  auto vectorRank = vType.getRank();
+  if (vectorRank <= targetRank) {
+    reportActualRank(vectorRank);
     return {};
+  }
   // Attempt to unroll until targetRank or the first scalable dimension (which
   // cannot be unrolled).
   auto shapeToUnroll = vType.getShape().drop_back(targetRank);
@@ -295,14 +303,17 @@ vector::createUnrollIterator(VectorType vType, int64_t targetRank) {
   auto it =
       std::find(scalableDimsToUnroll.begin(), scalableDimsToUnroll.end(), true);
   auto firstScalableDim = it - scalableDimsToUnroll.begin();
-  if (firstScalableDim == 0)
+  if (firstScalableDim == 0) {
+    reportActualRank(vectorRank);
     return {};
+  }
   // All scalable dimensions should be removed now.
   scalableDimsToUnroll = scalableDimsToUnroll.slice(0, firstScalableDim);
   assert(!llvm::is_contained(scalableDimsToUnroll, true) &&
          "unexpected leading scalable dimension");
   // Create an unroll iterator for leading dimensions.
   shapeToUnroll = shapeToUnroll.slice(0, firstScalableDim);
+  reportActualRank(vectorRank - shapeToUnroll.size());
   return StaticTileOffsetRange(shapeToUnroll, /*unrollStep=*/1);
 }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/94197


More information about the Mlir-commits mailing list