[Mlir-commits] [mlir] 9b02222 - [mlir][vector] Propagate `vector.extract` through elementwise ops (#131462)
llvmlistbot at llvm.org
Tue Mar 25 04:07:52 PDT 2025
Author: Ivan Butygin
Date: 2025-03-25T14:07:48+03:00
New Revision: 9b022220b7960946b5ab3be1e5ace32079b47c5c
URL: https://github.com/llvm/llvm-project/commit/9b022220b7960946b5ab3be1e5ace32079b47c5c
DIFF: https://github.com/llvm/llvm-project/commit/9b022220b7960946b5ab3be1e5ace32079b47c5c.diff
LOG: [mlir][vector] Propagate `vector.extract` through elementwise ops (#131462)
Propagate `Extract(Elementwise(...))` -> `Elementwise(Extract...)`.
Currently limited to the case where the extract is the only use of the
elementwise op, to avoid introducing additional elementwise ops.
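For illustration, the rewrite turns IR like the following (this mirrors the example documented on the new pattern below):
```
%0 = arith.addf %arg0, %arg1 : vector<4xf32>
%1 = vector.extract %0[1] : f32 from vector<4xf32>
```
into:
```
%0 = vector.extract %arg0[1] : f32 from vector<4xf32>
%1 = vector.extract %arg1[1] : f32 from vector<4xf32>
%2 = arith.addf %0, %1 : f32
```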
Added:
mlir/test/Dialect/Vector/vector-sink-transform.mlir
Modified:
mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
mlir/test/Dialect/Vector/vector-sink.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index c973eca0132a9..f46aa0428f12f 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -453,4 +453,27 @@ def ApplyVectorReductionToContractPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def ApplySinkVectorPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.vector.sink_ops",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Patterns that remove redundant Vector Ops by re-ordering them with
+ e.g. elementwise Ops:
+ ```
+ %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+ %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
+ %r = arith.addf %at, %bt : vector<2x4xf32>
+ ```
+ gets converted to:
+ ```
+ %0 = arith.addf %a, %b : vector<4x2xf32>
+    %r = vector.transpose %0, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
+ ```
+ At the moment, these patterns are limited to vector.broadcast and
+ vector.transpose.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
#endif // VECTOR_TRANSFORM_OPS
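For context, the new op is meant to be used inside a `transform.apply_patterns` region; a minimal usage sketch (essentially what the smoke test added later in this patch does):
```
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %func {
      transform.apply_patterns.vector.sink_ops
    } : !transform.any_op
    transform.yield
  }
}
```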
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 20c577273d786..12dcf768dd928 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -67,6 +67,9 @@ void transform::ApplyFoldElementwiseToVectorPatternsOp::populatePatterns(
void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::populateVectorReductionToContractPatterns(patterns);
+
+  // TODO: As we now have a dedicated transform for
+  // `populateSinkVectorOpsPatterns`, we can remove it from here.
vector::populateSinkVectorOpsPatterns(patterns);
}
@@ -204,6 +207,11 @@ void transform::ApplyTransferToScfPatternsOp::populatePatterns(
populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
}
+void transform::ApplySinkVectorPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ vector::populateSinkVectorOpsPatterns(patterns);
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index dc46ed17a374d..b6fac80d871e6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1043,6 +1043,66 @@ struct ReorderElementwiseOpsOnBroadcast final
}
};
+/// Pattern to rewrite an ExtractOp(Elementwise) -> Elementwise(ExtractOp).
+/// This may result in cleaner code when extracting a single value
+/// from a multi-element vector, and also helps canonicalize 1-element vectors
+/// to scalars.
+/// ```
+/// %0 = arith.addf %arg0, %arg1 : vector<4xf32>
+/// %1 = vector.extract %0[1] : f32 from vector<4xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// %0 = vector.extract %arg0[1] : f32 from vector<4xf32>
+/// %1 = vector.extract %arg1[1] : f32 from vector<4xf32>
+/// %2 = arith.addf %0, %1 : f32
+/// ```
+class ExtractOpFromElementwise final
+ : public OpRewritePattern<vector::ExtractOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ExtractOp op,
+ PatternRewriter &rewriter) const override {
+ Operation *eltwise = op.getVector().getDefiningOp();
+
+    // TODO: vector::FMAOp is not ElementwiseMappable even though it claims to
+    // be, as it doesn't support scalars.
+ if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
+ isa<vector::FMAOp>(eltwise))
+ return rewriter.notifyMatchFailure(op, "not an elementwise op");
+
+ if (eltwise->getNumResults() != 1)
+ return rewriter.notifyMatchFailure(op, "expected single result");
+
+ if (!eltwise->hasOneUse())
+ return rewriter.notifyMatchFailure(op, "expected single op use");
+
+ if (!llvm::all_equal(eltwise->getOperandTypes()))
+ return rewriter.notifyMatchFailure(op, "operand types are
diff erent");
+
+ Type dstType = op.getType();
+
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(eltwise);
+
+ IRMapping mapping;
+ Location loc = eltwise->getLoc();
+ SmallVector<OpFoldResult> pos = op.getMixedPosition();
+ for (Value arg : eltwise->getOperands()) {
+ Value newArg = rewriter.create<vector::ExtractOp>(loc, arg, pos);
+ mapping.map(arg, newArg);
+ }
+
+ Operation *newEltwise = rewriter.clone(*eltwise, mapping);
+ newEltwise->getResult(0).setType(dstType);
+
+ rewriter.replaceOp(op, newEltwise);
+ rewriter.eraseOp(eltwise);
+ return success();
+ }
+};
+
// Helper that returns a vector comparison that constructs a mask:
// mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
//
@@ -2111,8 +2171,8 @@ void mlir::vector::
void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
PatternBenefit benefit) {
patterns.add<ReorderElementwiseOpsOnTranspose, ReorderCastOpsOnBroadcast,
- ReorderElementwiseOpsOnBroadcast>(patterns.getContext(),
- benefit);
+ ReorderElementwiseOpsOnBroadcast, ExtractOpFromElementwise>(
+ patterns.getContext(), benefit);
}
void mlir::vector::populateChainedVectorReductionFoldingPatterns(
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index cd83e1239fdda..375fa37bd84b0 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -59,24 +59,19 @@ func.func @vectorize_nd_tensor_extract_transfer_read_complex(%6: tensor<45x80x16
// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_complex(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<45x80x16xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index, %[[VAL_4:.*]]: index,
-// CHECK-SAME: %[[VAL_5:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
-// CHECK-DAG: %[[VAL_6:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 79 : index
-// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : index
-// CHECK: %[[VAL_13:.*]] = vector.broadcast %[[VAL_3]] : index to vector<4xindex>
-// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_6]] : vector<4xindex>
-// CHECK: %[[VAL_15:.*]] = vector.broadcast %[[VAL_4]] : index to vector<4xindex>
-// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_14]], %[[VAL_15]] : vector<4xindex>
-
-// CHECK: %[[VAL_19:.*]] = vector.extract %[[VAL_16]][0] : index from vector<4xindex>
-
-// CHECK: %[[VAL_20:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_11]], %[[VAL_10]], %[[VAL_19]]], %[[VAL_8]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
-// CHECK: %[[VAL_21:.*]] = vector.transfer_write %[[VAL_20]], %[[VAL_5]]{{\[}}%[[VAL_9]], %[[VAL_9]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
-// CHECK: return %[[VAL_21]] : tensor<1x4xf32>
+// CHECK-SAME: %[[ARG0:.*]]: tensor<45x80x16xf32>,
+// CHECK-SAME: %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index,
+// CHECK-SAME: %[[ARG5:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
+
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[C79:.*]] = arith.constant 79 : index
+// CHECK: %[[ADD1:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : index
+// CHECK: %[[ADD2:.*]] = arith.addi %[[ARG3]], %[[ARG4]] : index
+
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]]{{\[}}%[[ADD1]], %[[C79]], %[[ADD2]]], %[[CST]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
+// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[ARG5]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
+// CHECK: return %[[WRITE]] : tensor<1x4xf32>
// CHECK: }
// -----
@@ -98,19 +93,17 @@ func.func @vectorize_nd_tensor_extract_with_affine_apply_contiguous(%6: tensor<8
}
// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_with_affine_apply_contiguous(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<80x16xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: index,
-// CHECK-SAME: %[[VAL_2:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 79 : index
-// CHECK: %[[VAL_8:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex>
-// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : vector<4xindex>
-// CHECK: %[[VAL_10:.*]] = vector.extract %[[VAL_9]][0] : index from vector<4xindex>
-// CHECK: %[[VAL_11:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_10]]], %[[VAL_5]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
-// CHECK: %[[VAL_12:.*]] = vector.transfer_write %[[VAL_11]], %[[VAL_2]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
-// CHECK: return %[[VAL_12]] : tensor<1x4xf32>
+// CHECK-SAME: %[[ARG0:.*]]: tensor<80x16xf32>,
+// CHECK-SAME: %[[ARG1:.*]]: index,
+// CHECK-SAME: %[[ARG2:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
+
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C79:.*]] = arith.constant 79 : index
+
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]]{{\[}}%[[C79]], %[[ARG1]]], %[[CST]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
+// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[ARG2]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
+// CHECK: return %[[WRITE]] : tensor<1x4xf32>
// CHECK: }
// -----
diff --git a/mlir/test/Dialect/Vector/vector-sink-transform.mlir b/mlir/test/Dialect/Vector/vector-sink-transform.mlir
new file mode 100644
index 0000000000000..ef17b69b2444c
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-sink-transform.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s
+
+// This is a smoke test for `transform.apply_patterns.vector.sink_ops`; this
+// file is also used in `vector-sink.mlir`.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.sink_ops
+ } : !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir
index 7ce840575a803..8c8f1797aaab6 100644
--- a/mlir/test/Dialect/Vector/vector-sink.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s -test-vector-sink-patterns -split-input-file | FileCheck %s
+// RUN: mlir-opt -transform-preload-library='transform-library-paths=%p/vector-sink-transform.mlir' -transform-interpreter -split-input-file %s | FileCheck %s
//-----------------------------------------------------------------------------
// [Pattern: ReorderElementwiseOpsOnBroadcast]
@@ -423,3 +424,92 @@ func.func @transpose_elementwise_diff_map_scalable(%a : vector<[4]x6x3x2xf32>, %
%r = arith.addf %at, %bt : vector<6x[4]x2x3xf32>
return %r : vector<6x[4]x2x3xf32>
}
+
+// -----
+
+//-----------------------------------------------------------------------------
+// [Pattern: ExtractOpFromElementwise]
+//-----------------------------------------------------------------------------
+
+// CHECK-LABEL: @extract_elementwise_scalar
+// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+func.func @extract_elementwise_scalar(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
+// CHECK: %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<4xf32>
+// CHECK: %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
+// CHECK: %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
+// CHECK: return %[[RES]] : f32
+ %0 = arith.addf %arg0, %arg1 : vector<4xf32>
+ %1 = vector.extract %0[1] : f32 from vector<4xf32>
+ return %1 : f32
+}
+
+// CHECK-LABEL: @extract_elementwise_arg_res_different_types
+// CHECK-SAME: (%[[ARG0:.*]]: vector<4xindex>)
+func.func @extract_elementwise_arg_res_different_types(%arg0: vector<4xindex>) -> i64 {
+// CHECK: %[[EXT:.*]] = vector.extract %[[ARG0]][1] : index from vector<4xindex>
+// CHECK: %[[RES:.*]] = arith.index_cast %[[EXT]] : index to i64
+// CHECK: return %[[RES]] : i64
+ %0 = arith.index_cast %arg0: vector<4xindex> to vector<4xi64>
+ %1 = vector.extract %0[1] : i64 from vector<4xi64>
+ return %1 : i64
+}
+
+// CHECK-LABEL: @extract_elementwise_vec
+// CHECK-SAME: (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<2x4xf32>)
+func.func @extract_elementwise_vec(%arg0: vector<2x4xf32>, %arg1: vector<2x4xf32>) -> vector<4xf32> {
+// CHECK: %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : vector<4xf32> from vector<2x4xf32>
+// CHECK: %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf32> from vector<2x4xf32>
+// CHECK: %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : vector<4xf32>
+// CHECK: return %[[RES]] : vector<4xf32>
+ %0 = arith.addf %arg0, %arg1 : vector<2x4xf32>
+ %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+ return %1 : vector<4xf32>
+}
+
+// CHECK-LABEL: @negative_extract_elementwise_no_single_use
+// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+func.func @negative_extract_elementwise_no_single_use(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) {
+// Do not propagate the extract, as the elementwise op has other uses.
+// CHECK: %[[ELT:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<4xf32>
+// CHECK: %[[EXT:.*]] = vector.extract %[[ELT]][1] : f32 from vector<4xf32>
+// CHECK: return %[[EXT]], %[[ELT]] : f32, vector<4xf32>
+ %0 = arith.addf %arg0, %arg1 : vector<4xf32>
+ %1 = vector.extract %0[1] : f32 from vector<4xf32>
+ return %1, %0 : f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @negative_extract_elementwise_not_one_res
+// CHECK-SAME: (%[[ARG0:.*]]: vector<4xi32>, %[[ARG1:.*]]: vector<4xi32>)
+func.func @negative_extract_elementwise_not_one_res(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
+// Do not propagate the extract, as the elementwise op has more than one result.
+// CHECK: %[[LOW:.*]], %[[HIGH:.*]] = arith.mulsi_extended %[[ARG0]], %[[ARG1]] : vector<4xi32>
+// CHECK: %[[EXT:.*]] = vector.extract %[[LOW]][1] : i32 from vector<4xi32>
+// CHECK: return %[[EXT]] : i32
+ %low, %hi = arith.mulsi_extended %arg0, %arg1 : vector<4xi32>
+ %1 = vector.extract %low[1] : i32 from vector<4xi32>
+ return %1 : i32
+}
+
+// CHECK-LABEL: @negative_extract_not_elementwise
+// CHECK-SAME: (%[[ARG0:.*]]: vector<4xi64>)
+func.func @negative_extract_not_elementwise(%arg0: vector<4xi64>) -> i64 {
+// `test.increment` is not an elementwise op.
+// CHECK: %[[INC:.*]] = test.increment %[[ARG0]] : vector<4xi64>
+// CHECK: %[[RES:.*]] = vector.extract %[[INC]][1] : i64 from vector<4xi64>
+// CHECK: return %[[RES]] : i64
+ %0 = test.increment %arg0: vector<4xi64>
+ %1 = vector.extract %0[1] : i64 from vector<4xi64>
+ return %1 : i64
+}
+
+// CHECK-LABEL: @negative_extract_vec_fma
+// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>, %[[ARG2:.*]]: vector<4xf32>)
+func.func @negative_extract_vec_fma(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> f32 {
+// `vector.fma` doesn't support scalars.
+// CHECK: %[[FMA:.*]] = vector.fma %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<4xf32>
+// CHECK: %[[RES:.*]] = vector.extract %[[FMA]][1] : f32 from vector<4xf32>
+// CHECK: return %[[RES]] : f32
+ %0 = vector.fma %arg0, %arg1, %arg2: vector<4xf32>
+ %1 = vector.extract %0[1] : f32 from vector<4xf32>
+ return %1 : f32
+}