[Mlir-commits] [mlir] [mlir][vector] Propagate `vector.extract` through elementwise ops (PR #131462)

Ivan Butygin llvmlistbot at llvm.org
Sun Mar 23 08:10:12 PDT 2025


https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/131462

>From 8f5ad58f074bf9b1720012191853d4f359efcb58 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sat, 15 Mar 2025 17:30:29 +0100
Subject: [PATCH 01/10] [mlir][vector] Propagate `vector.extract` through
 elementwise ops

Propagate `Extract(Elementwise(...))` -> `Elementwise(Extract(...))`.

Currently limited to the case where the extract is the only use of the elementwise op, to avoid introducing additional computations.
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 44 +++++++++++++++++++++-
 mlir/test/Dialect/Vector/canonicalize.mlir | 42 +++++++++++++++++++++
 2 files changed, 85 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d4c1da30d498d..bdb42efccab53 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2237,6 +2237,47 @@ class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
   }
 };
 
+/// Pattern to rewrite a ExtractOp(Elementwise) -> Elementwise(ExtractOp).
+class ExtractOpFromElemetwise final : public OpRewritePattern<ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractOp op,
+                                PatternRewriter &rewriter) const override {
+    Operation *eltwise = op.getVector().getDefiningOp();
+
+    // Elementwise op with single result and `extract` is single user.
+    if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
+        eltwise->getNumResults() != 1 || !eltwise->hasOneUse())
+      return failure();
+
+    // Arguments and result types must match.
+    if (!llvm::all_equal(llvm::concat<Type>(eltwise->getOperandTypes(),
+                                            eltwise->getResultTypes())))
+      return failure();
+
+    Type dstType = op.getType();
+
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(eltwise);
+
+    IRMapping mapping;
+    Location loc = eltwise->getLoc();
+    for (auto &&[i, arg] : llvm::enumerate(eltwise->getOperands())) {
+      Value newArg =
+          rewriter.create<ExtractOp>(loc, arg, op.getMixedPosition());
+      mapping.map(arg, newArg);
+    }
+
+    Operation *newEltwise = rewriter.clone(*eltwise, mapping);
+    newEltwise->getResult(0).setType(dstType);
+
+    rewriter.replaceOp(op, newEltwise);
+    rewriter.eraseOp(eltwise);
+    return success();
+  }
+};
+
 // Folds extract(shape_cast(..)) into shape_cast when the total element count
 // does not change.
 LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
@@ -2309,7 +2350,8 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask,
+              ExtractOpFromElemetwise>(context);
   results.add(foldExtractFromShapeCastToShapeCast);
   results.add(foldExtractFromFromElements);
 }
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index bf755b466c7eb..7ea461acbcd95 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -244,6 +244,48 @@ func.func @extract_from_non_constant_create_mask(%dim0: index) -> vector<[2]xi1>
 
 // -----
 
+// CHECK-LABEL: @extract_elementwise
+//  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+func.func @extract_elementwise(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
+// CHECK:   %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<4xf32>
+// CHECK:   %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
+// CHECK:   %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
+// CHECK:   return %[[RES]] : f32
+  %0 = arith.addf %arg0, %arg1 : vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @extract_vec_elementwise
+//  CHECK-SAME:   (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<2x4xf32>)
+func.func @extract_vec_elementwise(%arg0: vector<2x4xf32>, %arg1: vector<2x4xf32>) -> vector<4xf32> {
+// CHECK:   %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : vector<4xf32> from vector<2x4xf32>
+// CHECK:   %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf32> from vector<2x4xf32>
+// CHECK:   %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : vector<4xf32>
+// CHECK:   return %[[RES]] : vector<4xf32>
+  %0 = arith.addf %arg0, %arg1 : vector<2x4xf32>
+  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+  return %1 : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @extract_elementwise_use
+//  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+func.func @extract_elementwise_use(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
+// Dop not propagate extract, as elementwise has other uses
+// CHECK:   %[[ELT:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<4xf32>
+// CHECK:   %[[EXT:.*]] = vector.extract %[[ELT]][1] : f32 from vector<4xf32>
+// CHECK:   return %[[EXT]], %[[ELT]] : f32, vector<4xf32>
+  %0 = arith.addf %arg0, %arg1 : vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1, %0 : f32, vector<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: constant_mask_transpose_to_transposed_constant_mask
 func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x4xi1>, vector<4x2x3xi1>) {
   //     CHECK: vector.constant_mask [1, 2, 3] : vector<2x3x4xi1>

>From 11fe5a33b3becf7c483821e582c02ae650f07ae6 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 16 Mar 2025 01:09:36 +0100
Subject: [PATCH 02/10] make patterns standalone

---
 .../Vector/TransformOps/VectorTransformOps.td | 11 ++++
 .../Vector/Transforms/VectorRewritePatterns.h |  3 +
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 44 +------------
 .../TransformOps/VectorTransformOps.cpp       |  5 ++
 .../Dialect/Vector/Transforms/CMakeLists.txt  |  1 +
 .../Transforms/VectorPropagateExtract.cpp     | 66 +++++++++++++++++++
 mlir/test/Dialect/Vector/canonicalize.mlir    | 42 ------------
 .../Dialect/Vector/propagate-extracts.mlir    | 47 +++++++++++++
 8 files changed, 134 insertions(+), 85 deletions(-)
 create mode 100644 mlir/lib/Dialect/Vector/Transforms/VectorPropagateExtract.cpp
 create mode 100644 mlir/test/Dialect/Vector/propagate-extracts.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index c973eca0132a9..7be39519c1037 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -453,4 +453,15 @@ def ApplyVectorReductionToContractPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyVectorPropagateExtractPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.propagate_extract",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collect a set of patterns for propagating `vector.extract` through the
+    vector ops.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 #endif // VECTOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 7de4a6a315750..16c66e078821d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -409,6 +409,9 @@ void populateVectorLinearizeShuffleLikeOpsPatterns(
     const TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target, unsigned targetBitWidth);
 
+/// Populates patterns for propagating `vector.extract` through the vector ops.
+void populateVectorPropagateExtractsPatterns(RewritePatternSet &patterns);
+
 } // namespace vector
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index bdb42efccab53..d4c1da30d498d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2237,47 +2237,6 @@ class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
   }
 };
 
