[Mlir-commits] [mlir] [mlir][vector] Make ReorderElementwiseOpsOnBroadcast support vector.splat (PR #66596)

Andrzej WarzyƄski llvmlistbot at llvm.org
Sun Sep 17 06:32:28 PDT 2023


https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/66596

Extend `ReorderElementwiseOpsOnBroadcast` so that the broadcasting op
can be either `vector.broadcast` (already supported) or
`vector.splat` (support added in this patch).


>From 37c78fece46d7e5de63e6241e32d10e39255a4eb Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sun, 17 Sep 2023 13:14:13 +0000
Subject: [PATCH] [mlir][vector] Make ReorderElementwiseOpsOnBroadcast support
 vector.splat

Extend `ReorderElementwiseOpsOnBroadcast` so that the broadcasting op
can be either `vector.broadcast` (already supported) or
`vector.splat` (support added in this patch).
---
 .../Vector/Transforms/VectorTransforms.cpp    | 40 ++++++++++++-------
 .../Dialect/Vector/sink-vector-broadcast.mlir | 39 +++++++++++++++---
 2 files changed, 59 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 207df69929c1c9f..b2a5aef5ee62d0f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -880,7 +880,7 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
   std::function<bool(BitCastOp)> controlFn;
 };
 
-/// Reorders elementwise(broadcast) to broadcast(elementwise). Ex:
+/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
 /// ```
 /// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
 /// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
@@ -891,6 +891,9 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
 /// %r = arith.addi %arg0, %arg1 : index
 /// %b = vector.broadcast %r : index to vector<1x4xindex>
 /// ```
+///
+/// Both `vector.broadcast` and `vector.splat` are supported as broadcasting
+/// ops.
 struct ReorderElementwiseOpsOnBroadcast final
     : public OpTraitRewritePattern<OpTrait::Elementwise> {
   using OpTraitRewritePattern::OpTraitRewritePattern;
@@ -903,35 +906,42 @@ struct ReorderElementwiseOpsOnBroadcast final
     if (!OpTrait::hasElementwiseMappableTraits(op))
       return failure();
 
-    // Get the type of the first operand
-    auto firstBcast = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
-    if (!firstBcast)
+    // Get the type of the lhs operand
+    auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
+    if (!lhsBcastOrSplat ||
+        !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
       return failure();
-    auto firstOpType = firstBcast.getOperand().getType();
+    auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
 
-    // Make sure that operands are "broadcast"ed from identical (scalar or
-    // vector) types. That indicates that it's safe to skip the broadcasting of
-    // operands.
-    if (!llvm::all_of(op->getOperands(), [&firstOpType](Value val) {
+    // Make sure that all operands are broadcast from identical types:
+    //  * scalar (`vector.broadcast` + `vector.splat`), or
+    //  * vector (`vector.broadcast`).
+    // Otherwise the re-ordering wouldn't be safe.
+    if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
           auto bcast = val.getDefiningOp<vector::BroadcastOp>();
-          return (bcast && (bcast.getOperand().getType() == firstOpType));
+          if (bcast)
+            return (bcast.getOperand().getType() == lhsBcastOrSplatType);
+          auto splat = val.getDefiningOp<vector::SplatOp>();
+          if (splat)
+            return (splat.getOperand().getType() == lhsBcastOrSplatType);
+          return false;
         })) {
       return failure();
     }
 
-    // Collect the source values
+    // Collect the source values before broadcasting
     SmallVector<Value> srcValues;
     srcValues.reserve(op->getNumOperands());
-
     for (Value operand : op->getOperands()) {
-      srcValues.push_back(
-          operand.getDefiningOp<vector::BroadcastOp>().getOperand());
+      srcValues.push_back(operand.getDefiningOp()->getOperand(0));
     }
 
+    // Create the "elementwise" Op
     Operation *elementwiseOp =
         rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
-                        firstOpType, op->getAttrs());
+                        lhsBcastOrSplatType, op->getAttrs());
 
+    // Replace the original Op with the elementwise Op
     auto vectorType = op->getResultTypes()[0];
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
         op, vectorType, elementwiseOp->getResults());
diff --git a/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir b/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir
index fcf9815f6f6f1d1..d9d2f44e6f16c1f 100644
--- a/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir
+++ b/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir
@@ -1,13 +1,12 @@
 // RUN: mlir-opt %s -test-sink-vector-broadcast -split-input-file | FileCheck %s
 
-// CHECK-LABEL:   func.func @broadcast_scalar(
+// CHECK-LABEL:   func.func @broadcast_scalar_with_bcast(
 // CHECK-SAME:     %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -> vector<1x4xindex> {
 // CHECK:           %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[ARG_1]] : index
 // CHECK:           %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
 // CHECK:           return %[[BCAST]] : vector<1x4xindex>
