[Mlir-commits] [mlir] [mlir][vector] Move extract_strided_slice canonicalization to folding (PR #135676)

James Newling llvmlistbot at llvm.org
Tue Apr 15 00:26:47 PDT 2025


https://github.com/newling updated https://github.com/llvm/llvm-project/pull/135676

>From 2e027d809cdcad0df674afd6cfae4a053d58359c Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 14 Apr 2025 13:39:40 -0700
Subject: [PATCH 1/2] move canonicalizers to folders

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 161 ++++++++-------------
 mlir/test/Dialect/Vector/canonicalize.mlir |   7 +
 2 files changed, 66 insertions(+), 102 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index bee5c1fd6ed58..18dbd1167995e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3717,6 +3717,59 @@ OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
     return getVector();
   if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
     return getResult();
+
+  Attribute foldInput = adaptor.getVector();
+  if (!foldInput) {
+    return {};
+  }
+
+  // rewrite : ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
+  if (auto splat = llvm::dyn_cast<SplatElementsAttr>(foldInput))
+    DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());
+
+  // rewrite ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
+  if (auto dense = llvm::dyn_cast<DenseElementsAttr>(foldInput)) {
+    // TODO: Handle non-unit strides when they become available.
+    if (hasNonUnitStrides())
+      return {};
+
+    Value sourceVector = getVector();
+    auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType());
+    ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
+    SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
+
+    VectorType sliceVecTy = getType();
+    ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
+    int64_t sliceRank = sliceVecTy.getRank();
+
+    // Expand offsets and sizes to match the vector rank.
+    SmallVector<int64_t, 4> offsets(sliceRank, 0);
+    copy(getI64SubArray(getOffsets()), offsets.begin());
+
+    SmallVector<int64_t, 4> sizes(sourceShape);
+    copy(getI64SubArray(getSizes()), sizes.begin());
+
+    // Calculate the slice elements by enumerating all slice positions and
+    // linearizing them. The enumeration order is lexicographic which yields a
+    // sequence of monotonically increasing linearized position indices.
+    auto denseValuesBegin = dense.value_begin<Attribute>();
+    SmallVector<Attribute> sliceValues;
+    sliceValues.reserve(sliceVecTy.getNumElements());
+    SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
+    do {
+      int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
+      assert(linearizedPosition < sourceVecTy.getNumElements() &&
+             "Invalid index");
+      sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
+    } while (
+        succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
+
+    assert(static_cast<int64_t>(sliceValues.size()) ==
+               sliceVecTy.getNumElements() &&
+           "Invalid number of slice elements");
+    return DenseElementsAttr::get(sliceVecTy, sliceValues);
+  }
+
   return {};
 }
 
@@ -3781,98 +3834,6 @@ class StridedSliceConstantMaskFolder final
   }
 };
 
-// Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
-class StridedSliceSplatConstantFolder final
-    : public OpRewritePattern<ExtractStridedSliceOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
-                                PatternRewriter &rewriter) const override {
-    // Return if 'ExtractStridedSliceOp' operand is not defined by a splat
-    // ConstantOp.
-    Value sourceVector = extractStridedSliceOp.getVector();
-    Attribute vectorCst;
-    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
-      return failure();
-
-    auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
-    if (!splat)
-      return failure();
-
-    auto newAttr = SplatElementsAttr::get(extractStridedSliceOp.getType(),
-                                          splat.getSplatValue<Attribute>());
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
-                                                   newAttr);
-    return success();
-  }
-};
-
-// Pattern to rewrite a ExtractStridedSliceOp(non-splat ConstantOp) ->
-// ConstantOp.
-class StridedSliceNonSplatConstantFolder final
-    : public OpRewritePattern<ExtractStridedSliceOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
-                                PatternRewriter &rewriter) const override {
-    // Return if 'ExtractStridedSliceOp' operand is not defined by a non-splat
-    // ConstantOp.
-    Value sourceVector = extractStridedSliceOp.getVector();
-    Attribute vectorCst;
-    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
-      return failure();
-
-    // The splat case is handled by `StridedSliceSplatConstantFolder`.
-    auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
-    if (!dense || dense.isSplat())
-      return failure();
-
-    // TODO: Handle non-unit strides when they become available.
-    if (extractStridedSliceOp.hasNonUnitStrides())
-      return failure();
-
-    auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType());
-    ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
-    SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
-
-    VectorType sliceVecTy = extractStridedSliceOp.getType();
-    ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
-    int64_t sliceRank = sliceVecTy.getRank();
-
-    // Expand offsets and sizes to match the vector rank.
-    SmallVector<int64_t, 4> offsets(sliceRank, 0);
-    copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin());
-
-    SmallVector<int64_t, 4> sizes(sourceShape);
-    copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin());
-
-    // Calculate the slice elements by enumerating all slice positions and
-    // linearizing them. The enumeration order is lexicographic which yields a
-    // sequence of monotonically increasing linearized position indices.
-    auto denseValuesBegin = dense.value_begin<Attribute>();
-    SmallVector<Attribute> sliceValues;
-    sliceValues.reserve(sliceVecTy.getNumElements());
-    SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
-    do {
-      int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
-      assert(linearizedPosition < sourceVecTy.getNumElements() &&
-             "Invalid index");
-      sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
-    } while (
-        succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
-
-    assert(static_cast<int64_t>(sliceValues.size()) ==
-               sliceVecTy.getNumElements() &&
-           "Invalid number of slice elements");
-    auto newAttr = DenseElementsAttr::get(sliceVecTy, sliceValues);
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
-                                                   newAttr);
-    return success();
-  }
-};
-
 // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
 // BroadcastOp(ExtractStrideSliceOp).
 class StridedSliceBroadcast final