-/// Pattern to rewrite a ExtractOp(Elementwise) -> Elementwise(ExtractOp).
-class ExtractOpFromElemetwise final : public OpRewritePattern<ExtractOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExtractOp op,
-                                PatternRewriter &rewriter) const override {
-    Operation *eltwise = op.getVector().getDefiningOp();
-
-    // Elementwise op with single result and `extract` is single user.
-    if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
-        eltwise->getNumResults() != 1 || !eltwise->hasOneUse())
-      return failure();
-
-    // Arguments and result types must match.
-    if (!llvm::all_equal(llvm::concat<Type>(eltwise->getOperandTypes(),
-                                            eltwise->getResultTypes())))
-      return failure();
-
-    Type dstType = op.getType();
-
-    OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPoint(eltwise);
-
-    IRMapping mapping;
-    Location loc = eltwise->getLoc();
-    for (auto &&[i, arg] : llvm::enumerate(eltwise->getOperands())) {
-      Value newArg =
-          rewriter.create<ExtractOp>(loc, arg, op.getMixedPosition());
-      mapping.map(arg, newArg);
-    }
-
-    Operation *newEltwise = rewriter.clone(*eltwise, mapping);
-    newEltwise->getResult(0).setType(dstType);
-
-    rewriter.replaceOp(op, newEltwise);
-    rewriter.eraseOp(eltwise);
-    return success();
-  }
-};
-
 // Folds extract(shape_cast(..)) into shape_cast when the total element count
 // does not change.
 LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
@@ -2350,8 +2309,7 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask,
-              ExtractOpFromElemetwise>(context);
+  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
   results.add(foldExtractFromShapeCastToShapeCast);
   results.add(foldExtractFromFromElements);
 }
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 20c577273d786..616e563fcdc77 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -204,6 +204,11 @@ void transform::ApplyTransferToScfPatternsOp::populatePatterns(
   populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
 }
 
+void transform::ApplyVectorPropagateExtractPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  vector::populateVectorPropagateExtractsPatterns(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 8ca5cb6c6dfab..8830375f88104 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -24,6 +24,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   VectorTransforms.cpp
   VectorUnroll.cpp
   VectorMaskElimination.cpp
+  VectorPropagateExtract.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/Transforms
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorPropagateExtract.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorPropagateExtract.cpp
new file mode 100644
index 0000000000000..10f578179bc94
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorPropagateExtract.cpp
@@ -0,0 +1,66 @@
+//===- VectorPropagateExtract.cpp - vector.extract propagation - ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns for vector.extract propagation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Pattern to rewrite a ExtractOp(Elementwise) -> Elementwise(ExtractOp).
+class ExtractOpFromElementwise final
+    : public OpRewritePattern<vector::ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp op,
+                                PatternRewriter &rewriter) const override {
+    Operation *eltwise = op.getVector().getDefiningOp();
+
+    // Elementwise op with single result and `extract` is single user.
+    if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
+        eltwise->getNumResults() != 1 || !eltwise->hasOneUse())
+      return failure();
+
+    // Arguments and result types must match.
+    if (!llvm::all_equal(llvm::concat<Type>(eltwise->getOperandTypes(),
+                                            eltwise->getResultTypes())))
+      return failure();
+
+    Type dstType = op.getType();
+
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(eltwise);
+
+    IRMapping mapping;
+    Location loc = eltwise->getLoc();
+    for (auto &&[i, arg] : llvm::enumerate(eltwise->getOperands())) {
+      Value newArg =
+          rewriter.create<vector::ExtractOp>(loc, arg, op.getMixedPosition());
+      mapping.map(arg, newArg);
+    }
+
+    Operation *newEltwise = rewriter.clone(*eltwise, mapping);
+    newEltwise->getResult(0).setType(dstType);
+
+    rewriter.replaceOp(op, newEltwise);
+    rewriter.eraseOp(eltwise);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::vector::populateVectorPropagateExtractsPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<ExtractOpFromElementwise>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 7ea461acbcd95..bf755b466c7eb 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -244,48 +244,6 @@ func.func @extract_from_non_constant_create_mask(%dim0: index) -> vector<[2]xi1>
 
 // -----
 
-// CHECK-LABEL: @extract_elementwise
-//  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
-func.func @extract_elementwise(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
-// CHECK:   %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<4xf32>
-// CHECK:   %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
-// CHECK:   %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
-// CHECK:   return %[[RES]] : f32
-  %0 = arith.addf %arg0, %arg1 : vector<4xf32>
-  %1 = vector.extract %0[1] : f32 from vector<4xf32>
-  return %1 : f32
-}
-
-// -----
-
-// CHECK-LABEL: @extract_vec_elementwise
-//  CHECK-SAME:   (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<2x4xf32>)
-func.func @extract_vec_elementwise(%arg0: vector<2x4xf32>, %arg1: vector<2x4xf32>) -> vector<4xf32> {
-// CHECK:   %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : vector<4xf32> from vector<2x4xf32>
-// CHECK:   %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf32> from vector<2x4xf32>
-// CHECK:   %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : vector<4xf32>
-// CHECK:   return %[[RES]] : vector<4xf32>
-  %0 = arith.addf %arg0, %arg1 : vector<2x4xf32>
-  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
-  return %1 : vector<4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @extract_elementwise_use
-//  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
-func.func @extract_elementwise_use(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
-// Dop not propagate extract, as elementwise has other uses
-// CHECK:   %[[ELT:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<4xf32>
-// CHECK:   %[[EXT:.*]] = vector.extract %[[ELT]][1] : f32 from vector<4xf32>
-// CHECK:   return %[[EXT]], %[[ELT]] : f32, vector<4xf32>
-  %0 = arith.addf %arg0, %arg1 : vector<4xf32>
-  %1 = vector.extract %0[1] : f32 from vector<4xf32>
-  return %1, %0 : f32, vector<4xf32>
-}
-
-// -----
-
 // CHECK-LABEL: constant_mask_transpose_to_transposed_constant_mask
 func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x4xi1>, vector<4x2x3xi1>) {
   //     CHECK: vector.constant_mask [1, 2, 3] : vector<2x3x4xi1>
diff --git a/mlir/test/Dialect/Vector/propagate-extracts.mlir b/mlir/test/Dialect/Vector/propagate-extracts.mlir
new file mode 100644
index 0000000000000..6c6f812c8f6d2
--- /dev/null
+++ b/mlir/test/Dialect/Vector/propagate-extracts.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
+
+// CHECK-LABEL: @extract_elementwise
+//  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+func.func @extract_elementwise(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
+// CHECK:   %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<4xf32>
+// CHECK:   %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
+// CHECK:   %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
+// CHECK:   return %[[RES]] : f32
+  %0 = arith.addf %arg0, %arg1 : vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_vec_elementwise
+//  CHECK-SAME:   (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<2x4xf32>)
+func.func @extract_vec_elementwise(%arg0: vector<2x4xf32>, %arg1: vector<2x4xf32>) -> vector<4xf32> {
+// CHECK:   %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : vector<4xf32> from vector<2x4xf32>
+// CHECK:   %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf32> from vector<2x4xf32>
+// CHECK:   %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : vector<4xf32>
+// CHECK:   return %[[RES]] : vector<4xf32>
+  %0 = arith.addf %arg0, %arg1 : vector<2x4xf32>
+  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+  return %1 : vector<4xf32>
+}
+
+// CHECK-LABEL: @extract_elementwise_use
+//  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+func.func @extract_elementwise_use(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
+// Do not propagate extract, as elementwise has other uses.
+// CHECK:   %[[ELT:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<4xf32>
+// CHECK:   %[[EXT:.*]] = vector.extract %[[ELT]][1] : f32 from vector<4xf32>
+// CHECK:   return %[[EXT]], %[[ELT]] : f32, vector<4xf32>
+  %0 = arith.addf %arg0, %arg1 : vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1, %0 : f32, vector<4xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.propagate_extract
+    } : !transform.any_op
+    transform.yield
+  }
+}

