[Mlir-commits] [mlir] 330a7e1 - [mlir][Vector] Make elementwise-on-broadcast sinking handle splat consts (#150867)

llvmlistbot at llvm.org
Tue Jul 29 09:40:53 PDT 2025


Author: Krzysztof Drewniak
Date: 2025-07-29T09:40:49-07:00
New Revision: 330a7e1136f536bf7cd642e460734d0bd6e0d0bb

URL: https://github.com/llvm/llvm-project/commit/330a7e1136f536bf7cd642e460734d0bd6e0d0bb
DIFF: https://github.com/llvm/llvm-project/commit/330a7e1136f536bf7cd642e460734d0bd6e0d0bb.diff

LOG: [mlir][Vector] Make elementwise-on-broadcast sinking handle splat consts (#150867)

There is a pattern that rewrites
elementwise_op(broadcast(x1 : T to U), broadcast(x2 : T to U), ...) to
broadcast(elementwise_op(x1, x2, ...) : T to U).
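
For concreteness, a minimal sketch of that rewrite on scalar-to-vector
broadcasts (the value names here are illustrative, not from the patch):

    // Before: the addition happens on the broadcast vectors.
    %b1 = vector.broadcast %x1 : f32 to vector<4xf32>
    %b2 = vector.broadcast %x2 : f32 to vector<4xf32>
    %r  = arith.addf %b1, %b2 : vector<4xf32>

    // After: the addition is sunk below a single broadcast.
    %s = arith.addf %x1, %x2 : f32
    %r = vector.broadcast %s : f32 to vector<4xf32>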

This pattern did not, however, account for the case where a broadcast
constant is represented as a SplatElementsAttr, which can safely be
reshaped or scalarized but is not a `vector.broadcast` or `vector.splat`
operation.

This patch fixes that oversight, preventing premature broadcasting.
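
Roughly, as exercised by the broadcast_scalar_and_splat_const test added
below, a splat constant operand is now rebuilt at the pre-broadcast type
(scalarized here) so the elementwise op can still sink:

    // Before: one operand is a splat constant, so the old pattern gave up.
    %0   = vector.broadcast %arg0 : index to vector<1x4xindex>
    %cst = arith.constant dense<2> : vector<1x4xindex>
    %2   = arith.addi %0, %cst : vector<1x4xindex>

    // After: the splat is materialized as a scalar and the add is sunk.
    %c2  = arith.constant 2 : index
    %add = arith.addi %arg0, %c2 : index
    %res = vector.broadcast %add : index to vector<1x4xindex>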

This change required updating some Linalg dialect tests, which now feature
less broadcasting and/or more constant folding.

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir
    mlir/test/Dialect/Vector/vector-sink.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 8de87fef904fa..c51c7b7270fae 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1005,26 +1005,39 @@ struct ReorderElementwiseOpsOnBroadcast final
           "might be a scalar");
     }
 
-    // Get the type of the lhs operand
-    auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
-    if (!lhsBcastOrSplat ||
-        !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
+    // Get the type of the first non-constant operand
+    Operation *firstBroadcastOrSplat = nullptr;
+    for (Value operand : op->getOperands()) {
+      Operation *definingOp = operand.getDefiningOp();
+      if (!definingOp)
+        return failure();
+      if (definingOp->hasTrait<OpTrait::ConstantLike>())
+        continue;
+      if (!isa<vector::BroadcastOp, vector::SplatOp>(*definingOp))
+        return failure();
+      firstBroadcastOrSplat = definingOp;
+      break;
+    }
+    if (!firstBroadcastOrSplat)
       return failure();
-    auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
+    Type firstBroadcastOrSplatType =
+        firstBroadcastOrSplat->getOperand(0).getType();
 
     // Make sure that all operands are broadcast from identical types:
     //  * scalar (`vector.broadcast` + `vector.splat`), or
     //  * vector (`vector.broadcast`).
     // Otherwise the re-ordering wouldn't be safe.
-    if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
-          auto bcast = val.getDefiningOp<vector::BroadcastOp>();
-          if (bcast)
-            return (bcast.getOperand().getType() == lhsBcastOrSplatType);
-          auto splat = val.getDefiningOp<vector::SplatOp>();
-          if (splat)
-            return (splat.getOperand().getType() == lhsBcastOrSplatType);
-          return false;
-        })) {
+    if (!llvm::all_of(
+            op->getOperands(), [&firstBroadcastOrSplatType](Value val) {
+              if (auto bcastOp = val.getDefiningOp<vector::BroadcastOp>())
+                return (bcastOp.getOperand().getType() ==
+                        firstBroadcastOrSplatType);
+              if (auto splatOp = val.getDefiningOp<vector::SplatOp>())
+                return (splatOp.getOperand().getType() ==
+                        firstBroadcastOrSplatType);
+              SplatElementsAttr splatConst;
+              return matchPattern(val, m_Constant(&splatConst));
+            })) {
       return failure();
     }
 
@@ -1032,13 +1045,28 @@ struct ReorderElementwiseOpsOnBroadcast final
     SmallVector<Value> srcValues;
     srcValues.reserve(op->getNumOperands());
     for (Value operand : op->getOperands()) {
-      srcValues.push_back(operand.getDefiningOp()->getOperand(0));
+      SplatElementsAttr splatConst;
+      if (matchPattern(operand, m_Constant(&splatConst))) {
+        Attribute newConst;
+        if (auto shapedTy = dyn_cast<ShapedType>(firstBroadcastOrSplatType)) {
+          newConst = splatConst.resizeSplat(shapedTy);
+        } else {
+          newConst = splatConst.getSplatValue<Attribute>();
+        }
+        Operation *newConstOp =
+            operand.getDefiningOp()->getDialect()->materializeConstant(
+                rewriter, newConst, firstBroadcastOrSplatType,
+                operand.getLoc());
+        srcValues.push_back(newConstOp->getResult(0));
+      } else {
+        srcValues.push_back(operand.getDefiningOp()->getOperand(0));
+      }
     }
 
     // Create the "elementwise" Op
     Operation *elementwiseOp =
         rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
-                        lhsBcastOrSplatType, op->getAttrs());
+                        firstBroadcastOrSplatType, op->getAttrs());
 
     // Replace the original Op with the elementwise Op
     auto vectorType = op->getResultTypes()[0];

diff --git a/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir
index c3ee8929dc3f3..d7722eac2b91f 100644
--- a/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/extract-with-patterns.mlir
@@ -230,18 +230,17 @@ func.func @vectorize_nd_tensor_extract_index_from_tensor(%arg0: tensor<3x3xf32>,
 // CHECK-SAME: %[[ARG4:.*]]: tensor<4x7x3x2xf32>
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[PV:.*]] = ub.poison : i32
-// CHECK-DAG: %[[CST:.*]] = arith.constant dense<3> : vector<7x2x4x3xindex>
+// CHECK-DAG: %[[CST:.*]] = arith.constant dense<3> : vector<4x3xindex>
 // CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<true> : vector<4x7x3x2xi1>
 // CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<4x7x3x2xf32>
 // CHECK:    %[[V0:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], %[[PV]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32>
 // CHECK:    %[[V1:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], %[[PV]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32>
 // CHECK:    %[[CAST:.*]] = arith.index_cast %[[V0]] : vector<4x3xi32> to vector<4x3xindex>
-// CHECK:    %[[B1:.*]] = vector.broadcast %[[CAST]] : vector<4x3xindex> to vector<7x2x4x3xindex>
 // CHECK:    %[[CAST_1:.*]] = arith.index_cast %[[V1]] : vector<4x3xi32> to vector<4x3xindex>
-// CHECK:    %[[B2:.*]] = vector.broadcast %[[CAST_1]] : vector<4x3xindex> to vector<7x2x4x3xindex>
-// CHECK:    %[[MULI:.*]] = arith.muli %[[B1]], %[[CST]] : vector<7x2x4x3xindex>
-// CHECK:    %[[ADDI:.*]] = arith.addi %[[B2]], %[[MULI]] : vector<7x2x4x3xindex>
-// CHECK:    %[[T:.*]] = vector.transpose %[[ADDI]], [2, 0, 3, 1] : vector<7x2x4x3xindex> to vector<4x7x3x2xindex>
+// CHECK:    %[[MULI:.*]] = arith.muli %[[CAST]], %[[CST]] : vector<4x3xindex>
+// CHECK:    %[[ADDI:.*]] = arith.addi %[[CAST_1]], %[[MULI]] : vector<4x3xindex>
+// CHECK:    %[[B:.*]] = vector.broadcast %[[ADDI]] : vector<4x3xindex> to vector<7x2x4x3xindex>
+// CHECK:    %[[T:.*]] = vector.transpose %[[B]], [2, 0, 3, 1] : vector<7x2x4x3xindex> to vector<4x7x3x2xindex>
 // CHECK:    %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]]] [%[[T]]], %[[CST_1]], %[[PASSTHRU]] : tensor<3x3xf32>, vector<4x7x3x2xindex>, vector<4x7x3x2xi1>, vector<4x7x3x2xf32> into vector<4x7x3x2xf32>
 // CHECK:    vector.transfer_write %[[GATHER]], %[[ARG4]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true]} : vector<4x7x3x2xf32>, tensor<4x7x3x2xf32>
 
@@ -270,20 +269,16 @@ func.func @vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load(%
 // CHECK-SAME: %[[ARG0:.*]]: tensor<8x128x768xf32>
 // CHECK-SAME: %[[ARG1:.*]]: index
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[CST:.*]] = arith.constant dense<768> : vector<1x8xindex>
-// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<128> : vector<1x8xindex>
 // CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
-// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<true> : vector<8x1xi1>
-// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<true> : vector<8x1xi1>
+// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<[0, 98304, 196608, 294912, 393216, 491520, 589824, 688128]> : vector<8xindex>
 // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<8x1xf32>
-// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_3]] : vector<8xindex> to vector<1x8xindex>
 // CHECK: %[[ADDI_ARG1:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index
-// CHECK: %[[MULI_1:.*]] = arith.muli %[[B1]], %[[CST_0]] : vector<1x8xindex>
-// CHECK: %[[MULI_2:.*]] = arith.muli %[[MULI_1]], %[[CST]] : vector<1x8xindex>
-// CHECK: %[[T:.*]] = vector.transpose %[[MULI_2]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
+// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_1]] : vector<8xindex> to vector<1x8xindex>
+// CHECK: %[[T:.*]] = vector.transpose %[[B1]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
 // CHECK: %[[B2:.*]] = vector.broadcast %[[ADDI_ARG1]] : index to vector<8x1xindex>
 // CHECK: %[[ADDI:.*]] = arith.addi %[[B2]], %[[T]] : vector<8x1xindex>
-// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_2]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
+// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_0]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
 // CHECK: vector.transfer_write %[[GATHER]], %[[EMPTY]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
 
 // -----
@@ -309,15 +304,13 @@ func.func @index_from_output_column_vector_gather_load(%src: tensor<8x128xf32>)
 
 // CHECK-LABEL:   func.func @index_from_output_column_vector_gather_load(
 // CHECK-SAME:      %[[SRC:.*]]: tensor<8x128xf32>) -> tensor<8x1xf32> {
-// CHECK:           %[[C128:.*]] = arith.constant dense<128> : vector<1x8xindex>
+// CHECK:           %[[IDX_VEC:.*]] = arith.constant dense<[0, 128, 256, 384, 512, 640, 768, 896]> : vector<8xindex>
 // CHECK:           %[[C0:.*]] = arith.constant 0 : index
 // CHECK:           %[[PASS_THRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
 // CHECK:           %[[MASK:.*]] = arith.constant dense<true> : vector<8x1xi1>
-// CHECK:           %[[IDX_VEC:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
 // CHECK:           %[[OUT:.*]] = tensor.empty() : tensor<8x1xf32>
 // CHECK:           %[[B:.*]] = vector.broadcast %[[IDX_VEC]] : vector<8xindex> to vector<1x8xindex>
-// CHECK:           %[[MUL:.*]] = arith.muli %[[B]], %[[C128]] : vector<1x8xindex>
-// CHECK:           %[[TR:.*]] = vector.transpose %[[MUL]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
+// CHECK:           %[[TR:.*]] = vector.transpose %[[B]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
 // CHECK:           %[[GATHER:.*]] = vector.gather %[[SRC]]{{\[}}%[[C0]], %[[C0]]] {{\[}}%[[TR]]], %[[MASK]], %[[PASS_THRU]] : tensor<8x128xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
 // CHECK:           %[[RES:.*]] = vector.transfer_write %[[GATHER]], %[[OUT]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
 // CHECK:           return %[[RES]] : tensor<8x1xf32>
@@ -420,12 +413,12 @@ func.func @vectorize_nd_tensor_extract_with_affine_apply_gather(%6: tensor<80x16
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant dense<true> : vector<1x4xi1>
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32>
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0 : index
-// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant dense<16> : vector<1x4xindex>
+// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant dense<16> : vector<4xindex>
 // CHECK:           %[[VAL_8:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex>
 // CHECK:           %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : vector<4xindex>
-// CHECK:           %[[VAL_10:.*]] = vector.broadcast %[[VAL_9]] : vector<4xindex> to vector<1x4xindex>
-// CHECK:           %[[VAL_11:.*]] = arith.muli %[[VAL_10]], %[[VAL_7]] : vector<1x4xindex>
-// CHECK:           %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_7]] : vector<1x4xindex>
+// CHECK:           %[[VAL_10:.*]] = arith.muli %[[VAL_9]], %[[VAL_7]] : vector<4xindex>
+// CHECK:           %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_7]] : vector<4xindex>
+// CHECK:           %[[VAL_12:.*]] = vector.broadcast %[[VAL_11]] : vector<4xindex> to vector<1x4xindex>
 // CHECK:           %[[VAL_13:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_12]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
 // CHECK:           %[[VAL_14:.*]] = vector.transfer_write %[[VAL_13]], %[[VAL_2]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
 // CHECK:           return %[[VAL_14]] : tensor<1x4xf32>
@@ -450,14 +443,12 @@ func.func @vectorize_nd_tensor_extract_with_maxsi_gather(%arg0: tensor<80x16xf32
 // CHECK-LABEL:   func.func @vectorize_nd_tensor_extract_with_maxsi_gather(
 // CHECK-SAME:                                                             %[[VAL_0:.*]]: tensor<80x16xf32>,
 // CHECK-SAME:                                                             %[[VAL_1:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
-// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant dense<1264> : vector<1x4xindex>
+// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant dense<[1264, 1265, 1266, 1267]> : vector<4xindex>
 // CHECK-DAG:       %[[VAL_4:.*]] = arith.constant dense<true> : vector<1x4xi1>
 // CHECK-DAG:       %[[VAL_5:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32>
 // CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_7:.*]] = vector.broadcast %[[VAL_2]] : vector<4xindex> to vector<1x4xindex>
-// CHECK:           %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_3]] : vector<1x4xindex>
-// CHECK:           %[[VAL_9:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_8]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
+// CHECK:           %[[VAL_9:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {{\[}}%[[VAL_7]]], %[[VAL_4]], %[[VAL_5]] : tensor<80x16xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
 // CHECK:           %[[VAL_10:.*]] = vector.transfer_write %[[VAL_9]], %[[VAL_1]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
 // CHECK:           return %[[VAL_10]] : tensor<1x4xf32>
 // CHECK:         }
@@ -519,13 +510,13 @@ func.func @vectorize_reverse_like_tensor_extract(%arg0: tensor<1x2x3xf32>, %arg1
 // CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]*]]
 // CHECK-SAME:    %[[ARG1:[0-9a-zA-Z]*]]
 // CHECK-SAME:    %[[ARG2:[0-9a-zA-Z]*]]
-// CHECK-DAG:    %[[CST:.+]] = arith.constant dense<3> : vector<1x1x3xindex>
+// CHECK-DAG:    %[[C3:.+]] = arith.constant 3 : index
 // CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
 // CHECK-DAG:    %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x3xi1>
 // CHECK-DAG:    %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32>
 // CHECK-DAG:    %[[INIT_IDX:.+]] = arith.constant dense<[2, 1, 0]> : vector<3xindex>
-// CHECK:        %[[T0:.+]] = vector.broadcast %[[ARG2]] : index to vector<1x1x3xindex>
-// CHECK:        %[[T1:.+]] = arith.muli %[[T0]], %[[CST]] : vector<1x1x3xindex>
+// CHECK:        %[[T0:.+]] = arith.muli %[[ARG2]], %[[C3]] : index
+// CHECK:        %[[T1:.+]] = vector.broadcast %[[T0]] : index to vector<1x1x3xindex>
 // CHECK:        %[[T2:.+]] = vector.broadcast %[[INIT_IDX]]
 // CHECK:        %[[T3:.+]] = arith.addi %[[T2]], %[[T1]]
 // CHECK:        %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[T3]]], %[[MASK]], %[[PASSTHRU]]

diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir
index b826cdca134e6..f8638ab843ecb 100644
--- a/mlir/test/Dialect/Vector/vector-sink.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink.mlir
@@ -257,6 +257,70 @@ func.func @broadcast_scalar_extsi_scalable(%a : i8) -> vector<2x[4]xi32> {
   return %r : vector<2x[4]xi32>
 }
 
+// -----
+
+// CHECK-LABEL:   func.func @broadcast_scalar_and_splat_const(
+// CHECK-SAME:     %[[ARG_0:.*]]: index) -> vector<1x4xindex> {
+// CHECK:           %[[NEW_CST:.*]] = arith.constant 2 : index
+// CHECK:           %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[NEW_CST]] : index
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
+// CHECK:           return %[[BCAST]] : vector<1x4xindex>
+
+func.func @broadcast_scalar_and_splat_const(%arg0: index) -> vector<1x4xindex> {
+  %0 = vector.broadcast %arg0 : index to vector<1x4xindex>
+  %cst = arith.constant dense<2> : vector<1x4xindex>
+  %2 = arith.addi %0, %cst : vector<1x4xindex>
+  return %2 : vector<1x4xindex>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @broadcast_scalar_and_splat_const_const_first(
+// CHECK-SAME:     %[[ARG_0:.*]]: index) -> vector<1x4xindex> {
+// CHECK:           %[[NEW_CST:.*]] = arith.constant 2 : index
+// CHECK:           %[[SUB:.*]] = arith.subi %[[NEW_CST]], %[[ARG_0]] : index
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[SUB]] : index to vector<1x4xindex>
+// CHECK:           return %[[BCAST]] : vector<1x4xindex>
+
+func.func @broadcast_scalar_and_splat_const_const_first(%arg0: index) -> vector<1x4xindex> {
+  %0 = vector.broadcast %arg0 : index to vector<1x4xindex>
+  %cst = arith.constant dense<2> : vector<1x4xindex>
+  %2 = arith.subi %cst, %0 : vector<1x4xindex>
+  return %2 : vector<1x4xindex>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @broadcast_vector_and_splat_const(
+// CHECK-SAME:     %[[ARG_0:.*]]: vector<4xf32>) -> vector<3x4xf32> {
+// CHECK:           %[[NEW_CST:.*]] = arith.constant dense<2.000000e+00> : vector<4xf32>
+// CHECK:           %[[ADD:.*]] = arith.mulf %[[ARG_0]], %[[NEW_CST]] : vector<4xf32>
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[ADD]] : vector<4xf32> to vector<3x4xf32>
+// CHECK:           return %[[BCAST]] : vector<3x4xf32>
+
+func.func @broadcast_vector_and_splat_const(%arg0: vector<4xf32>) -> vector<3x4xf32> {
+  %0 = vector.broadcast %arg0 : vector<4xf32> to vector<3x4xf32>
+  %cst = arith.constant dense<2.000000e+00> : vector<3x4xf32>
+  %2 = arith.mulf %0, %cst : vector<3x4xf32>
+  return %2 : vector<3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @negative_broadcast_with_non_splat_const(
+// CHECK-SAME:     %[[ARG_0:.*]]: index) -> vector<1x4xindex> {
+// CHECK-DAG:       %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : index to vector<1x4xindex>
+// CHECK-DAG:       %[[CST:.*]] = arith.constant dense<{{\[}}[0, 1, 2, 3]]> : vector<1x4xindex>
+// CHECK:           %[[ADD:.*]] = arith.addi %[[BCAST]], %[[CST]] : vector<1x4xindex>
+// CHECK:           return %[[ADD]] : vector<1x4xindex>
+
+func.func @negative_broadcast_with_non_splat_const(%arg0: index) -> vector<1x4xindex> {
+  %0 = vector.broadcast %arg0 : index to vector<1x4xindex>
+  %cst = arith.constant dense<[[0, 1, 2, 3]]> : vector<1x4xindex>
+  %2 = arith.addi %0, %cst : vector<1x4xindex>
+  return %2 : vector<1x4xindex>
+}
+
 //===----------------------------------------------------------------------===//
 // [Pattern: ReorderElementwiseOpsOnTranspose]
 //===----------------------------------------------------------------------===//

