[Mlir-commits] [mlir] [mlir][vector] Move hidden function to op definition (PR #140813)

James Newling llvmlistbot at llvm.org
Thu May 22 12:06:05 PDT 2025


https://github.com/newling updated https://github.com/llvm/llvm-project/pull/140813

>From 207b0e11365b61db78896741d17d2505f8c3f891 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 20 May 2025 15:25:12 -0700
Subject: [PATCH 1/2] [mlir][vector] Add getLinearIndices to strided slice ops

---
 .../mlir/Dialect/Vector/IR/VectorOps.td       |   4 +
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 114 ++++++++++++++++
 .../Vector/Transforms/VectorLinearize.cpp     | 126 ++----------------
 3 files changed, 132 insertions(+), 112 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 3f5564541554e..f8412863b18c9 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1114,6 +1114,8 @@ def Vector_InsertStridedSliceOp :
         return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
       });
     }
+    // \return The indices in dest that the values are inserted to.
+    FailureOr<SmallVector<int64_t>> getLinearIndices();
   }];
 
   let hasFolder = 1;
@@ -1254,6 +1256,8 @@ def Vector_ExtractStridedSliceOp :
         return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
       });
     }
+    // \return The indices in source that the values are taken from.
+    FailureOr<SmallVector<int64_t>> getLinearIndices();
   }];
   let hasCanonicalizer = 1;
   let hasFolder = 1;
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 41777347975da..e800b7b7c9ff6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3182,6 +3182,101 @@ void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
                       stridesAttr);
 }
 
+/// Convert an array of attributes into a vector of integers, if possible.
+/// Fails if `attrs` is null or if any element is not an IntegerAttr.
+static FailureOr<SmallVector<int64_t>> intsFromArrayAttr(ArrayAttr attrs) {
+  if (!attrs)
+    return failure();
+  SmallVector<int64_t> ints;
+  ints.reserve(attrs.size());
+  for (auto attr : attrs) {
+    auto intAttr = dyn_cast<IntegerAttr>(attr);
+    if (!intAttr)
+      return failure();
+    ints.push_back(intAttr.getInt());
+  }
+  return ints;
+}
+
+/// Consider inserting a vector of shape `small` into a vector of shape `large`,
+/// at position `offsets`: this function enumerates all the indices in `large`
+/// that are written to. The enumeration is with row-major ordering.
+///
+/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
+/// positions written to are (1,3) and (1,4), which have linearized indices 8
+/// and 9. So [8,9] is returned.
+///
+/// The length of the returned vector is equal to the number of elements in
+/// the shape `small` (i.e. the product of dimensions of `small`).
+static SmallVector<int64_t>
+getStridedSliceInsertionIndices(ArrayRef<int64_t> small,
+                                ArrayRef<int64_t> large,
+                                ArrayRef<int64_t> offsets) {
+  // Example of alignment between `large`, `small` and `offsets`:
+  //    large  =  4, 5, 6, 7, 8
+  //    small  =     1, 6, 7, 8
+  //  offsets  =  2, 3, 0
+  //
+  // `offsets` has implicit trailing 0s, `small` has implicit leading 1s, so
+  // both are conceptually padded to the rank of `large`.
+  assert((large.size() >= small.size()) &&
+         "rank of 'large' cannot be lower than rank of 'small'");
+  assert((large.size() >= offsets.size()) &&
+         "rank of 'large' cannot be lower than the number of offsets");
+  int64_t delta = large.size() - small.size();
+  int64_t nOffsets = offsets.size();
+  auto getSmall = [&](int64_t i) -> int64_t {
+    return i >= delta ? small[i - delta] : 1;
+  };
+  auto getOffset = [&](int64_t i) -> int64_t {
+    return i < nOffsets ? offsets[i] : 0;
+  };
+
+  // Build the indices one dimension at a time, innermost first: each step
+  // replicates the current indices once per position of `small` along
+  // dimension i, shifted by that position's linear offset in `large`.
+  SmallVector<int64_t> indices{0};
+  int64_t stride = 1;
+  for (int64_t i = static_cast<int64_t>(large.size()) - 1; i >= 0; --i) {
+    int64_t currentSize = indices.size();
+    int64_t smallSize = getSmall(i);
+    int64_t nextSize = currentSize * smallSize;
+    SmallVector<int64_t> nextIndices(nextSize);
+    int64_t *base = nextIndices.begin();
+    int64_t offset = getOffset(i) * stride;
+    for (int64_t j = 0; j < smallSize; ++j) {
+      for (int64_t k = 0; k < currentSize; ++k) {
+        base[k] = indices[k] + offset;
+      }
+      offset += stride;
+      base += currentSize;
+    }
+    stride *= large[i];
+    indices = std::move(nextIndices);
+  }
+  return indices;
+}
+
+FailureOr<SmallVector<int64_t>> InsertStridedSliceOp::getLinearIndices() {
+  // Enumerate the row-major linear positions in `dest` that are written.
+  // Stride > 1 to be considered if/when the insert_strided_slice supports it.
+  if (hasNonUnitStrides())
+    return failure();
+
+  // Only when the destination has a static size can the indices be enumerated.
+  if (getDestVectorType().isScalable())
+    return failure();
+
+  // Only when the offsets are all static can the indices be enumerated.
+  FailureOr<SmallVector<int64_t>> offsets = intsFromArrayAttr(getOffsets());
+  if (failed(offsets))
+    return failure();
+
+  return getStridedSliceInsertionIndices(getSourceVectorType().getShape(),
+                                         getDestVectorType().getShape(),
+                                         offsets.value());
+}
+
 // TODO: Should be moved to Tablegen ConfinedAttr attributes.
 template <typename OpType>
 static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
