[Mlir-commits] [mlir] [mlir][vector] Propagate `vector.extract` through elementwise ops (PR #131462)
Ivan Butygin
llvmlistbot at llvm.org
Sat Mar 15 09:38:24 PDT 2025
https://github.com/Hardcode84 created https://github.com/llvm/llvm-project/pull/131462
Propagate `Extract(Elementwise(...))` -> `Elementwise(Extract(...))`.
Currently limited to the case where the extract is the only use of the elementwise op, to avoid introducing additional computation.
>From c8b7a747114b79d1796160c22cc01d325861dda0 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 15 Mar 2025 17:30:29 +0100
Subject: [PATCH] [mlir][vector] Propagate `vector.extract` through elementwise
ops
Propagate `Extract(Elementwise(...))` -> `Elementwise(Extract(...))`.
Currently limited to the case where the extract is the only use of the elementwise op, to avoid introducing additional computation.
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 44 +++++++++++++++++++++-
mlir/test/Dialect/Vector/canonicalize.mlir | 42 +++++++++++++++++++++
2 files changed, 85 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8e0e723cf4ed3..2e326882b5cd1 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2237,6 +2237,47 @@ class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
}
};
+/// Pattern to rewrite an ExtractOp(Elementwise) -> Elementwise(ExtractOp).
+///
+/// For example:
+/// ```
+///   %0 = arith.addf %arg0, %arg1 : vector<4xf32>
+///   %1 = vector.extract %0[1] : f32 from vector<4xf32>
+/// ```
+/// becomes:
+/// ```
+///   %0 = vector.extract %arg0[1] : f32 from vector<4xf32>
+///   %1 = vector.extract %arg1[1] : f32 from vector<4xf32>
+///   %2 = arith.addf %0, %1 : f32
+/// ```
+///
+/// The pattern only applies when the extract is the sole user of a
+/// single-result elementwise op whose operand and result types all match, so
+/// the original vector computation is erased rather than duplicated.
+/// Extracts with dynamic indices are not handled.
+class ExtractOpFromElemetwise final : public OpRewritePattern<ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractOp op,
+                                PatternRewriter &rewriter) const override {
+    Operation *eltwise = op.getVector().getDefiningOp();
+    if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise))
+      return rewriter.notifyMatchFailure(op, "not an elementwise op");
+
+    if (eltwise->getNumResults() != 1)
+      return rewriter.notifyMatchFailure(op, "expected single result");
+
+    // Require `extract` to be the only user so no computation is duplicated.
+    if (!eltwise->hasOneUse())
+      return rewriter.notifyMatchFailure(op, "expected single op use");
+
+    // Arguments and result types must match.
+    if (!llvm::all_equal(llvm::concat<Type>(eltwise->getOperandTypes(),
+                                            eltwise->getResultTypes())))
+      return rewriter.notifyMatchFailure(op, "operand/result types differ");
+
+    // The new extracts are created right before `eltwise`, but dynamic index
+    // values may be defined between `eltwise` and `op`; they would then not
+    // dominate their new uses. Conservatively bail out.
+    if (!op.getDynamicPosition().empty())
+      return rewriter.notifyMatchFailure(op, "dynamic position not supported");
+
+    Type dstType = op.getType();
+
+    // `vector.fma` has the elementwise-mappable traits but only accepts vector
+    // operands, so it cannot be rebuilt on extracted scalars.
+    if (isa<FMAOp>(eltwise) && !isa<VectorType>(dstType))
+      return rewriter.notifyMatchFailure(op, "vector.fma requires vectors");
+
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(eltwise);
+
+    // Extract the same position from every operand. The position is computed
+    // once, and an operand that appears several times (e.g.
+    // `arith.mulf %a, %a`) is extracted only once.
+    IRMapping mapping;
+    Location loc = eltwise->getLoc();
+    SmallVector<OpFoldResult> pos = op.getMixedPosition();
+    for (Value arg : eltwise->getOperands()) {
+      if (mapping.contains(arg))
+        continue;
+      Value newArg = rewriter.create<ExtractOp>(loc, arg, pos);
+      mapping.map(arg, newArg);
+    }
+
+    // Rebuild the elementwise op on the extracted values; its result type
+    // becomes the type the original extract produced.
+    Operation *newEltwise = rewriter.clone(*eltwise, mapping);
+    newEltwise->getResult(0).setType(dstType);
+
+    rewriter.replaceOp(op, newEltwise);
+    rewriter.eraseOp(eltwise);
+    return success();
+  }
+};
+
// Folds extract(shape_cast(..)) into shape_cast when the total element count
// does not change.
LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
@@ -2309,7 +2350,8 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+ // ExtractOpFromElemetwise moves a vector.extract above its single-use
+ // elementwise producer, creating one extract per operand.
+ // NOTE(review): this is not always profitable; confirm it belongs in
+ // canonicalization rather than a separately populated pattern set.
+ results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask,
+ ExtractOpFromElemetwise>(context);
results.add(foldExtractFromShapeCastToShapeCast);
results.add(foldExtractFromFromElements);
}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index bf755b466c7eb..7ea461acbcd95 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -244,6 +244,48 @@ func.func @extract_from_non_constant_create_mask(%dim0: index) -> vector<[2]xi1>
// -----
+// A scalar extract from a single-use elementwise op is applied to each operand,
+// and the elementwise op is rebuilt on the extracted scalars.
+// CHECK-LABEL: @extract_elementwise
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
func.func @extract_elementwise(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
// CHECK: %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<4xf32>
// CHECK: %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
// CHECK: %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
// CHECK: return %[[RES]] : f32
+ %0 = arith.addf %arg0, %arg1 : vector<4xf32>
+ %1 = vector.extract %0[1] : f32 from vector<4xf32>
+ return %1 : f32
+}
+
+// -----
+
+// Extracting a sub-vector (rather than a scalar) is propagated the same way;
+// the elementwise op is rebuilt on vector<4xf32>.
+// CHECK-LABEL: @extract_vec_elementwise
// CHECK-SAME: (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<2x4xf32>)
func.func @extract_vec_elementwise(%arg0: vector<2x4xf32>, %arg1: vector<2x4xf32>) -> vector<4xf32> {
// CHECK: %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : vector<4xf32> from vector<2x4xf32>
// CHECK: %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf32> from vector<2x4xf32>
// CHECK: %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : vector<4xf32>
// CHECK: return %[[RES]] : vector<4xf32>
+ %0 = arith.addf %arg0, %arg1 : vector<2x4xf32>
+ %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+ return %1 : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @extract_elementwise_use
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
func.func @extract_elementwise_use(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
// Do not propagate the extract, as the elementwise op has other uses.
// CHECK: %[[ELT:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<4xf32>
// CHECK: %[[EXT:.*]] = vector.extract %[[ELT]][1] : f32 from vector<4xf32>
// CHECK: return %[[EXT]], %[[ELT]] : f32, vector<4xf32>
+ %0 = arith.addf %arg0, %arg1 : vector<4xf32>
+ %1 = vector.extract %0[1] : f32 from vector<4xf32>
+ return %1, %0 : f32, vector<4xf32>
+}
+
+// -----
+
// CHECK-LABEL: constant_mask_transpose_to_transposed_constant_mask
func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x4xi1>, vector<4x2x3xi1>) {
// CHECK: vector.constant_mask [1, 2, 3] : vector<2x3x4xi1>
More information about the Mlir-commits
mailing list