[Mlir-commits] [mlir] [mlir][vector] Make ReorderElementwiseOpsOnBroadcast support vector.splat (PR #66596)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Sep 17 06:33:30 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
<details>
<summary>Changes</summary>
Extend `ReorderElementwiseOpsOnBroadcast` so that the broadcasting op
can be either `vector.broadcast` (already supported) or
`vector.splat` (support added in this patch).
---
Full diff: https://github.com/llvm/llvm-project/pull/66596.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+25-15)
- (modified) mlir/test/Dialect/Vector/sink-vector-broadcast.mlir (+34-5)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 207df69929c1c9f..b2a5aef5ee62d0f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -880,7 +880,7 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
std::function<bool(BitCastOp)> controlFn;
};
-/// Reorders elementwise(broadcast) to broadcast(elementwise). Ex:
+/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
/// ```
/// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
/// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
@@ -891,6 +891,9 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
/// %r = arith.addi %arg0, %arg1 : index
/// %b = vector.broadcast %r : index to vector<1x4xindex>
/// ```
+///
+/// Both `vector.broadcast` and `vector.splat` are supported as broadcasting
+/// ops.
struct ReorderElementwiseOpsOnBroadcast final
: public OpTraitRewritePattern<OpTrait::Elementwise> {
using OpTraitRewritePattern::OpTraitRewritePattern;
@@ -903,35 +906,42 @@ struct ReorderElementwiseOpsOnBroadcast final
if (!OpTrait::hasElementwiseMappableTraits(op))
return failure();
- // Get the type of the first operand
- auto firstBcast = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
- if (!firstBcast)
+ // Get the type of the lhs operand
+ auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
+ if (!lhsBcastOrSplat ||
+ !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
return failure();
- auto firstOpType = firstBcast.getOperand().getType();
+ auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
- // Make sure that operands are "broadcast"ed from identical (scalar or
- // vector) types. That indicates that it's safe to skip the broadcasting of
- // operands.
- if (!llvm::all_of(op->getOperands(), [&firstOpType](Value val) {
+ // Make sure that all operands are broadcast from identical types:
+ // * scalar (`vector.broadcast` + `vector.splat`), or
+ // * vector (`vector.broadcast`).
+ // Otherwise the re-ordering wouldn't be safe.
+ if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
auto bcast = val.getDefiningOp<vector::BroadcastOp>();
- return (bcast && (bcast.getOperand().getType() == firstOpType));
+ if (bcast)
+ return (bcast.getOperand().getType() == lhsBcastOrSplatType);
+ auto splat = val.getDefiningOp<vector::SplatOp>();
+ if (splat)
+ return (splat.getOperand().getType() == lhsBcastOrSplatType);
+ return false;
})) {
return failure();
}
- // Collect the source values
+ // Collect the source values before broadcasting
SmallVector<Value> srcValues;
srcValues.reserve(op->getNumOperands());
-
for (Value operand : op->getOperands()) {
- srcValues.push_back(
- operand.getDefiningOp<vector::BroadcastOp>().getOperand());
+ srcValues.push_back(operand.getDefiningOp()->getOperand(0));
}
+ // Create the "elementwise" Op
Operation *elementwiseOp =
rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
- firstOpType, op->getAttrs());
+ lhsBcastOrSplatType, op->getAttrs());
+ // Replace the original Op with the elementwise Op
auto vectorType = op->getResultTypes()[0];
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
op, vectorType, elementwiseOp->getResults());
diff --git a/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir b/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir
index fcf9815f6f6f1d1..d9d2f44e6f16c1f 100644
--- a/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir
+++ b/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir
@@ -1,13 +1,12 @@
// RUN: mlir-opt %s -test-sink-vector-broadcast -split-input-file | FileCheck %s
-// CHECK-LABEL: func.func @broadcast_scalar(
+// CHECK-LABEL: func.func @broadcast_scalar_with_bcast(
// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -> vector<1x4xindex> {
// CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[ARG_1]] : index
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
// CHECK: return %[[BCAST]] : vector<1x4xindex>
-// CHECK: }
-func.func @broadcast_scalar( %arg1: index, %arg2: index) -> vector<1x4xindex> {
+func.func @broadcast_scalar_with_bcast( %arg1: index, %arg2: index) -> vector<1x4xindex> {
%0 = vector.broadcast %arg1 : index to vector<1x4xindex>
%1 = vector.broadcast %arg2 : index to vector<1x4xindex>
%2 = arith.addi %0, %1 : vector<1x4xindex>
@@ -16,13 +15,27 @@ func.func @broadcast_scalar( %arg1: index, %arg2: index) -> vector<1x4xindex> {
// -----
+// CHECK-LABEL: func.func @broadcast_scalar_with_bcast_and_splat(
+// CHECK-SAME: %[[ARG1:.*]]: index,
+// CHECK-SAME: %[[ARG2:.*]]: index) -> vector<1x4xindex> {
+// CHECK: %[[ADD:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : index
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
+// CHECK: return %[[BCAST]] : vector<1x4xindex>
+func.func @broadcast_scalar_with_bcast_and_splat( %arg1: index, %arg2: index) -> vector<1x4xindex> {
+ %0 = vector.splat %arg1 : vector<1x4xindex>
+ %1 = vector.broadcast %arg2 : index to vector<1x4xindex>
+ %2 = arith.addi %0, %1 : vector<1x4xindex>
+ return %2 : vector<1x4xindex>
+}
+
+// -----
+
// CHECK-LABEL: func.func @broadcast_vector(
// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>,
// CHECK-SAME: %[[ARG_1:.*]]: vector<4xf32>) -> vector<3x4xf32> {
// CHECK: %[[ADDF:.*]] = arith.addf %[[ARG_0]], %[[ARG_1]] : vector<4xf32>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADDF]] : vector<4xf32> to vector<3x4xf32>
// CHECK: return %[[BCAST]] : vector<3x4xf32>
-// CHECK: }
func.func @broadcast_vector( %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vector<3x4xf32> {
%arg1_bcast = vector.broadcast %arg1 : vector<4xf32> to vector<3x4xf32>
@@ -30,6 +43,23 @@ func.func @broadcast_vector( %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vect
%2 = arith.addf %arg1_bcast, %arg2_bcast : vector<3x4xf32>
return %2 : vector<3x4xf32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_scalar_and_vec(
+// CHECK-SAME: %[[ARG1:.*]]: index,
+// CHECK-SAME: %[[ARG2:.*]]: vector<4xindex>) -> vector<1x4xindex> {
+// CHECK: %[[SPLAT:.*]] = vector.splat %[[ARG1]] : vector<1x4xindex>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG2]] : vector<4xindex> to vector<1x4xindex>
+// CHECK: %[[ADD:.*]] = arith.addi %[[SPLAT]], %[[BCAST]] : vector<1x4xindex>
+// CHECK: return %[[ADD]] : vector<1x4xindex>
+func.func @broadcast_scalar_and_vec( %arg1: index, %arg2: vector<4xindex>) -> vector<1x4xindex> {
+ %0 = vector.splat %arg1 : vector<1x4xindex>
+ %1 = vector.broadcast %arg2 : vector<4xindex> to vector<1x4xindex>
+ %2 = arith.addi %0, %1 : vector<1x4xindex>
+ return %2 : vector<1x4xindex>
+}
+
// -----
// CHECK-LABEL: func.func @broadcast_vector_and_scalar(
@@ -38,7 +68,6 @@ func.func @broadcast_vector( %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vect
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : i32 to vector<4xi32>
// CHECK: %[[ADD:.*]] = arith.addi %[[BCAST]], %[[ARG_1]] : vector<4xi32>
// CHECK: return %[[ADD]] : vector<4xi32>
-// CHECK: }
func.func @broadcast_vector_and_scalar( %arg1: i32, %arg2: vector<4xi32>) -> vector<4xi32> {
%arg1_bcast = vector.broadcast %arg1 : i32 to vector<4xi32>
``````````
</details>
https://github.com/llvm/llvm-project/pull/66596
More information about the Mlir-commits mailing list