-// CHECK:         }
 
-func.func @broadcast_scalar( %arg1: index, %arg2: index) -> vector<1x4xindex> {
+func.func @broadcast_scalar_with_bcast( %arg1: index, %arg2: index) -> vector<1x4xindex> {
   %0 = vector.broadcast %arg1 : index to vector<1x4xindex>
   %1 = vector.broadcast %arg2 : index to vector<1x4xindex>
   %2 = arith.addi %0, %1 : vector<1x4xindex>
@@ -16,13 +15,27 @@ func.func @broadcast_scalar( %arg1: index, %arg2: index) -> vector<1x4xindex> {
 
 // -----
 
+// CHECK-LABEL:   func.func @broadcast_scalar_with_bcast_and_splat(
+// CHECK-SAME:      %[[ARG1:.*]]: index,
+// CHECK-SAME:      %[[ARG2:.*]]: index) -> vector<1x4xindex> {
+// CHECK:           %[[ADD:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : index
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
+// CHECK:           return %[[BCAST]] : vector<1x4xindex>
+func.func @broadcast_scalar_with_bcast_and_splat( %arg1: index, %arg2: index) -> vector<1x4xindex> {
+  %0 = vector.splat %arg1 : vector<1x4xindex>
+  %1 = vector.broadcast %arg2 : index to vector<1x4xindex>
+  %2 = arith.addi %0, %1 : vector<1x4xindex>
+  return %2 : vector<1x4xindex>
+}
+
+// -----
+
 // CHECK-LABEL:   func.func @broadcast_vector(
 // CHECK-SAME:      %[[ARG_0:.*]]: vector<4xf32>,
 // CHECK-SAME:      %[[ARG_1:.*]]: vector<4xf32>) -> vector<3x4xf32> {
 // CHECK:           %[[ADDF:.*]] = arith.addf %[[ARG_0]], %[[ARG_1]] : vector<4xf32>
 // CHECK:           %[[BCAST:.*]] = vector.broadcast %[[ADDF]] : vector<4xf32> to vector<3x4xf32>
 // CHECK:           return %[[BCAST]] : vector<3x4xf32>
-// CHECK:         }
 
 func.func @broadcast_vector( %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vector<3x4xf32> {
   %arg1_bcast = vector.broadcast %arg1 : vector<4xf32> to vector<3x4xf32>
@@ -30,6 +43,23 @@ func.func @broadcast_vector( %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vect
   %2 = arith.addf %arg1_bcast, %arg2_bcast : vector<3x4xf32>
   return %2 : vector<3x4xf32>
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @broadcast_scalar_and_vec(
+// CHECK-SAME:       %[[ARG1:.*]]: index,
+// CHECK-SAME:       %[[ARG2:.*]]: vector<4xindex>) -> vector<1x4xindex> {
+// CHECK:            %[[SPLAT:.*]] = vector.splat %[[ARG1]] : vector<1x4xindex>
+// CHECK:            %[[BCAST:.*]] = vector.broadcast %[[ARG2]] : vector<4xindex> to vector<1x4xindex>
+// CHECK:            %[[ADD:.*]] = arith.addi %[[SPLAT]], %[[BCAST]] : vector<1x4xindex>
+// CHECK:            return %[[ADD]] : vector<1x4xindex>
+func.func @broadcast_scalar_and_vec( %arg1: index, %arg2: vector<4xindex>) -> vector<1x4xindex> {
+  %0 = vector.splat %arg1 : vector<1x4xindex>
+  %1 = vector.broadcast %arg2 : vector<4xindex> to vector<1x4xindex>
+  %2 = arith.addi %0, %1 : vector<1x4xindex>
+  return %2 : vector<1x4xindex>
+}
+
 // -----
 
 // CHECK-LABEL:   func.func @broadcast_vector_and_scalar(
@@ -38,7 +68,6 @@ func.func @broadcast_vector( %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vect
 // CHECK:           %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : i32 to vector<4xi32>
 // CHECK:           %[[ADD:.*]] = arith.addi %[[BCAST]], %[[ARG_1]] : vector<4xi32>
 // CHECK:           return %[[ADD]] : vector<4xi32>
-// CHECK:         }
 
 func.func @broadcast_vector_and_scalar( %arg1: i32, %arg2: vector<4xi32>) -> vector<4xi32> {
   %arg1_bcast = vector.broadcast %arg1 : i32 to vector<4xi32>



More information about the Mlir-commits mailing list