[Mlir-commits] [mlir] [mlir][vector] Propagate `vector.extract` through elementwise ops (PR #131462)

Ivan Butygin llvmlistbot at llvm.org
Sat Mar 15 09:38:24 PDT 2025


https://github.com/Hardcode84 created https://github.com/llvm/llvm-project/pull/131462

Propagate `Extract(Elementwise(...))` -> `Elementwise(Extract(...))`.

Currently limited to the case where the extract is the only use of the elementwise op, to avoid introducing additional computations.

>From c8b7a747114b79d1796160c22cc01d325861dda0 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 15 Mar 2025 17:30:29 +0100
Subject: [PATCH] [mlir][vector] Propagate `vector.extract` through elementwise
 ops

Propagate `Extract(Elementwise(...))` -> `Elementwise(Extract(...))`.

Currently limited to the case where the extract is the only use of the elementwise op, to avoid introducing additional computations.
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 44 +++++++++++++++++++++-
 mlir/test/Dialect/Vector/canonicalize.mlir | 42 +++++++++++++++++++++
 2 files changed, 85 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8e0e723cf4ed3..2e326882b5cd1 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2237,6 +2237,47 @@ class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
   }
 };
 
+/// Pattern to rewrite an ExtractOp(Elementwise) -> Elementwise(ExtractOp).
+class ExtractOpFromElemetwise final : public OpRewritePattern<ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractOp op,
+                                PatternRewriter &rewriter) const override {
+    Operation *eltwise = op.getVector().getDefiningOp();
+
+    // Elementwise op with single result and `extract` is single user.
+    if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
+        eltwise->getNumResults() != 1 || !eltwise->hasOneUse())
+      return failure();
+
+    // Arguments and result types must match.
+    if (!llvm::all_equal(llvm::concat<Type>(eltwise->getOperandTypes(),
+                                            eltwise->getResultTypes())))
+      return failure();
+
+    Type dstType = op.getType();
+
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(eltwise);
+
+    IRMapping mapping;
+    Location loc = eltwise->getLoc();
+    for (auto &&[i, arg] : llvm::enumerate(eltwise->getOperands())) {
+      Value newArg =
+          rewriter.create<ExtractOp>(loc, arg, op.getMixedPosition());
+      mapping.map(arg, newArg);
+    }
+
+    Operation *newEltwise = rewriter.clone(*eltwise, mapping);
+    newEltwise->getResult(0).setType(dstType);
+
+    rewriter.replaceOp(op, newEltwise);
+    rewriter.eraseOp(eltwise);
+    return success();
+  }
+};
+
 // Folds extract(shape_cast(..)) into shape_cast when the total element count
 // does not change.
 LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
@@ -2309,7 +2350,8 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask,
+              ExtractOpFromElemetwise>(context);
   results.add(foldExtractFromShapeCastToShapeCast);
   results.add(foldExtractFromFromElements);
 }
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index bf755b466c7eb..7ea461acbcd95 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -244,6 +244,48 @@ func.func @extract_from_non_constant_create_mask(%dim0: index) -> vector<[2]xi1>
 
 // -----
 
+// CHECK-LABEL: @extract_elementwise
+//  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+func.func @extract_elementwise(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
+// CHECK:   %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<4xf32>
+// CHECK:   %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
+// CHECK:   %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
+// CHECK:   return %[[RES]] : f32
+  %0 = arith.addf %arg0, %arg1 : vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @extract_vec_elementwise
+//  CHECK-SAME:   (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<2x4xf32>)
+func.func @extract_vec_elementwise(%arg0: vector<2x4xf32>, %arg1: vector<2x4xf32>) -> vector<4xf32> {
+// CHECK:   %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : vector<4xf32> from vector<2x4xf32>
+// CHECK:   %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf32> from vector<2x4xf32>
+// CHECK:   %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : vector<4xf32>
+// CHECK:   return %[[RES]] : vector<4xf32>
+  %0 = arith.addf %arg0, %arg1 : vector<2x4xf32>
+  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+  return %1 : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @extract_elementwise_use
+//  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+func.func @extract_elementwise_use(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
+// Do not propagate extract, as elementwise has other uses.
+// CHECK:   %[[ELT:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<4xf32>
+// CHECK:   %[[EXT:.*]] = vector.extract %[[ELT]][1] : f32 from vector<4xf32>
+// CHECK:   return %[[EXT]], %[[ELT]] : f32, vector<4xf32>
+  %0 = arith.addf %arg0, %arg1 : vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1, %0 : f32, vector<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: constant_mask_transpose_to_transposed_constant_mask
 func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x4xi1>, vector<4x2x3xi1>) {
   //     CHECK: vector.constant_mask [1, 2, 3] : vector<2x3x4xi1>



More information about the Mlir-commits mailing list