[Mlir-commits] [mlir] [mlir][vector] Propagate `vector.extract` through elementwise ops (PR #131462)

Ivan Butygin llvmlistbot at llvm.org
Sat Mar 15 17:17:12 PDT 2025


https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/131462

>From 0b8c7b1a943e7255fc76c114b6ac900d8feb23dd Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 15 Mar 2025 17:30:29 +0100
Subject: [PATCH 1/2] [mlir][vector] Propagate `vector.extract` through
 elementwise ops

Propagate `Extract(Elementwise(...))` -> `Elementwise(Extract(...))`.

Currently limited to the case where the extract is the only use of the elementwise op, to avoid introducing additional computations.
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 44 +++++++++++++++++++++-
 mlir/test/Dialect/Vector/canonicalize.mlir | 42 +++++++++++++++++++++
 2 files changed, 85 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8e0e723cf4ed3..2e326882b5cd1 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2237,6 +2237,47 @@ class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
   }
 };
 
+/// Pattern to rewrite a ExtractOp(Elementwise) -> Elementwise(ExtractOp).
+class ExtractOpFromElemetwise final : public OpRewritePattern<ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractOp op,
+                                PatternRewriter &rewriter) const override {
+    Operation *eltwise = op.getVector().getDefiningOp();
+
+    // Elementwise op with single result and `extract` is single user.
+    if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
+        eltwise->getNumResults() != 1 || !eltwise->hasOneUse())
+      return failure();
+
+    // Arguments and result types must match.
+    if (!llvm::all_equal(llvm::concat<Type>(eltwise->getOperandTypes(),
+                                            eltwise->getResultTypes())))
+      return failure();
+
+    Type dstType = op.getType();
+
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(eltwise);
+
+    IRMapping mapping;
+    Location loc = eltwise->getLoc();
+    for (auto &&[i, arg] : llvm::enumerate(eltwise->getOperands())) {
+      Value newArg =
+          rewriter.create<ExtractOp>(loc, arg, op.getMixedPosition());
+      mapping.map(arg, newArg);
+    }
+
+    Operation *newEltwise = rewriter.clone(*eltwise, mapping);
+    newEltwise->getResult(0).setType(dstType);
+
+    rewriter.replaceOp(op, newEltwise);
+    rewriter.eraseOp(eltwise);
+    return success();
+  }
+};
+
 // Folds extract(shape_cast(..)) into shape_cast when the total element count
 // does not change.
 LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
@@ -2309,7 +2350,8 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask,
+              ExtractOpFromElemetwise>(context);
   results.add(foldExtractFromShapeCastToShapeCast);
   results.add(foldExtractFromFromElements);
 }
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index bf755b466c7eb..7ea461acbcd95 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -244,6 +244,48 @@ func.func @extract_from_non_constant_create_mask(%dim0: index) -> vector<[2]xi1>
 
 // -----
 