@@ -3638,6 +3733,25 @@ void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
                       stridesAttr);
 }
 
+FailureOr<SmallVector<int64_t>> ExtractStridedSliceOp::getLinearIndices() {
+  // Enumerate the row-major linear positions in `source` that are read.
+  // Stride > 1 to be considered if/when extract_strided_slice supports it.
+  if (hasNonUnitStrides())
+    return failure();
+
+  // Only when the source has a static size can the indices be enumerated.
+  if (getSourceVectorType().isScalable())
+    return failure();
+
+  // Only when the offsets are all static can the indices be enumerated.
+  FailureOr<SmallVector<int64_t>> offsets = intsFromArrayAttr(getOffsets());
+  if (failed(offsets))
+    return failure();
+
+  return getStridedSliceInsertionIndices(
+      getType().getShape(), getSourceVectorType().getShape(), offsets.value());
+}
+
 LogicalResult ExtractStridedSliceOp::verify() {
   auto type = getSourceVectorType();
   auto offsets = getOffsetsAttr();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 678a88627ca82..6cf818bbd0695 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -109,90 +109,6 @@ struct LinearizeVectorizable final
   }
 };
 
-template <typename TOp>
-static bool stridesAllOne(TOp op) {
-  static_assert(
-      std::is_same_v<TOp, vector::ExtractStridedSliceOp> ||
-          std::is_same_v<TOp, vector::InsertStridedSliceOp>,
-      "expected vector.extract_strided_slice or vector.insert_strided_slice");
-  ArrayAttr strides = op.getStrides();
-  return llvm::all_of(strides, isOneInteger);
-}
-
-/// Convert an array of attributes into a vector of integers, if possible.
-static FailureOr<SmallVector<int64_t>> intsFromArrayAttr(ArrayAttr attrs) {
-  if (!attrs)
-    return failure();
-  SmallVector<int64_t> ints;
-  ints.reserve(attrs.size());
-  for (auto attr : attrs) {
-    if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
-      ints.push_back(intAttr.getInt());
-    } else {
-      return failure();
-    }
-  }
-  return ints;
-}
-
-/// Consider inserting a vector of shape `small` into a vector of shape `large`,
-/// at position `offsets`: this function enumeratates all the indices in `large`
-/// that are written to. The enumeration is with row-major ordering.
-///
-/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
-/// positions written to are (1,3) and (1,4), which have linearized indices 8
-/// and 9. So [8,9] is returned.
-///
-/// The length of the returned vector is equal to the number of elements in
-/// the shape `small` (i.e. the product of dimensions of `small`).
-SmallVector<int64_t> static getStridedSliceInsertionIndices(
-    ArrayRef<int64_t> small, ArrayRef<int64_t> large,
-    ArrayRef<int64_t> offsets) {
-
-  // Example of alignment between, `large`, `small` and `offsets`:
-  //    large  =  4, 5, 6, 7, 8
-  //    small  =     1, 6, 7, 8
-  //  offsets  =  2, 3, 0
-  //
-  // `offsets` has implicit trailing 0s, `small` has implicit leading 1s.
-  assert((large.size() >= small.size()) &&
-         "rank of 'large' cannot be lower than rank of 'small'");
-  assert((large.size() >= offsets.size()) &&
-         "rank of 'large' cannot be lower than the number of offsets");
-  unsigned delta = large.size() - small.size();
-  unsigned nOffsets = offsets.size();
-  auto getSmall = [&](int64_t i) -> int64_t {
-    return i >= delta ? small[i - delta] : 1;
-  };
-  auto getOffset = [&](int64_t i) -> int64_t {
-    return i < nOffsets ? offsets[i] : 0;
-  };
-
-  // Using 2 vectors of indices, at each iteration populate the updated set of
-  // indices based on the old set of indices, and the size of the small vector
-  // in the current iteration.
-  SmallVector<int64_t> indices{0};
-  int64_t stride = 1;
-  for (int i = large.size() - 1; i >= 0; --i) {
-    int64_t currentSize = indices.size();
-    int64_t smallSize = getSmall(i);
-    int64_t nextSize = currentSize * smallSize;
-    SmallVector<int64_t> nextIndices(nextSize);
-    int64_t *base = nextIndices.begin();
-    int64_t offset = getOffset(i) * stride;
-    for (int j = 0; j < smallSize; ++j) {
-      for (int k = 0; k < currentSize; ++k) {
-        base[k] = indices[k] + offset;
-      }
-      offset += stride;
-      base += currentSize;
-    }
-    stride *= large[i];
-    indices = std::move(nextIndices);
-  }
-  return indices;
-}
-
 /// This pattern converts a vector.extract_strided_slice operation into a
 /// vector.shuffle operation that has a rank-1 (linearized) operand and result.
 ///