>From c92969b30152503c3aab2c88973645e476fb072b Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 18 Mar 2025 16:38:45 +0100
Subject: [PATCH 03/10] move patterns to vector-sink

---
 .../Vector/TransformOps/VectorTransformOps.td | 20 ++++--
 .../Vector/Transforms/VectorRewritePatterns.h |  3 -
 .../TransformOps/VectorTransformOps.cpp       |  4 +-
 .../Dialect/Vector/Transforms/CMakeLists.txt  |  1 -
 .../Transforms/VectorPropagateExtract.cpp     | 66 -------------------
 .../Vector/Transforms/VectorTransforms.cpp    | 48 +++++++++++++-
 .../Linalg/vectorize-tensor-extract.mlir      | 31 ++++-----
 .../Dialect/Vector/propagate-extracts.mlir    | 47 -------------
 .../Dialect/Vector/vector-sink-transform.mlir | 26 ++++++++
 mlir/test/Dialect/Vector/vector-sink.mlir     | 38 +++++++++++
 10 files changed, 140 insertions(+), 144 deletions(-)
 delete mode 100644 mlir/lib/Dialect/Vector/Transforms/VectorPropagateExtract.cpp
 delete mode 100644 mlir/test/Dialect/Vector/propagate-extracts.mlir
 create mode 100644 mlir/test/Dialect/Vector/vector-sink-transform.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 7be39519c1037..f46aa0428f12f 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -453,12 +453,24 @@ def ApplyVectorReductionToContractPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
-def ApplyVectorPropagateExtractPatternsOp : Op<Transform_Dialect,
-    "apply_patterns.vector.propagate_extract",
+def ApplySinkVectorPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.sink_ops",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
-    Collect a set of patterns for propagating `vector.extract` through the
-    vector ops.
+    Patterns that remove redundant Vector Ops by re-ordering them with
+    e.g. elementwise Ops:
+    ```
+    %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+    %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+    %r = arith.addf %at, %bt : vector<2x4xf32>
+    ```
+    gets converted to:
+    ```
+    %0 = arith.addf %a, %b : vector<4x2xf32>
+    %r = vector.transpose %0, [1, 0] : vector<2x4xf32>
+    ```
+    At the moment, these patterns are limited to vector.broadcast and
+    vector.transpose.
   }];
 
   let assemblyFormat = "attr-dict";
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 16c66e078821d..7de4a6a315750 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -409,9 +409,6 @@ void populateVectorLinearizeShuffleLikeOpsPatterns(
     const TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target, unsigned targetBitWidth);
 
-/// Populates patterns for propagating `vector.extract` through the vector ops.
-void populateVectorPropagateExtractsPatterns(RewritePatternSet &patterns);
-
 } // namespace vector
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 616e563fcdc77..80a6ffa30916a 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -204,9 +204,9 @@ void transform::ApplyTransferToScfPatternsOp::populatePatterns(
   populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
 }
 
-void transform::ApplyVectorPropagateExtractPatternsOp::populatePatterns(
+void transform::ApplySinkVectorPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
-  vector::populateVectorPropagateExtractsPatterns(patterns);
+  vector::populateSinkVectorOpsPatterns(patterns);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 8830375f88104..8ca5cb6c6dfab 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -24,7 +24,6 @@ add_mlir_dialect_library(MLIRVectorTransforms
   VectorTransforms.cpp
   VectorUnroll.cpp
   VectorMaskElimination.cpp
-  VectorPropagateExtract.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/Transforms
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorPropagateExtract.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorPropagateExtract.cpp
deleted file mode 100644
index 10f578179bc94..0000000000000
--- a/mlir/lib/Dialect/Vector/Transforms/VectorPropagateExtract.cpp
+++ /dev/null
@@ -1,66 +0,0 @@
-//===- VectorPropagateExtract.cpp - vector.extract propagation - ----------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements patterns for vector.extract propagation.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
-
-using namespace mlir;
-
-namespace {
-
-/// Pattern to rewrite a ExtractOp(Elementwise) -> Elementwise(ExtractOp).
-class ExtractOpFromElementwise final
-    : public OpRewritePattern<vector::ExtractOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::ExtractOp op,
-                                PatternRewriter &rewriter) const override {
-    Operation *eltwise = op.getVector().getDefiningOp();
-
-    // Elementwise op with single result and `extract` is single user.
-    if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
-        eltwise->getNumResults() != 1 || !eltwise->hasOneUse())
-      return failure();
-
-    // Arguments and result types must match.
-    if (!llvm::all_equal(llvm::concat<Type>(eltwise->getOperandTypes(),
-                                            eltwise->getResultTypes())))
-      return failure();
-
-    Type dstType = op.getType();
-
-    OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPoint(eltwise);
-
-    IRMapping mapping;
-    Location loc = eltwise->getLoc();
-    for (auto &&[i, arg] : llvm::enumerate(eltwise->getOperands())) {
-      Value newArg =
-          rewriter.create<vector::ExtractOp>(loc, arg, op.getMixedPosition());
-      mapping.map(arg, newArg);
-    }
-
-    Operation *newEltwise = rewriter.clone(*eltwise, mapping);
-    newEltwise->getResult(0).setType(dstType);
-
-    rewriter.replaceOp(op, newEltwise);
-    rewriter.eraseOp(eltwise);
-    return success();
-  }
-};
-
-} // namespace
-
-void mlir::vector::populateVectorPropagateExtractsPatterns(
-    RewritePatternSet &patterns) {
-  patterns.add<ExtractOpFromElementwise>(patterns.getContext());
-}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index dc46ed17a374d..e633f6373b47d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1043,6 +1043,50 @@ struct ReorderElementwiseOpsOnBroadcast final
   }
 };
 