+// CHECK-LABEL: @extract_elementwise
+//  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+func.func @extract_elementwise(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
+// CHECK:   %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<4xf32>
+// CHECK:   %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
+// CHECK:   %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
+// CHECK:   return %[[RES]] : f32
+  %0 = arith.addf %arg0, %arg1 : vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @extract_vec_elementwise
+//  CHECK-SAME:   (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<2x4xf32>)
+func.func @extract_vec_elementwise(%arg0: vector<2x4xf32>, %arg1: vector<2x4xf32>) -> vector<4xf32> {
+// CHECK:   %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : vector<4xf32> from vector<2x4xf32>
+// CHECK:   %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf32> from vector<2x4xf32>
+// CHECK:   %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : vector<4xf32>
+// CHECK:   return %[[RES]] : vector<4xf32>
+  %0 = arith.addf %arg0, %arg1 : vector<2x4xf32>
+  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+  return %1 : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @extract_elementwise_use
+//  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+func.func @extract_elementwise_use(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
+// Dop not propagate extract, as elementwise has other uses
+// CHECK:   %[[ELT:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<4xf32>
+// CHECK:   %[[EXT:.*]] = vector.extract %[[ELT]][1] : f32 from vector<4xf32>
+// CHECK:   return %[[EXT]], %[[ELT]] : f32, vector<4xf32>
+  %0 = arith.addf %arg0, %arg1 : vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1, %0 : f32, vector<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: constant_mask_transpose_to_transposed_constant_mask
 func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x4xi1>, vector<4x2x3xi1>) {
   //     CHECK: vector.constant_mask [1, 2, 3] : vector<2x3x4xi1>

>From ac16b306fc15ee6a761568c57107cc71df3f3993 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 16 Mar 2025 01:09:36 +0100
Subject: [PATCH 2/2] make patterns standalone

---
 .../Vector/TransformOps/VectorTransformOps.td | 11 ++++
 .../Vector/Transforms/VectorRewritePatterns.h |  3 +
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 44 +------------
 .../TransformOps/VectorTransformOps.cpp       |  5 ++
 .../Dialect/Vector/Transforms/CMakeLists.txt  |  1 +
 .../Transforms/VectorPropagateExtract.cpp     | 66 +++++++++++++++++++
 mlir/test/Dialect/Vector/canonicalize.mlir    | 42 ------------
 .../Dialect/Vector/propagate-extracts.mlir    | 47 +++++++++++++
 8 files changed, 134 insertions(+), 85 deletions(-)
 create mode 100644 mlir/lib/Dialect/Vector/Transforms/VectorPropagateExtract.cpp
 create mode 100644 mlir/test/Dialect/Vector/propagate-extracts.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index c973eca0132a9..7be39519c1037 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -453,4 +453,15 @@ def ApplyVectorReductionToContractPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyVectorPropagateExtractPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.propagate_extract",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collect a set of patterns for propagating `vector.extract` through the
+    vector ops.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 #endif // VECTOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 7de4a6a315750..16c66e078821d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -409,6 +409,9 @@ void populateVectorLinearizeShuffleLikeOpsPatterns(
     const TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target, unsigned targetBitWidth);
 
+/// Populates patterns for propagating `vector.extract` through the vector ops.
+void populateVectorPropagateExtractsPatterns(RewritePatternSet &patterns);
+
 } // namespace vector
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2e326882b5cd1..8e0e723cf4ed3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2237,47 +2237,6 @@ class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
   }
 };
 
-/// Pattern to rewrite a ExtractOp(Elementwise) -> Elementwise(ExtractOp).
-class ExtractOpFromElemetwise final : public OpRewritePattern<ExtractOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExtractOp op,
-                                PatternRewriter &rewriter) const override {
-    Operation *eltwise = op.getVector().getDefiningOp();
-
-    // Elementwise op with single result and `extract` is single user.
-    if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
-        eltwise->getNumResults() != 1 || !eltwise->hasOneUse())
-      return failure();
-
-    // Arguments and result types must match.
-    if (!llvm::all_equal(llvm::concat<Type>(eltwise->getOperandTypes(),
-                                            eltwise->getResultTypes())))
-      return failure();
-
-    Type dstType = op.getType();
-
-    OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPoint(eltwise);
-
-    IRMapping mapping;
-    Location loc = eltwise->getLoc();
-    for (auto &&[i, arg] : llvm::enumerate(eltwise->getOperands())) {
-      Value newArg =
-          rewriter.create<ExtractOp>(loc, arg, op.getMixedPosition());
-      mapping.map(arg, newArg);
-    }
-
-    Operation *newEltwise = rewriter.clone(*eltwise, mapping);
-    newEltwise->getResult(0).setType(dstType);
-
-    rewriter.replaceOp(op, newEltwise);
-    rewriter.eraseOp(eltwise);
-    return success();
-  }
-};
-
 // Folds extract(shape_cast(..)) into shape_cast when the total element count
 // does not change.
 LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