@@ -4016,8 +3977,7 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
   // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
-  results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
-              StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
+  results.add<StridedSliceConstantMaskFolder, StridedSliceBroadcast,
               StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
       context);
 }
@@ -5657,10 +5617,8 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
   // shape_cast(constant) -> constant
   if (auto splatAttr =
-          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
-    return DenseElementsAttr::get(resultType,
-                                  splatAttr.getSplatValue<Attribute>());
-  }
+          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
+    return splatAttr.reshape(getType());
 
   // shape_cast(poison) -> poison
   if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
@@ -6004,10 +5962,9 @@ void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
 
 OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
   // Eliminate splat constant transpose ops.
-  if (auto attr =
-          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
-    if (attr.isSplat())
-      return attr.reshape(getResultVectorType());
+  if (auto splat =
+          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
+    return splat.reshape(getResultVectorType());
 
   // Eliminate identity transpose ops. This happens when the dimensions of the
   // input vector remain in their original order after the transpose operation.
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 78b0ea78849e8..6556df22e069b 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1121,6 +1121,8 @@ func.func @bitcast_folding(%I1: vector<4x8xf32>, %I2: vector<2xi32>) -> (vector<
   return %0, %2 : vector<4x8xf32>, vector<2xi32>
 }
 
+// -----
+
 // CHECK-LABEL: func @bitcast_f16_to_f32
 //              bit pattern: 0x40004000
 //       CHECK-DAG: %[[CST1:.+]] = arith.constant dense<2.00390625> : vector<4xf32>
@@ -1135,6 +1137,8 @@ func.func @bitcast_f16_to_f32() -> (vector<4xf32>, vector<4xf32>) {
   return %cast0, %cast1: vector<4xf32>, vector<4xf32>
 }
 
+// -----
+
 // CHECK-LABEL: func @bitcast_i8_to_i32
 //              bit pattern: 0xA0A0A0A0
 //       CHECK-DAG: %[[CST1:.+]] = arith.constant dense<-1600085856> : vector<4xi32>
@@ -1710,6 +1714,7 @@ func.func @vector_multi_reduction_unit_dimensions(%source: vector<5x1x4x1x20xf32
 }
 
 // -----
+
 // CHECK-LABEL:   func.func @vector_multi_reduction_scalable(
 // CHECK-SAME:     %[[VAL_0:.*]]: vector<1x[4]x1xf32>,
 // CHECK-SAME:     %[[VAL_1:.*]]: vector<1x[4]xf32>,
@@ -2251,6 +2256,8 @@ func.func @transpose_splat_constant() -> vector<8x4xf32> {
   return %0 : vector<8x4xf32>
 }
 
+// -----
+
 // CHECK-LABEL:   func @transpose_splat2(
 // CHECK-SAME:                           %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
 // CHECK:           %[[VAL_1:.*]] = vector.splat %[[VAL_0]] : vector<3x4xf32>

>From c0cebd2ecad27414d72f619e932c18c2511cba3e Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 15 Apr 2025 00:26:26 -0700
Subject: [PATCH 2/2] tidy

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 18dbd1167995e..1eb6daa402422 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3718,32 +3718,31 @@ OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
   if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
     return getResult();
 
+  // All subsequent successful folds require a constant input.
   Attribute foldInput = adaptor.getVector();
-  if (!foldInput) {
+  if (!foldInput)
     return {};
-  }
 
-  // rewrite : ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
+  // ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
   if (auto splat = llvm::dyn_cast<SplatElementsAttr>(foldInput))
     DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());
 
-  // rewrite ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
+  // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
   if (auto dense = llvm::dyn_cast<DenseElementsAttr>(foldInput)) {
     // TODO: Handle non-unit strides when they become available.
     if (hasNonUnitStrides())
       return {};
 
-    Value sourceVector = getVector();
-    auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType());
+    VectorType sourceVecTy = getSourceVectorType();
     ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
     SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
 
     VectorType sliceVecTy = getType();
     ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
-    int64_t sliceRank = sliceVecTy.getRank();
+    int64_t rank = sliceVecTy.getRank();
 
     // Expand offsets and sizes to match the vector rank.
-    SmallVector<int64_t, 4> offsets(sliceRank, 0);
+    SmallVector<int64_t, 4> offsets(rank, 0);
     copy(getI64SubArray(getOffsets()), offsets.begin());
 
     SmallVector<int64_t, 4> sizes(sourceShape);
@@ -3752,7 +3751,7 @@ OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
     // Calculate the slice elements by enumerating all slice positions and
     // linearizing them. The enumeration order is lexicographic which yields a
     // sequence of monotonically increasing linearized position indices.
-    auto denseValuesBegin = dense.value_begin<Attribute>();
+    const auto denseValuesBegin = dense.value_begin<Attribute>();
     SmallVector<Attribute> sliceValues;
     sliceValues.reserve(sliceVecTy.getNumElements());
     SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());



More information about the Mlir-commits mailing list