+/// Pattern to rewrite a ExtractOp(Elementwise) -> Elementwise(ExtractOp).
+/// This may result in more efficient code when we extracting a single value
+/// from multi-element vector and also to help canonicalize 1-element vectors to
+/// scalars.
+class ExtractOpFromElementwise final
+    : public OpRewritePattern<vector::ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp op,
+                                PatternRewriter &rewriter) const override {
+    Operation *eltwise = op.getVector().getDefiningOp();
+
+    // Elementwise op with single result and `extract` is single user.
+    if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
+        eltwise->getNumResults() != 1 || !eltwise->hasOneUse())
+      return failure();
+
+    // Arguments and result types must match.
+    if (!llvm::all_equal(eltwise->getOperandTypes()))
+      return failure();
+
+    Type dstType = op.getType();
+
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(eltwise);
+
+    IRMapping mapping;
+    Location loc = eltwise->getLoc();
+    for (auto &&[i, arg] : llvm::enumerate(eltwise->getOperands())) {
+      Value newArg =
+          rewriter.create<vector::ExtractOp>(loc, arg, op.getMixedPosition());
+      mapping.map(arg, newArg);
+    }
+
+    Operation *newEltwise = rewriter.clone(*eltwise, mapping);
+    newEltwise->getResult(0).setType(dstType);
+
+    rewriter.replaceOp(op, newEltwise);
+    rewriter.eraseOp(eltwise);
+    return success();
+  }
+};
+
 // Helper that returns a vector comparison that constructs a mask:
 //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
 //