@@ -2350,8 +2309,7 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask,
-              ExtractOpFromElemetwise>(context);
+  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
   results.add(foldExtractFromShapeCastToShapeCast);
   results.add(foldExtractFromFromElements);
 }
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 20c577273d786..616e563fcdc77 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -204,6 +204,11 @@ void transform::ApplyTransferToScfPatternsOp::populatePatterns(
   populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
 }
 
+void transform::ApplyVectorPropagateExtractPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  vector::populateVectorPropagateExtractsPatterns(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 8ca5cb6c6dfab..8830375f88104 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -24,6 +24,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   VectorTransforms.cpp
   VectorUnroll.cpp
   VectorMaskElimination.cpp
+  VectorPropagateExtract.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/Transforms
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorPropagateExtract.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorPropagateExtract.cpp
new file mode 100644
index 0000000000000..10f578179bc94
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorPropagateExtract.cpp
@@ -0,0 +1,100 @@
+//===- VectorPropagateExtract.cpp - vector.extract propagation ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns for vector.extract propagation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Pattern to rewrite `ExtractOp(Elementwise) -> Elementwise(ExtractOp)`:
+///
+/// ```
+///   %0 = arith.addf %arg0, %arg1 : vector<4xf32>
+///   %1 = vector.extract %0[1] : f32 from vector<4xf32>
+/// ```
+/// becomes
+/// ```
+///   %0 = vector.extract %arg0[1] : f32 from vector<4xf32>
+///   %1 = vector.extract %arg1[1] : f32 from vector<4xf32>
+///   %2 = arith.addf %0, %1 : f32
+/// ```
+///
+/// Only fires when the extract is the sole user of the elementwise op, so the
+/// wide computation is replaced rather than duplicated.
+class ExtractOpFromElementwise final
+    : public OpRewritePattern<vector::ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp op,
+                                PatternRewriter &rewriter) const override {
+    Operation *eltwise = op.getVector().getDefiningOp();
+
+    // `vector.fma` has the elementwise-mappable traits but is only defined on
+    // vector operands, so it cannot be rebuilt with scalar operands.
+    if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
+        isa<vector::FMAOp>(eltwise))
+      return rewriter.notifyMatchFailure(op, "not an elementwise op");
+
+    if (eltwise->getNumResults() != 1)
+      return rewriter.notifyMatchFailure(op, "expected single result");
+
+    // Other users still need the full vector, so propagating would add
+    // computation instead of removing it.
+    if (!eltwise->hasOneUse())
+      return rewriter.notifyMatchFailure(op, "expected single op use");
+
+    // Arguments and result types must match.
+    if (!llvm::all_equal(llvm::concat<Type>(eltwise->getOperandTypes(),
+                                            eltwise->getResultTypes())))
+      return rewriter.notifyMatchFailure(op, "operand/result types differ");
+
+    // The new extracts are inserted before `eltwise`; a dynamic index defined
+    // between `eltwise` and `op` would not dominate them.
+    if (!op.getDynamicPosition().empty())
+      return rewriter.notifyMatchFailure(op, "dynamic position not supported");
+
+    Type dstType = op.getType();
+    SmallVector<OpFoldResult> position = op.getMixedPosition();
+
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(eltwise);
+
+    IRMapping mapping;
+    Location loc = eltwise->getLoc();
+    for (Value arg : eltwise->getOperands()) {
+      // A value feeding several operands only needs to be extracted once.
+      if (mapping.contains(arg))
+        continue;
+      Value newArg = rewriter.create<vector::ExtractOp>(loc, arg, position);
+      mapping.map(arg, newArg);
+    }
+
+    Operation *newEltwise = rewriter.clone(*eltwise, mapping);
+    newEltwise->getResult(0).setType(dstType);
+
+    rewriter.replaceOp(op, newEltwise);
+    rewriter.eraseOp(eltwise);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::vector::populateVectorPropagateExtractsPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<ExtractOpFromElementwise>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 7ea461acbcd95..bf755b466c7eb 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -244,48 +244,6 @@ func.func @extract_from_non_constant_create_mask(%dim0: index) -> vector<[2]xi1>
 
 // -----
 
