[Mlir-commits] [mlir] [mlir][vector] Move extract_strided_slice canonicalization to folding (PR #135676)
James Newling
llvmlistbot at llvm.org
Wed Apr 16 10:18:11 PDT 2025
https://github.com/newling updated https://github.com/llvm/llvm-project/pull/135676
>From 2e027d809cdcad0df674afd6cfae4a053d58359c Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 14 Apr 2025 13:39:40 -0700
Subject: [PATCH 1/4] move canonicalizers to folders
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 161 ++++++++-------------
mlir/test/Dialect/Vector/canonicalize.mlir | 7 +
2 files changed, 66 insertions(+), 102 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index bee5c1fd6ed58..18dbd1167995e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3717,6 +3717,59 @@ OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
return getVector();
if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
return getResult();
+
+ Attribute foldInput = adaptor.getVector();
+ if (!foldInput) {
+ return {};
+ }
+
+ // rewrite : ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
+ if (auto splat = llvm::dyn_cast<SplatElementsAttr>(foldInput))
+ DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());
+
+ // rewrite ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
+ if (auto dense = llvm::dyn_cast<DenseElementsAttr>(foldInput)) {
+ // TODO: Handle non-unit strides when they become available.
+ if (hasNonUnitStrides())
+ return {};
+
+ Value sourceVector = getVector();
+ auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType());
+ ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
+ SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
+
+ VectorType sliceVecTy = getType();
+ ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
+ int64_t sliceRank = sliceVecTy.getRank();
+
+ // Expand offsets and sizes to match the vector rank.
+ SmallVector<int64_t, 4> offsets(sliceRank, 0);
+ copy(getI64SubArray(getOffsets()), offsets.begin());
+
+ SmallVector<int64_t, 4> sizes(sourceShape);
+ copy(getI64SubArray(getSizes()), sizes.begin());
+
+ // Calculate the slice elements by enumerating all slice positions and
+ // linearizing them. The enumeration order is lexicographic which yields a
+ // sequence of monotonically increasing linearized position indices.
+ auto denseValuesBegin = dense.value_begin<Attribute>();
+ SmallVector<Attribute> sliceValues;
+ sliceValues.reserve(sliceVecTy.getNumElements());
+ SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
+ do {
+ int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
+ assert(linearizedPosition < sourceVecTy.getNumElements() &&
+ "Invalid index");
+ sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
+ } while (
+ succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
+
+ assert(static_cast<int64_t>(sliceValues.size()) ==
+ sliceVecTy.getNumElements() &&
+ "Invalid number of slice elements");
+ return DenseElementsAttr::get(sliceVecTy, sliceValues);
+ }
+
return {};
}
@@ -3781,98 +3834,6 @@ class StridedSliceConstantMaskFolder final
}
};
-// Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
-class StridedSliceSplatConstantFolder final
- : public OpRewritePattern<ExtractStridedSliceOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
- PatternRewriter &rewriter) const override {
- // Return if 'ExtractStridedSliceOp' operand is not defined by a splat
- // ConstantOp.
- Value sourceVector = extractStridedSliceOp.getVector();
- Attribute vectorCst;
- if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
- return failure();
-
- auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst);
- if (!splat)
- return failure();
-
- auto newAttr = SplatElementsAttr::get(extractStridedSliceOp.getType(),
- splat.getSplatValue<Attribute>());
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
- newAttr);
- return success();
- }
-};
-
-// Pattern to rewrite a ExtractStridedSliceOp(non-splat ConstantOp) ->
-// ConstantOp.
-class StridedSliceNonSplatConstantFolder final
- : public OpRewritePattern<ExtractStridedSliceOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
- PatternRewriter &rewriter) const override {
- // Return if 'ExtractStridedSliceOp' operand is not defined by a non-splat
- // ConstantOp.
- Value sourceVector = extractStridedSliceOp.getVector();
- Attribute vectorCst;
- if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
- return failure();
-
- // The splat case is handled by `StridedSliceSplatConstantFolder`.
- auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst);
- if (!dense || dense.isSplat())
- return failure();
-
- // TODO: Handle non-unit strides when they become available.
- if (extractStridedSliceOp.hasNonUnitStrides())
- return failure();
-
- auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType());
- ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
- SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
-
- VectorType sliceVecTy = extractStridedSliceOp.getType();
- ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
- int64_t sliceRank = sliceVecTy.getRank();
-
- // Expand offsets and sizes to match the vector rank.
- SmallVector<int64_t, 4> offsets(sliceRank, 0);
- copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin());
-
- SmallVector<int64_t, 4> sizes(sourceShape);
- copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin());
-
- // Calculate the slice elements by enumerating all slice positions and
- // linearizing them. The enumeration order is lexicographic which yields a
- // sequence of monotonically increasing linearized position indices.
- auto denseValuesBegin = dense.value_begin<Attribute>();
- SmallVector<Attribute> sliceValues;
- sliceValues.reserve(sliceVecTy.getNumElements());
- SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
- do {
- int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
- assert(linearizedPosition < sourceVecTy.getNumElements() &&
- "Invalid index");
- sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
- } while (
- succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
-
- assert(static_cast<int64_t>(sliceValues.size()) ==
- sliceVecTy.getNumElements() &&
- "Invalid number of slice elements");
- auto newAttr = DenseElementsAttr::get(sliceVecTy, sliceValues);
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
- newAttr);
- return success();
- }
-};
-
// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
// BroadcastOp(ExtractStrideSliceOp).
class StridedSliceBroadcast final
@@ -4016,8 +3977,7 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
// Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
// ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
- results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
- StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
+ results.add<StridedSliceConstantMaskFolder, StridedSliceBroadcast,
StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
context);
}
@@ -5657,10 +5617,8 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
// shape_cast(constant) -> constant
if (auto splatAttr =
- llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
- return DenseElementsAttr::get(resultType,
- splatAttr.getSplatValue<Attribute>());
- }
+ llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
+ return splatAttr.reshape(getType());
// shape_cast(poison) -> poison
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
@@ -6004,10 +5962,9 @@ void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
// Eliminate splat constant transpose ops.
- if (auto attr =
- llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector()))
- if (attr.isSplat())
- return attr.reshape(getResultVectorType());
+ if (auto splat =
+ llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
+ return splat.reshape(getResultVectorType());
// Eliminate identity transpose ops. This happens when the dimensions of the
// input vector remain in their original order after the transpose operation.
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 78b0ea78849e8..6556df22e069b 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1121,6 +1121,8 @@ func.func @bitcast_folding(%I1: vector<4x8xf32>, %I2: vector<2xi32>) -> (vector<
return %0, %2 : vector<4x8xf32>, vector<2xi32>
}
+// -----
+
// CHECK-LABEL: func @bitcast_f16_to_f32
// bit pattern: 0x40004000
// CHECK-DAG: %[[CST1:.+]] = arith.constant dense<2.00390625> : vector<4xf32>
@@ -1135,6 +1137,8 @@ func.func @bitcast_f16_to_f32() -> (vector<4xf32>, vector<4xf32>) {
return %cast0, %cast1: vector<4xf32>, vector<4xf32>
}
+// -----
+
// CHECK-LABEL: func @bitcast_i8_to_i32
// bit pattern: 0xA0A0A0A0
// CHECK-DAG: %[[CST1:.+]] = arith.constant dense<-1600085856> : vector<4xi32>
@@ -1710,6 +1714,7 @@ func.func @vector_multi_reduction_unit_dimensions(%source: vector<5x1x4x1x20xf32
}
// -----
+
// CHECK-LABEL: func.func @vector_multi_reduction_scalable(
// CHECK-SAME: %[[VAL_0:.*]]: vector<1x[4]x1xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x[4]xf32>,
@@ -2251,6 +2256,8 @@ func.func @transpose_splat_constant() -> vector<8x4xf32> {
return %0 : vector<8x4xf32>
}
+// -----
+
// CHECK-LABEL: func @transpose_splat2(
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
// CHECK: %[[VAL_1:.*]] = vector.splat %[[VAL_0]] : vector<3x4xf32>
>From c0cebd2ecad27414d72f619e932c18c2511cba3e Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 15 Apr 2025 00:26:26 -0700
Subject: [PATCH 2/4] tidy
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 17 ++++++++---------
1 file changed, 8 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 18dbd1167995e..1eb6daa402422 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3718,32 +3718,31 @@ OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
return getResult();
+ // All subsequent successful folds require a constant input.
Attribute foldInput = adaptor.getVector();
- if (!foldInput) {
+ if (!foldInput)
return {};
- }
- // rewrite : ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
+ // ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
if (auto splat = llvm::dyn_cast<SplatElementsAttr>(foldInput))
DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());
- // rewrite ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
+ // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
if (auto dense = llvm::dyn_cast<DenseElementsAttr>(foldInput)) {
// TODO: Handle non-unit strides when they become available.
if (hasNonUnitStrides())
return {};
- Value sourceVector = getVector();
- auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType());
+ VectorType sourceVecTy = getSourceVectorType();
ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
VectorType sliceVecTy = getType();
ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
- int64_t sliceRank = sliceVecTy.getRank();
+ int64_t rank = sliceVecTy.getRank();
// Expand offsets and sizes to match the vector rank.
- SmallVector<int64_t, 4> offsets(sliceRank, 0);
+ SmallVector<int64_t, 4> offsets(rank, 0);
copy(getI64SubArray(getOffsets()), offsets.begin());
SmallVector<int64_t, 4> sizes(sourceShape);
@@ -3752,7 +3751,7 @@ OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
// Calculate the slice elements by enumerating all slice positions and
// linearizing them. The enumeration order is lexicographic which yields a
// sequence of monotonically increasing linearized position indices.
- auto denseValuesBegin = dense.value_begin<Attribute>();
+ const auto denseValuesBegin = dense.value_begin<Attribute>();
SmallVector<Attribute> sliceValues;
sliceValues.reserve(sliceVecTy.getNumElements());
SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
>From 788f44cf917a4b17bcaa8cacaf7fef1786083bf6 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 15 Apr 2025 16:35:34 -0700
Subject: [PATCH 3/4] factorize to function
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 101 ++++++++++++-----------
1 file changed, 53 insertions(+), 48 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 1eb6daa402422..29d9d1f0f50ae 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3712,64 +3712,69 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
return failure();
}
+namespace {
+
+// ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
+OpFoldResult foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op,
+ Attribute foldInput) {
+
+ auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(foldInput);
+ if (!dense)
+ return {};
+
+ // TODO: Handle non-unit strides when they become available.
+ if (op.hasNonUnitStrides())
+ return {};
+
+ VectorType sourceVecTy = op.getSourceVectorType();
+ ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
+ SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
+
+ VectorType sliceVecTy = op.getType();
+ ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
+ int64_t rank = sliceVecTy.getRank();
+
+ // Expand offsets and sizes to match the vector rank.
+ SmallVector<int64_t, 4> offsets(rank, 0);
+ copy(getI64SubArray(op.getOffsets()), offsets.begin());
+
+ SmallVector<int64_t, 4> sizes(sourceShape);
+ copy(getI64SubArray(op.getSizes()), sizes.begin());
+
+ // Calculate the slice elements by enumerating all slice positions and
+ // linearizing them. The enumeration order is lexicographic which yields a
+ // sequence of monotonically increasing linearized position indices.
+ const auto denseValuesBegin = dense.value_begin<Attribute>();
+ SmallVector<Attribute> sliceValues;
+ sliceValues.reserve(sliceVecTy.getNumElements());
+ SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
+ do {
+ int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
+ assert(linearizedPosition < sourceVecTy.getNumElements() &&
+ "Invalid index");
+ sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
+ } while (succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
+
+ assert(static_cast<int64_t>(sliceValues.size()) ==
+ sliceVecTy.getNumElements() &&
+ "Invalid number of slice elements");
+ return DenseElementsAttr::get(sliceVecTy, sliceValues);
+}
+} // namespace
+
OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
if (getSourceVectorType() == getResult().getType())
return getVector();
if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
return getResult();
- // All subsequent successful folds require a constant input.
- Attribute foldInput = adaptor.getVector();
- if (!foldInput)
- return {};
-
// ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
- if (auto splat = llvm::dyn_cast<SplatElementsAttr>(foldInput))
+ if (auto splat =
+ llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());
// ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
- if (auto dense = llvm::dyn_cast<DenseElementsAttr>(foldInput)) {
- // TODO: Handle non-unit strides when they become available.
- if (hasNonUnitStrides())
- return {};
-
- VectorType sourceVecTy = getSourceVectorType();
- ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
- SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
-
- VectorType sliceVecTy = getType();
- ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
- int64_t rank = sliceVecTy.getRank();
-
- // Expand offsets and sizes to match the vector rank.
- SmallVector<int64_t, 4> offsets(rank, 0);
- copy(getI64SubArray(getOffsets()), offsets.begin());
-
- SmallVector<int64_t, 4> sizes(sourceShape);
- copy(getI64SubArray(getSizes()), sizes.begin());
-
- // Calculate the slice elements by enumerating all slice positions and
- // linearizing them. The enumeration order is lexicographic which yields a
- // sequence of monotonically increasing linearized position indices.
- const auto denseValuesBegin = dense.value_begin<Attribute>();
- SmallVector<Attribute> sliceValues;
- sliceValues.reserve(sliceVecTy.getNumElements());
- SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
- do {
- int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
- assert(linearizedPosition < sourceVecTy.getNumElements() &&
- "Invalid index");
- sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
- } while (
- succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
-
- assert(static_cast<int64_t>(sliceValues.size()) ==
- sliceVecTy.getNumElements() &&
- "Invalid number of slice elements");
- return DenseElementsAttr::get(sliceVecTy, sliceValues);
- }
-
- return {};
+ return foldExtractStridedSliceNonSplatConstant(*this, adaptor.getVector());
}
void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
>From fc62a9b2445a581246fe1c14f602ad8041fe0ee4 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Wed, 16 Apr 2025 10:17:58 -0700
Subject: [PATCH 4/4] prefer static to anonymous namespace in this file
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 8 +++-----
1 file changed, 3 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 29d9d1f0f50ae..bbe222e72bf24 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3712,11 +3712,10 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
return failure();
}
-namespace {
-
// ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
-OpFoldResult foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op,
- Attribute foldInput) {
+static OpFoldResult
+foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op,
+ Attribute foldInput) {
auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(foldInput);
if (!dense)
@@ -3760,7 +3759,6 @@ OpFoldResult foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op,
"Invalid number of slice elements");
return DenseElementsAttr::get(sliceVecTy, sliceValues);
}
-} // namespace
OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
if (getSourceVectorType() == getResult().getType())
More information about the Mlir-commits
mailing list