@@ -2111,8 +2155,8 @@ void mlir::vector::
 void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
                                                  PatternBenefit benefit) {
   patterns.add<ReorderElementwiseOpsOnTranspose, ReorderCastOpsOnBroadcast,
-               ReorderElementwiseOpsOnBroadcast>(patterns.getContext(),
-                                                 benefit);
+               ReorderElementwiseOpsOnBroadcast, ExtractOpFromElementwise>(
+      patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateChainedVectorReductionFoldingPatterns(
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index cd83e1239fdda..b553681953a82 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -62,20 +62,15 @@ func.func @vectorize_nd_tensor_extract_transfer_read_complex(%6: tensor<45x80x16
 // CHECK-SAME:      %[[VAL_0:.*]]: tensor<45x80x16xf32>,
 // CHECK-SAME:      %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index, %[[VAL_4:.*]]: index,
 // CHECK-SAME:      %[[VAL_5:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
-// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 0 : index
-// CHECK-DAG:       %[[VAL_10:.*]] = arith.constant 79 : index
-// CHECK:           %[[VAL_11:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : index
-// CHECK:           %[[VAL_13:.*]] = vector.broadcast %[[VAL_3]] : index to vector<4xindex>
-// CHECK:           %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_6]] : vector<4xindex>
-// CHECK:           %[[VAL_15:.*]] = vector.broadcast %[[VAL_4]] : index to vector<4xindex>
-// CHECK:           %[[VAL_16:.*]] = arith.addi %[[VAL_14]], %[[VAL_15]] : vector<4xindex>
-
-// CHECK:           %[[VAL_19:.*]] = vector.extract %[[VAL_16]][0] : index from vector<4xindex>
-
-// CHECK:           %[[VAL_20:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_11]], %[[VAL_10]], %[[VAL_19]]], %[[VAL_8]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
-// CHECK:           %[[VAL_21:.*]] = vector.transfer_write %[[VAL_20]], %[[VAL_5]]{{\[}}%[[VAL_9]], %[[VAL_9]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
+
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:       %[[C79:.*]] = arith.constant 79 : index
+// CHECK:           %[[VAL6:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : index
+// CHECK:           %[[VAL7:.*]] = arith.addi %[[VAL_3]], %[[VAL_4]] : index
+
+// CHECK:           %[[VAL_20:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL6]], %[[C79]], %[[VAL7]]], %[[CST]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
+// CHECK:           %[[VAL_21:.*]] = vector.transfer_write %[[VAL_20]], %[[VAL_5]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
 // CHECK:           return %[[VAL_21]] : tensor<1x4xf32>
 // CHECK:         }
 
@@ -101,14 +96,12 @@ func.func @vectorize_nd_tensor_extract_with_affine_apply_contiguous(%6: tensor<8
 // CHECK-SAME:                                                                        %[[VAL_0:.*]]: tensor<80x16xf32>,
 // CHECK-SAME:                                                                        %[[VAL_1:.*]]: index,
 // CHECK-SAME:                                                                        %[[VAL_2:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
-// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 79 : index
-// CHECK:           %[[VAL_8:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex>
-// CHECK:           %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : vector<4xindex>
-// CHECK:           %[[VAL_10:.*]] = vector.extract %[[VAL_9]][0] : index from vector<4xindex>
-// CHECK:           %[[VAL_11:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_10]]], %[[VAL_5]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
+
+// CHECK:           %[[VAL_11:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_1]]], %[[VAL_5]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
 // CHECK:           %[[VAL_12:.*]] = vector.transfer_write %[[VAL_11]], %[[VAL_2]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
 // CHECK:           return %[[VAL_12]] : tensor<1x4xf32>
 // CHECK:         }
diff --git a/mlir/test/Dialect/Vector/propagate-extracts.mlir b/mlir/test/Dialect/Vector/propagate-extracts.mlir
deleted file mode 100644
index 6c6f812c8f6d2..0000000000000
--- a/mlir/test/Dialect/Vector/propagate-extracts.mlir
+++ /dev/null
@@ -1,47 +0,0 @@
-// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
-
-// CHECK-LABEL: @extract_elementwise
-//  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
-func.func @extract_elementwise(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
-// CHECK:   %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<4xf32>
-// CHECK:   %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
-// CHECK:   %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
-// CHECK:   return %[[RES]] : f32
-  %0 = arith.addf %arg0, %arg1 : vector<4xf32>
-  %1 = vector.extract %0[1] : f32 from vector<4xf32>
-  return %1 : f32
-}
-
-// CHECK-LABEL: @extract_vec_elementwise
-//  CHECK-SAME:   (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<2x4xf32>)
-func.func @extract_vec_elementwise(%arg0: vector<2x4xf32>, %arg1: vector<2x4xf32>) -> vector<4xf32> {
-// CHECK:   %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : vector<4xf32> from vector<2x4xf32>
-// CHECK:   %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf32> from vector<2x4xf32>
-// CHECK:   %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : vector<4xf32>
-// CHECK:   return %[[RES]] : vector<4xf32>
-  %0 = arith.addf %arg0, %arg1 : vector<2x4xf32>
-  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
-  return %1 : vector<4xf32>
-}
-
-// CHECK-LABEL: @extract_elementwise_use
-//  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
-func.func @extract_elementwise_use(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
-// Do not propagate extract, as elementwise has other uses.
-// CHECK:   %[[ELT:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<4xf32>
-// CHECK:   %[[EXT:.*]] = vector.extract %[[ELT]][1] : f32 from vector<4xf32>
-// CHECK:   return %[[EXT]], %[[ELT]] : f32, vector<4xf32>
-  %0 = arith.addf %arg0, %arg1 : vector<4xf32>
-  %1 = vector.extract %0[1] : f32 from vector<4xf32>
-  return %1, %0 : f32, vector<4xf32>
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
-    %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
-    transform.apply_patterns to %func {
-      transform.apply_patterns.vector.propagate_extract
-    } : !transform.any_op
-    transform.yield
-  }
-}
diff --git a/mlir/test/Dialect/Vector/vector-sink-transform.mlir b/mlir/test/Dialect/Vector/vector-sink-transform.mlir
new file mode 100644
index 0000000000000..ddf04fa8ae54a
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-sink-transform.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
+
+// This is smoke test for `transform.apply_patterns.vector.sink_ops` the actual
+// patterns are tested in `vector-sink.mlir`.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.sink_ops
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
+
+// CHECK-LABEL: @extract_elementwise
+//  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+func.func @extract_elementwise(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
+// CHECK:   %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<4xf32>
+// CHECK:   %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
+// CHECK:   %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
+// CHECK:   return %[[RES]] : f32
+  %0 = arith.addf %arg0, %arg1 : vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1 : f32
+}
diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir
index 7ce840575a803..a0882038d2af1 100644
--- a/mlir/test/Dialect/Vector/vector-sink.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink.mlir
@@ -423,3 +423,41 @@ func.func @transpose_elementwise_diff_map_scalable(%a : vector<[4]x6x3x2xf32>, %
   %r = arith.addf %at, %bt : vector<6x[4]x2x3xf32>
   return %r : vector<6x[4]x2x3xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @extract_elementwise
+//  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+func.func @extract_elementwise(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
+// CHECK:   %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<4xf32>
+// CHECK:   %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
+// CHECK:   %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
+// CHECK:   return %[[RES]] : f32
+  %0 = arith.addf %arg0, %arg1 : vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_vec_elementwise
+//  CHECK-SAME:   (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<2x4xf32>)
+func.func @extract_vec_elementwise(%arg0: vector<2x4xf32>, %arg1: vector<2x4xf32>) -> vector<4xf32> {
+// CHECK:   %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : vector<4xf32> from vector<2x4xf32>
+// CHECK:   %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf32> from vector<2x4xf32>
+// CHECK:   %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : vector<4xf32>
+// CHECK:   return %[[RES]] : vector<4xf32>
+  %0 = arith.addf %arg0, %arg1 : vector<2x4xf32>
+  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+  return %1 : vector<4xf32>
+}
+
+// CHECK-LABEL: @extract_elementwise_use
+//  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+func.func @extract_elementwise_use(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
+// Do not propagate extract, as elementwise has other uses.
+// CHECK:   %[[ELT:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<4xf32>
+// CHECK:   %[[EXT:.*]] = vector.extract %[[ELT]][1] : f32 from vector<4xf32>
+// CHECK:   return %[[EXT]], %[[ELT]] : f32, vector<4xf32>
+  %0 = arith.addf %arg0, %arg1 : vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1, %0 : f32, vector<4xf32>
+}

>From eacec46d156b6a66651f11e9dcf6b9c676426626 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 18 Mar 2025 17:14:05 +0100
Subject: [PATCH 04/10] style fixes

---
 .../Dialect/Vector/Transforms/VectorTransforms.cpp | 14 ++++++++++++--
 .../test/Dialect/Vector/vector-sink-transform.mlir |  4 ++--
 mlir/test/Dialect/Vector/vector-sink.mlir          | 12 ++++++------
 3 files changed, 20 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index e633f6373b47d..de3fa45b9b47b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1047,6 +1047,16 @@ struct ReorderElementwiseOpsOnBroadcast final
 /// This may result in more efficient code when we extracting a single value
 /// from multi-element vector and also to help canonicalize 1-element vectors to
 /// scalars.
+/// ```
+///  %0 = arith.addf %arg0, %arg1 : vector<4xf32>
+///  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+/// ```
+/// Gets converted to:
+/// ```
+///  %0 = vector.extract %arg0[1] : f32 from vector<4xf32>
+///  %1 = vector.extract %arg1[1] : f32 from vector<4xf32>
+///  %2 = arith.addf %0, %1 : f32
+/// ```
 class ExtractOpFromElementwise final
     : public OpRewritePattern<vector::ExtractOp> {
 public:
@@ -1061,7 +1071,7 @@ class ExtractOpFromElementwise final
         eltwise->getNumResults() != 1 || !eltwise->hasOneUse())
       return failure();
 
-    // Arguments and result types must match.
+    // Arguments types must match.
     if (!llvm::all_equal(eltwise->getOperandTypes()))
       return failure();
 
@@ -1072,7 +1082,7 @@ class ExtractOpFromElementwise final
 
     IRMapping mapping;
     Location loc = eltwise->getLoc();
-    for (auto &&[i, arg] : llvm::enumerate(eltwise->getOperands())) {
+    for (auto arg : eltwise->getOperands()) {
       Value newArg =
           rewriter.create<vector::ExtractOp>(loc, arg, op.getMixedPosition());
       mapping.map(arg, newArg);
diff --git a/mlir/test/Dialect/Vector/vector-sink-transform.mlir b/mlir/test/Dialect/Vector/vector-sink-transform.mlir
index ddf04fa8ae54a..0fc4fd8c804e9 100644
--- a/mlir/test/Dialect/Vector/vector-sink-transform.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink-transform.mlir
@@ -13,9 +13,9 @@ module attributes {transform.with_named_sequence} {
 }
 
 
-// CHECK-LABEL: @extract_elementwise
+// CHECK-LABEL: @extract_elementwise_scalar
 //  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
-func.func @extract_elementwise(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
+func.func @extract_elementwise_scalar(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
 // CHECK:   %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<4xf32>
 // CHECK:   %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
 // CHECK:   %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir
index a0882038d2af1..75fc2c7af06b3 100644
--- a/mlir/test/Dialect/Vector/vector-sink.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink.mlir
@@ -426,9 +426,9 @@ func.func @transpose_elementwise_diff_map_scalable(%a : vector<[4]x6x3x2xf32>, %
 
 // -----
 
-// CHECK-LABEL: @extract_elementwise
+// CHECK-LABEL: @extract_elementwise_scalar
 //  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
-func.func @extract_elementwise(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
+func.func @extract_elementwise_scalar(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
 // CHECK:   %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<4xf32>
 // CHECK:   %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
 // CHECK:   %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
@@ -438,9 +438,9 @@ func.func @extract_elementwise(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f3
   return %1 : f32
 }
 
-// CHECK-LABEL: @extract_vec_elementwise
+// CHECK-LABEL: @extract_elementwise_vec
 //  CHECK-SAME:   (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<2x4xf32>)
-func.func @extract_vec_elementwise(%arg0: vector<2x4xf32>, %arg1: vector<2x4xf32>) -> vector<4xf32> {
+func.func @extract_elementwise_vec(%arg0: vector<2x4xf32>, %arg1: vector<2x4xf32>) -> vector<4xf32> {
 // CHECK:   %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : vector<4xf32> from vector<2x4xf32>
 // CHECK:   %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf32> from vector<2x4xf32>
 // CHECK:   %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : vector<4xf32>
@@ -450,9 +450,9 @@ func.func @extract_vec_elementwise(%arg0: vector<2x4xf32>, %arg1: vector<2x4xf32
   return %1 : vector<4xf32>
 }
 
-// CHECK-LABEL: @extract_elementwise_use
+// CHECK-LABEL: @extract_elementwise_no_single_use
 //  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
-func.func @extract_elementwise_use(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
+func.func @extract_elementwise_no_single_use(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
 // Do not propagate extract, as elementwise has other uses.
 // CHECK:   %[[ELT:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<4xf32>
 // CHECK:   %[[EXT:.*]] = vector.extract %[[ELT]][1] : f32 from vector<4xf32>

>From 72ef7bba725c8c75a5644e8bd1a628a80a7695ae Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 18 Mar 2025 17:35:48 +0100
Subject: [PATCH 05/10] pattern cleanup

---
 mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index de3fa45b9b47b..10da3d3f874e3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1082,9 +1082,9 @@ class ExtractOpFromElementwise final
 
     IRMapping mapping;
     Location loc = eltwise->getLoc();
-    for (auto arg : eltwise->getOperands()) {
-      Value newArg =
-          rewriter.create<vector::ExtractOp>(loc, arg, op.getMixedPosition());
+    SmallVector<OpFoldResult> pos = op.getMixedPosition();
+    for (Value arg : eltwise->getOperands()) {
+      Value newArg = rewriter.create<vector::ExtractOp>(loc, arg, pos);
       mapping.map(arg, newArg);
     }
 

>From dde17737abfceac76edb31f4c79d64011680c23e Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 18 Mar 2025 17:36:02 +0100
Subject: [PATCH 06/10] more tests

---
 mlir/test/Dialect/Vector/vector-sink.mlir | 35 +++++++++++++++++++++++
 1 file changed, 35 insertions(+)

diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir
index 75fc2c7af06b3..4c2826416056e 100644
--- a/mlir/test/Dialect/Vector/vector-sink.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink.mlir
@@ -438,6 +438,17 @@ func.func @extract_elementwise_scalar(%arg0: vector<4xf32>, %arg1: vector<4xf32>
   return %1 : f32
 }
 
+// CHECK-LABEL: @extract_elementwise_arg_res_different_types
+//  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xindex>)
+func.func @extract_elementwise_arg_res_different_types(%arg0: vector<4xindex>) -> i64 {
+// CHECK:   %[[EXT:.*]] = vector.extract %[[ARG0]][1] : index from vector<4xindex>
+// CHECK:   %[[RES:.*]] = arith.index_cast %[[EXT]] : index to i64
+// CHECK:   return %[[RES]] : i64
+  %0 = arith.index_cast %arg0: vector<4xindex> to vector<4xi64>
+  %1 = vector.extract %0[1] : i64 from vector<4xi64>
+  return %1 : i64
+}
+
 // CHECK-LABEL: @extract_elementwise_vec
 //  CHECK-SAME:   (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<2x4xf32>)
 func.func @extract_elementwise_vec(%arg0: vector<2x4xf32>, %arg1: vector<2x4xf32>) -> vector<4xf32> {
@@ -461,3 +472,27 @@ func.func @extract_elementwise_no_single_use(%arg0: vector<4xf32>, %arg1: vector
   %1 = vector.extract %0[1] : f32 from vector<4xf32>
   return %1, %0 : f32, vector<4xf32>
 }
+
+// CHECK-LABEL: @extract_elementwise_not_one_res
+//  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xi32>, %[[ARG1:.*]]: vector<4xi32>)
+func.func @extract_elementwise_not_one_res(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
+// Do not propagate extract, as elementwise has more than 1 result.
+// CHECK:   %[[LOW:.*]], %[[HIGH:.*]] = arith.mulsi_extended %[[ARG0]], %[[ARG1]] : vector<4xi32>
+// CHECK:   %[[EXT:.*]] = vector.extract %[[LOW]][1] : i32 from vector<4xi32>
+// CHECK:   return %[[EXT]] : i32
+  %low, %hi = arith.mulsi_extended %arg0, %arg1 : vector<4xi32>
+  %1 = vector.extract %low[1] : i32 from vector<4xi32>
+  return %1 : i32
+}
+
+// CHECK-LABEL: @extract_not_elementwise
+//  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xi64>)
+func.func @extract_not_elementwise(%arg0: vector<4xi64>) -> i64 {
+// `test.increment` is not an elemewise op.
+// CHECK:   %[[INC:.*]] = test.increment %[[ARG0]] : vector<4xi64>
+// CHECK:   %[[RES:.*]] = vector.extract %[[INC]][1] : i64 from vector<4xi64>
+// CHECK:   return %[[RES]] : i64
+  %0 = test.increment %arg0: vector<4xi64>
+  %1 = vector.extract %0[1] : i64 from vector<4xi64>
+  return %1 : i64
+}

>From bf628cc70a24743503bbbcf80758241339355311 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Tue, 18 Mar 2025 17:37:59 +0100
Subject: [PATCH 07/10] notifyMatchFailure

---
 mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 10da3d3f874e3..29bc913af15a7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1069,11 +1069,11 @@ class ExtractOpFromElementwise final
     // Elementwise op with single result and `extract` is single user.
     if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
         eltwise->getNumResults() != 1 || !eltwise->hasOneUse())
-      return failure();
+      return rewriter.notifyMatchFailure(op, "not a suitable op");
 
     // Arguments types must match.
     if (!llvm::all_equal(eltwise->getOperandTypes()))
-      return failure();
+      return rewriter.notifyMatchFailure(op, "arg types are different");
 
     Type dstType = op.getType();
 

>From 64ea88ddc63a7e8173a824887ee9214f0d4cc689 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 23 Mar 2025 00:10:34 +0100
Subject: [PATCH 08/10] review comments

---
 .../TransformOps/VectorTransformOps.cpp       |  3 ++
 .../Vector/Transforms/VectorTransforms.cpp    | 17 ++++++----
 .../Linalg/vectorize-tensor-extract.mlir      | 34 +++++++++----------
 .../Dialect/Vector/vector-sink-transform.mlir | 19 ++---------
 mlir/test/Dialect/Vector/vector-sink.mlir     | 17 ++++++----
 5 files changed, 44 insertions(+), 46 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 80a6ffa30916a..12dcf768dd928 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -67,6 +67,9 @@ void transform::ApplyFoldElementwiseToVectorPatternsOp::populatePatterns(
 void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   vector::populateVectorReductionToContractPatterns(patterns);
+
+  // TODO: As we now have a dedicated transform for
+  // `populateSinkVectorOpsPatterns` we can remove it from here.
   vector::populateSinkVectorOpsPatterns(patterns);
 }
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 29bc913af15a7..a7ca9f57b06c4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1044,7 +1044,7 @@ struct ReorderElementwiseOpsOnBroadcast final
 };
 
 /// Pattern to rewrite a ExtractOp(Elementwise) -> Elementwise(ExtractOp).
-/// This may result in more efficient code when we extracting a single value
+/// This may result in cleaner code when we extracting a single value
 /// from multi-element vector and also to help canonicalize 1-element vectors to
 /// scalars.
 /// ```
@@ -1066,14 +1066,17 @@ class ExtractOpFromElementwise final
                                 PatternRewriter &rewriter) const override {
     Operation *eltwise = op.getVector().getDefiningOp();
 
-    // Elementwise op with single result and `extract` is single user.
-    if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
-        eltwise->getNumResults() != 1 || !eltwise->hasOneUse())
-      return rewriter.notifyMatchFailure(op, "not a suitable op");
+    if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise))
+      return rewriter.notifyMatchFailure(op, "not an elementwise op");
+
+    if (eltwise->getNumResults() != 1)
+      return rewriter.notifyMatchFailure(op, "expected single result");
+
+    if (!eltwise->hasOneUse())
+      return rewriter.notifyMatchFailure(op, "expected single op use");
 
-    // Arguments types must match.
     if (!llvm::all_equal(eltwise->getOperandTypes()))
-      return rewriter.notifyMatchFailure(op, "arg types are different");
+      return rewriter.notifyMatchFailure(op, "operand types are different");
 
     Type dstType = op.getType();
 
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index b553681953a82..375fa37bd84b0 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -59,19 +59,19 @@ func.func @vectorize_nd_tensor_extract_transfer_read_complex(%6: tensor<45x80x16
 
 
 // CHECK-LABEL:   func.func @vectorize_nd_tensor_extract_transfer_read_complex(
-// CHECK-SAME:      %[[VAL_0:.*]]: tensor<45x80x16xf32>,
-// CHECK-SAME:      %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index, %[[VAL_4:.*]]: index,
-// CHECK-SAME:      %[[VAL_5:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
+// CHECK-SAME:      %[[ARG0:.*]]: tensor<45x80x16xf32>,
+// CHECK-SAME:      %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index,
+// CHECK-SAME:      %[[ARG5:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
 
 // CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[CST:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG:       %[[C79:.*]] = arith.constant 79 : index
-// CHECK:           %[[VAL6:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : index
-// CHECK:           %[[VAL7:.*]] = arith.addi %[[VAL_3]], %[[VAL_4]] : index
+// CHECK:           %[[ADD1:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : index
+// CHECK:           %[[ADD2:.*]] = arith.addi %[[ARG3]], %[[ARG4]] : index
 
-// CHECK:           %[[VAL_20:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL6]], %[[C79]], %[[VAL7]]], %[[CST]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
-// CHECK:           %[[VAL_21:.*]] = vector.transfer_write %[[VAL_20]], %[[VAL_5]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
-// CHECK:           return %[[VAL_21]] : tensor<1x4xf32>
+// CHECK:           %[[READ:.*]] = vector.transfer_read %[[ARG0]]{{\[}}%[[ADD1]], %[[C79]], %[[ADD2]]], %[[CST]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
+// CHECK:           %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[ARG5]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
+// CHECK:           return %[[WRITE]] : tensor<1x4xf32>
 // CHECK:         }
 
 // -----
@@ -93,17 +93,17 @@ func.func @vectorize_nd_tensor_extract_with_affine_apply_contiguous(%6: tensor<8
 }
 
 // CHECK-LABEL:   func.func @vectorize_nd_tensor_extract_with_affine_apply_contiguous(
-// CHECK-SAME:                                                                        %[[VAL_0:.*]]: tensor<80x16xf32>,
-// CHECK-SAME:                                                                        %[[VAL_1:.*]]: index,
-// CHECK-SAME:                                                                        %[[VAL_2:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
+// CHECK-SAME:      %[[ARG0:.*]]: tensor<80x16xf32>,
+// CHECK-SAME:      %[[ARG1:.*]]: index,
+// CHECK-SAME:      %[[ARG2:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
 
-// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0 : index
-// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 79 : index
+// CHECK-DAG:       %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C79:.*]] = arith.constant 79 : index
 
-// CHECK:           %[[VAL_11:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_1]]], %[[VAL_5]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
-// CHECK:           %[[VAL_12:.*]] = vector.transfer_write %[[VAL_11]], %[[VAL_2]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
-// CHECK:           return %[[VAL_12]] : tensor<1x4xf32>
+// CHECK:           %[[READ:.*]] = vector.transfer_read %[[ARG0]]{{\[}}%[[C79]], %[[ARG1]]], %[[CST]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
+// CHECK:           %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[ARG2]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
+// CHECK:           return %[[WRITE]] : tensor<1x4xf32>
 // CHECK:         }
 
 // -----
diff --git a/mlir/test/Dialect/Vector/vector-sink-transform.mlir b/mlir/test/Dialect/Vector/vector-sink-transform.mlir
index 0fc4fd8c804e9..ef17b69b2444c 100644
--- a/mlir/test/Dialect/Vector/vector-sink-transform.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink-transform.mlir
@@ -1,7 +1,7 @@
-// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
+// RUN: mlir-opt %s
 
-// This is smoke test for `transform.apply_patterns.vector.sink_ops` the actual
-// patterns are tested in `vector-sink.mlir`.
+// This is a smoke test for `transform.apply_patterns.vector.sink_ops`; this
+// file is also used by `vector-sink.mlir`.
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
@@ -11,16 +11,3 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
-
-
-// CHECK-LABEL: @extract_elementwise_scalar
-//  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
-func.func @extract_elementwise_scalar(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
-// CHECK:   %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<4xf32>
-// CHECK:   %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
-// CHECK:   %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
-// CHECK:   return %[[RES]] : f32
-  %0 = arith.addf %arg0, %arg1 : vector<4xf32>
-  %1 = vector.extract %0[1] : f32 from vector<4xf32>
-  return %1 : f32
-}
diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir
index 4c2826416056e..b9a4c199fc8b1 100644
--- a/mlir/test/Dialect/Vector/vector-sink.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s -test-vector-sink-patterns -split-input-file | FileCheck %s
+// RUN: mlir-opt -transform-preload-library='transform-library-paths=%p/vector-sink-transform.mlir' -transform-interpreter -split-input-file %s | FileCheck %s
 
 //-----------------------------------------------------------------------------
 // [Pattern: ReorderElementwiseOpsOnBroadcast]
@@ -426,6 +427,10 @@ func.func @transpose_elementwise_diff_map_scalable(%a : vector<[4]x6x3x2xf32>, %
 
 // -----
 
+//-----------------------------------------------------------------------------
+// [Pattern: ExtractOpFromElementwise]
+//-----------------------------------------------------------------------------
+
 // CHECK-LABEL: @extract_elementwise_scalar
 //  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
 func.func @extract_elementwise_scalar(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
@@ -461,9 +466,9 @@ func.func @extract_elementwise_vec(%arg0: vector<2x4xf32>, %arg1: vector<2x4xf32
   return %1 : vector<4xf32>
 }
 
-// CHECK-LABEL: @extract_elementwise_no_single_use
+// CHECK-LABEL: @negative_extract_elementwise_no_single_use
 //  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
-func.func @extract_elementwise_no_single_use(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
+func.func @negative_extract_elementwise_no_single_use(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
 // Do not propagate extract, as elementwise has other uses.
 // CHECK:   %[[ELT:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<4xf32>
 // CHECK:   %[[EXT:.*]] = vector.extract %[[ELT]][1] : f32 from vector<4xf32>
@@ -473,9 +478,9 @@ func.func @extract_elementwise_no_single_use(%arg0: vector<4xf32>, %arg1: vector
   return %1, %0 : f32, vector<4xf32>
 }
 
-// CHECK-LABEL: @extract_elementwise_not_one_res
+// CHECK-LABEL: @negative_extract_elementwise_not_one_res
 //  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xi32>, %[[ARG1:.*]]: vector<4xi32>)
-func.func @extract_elementwise_not_one_res(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
+func.func @negative_extract_elementwise_not_one_res(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
 // Do not propagate extract, as elementwise has more than 1 result.
 // CHECK:   %[[LOW:.*]], %[[HIGH:.*]] = arith.mulsi_extended %[[ARG0]], %[[ARG1]] : vector<4xi32>
 // CHECK:   %[[EXT:.*]] = vector.extract %[[LOW]][1] : i32 from vector<4xi32>
@@ -485,9 +490,9 @@ func.func @extract_elementwise_not_one_res(%arg0: vector<4xi32>, %arg1: vector<4
   return %1 : i32
 }
 
-// CHECK-LABEL: @extract_not_elementwise
+// CHECK-LABEL: @negative_extract_not_elementwise
 //  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xi64>)
-func.func @extract_not_elementwise(%arg0: vector<4xi64>) -> i64 {
+func.func @negative_extract_not_elementwise(%arg0: vector<4xi64>) -> i64 {
 // `test.increment` is not an elemewise op.
 // CHECK:   %[[INC:.*]] = test.increment %[[ARG0]] : vector<4xi64>
 // CHECK:   %[[RES:.*]] = vector.extract %[[INC]][1] : i64 from vector<4xi64>

>From a65f0727a96b796f39dfc0d769c3963f40034288 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 23 Mar 2025 13:35:57 +0100
Subject: [PATCH 09/10] comment

---
 mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index a7ca9f57b06c4..ccc4ebd03ab47 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1044,7 +1044,7 @@ struct ReorderElementwiseOpsOnBroadcast final
 };
 
 /// Pattern to rewrite a ExtractOp(Elementwise) -> Elementwise(ExtractOp).
-/// This may result in cleaner code when we extracting a single value
+/// This may result in cleaner code when extracting a single value
 /// from multi-element vector and also to help canonicalize 1-element vectors to
 /// scalars.
 /// ```

>From d3281b51a96711142a5015d856ef82667317d5bb Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 23 Mar 2025 16:08:28 +0100
Subject: [PATCH 10/10] fma workaround

---
 .../Dialect/Vector/Transforms/VectorTransforms.cpp   |  5 ++++-
 mlir/test/Dialect/Vector/vector-sink.mlir            | 12 ++++++++++++
 2 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index ccc4ebd03ab47..df2def5abc040 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1066,7 +1066,10 @@ class ExtractOpFromElementwise final
                                 PatternRewriter &rewriter) const override {
     Operation *eltwise = op.getVector().getDefiningOp();
 
-    if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise))
+    // TODO: vector::FMAOp is not ElementwiseMappable even if it claims to be,
+    // as it doesn't support scalars.
+    if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
+        isa<vector::FMAOp>(eltwise))
       return rewriter.notifyMatchFailure(op, "not an elementwise op");
 
     if (eltwise->getNumResults() != 1)
diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir
index b9a4c199fc8b1..8c8f1797aaab6 100644
--- a/mlir/test/Dialect/Vector/vector-sink.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink.mlir
@@ -501,3 +501,15 @@ func.func @negative_extract_not_elementwise(%arg0: vector<4xi64>) -> i64 {
   %1 = vector.extract %0[1] : i64 from vector<4xi64>
   return %1 : i64
 }
+
+// CHECK-LABEL: @negative_extract_vec_fma
+//  CHECK-SAME:   (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>, %[[ARG2:.*]]: vector<4xf32>)
+func.func @negative_extract_vec_fma(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> f32 {
+// `vector.fma` doesn't support scalars.
+// CHECK:   %[[FMA:.*]] = vector.fma %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<4xf32>
+// CHECK:   %[[RES:.*]] = vector.extract %[[FMA]][1] : f32 from vector<4xf32>
+// CHECK:   return %[[RES]] : f32
+  %0 = vector.fma %arg0, %arg1, %arg2: vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1 : f32
+}



More information about the Mlir-commits mailing list