-// CHECK-LABEL: @extract_elementwise
-//  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
-func.func @extract_elementwise(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
-// CHECK:   %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<4xf32>
-// CHECK:   %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
-// CHECK:   %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
-// CHECK:   return %[[RES]] : f32
-  %0 = arith.addf %arg0, %arg1 : vector<4xf32>
-  %1 = vector.extract %0[1] : f32 from vector<4xf32>
-  return %1 : f32
-}
-
-// -----
-
-// CHECK-LABEL: @extract_vec_elementwise
-//  CHECK-SAME:   (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<2x4xf32>)
-func.func @extract_vec_elementwise(%arg0: vector<2x4xf32>, %arg1: vector<2x4xf32>) -> vector<4xf32> {
-// CHECK:   %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : vector<4xf32> from vector<2x4xf32>
-// CHECK:   %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf32> from vector<2x4xf32>
-// CHECK:   %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : vector<4xf32>
-// CHECK:   return %[[RES]] : vector<4xf32>
-  %0 = arith.addf %arg0, %arg1 : vector<2x4xf32>
-  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
-  return %1 : vector<4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @extract_elementwise_use
-//  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
-func.func @extract_elementwise_use(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
-// Dop not propagate extract, as elementwise has other uses
-// CHECK:   %[[ELT:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<4xf32>
-// CHECK:   %[[EXT:.*]] = vector.extract %[[ELT]][1] : f32 from vector<4xf32>
-// CHECK:   return %[[EXT]], %[[ELT]] : f32, vector<4xf32>
-  %0 = arith.addf %arg0, %arg1 : vector<4xf32>
-  %1 = vector.extract %0[1] : f32 from vector<4xf32>
-  return %1, %0 : f32, vector<4xf32>
-}
-
-// -----
-
 // CHECK-LABEL: constant_mask_transpose_to_transposed_constant_mask
 func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x4xi1>, vector<4x2x3xi1>) {
   //     CHECK: vector.constant_mask [1, 2, 3] : vector<2x3x4xi1>
diff --git a/mlir/test/Dialect/Vector/propagate-extracts.mlir b/mlir/test/Dialect/Vector/propagate-extracts.mlir
new file mode 100644
index 0000000000000..6c6f812c8f6d2
--- /dev/null
+++ b/mlir/test/Dialect/Vector/propagate-extracts.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
+
+// CHECK-LABEL: @extract_elementwise
+//  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+func.func @extract_elementwise(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
+// CHECK:   %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<4xf32>
+// CHECK:   %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
+// CHECK:   %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
+// CHECK:   return %[[RES]] : f32
+  %0 = arith.addf %arg0, %arg1 : vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_vec_elementwise
+//  CHECK-SAME:   (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<2x4xf32>)
+func.func @extract_vec_elementwise(%arg0: vector<2x4xf32>, %arg1: vector<2x4xf32>) -> vector<4xf32> {
+// CHECK:   %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : vector<4xf32> from vector<2x4xf32>
+// CHECK:   %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf32> from vector<2x4xf32>
+// CHECK:   %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : vector<4xf32>
+// CHECK:   return %[[RES]] : vector<4xf32>
+  %0 = arith.addf %arg0, %arg1 : vector<2x4xf32>
+  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+  return %1 : vector<4xf32>
+}
+
+// CHECK-LABEL: @extract_elementwise_use
+//  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+func.func @extract_elementwise_use(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
+// Do not propagate extract, as elementwise has other uses.
+// CHECK:   %[[ELT:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<4xf32>
+// CHECK:   %[[EXT:.*]] = vector.extract %[[ELT]][1] : f32 from vector<4xf32>
+// CHECK:   return %[[EXT]], %[[ELT]] : f32, vector<4xf32>
+  %0 = arith.addf %arg0, %arg1 : vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1, %0 : f32, vector<4xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.propagate_extract
+    } : !transform.any_op
+    transform.yield
+  }
+}



More information about the Mlir-commits mailing list