[Mlir-commits] [mlir] [mlir][vector] Propagate `vector.extract` through elementwise ops (PR #131462)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Mar 15 17:20:38 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Ivan Butygin (Hardcode84)
<details>
<summary>Changes</summary>
Propagate `Extract(Elementwise(...))` -> `Elementwise(Extract(...))`.
Currently limited to the case where the extract is the only use of the elementwise op, to avoid introducing additional elementwise ops.
---
Full diff: https://github.com/llvm/llvm-project/pull/131462.diff
6 Files Affected:
- (modified) mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td (+11)
- (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h (+3)
- (modified) mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp (+5)
- (modified) mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt (+1)
- (added) mlir/lib/Dialect/Vector/Transforms/VectorPropagateExtract.cpp (+66)
- (added) mlir/test/Dialect/Vector/propagate-extracts.mlir (+47)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index c973eca0132a9..7be39519c1037 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -453,4 +453,15 @@ def ApplyVectorReductionToContractPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def ApplyVectorPropagateExtractPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.propagate_extract",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collect patterns that rewrite `vector.extract(elementwise(x...))` into
+    `elementwise(vector.extract(x)...)` when the extract is the only user.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
#endif // VECTOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 7de4a6a315750..16c66e078821d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -409,6 +409,9 @@ void populateVectorLinearizeShuffleLikeOpsPatterns(
const TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, unsigned targetBitWidth);
+/// Populates patterns that propagate `vector.extract` through elementwise ops.
+void populateVectorPropagateExtractsPatterns(RewritePatternSet &patterns);
+
} // namespace vector
} // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 20c577273d786..616e563fcdc77 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -204,6 +204,11 @@ void transform::ApplyTransferToScfPatternsOp::populatePatterns(
populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
}
+void transform::ApplyVectorPropagateExtractPatternsOp::populatePatterns(
+    RewritePatternSet &patternSet) {
+  mlir::vector::populateVectorPropagateExtractsPatterns(patternSet);
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 8ca5cb6c6dfab..8830375f88104 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -24,6 +24,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
VectorTransforms.cpp
VectorUnroll.cpp
VectorMaskElimination.cpp
+ VectorPropagateExtract.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/Transforms
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorPropagateExtract.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorPropagateExtract.cpp
new file mode 100644
index 0000000000000..10f578179bc94
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorPropagateExtract.cpp
@@ -0,0 +1,107 @@
+//===- VectorPropagateExtract.cpp - vector.extract propagation ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns for vector.extract propagation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/ADT/STLExtras.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Rewrites `vector.extract(elementwise(a, b, ...))` into
+/// `elementwise(vector.extract(a), vector.extract(b), ...)`. For example:
+///
+/// ```
+///   %0 = arith.addf %a, %b : vector<4xf32>
+///   %1 = vector.extract %0[1] : f32 from vector<4xf32>
+/// ```
+///
+/// becomes
+///
+/// ```
+///   %0 = vector.extract %a[1] : f32 from vector<4xf32>
+///   %1 = vector.extract %b[1] : f32 from vector<4xf32>
+///   %2 = arith.addf %0, %1 : f32
+/// ```
+///
+/// The rewrite only applies when the extract is the sole user of the
+/// elementwise op, so no elementwise computation is duplicated.
+class ExtractOpFromElementwise final
+    : public OpRewritePattern<vector::ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp op,
+                                PatternRewriter &rewriter) const override {
+    Operation *eltwise = op.getVector().getDefiningOp();
+    if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise))
+      return rewriter.notifyMatchFailure(op, "not an elementwise op");
+
+    if (eltwise->getNumResults() != 1)
+      return rewriter.notifyMatchFailure(op, "expected a single result");
+
+    // Only rewrite when the extract is the sole user: otherwise the vector
+    // op would have to be kept alive next to the new one.
+    if (!eltwise->hasOneUse())
+      return rewriter.notifyMatchFailure(op, "expected a single use");
+
+    // All operands must share one type so that the same position can be
+    // extracted from each of them. The result element type may differ
+    // (e.g. `arith.cmpf`); the new result type is taken from `op`.
+    if (!llvm::all_equal(eltwise->getOperandTypes()))
+      return rewriter.notifyMatchFailure(op, "operand types differ");
+
+    // `vector.fma` has the elementwise traits but only accepts vector
+    // operands, so it cannot be turned into a scalar op.
+    Type dstType = op.getType();
+    if (isa<vector::FMAOp>(eltwise) && !isa<VectorType>(dstType))
+      return rewriter.notifyMatchFailure(op, "cannot scalarize vector.fma");
+
+    // The new extracts are created right before `eltwise`; dynamic position
+    // values only need to dominate `op` and may not be available there.
+    if (!op.getDynamicPosition().empty())
+      return rewriter.notifyMatchFailure(op, "dynamic position unsupported");
+
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(eltwise);
+
+    SmallVector<OpFoldResult> position = op.getMixedPosition();
+    Location loc = eltwise->getLoc();
+    IRMapping mapping;
+    for (Value operand : eltwise->getOperands()) {
+      // An operand used more than once (e.g. `addf %a, %a`) is extracted once.
+      if (mapping.contains(operand))
+        continue;
+      Value extracted =
+          rewriter.create<vector::ExtractOp>(loc, operand, position);
+      mapping.map(operand, extracted);
+    }
+
+    Operation *newEltwise = rewriter.clone(*eltwise, mapping);
+    newEltwise->getResult(0).setType(dstType);
+
+    rewriter.replaceOp(op, newEltwise);
+    rewriter.eraseOp(eltwise);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::vector::populateVectorPropagateExtractsPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<ExtractOpFromElementwise>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Vector/propagate-extracts.mlir b/mlir/test/Dialect/Vector/propagate-extracts.mlir
new file mode 100644
index 0000000000000..6c6f812c8f6d2
--- /dev/null
+++ b/mlir/test/Dialect/Vector/propagate-extracts.mlir
@@ -0,0 +1,53 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
+
+// A scalar extract from a single-use elementwise op is propagated to the
+// operands, turning the vector op into a scalar op.
+// CHECK-LABEL: @extract_elementwise
+// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+func.func @extract_elementwise(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
+// CHECK: %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<4xf32>
+// CHECK: %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
+// CHECK: %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
+// CHECK: return %[[RES]] : f32
+  %0 = arith.addf %arg0, %arg1 : vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// Extracting a sub-vector is propagated the same way; the elementwise op then
+// works on the extracted lower-rank vectors.
+// CHECK-LABEL: @extract_vec_elementwise
+// CHECK-SAME: (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<2x4xf32>)
+func.func @extract_vec_elementwise(%arg0: vector<2x4xf32>, %arg1: vector<2x4xf32>) -> vector<4xf32> {
+// CHECK: %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : vector<4xf32> from vector<2x4xf32>
+// CHECK: %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf32> from vector<2x4xf32>
+// CHECK: %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : vector<4xf32>
+// CHECK: return %[[RES]] : vector<4xf32>
+  %0 = arith.addf %arg0, %arg1 : vector<2x4xf32>
+  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+  return %1 : vector<4xf32>
+}
+
+// The elementwise result has another use, so the extract is not propagated
+// (propagating it would duplicate the elementwise computation).
+// CHECK-LABEL: @extract_elementwise_use
+// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+func.func @extract_elementwise_use(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
+// CHECK: %[[ELT:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<4xf32>
+// CHECK: %[[EXT:.*]] = vector.extract %[[ELT]][1] : f32 from vector<4xf32>
+// CHECK: return %[[EXT]], %[[ELT]] : f32, vector<4xf32>
+  %0 = arith.addf %arg0, %arg1 : vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1, %0 : f32, vector<4xf32>
+}
+
+// Apply only the vector.extract propagation patterns to every function.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.propagate_extract
+    } : !transform.any_op
+    transform.yield
+  }
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/131462
More information about the Mlir-commits
mailing list