@@ -231,30 +147,23 @@ struct LinearizeVectorExtractStridedSlice final
 
     // Expect a legalization failure if the strides are not all 1 (if ever the
     // verifier for extract_strided_slice allows non-1 strides).
-    if (!stridesAllOne(extractStridedSliceOp)) {
+    if (extractStridedSliceOp.hasNonUnitStrides()) {
       return rewriter.notifyMatchFailure(
           extractStridedSliceOp,
           "extract_strided_slice with strides != 1 not supported");
     }
 
-    FailureOr<SmallVector<int64_t>> offsets =
-        intsFromArrayAttr(extractStridedSliceOp.getOffsets());
-    if (failed(offsets)) {
+    FailureOr<SmallVector<int64_t>> indices =
+        extractStridedSliceOp.getLinearIndices();
+    if (failed(indices)) {
       return rewriter.notifyMatchFailure(extractStridedSliceOp,
-                                         "failed to get integer offsets");
+                                         "failed to get indices");
     }
 
-    ArrayRef<int64_t> inputShape =
-        extractStridedSliceOp.getSourceVectorType().getShape();
-
-    ArrayRef<int64_t> outputShape = extractStridedSliceOp.getType().getShape();
-
-    SmallVector<int64_t> indices = getStridedSliceInsertionIndices(
-        outputShape, inputShape, offsets.value());
-
     Value srcVector = adaptor.getVector();
-    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
-        extractStridedSliceOp, flatOutputType, srcVector, srcVector, indices);
+    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(extractStridedSliceOp,
+                                                   flatOutputType, srcVector,
+                                                   srcVector, indices.value());
     return success();
   }
 };
@@ -298,31 +207,24 @@ struct LinearizeVectorInsertStridedSlice final
 
     // Expect a legalization failure if the strides are not all 1 (if ever the
     // verifier for insert_strided_slice allows non-1 strides).
-    if (!stridesAllOne(insertStridedSliceOp)) {
+    if (insertStridedSliceOp.hasNonUnitStrides()) {
       return rewriter.notifyMatchFailure(
           insertStridedSliceOp,
           "insert_strided_slice with strides != 1 not supported");
     }
 
-    VectorType inputType = insertStridedSliceOp.getValueToStore().getType();
-    ArrayRef<int64_t> inputShape = inputType.getShape();
-
     VectorType outputType = insertStridedSliceOp.getType();
-    ArrayRef<int64_t> outputShape = outputType.getShape();
     int64_t nOutputElements = outputType.getNumElements();
 
-    FailureOr<SmallVector<int64_t>> offsets =
-        intsFromArrayAttr(insertStridedSliceOp.getOffsets());
-    if (failed(offsets)) {
+    FailureOr<SmallVector<int64_t>> sliceIndices =
+        insertStridedSliceOp.getLinearIndices();
+    if (failed(sliceIndices))
       return rewriter.notifyMatchFailure(insertStridedSliceOp,
-                                         "failed to get integer offsets");
-    }
-    SmallVector<int64_t> sliceIndices = getStridedSliceInsertionIndices(
-        inputShape, outputShape, offsets.value());
+                                         "failed to get indices");
 
     SmallVector<int64_t> indices(nOutputElements);
     std::iota(indices.begin(), indices.end(), 0);
-    for (auto [index, sliceIndex] : llvm::enumerate(sliceIndices)) {
+    for (auto [index, sliceIndex] : llvm::enumerate(sliceIndices.value())) {
       indices[sliceIndex] = index + nOutputElements;
     }
 

>From 1cc6345d5cde6972b3bd1d1708f0c4f152af349b Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Thu, 22 May 2025 12:06:23 -0700
Subject: [PATCH 2/2] [mlir][vector] Reword getLinearIndices doc comments

---
 mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index f8412863b18c9..481523ff10c3f 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1114,7 +1114,7 @@ def Vector_InsertStridedSliceOp :
         return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
       });
     }
-    // \return The indices in dest that the values are inserted to.
+    /// \return Row-major linear indices in `dest` where values are stored.
     FailureOr<SmallVector<int64_t>> getLinearIndices();
   }];
 
@@ -1256,7 +1256,7 @@ def Vector_ExtractStridedSliceOp :
         return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
       });
     }
-    // \return The indices in source that the values are taken from.
+    /// \return Row-major linear indices in `source` of the extracted values.
     FailureOr<SmallVector<int64_t>> getLinearIndices();
   }];
   let hasCanonicalizer = 1;



More information about the Mlir-commits mailing list