[Mlir-commits] [mlir] andrzej/group reorder tests (PR #102856)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Aug 12 00:18:14 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

<details>
<summary>Changes</summary>

- **[mlir][vector] Add tests for `populateSinkVectorBroadcastPatterns` (1/n)**
- **[mlir][vector] Group tests for re-order patterns**


---

Patch is 32.20 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/102856.diff


7 Files Affected:

- (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h (+4) 
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+1) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+17-9) 
- (removed) mlir/test/Dialect/Vector/sink-vector-broadcast.mlir (-127) 
- (modified) mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir (-112) 
- (added) mlir/test/Dialect/Vector/vector-reorder.mlir (+321) 
- (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+3-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 10970fd03e6eb2..2c17ffd49d1d41 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -148,6 +148,10 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
 void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
                                          PatternBenefit benefit = 1);
 
+/// Patterns that re-order transpose Ops.
+void populateReoderVectorTransposePatterns(RewritePatternSet &patterns,
+                                           PatternBenefit benefit = 1);
+
 /// Patterns that fold chained vector reductions. These patterns assume that
 /// elementwise operations (e.g., `arith.addf` with vector operands) are
 /// cheaper than vector reduction.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 48b3abbeee7010..e2f9ca1fc75027 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3452,6 +3452,7 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
   if (!getDisableMultiReductionToContractPatterns())
     vector::populateVectorReductionToContractPatterns(patterns);
 
+  vector::populateReoderVectorTransposePatterns(patterns);
   vector::populateSinkVectorBroadcastPatterns(patterns);
 
   patterns.add<linalg::LinalgCopyVTRForwardingPattern,
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 7f59a378e03512..922e3c61a77310 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -979,15 +979,18 @@ struct ReorderElementwiseOpsOnBroadcast final
     if (!llvm::isa<ShapedType>(op->getResults()[0].getType()))
       return failure();
     if (!OpTrait::hasElementwiseMappableTraits(op))
+      return rewriter.notifyMatchFailure(
+          op, "Op doesn't have ElementwiseMappableTraits");
+    if (op->getNumOperands() == 0)
       return failure();
-    if (op->getNumOperands() == 0 ||
-        op->getResults()[0].getType() != op->getOperand(0).getType()) {
-      return failure();
-    }
-    // Avoid operations that only accept vector types, since broadcast
-    // source might be scalar types.
+    if (op->getResults()[0].getType() != op->getOperand(0).getType())
+      return rewriter.notifyMatchFailure(op,
+                                         "result and operand type mismatch");
     if (isa<vector::FMAOp>(op)) {
-      return failure();
+      return rewriter.notifyMatchFailure(
+          op,
+          "Op only accepts vector types - not supported as broadcast source "
+          "might be a scalar");
     }
 
     // Get the type of the lhs operand
@@ -2027,8 +2030,7 @@ void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
 void mlir::vector::populateVectorReductionToContractPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<MultiReduceToContract, CombineContractBroadcast,
-               CombineContractABTranspose, CombineContractResultTranspose,
-               ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnTranspose>(
+               CombineContractABTranspose, CombineContractResultTranspose>(
       patterns.getContext(), benefit);
 }
 
@@ -2046,6 +2048,12 @@ void mlir::vector::populateSinkVectorBroadcastPatterns(
       patterns.getContext(), benefit);
 }
 
+void mlir::vector::populateReoderVectorTransposePatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<ReorderElementwiseOpsOnTranspose>(patterns.getContext(),
+                                                 benefit);
+}
+
 void mlir::vector::populateChainedVectorReductionFoldingPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<ChainedReduction>(patterns.getContext(), benefit);
diff --git a/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir b/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir
deleted file mode 100644
index e7863a9e8b7b78..00000000000000
--- a/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir
+++ /dev/null
@@ -1,127 +0,0 @@
-// RUN: mlir-opt %s -test-sink-vector-broadcast -split-input-file | FileCheck %s
-
-// CHECK-LABEL:   func.func @broadcast_scalar_with_bcast(
-// CHECK-SAME:     %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -> vector<1x4xindex> {
-// CHECK:           %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[ARG_1]] : index
-// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
-// CHECK:           return %[[BCAST]] : vector<1x4xindex>
-
-func.func @broadcast_scalar_with_bcast( %arg1: index, %arg2: index) -> vector<1x4xindex> {
-  %0 = vector.broadcast %arg1 : index to vector<1x4xindex>
-  %1 = vector.broadcast %arg2 : index to vector<1x4xindex>
-  %2 = arith.addi %0, %1 : vector<1x4xindex>
-  return %2 : vector<1x4xindex>
-}
-
-// -----
-
-// CHECK-LABEL:   func.func @broadcast_scalar_with_bcast_and_splat(
-// CHECK-SAME:      %[[ARG1:.*]]: index,
-// CHECK-SAME:      %[[ARG2:.*]]: index) -> vector<1x4xindex> {
-// CHECK:           %[[ADD:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : index
-// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
-// CHECK:           return %[[BCAST]] : vector<1x4xindex>
-func.func @broadcast_scalar_with_bcast_and_splat( %arg1: index, %arg2: index) -> vector<1x4xindex> {
-  %0 = vector.splat %arg1 : vector<1x4xindex>
-  %1 = vector.broadcast %arg2 : index to vector<1x4xindex>
-  %2 = arith.addi %0, %1 : vector<1x4xindex>
-  return %2 : vector<1x4xindex>
-}
-
-// -----
-
-// CHECK-LABEL:   func.func @broadcast_vector(
-// CHECK-SAME:      %[[ARG_0:.*]]: vector<4xf32>,
-// CHECK-SAME:      %[[ARG_1:.*]]: vector<4xf32>) -> vector<3x4xf32> {
-// CHECK:           %[[ADDF:.*]] = arith.addf %[[ARG_0]], %[[ARG_1]] : vector<4xf32>
-// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[ADDF]] : vector<4xf32> to vector<3x4xf32>
-// CHECK:           return %[[BCAST]] : vector<3x4xf32>
-
-func.func @broadcast_vector( %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vector<3x4xf32> {
-  %arg1_bcast = vector.broadcast %arg1 : vector<4xf32> to vector<3x4xf32>
-  %arg2_bcast = vector.broadcast %arg2 : vector<4xf32> to vector<3x4xf32>
-  %2 = arith.addf %arg1_bcast, %arg2_bcast : vector<3x4xf32>
-  return %2 : vector<3x4xf32>
-}
-
-// -----
-
-// CHECK-LABEL:   func.func @broadcast_scalar_and_vec(
-// CHECK-SAME:       %[[ARG1:.*]]: index,
-// CHECK-SAME:       %[[ARG2:.*]]: vector<4xindex>) -> vector<1x4xindex> {
-// CHECK:            %[[SPLAT:.*]] = vector.splat %[[ARG1]] : vector<1x4xindex>
-// CHECK:            %[[BCAST:.*]] = vector.broadcast %[[ARG2]] : vector<4xindex> to vector<1x4xindex>
-// CHECK:            %[[ADD:.*]] = arith.addi %[[SPLAT]], %[[BCAST]] : vector<1x4xindex>
-// CHECK:            return %[[ADD]] : vector<1x4xindex>
-func.func @broadcast_scalar_and_vec( %arg1: index, %arg2: vector<4xindex>) -> vector<1x4xindex> {
-  %0 = vector.splat %arg1 : vector<1x4xindex>
-  %1 = vector.broadcast %arg2 : vector<4xindex> to vector<1x4xindex>
-  %2 = arith.addi %0, %1 : vector<1x4xindex>
-  return %2 : vector<1x4xindex>
-}
-
-// -----
-
-// CHECK-LABEL:   func.func @broadcast_vector_and_scalar(
-// CHECK-SAME:      %[[ARG_0:.*]]: i32,
-// CHECK-SAME:      %[[ARG_1:.*]]: vector<4xi32>) -> vector<4xi32> {
-// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : i32 to vector<4xi32>
-// CHECK:           %[[ADD:.*]] = arith.addi %[[BCAST]], %[[ARG_1]] : vector<4xi32>
-// CHECK:           return %[[ADD]] : vector<4xi32>
-
-func.func @broadcast_vector_and_scalar( %arg1: i32, %arg2: vector<4xi32>) -> vector<4xi32> {
-  %arg1_bcast = vector.broadcast %arg1 : i32 to vector<4xi32>
-  %2 = arith.addi %arg1_bcast, %arg2 : vector<4xi32>
-  return %2 : vector<4xi32>
-}
-
-// -----
-
-#matmat_accesses = [
-  affine_map<(i, j, k) -> (i, k)>,
-  affine_map<(i, j, k) -> (k, j)>,
-  affine_map<(i, j, k) -> (i, j)>
-]
-#matmat_trait = {
-  indexing_maps = #matmat_accesses,
-  iterator_types = ["parallel", "parallel", "reduction"]
-}
-
-// CHECK-LABEL:   func.func @broadcast_not_elementwise() -> vector<2x2xf32> {
-// CHECK-DAG:       %[[VAL_0:.*]] = arith.constant dense<1.000000e+00> : vector<2x2xf32>
-// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant dense<2.000000e+00> : vector<2x2xf32>
-// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant dense<3.000000e+00> : vector<2x2xf32>
-// CHECK:           %[[VAL_3:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-func.func @broadcast_not_elementwise() -> vector<2x2xf32> {
-  %f1 = arith.constant 1.0: f32
-  %f2 = arith.constant 2.0: f32
-  %f3 = arith.constant 3.0: f32
-
-  %A = vector.broadcast %f1 : f32 to vector<2x2xf32>
-  %B = vector.broadcast %f2 : f32 to vector<2x2xf32>
-  %C = vector.broadcast %f3 : f32 to vector<2x2xf32>
-  %mm1 = vector.contract #matmat_trait %A, %B, %C
-    : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-
-  return %mm1 : vector<2x2xf32>
-}
-
-// CHECK-LABEL: func.func @dont_sink_cmp(
-//       CHECK:   %[[BROADCAST:.+]] = vector.broadcast
-//       CHECK:   %[[RETURN:.+]] = arith.cmpf uno, %[[BROADCAST]], %[[BROADCAST]]
-//       CHECK:   return %[[RETURN]]
-func.func @dont_sink_cmp(%arg0 : f32, %arg1 : vector<1xf32>) -> vector<1xi1> {
-  %0 = vector.broadcast %arg0 : f32 to vector<1xf32>
-  %1 = arith.cmpf uno, %0, %0 : vector<1xf32>
-  return %1 : vector<1xi1>
-}
-
-// CHECK-LABEL: func.func @dont_sink_fma(
-  //     CHECK:   %[[BROADCAST:.+]] = vector.broadcast
-  //     CHECK:   %[[RESULT:.+]] = vector.fma %[[BROADCAST]]
-  //     CHECK:   return %[[RESULT]]
-func.func @dont_sink_fma(%arg0 : f32) -> vector<1xf32> {
-  %0 = vector.broadcast %arg0 : f32 to vector<1xf32>
-  %1 = vector.fma %0, %0, %0 : vector<1xf32>
-  return %1 : vector<1xf32>
-}
diff --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
index 23a44b7c03f8f4..24070dbf017a58 100644
--- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
+++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
@@ -245,118 +245,6 @@ func.func @contract_broadcast_would_have_no_reduction_dim_pair(%arg0 : vector<1x
 }
 
 
-//===----------------------------------------------------------------------===//
-// Reorder casting ops and vector ops. The casting ops have almost identical
-// pattern, so only arith.extsi op is tested.
-//===----------------------------------------------------------------------===//
-
-// -----
-
-func.func @broadcast_vector_extsi(%a : vector<4xi8>) -> vector<2x4xi32> {
-  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4xi8> to vector<4xi32>
-  // CHECK: vector.broadcast %[[EXT:.+]] : vector<4xi32> to vector<2x4xi32>
-  %b = vector.broadcast %a : vector<4xi8> to vector<2x4xi8>
-  %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
-  return %r : vector<2x4xi32>
-}
-
-// -----
-
-func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
-  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
-  // CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x4xi32>
-  %b = vector.broadcast %a : i8 to vector<2x4xi8>
-  %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
-  return %r : vector<2x4xi32>
-}
-
-// -----
-
-func.func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> {
-  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4x2xi8> to vector<4x2xi32>
-  // CHECK: vector.transpose %[[EXT]], [1, 0] : vector<4x2xi32> to vector<2x4xi32>
-  %b = vector.transpose %a, [1, 0]: vector<4x2xi8> to vector<2x4xi8>
-  %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
-  return %r : vector<2x4xi32>
-}
-
-//===----------------------------------------------------------------------===//
-// Reorder elementwise ops and vector ops.
-//===----------------------------------------------------------------------===//
-
-// -----
-
-// CHECK-LABEL: func @transpose_elementwise_same_type
-//  CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
-//       CHECK:   %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x2xf32>
-//       CHECK:   %[[T:.+]] = vector.transpose %[[ADD]], [1, 0]
-//       CHECK:   return %[[T]]
-
-func.func @transpose_elementwise_same_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> {
-  %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
-  %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
-  %r = arith.addf %at, %bt : vector<2x4xf32>
-  return %r : vector<2x4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @transpose_elementwise_diff_operand_types
-//  CHECK-SAME: (%[[COND:.+]]: vector<4x2xi1>, %[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
-//       CHECK:   %[[S:.+]] = arith.select %[[COND]], %[[A]], %[[B]] : vector<4x2xi1>, vector<4x2xf32>
-//       CHECK:   %[[T:.+]] = vector.transpose %[[S]], [1, 0] : vector<4x2xf32> to vector<2x4xf32>
-//       CHECK:   return %[[T]]
-func.func @transpose_elementwise_diff_operand_types(%cond: vector<4x2xi1>, %a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> {
-  %condt = vector.transpose %cond, [1, 0]: vector<4x2xi1> to vector<2x4xi1>
-  %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
-  %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
-  %r = arith.select %condt, %at, %bt : vector<2x4xi1>, vector<2x4xf32>
-  return %r : vector<2x4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @transpose_elementwise_diff_operand_result_type
-//  CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
-//       CHECK:   %[[CMP:.+]] = arith.cmpf olt, %[[A]], %[[B]] : vector<4x2xf32>
-//       CHECK:   %[[T:.+]] = vector.transpose %[[CMP]], [1, 0] : vector<4x2xi1> to vector<2x4xi1>
-//       CHECK:   return %[[T]]
-func.func @transpose_elementwise_diff_operand_result_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xi1> {
-  %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
-  %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
-  %r = arith.cmpf olt, %at, %bt : vector<2x4xf32>
-  return %r : vector<2x4xi1>
-}
-
-// -----
-
-// CHECK-LABEL: func @transpose_elementwise_splat_constant
-//  CHECK-SAME: (%[[A:.+]]: vector<4x6x3x2xf32>)
-//       CHECK:   %[[B:.+]] = arith.constant dense<5.000000e+00> : vector<4x6x3x2xf32>
-//       CHECK:   %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x6x3x2xf32>
-//       CHECK:   %[[T:.+]] = vector.transpose %[[ADD]], [1, 0, 3, 2] : vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
-//       CHECK:   return %[[T:.+]] : vector<6x4x2x3xf32>
-
-func.func @transpose_elementwise_splat_constant(%a : vector<4x6x3x2xf32>) -> vector<6x4x2x3xf32> {
-  %b = arith.constant dense<5.0> : vector<6x4x2x3xf32>
-  %at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
-  %r = arith.addf %at, %b : vector<6x4x2x3xf32>
-  return %r : vector<6x4x2x3xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @transpose_elementwise_diff_map
-//       CHECK:   vector.transpose
-//       CHECK:   vector.transpose
-//       CHECK:   arith.addf
-func.func @transpose_elementwise_diff_map(%a : vector<4x6x3x2xf32>, %b: vector<6x2x4x3xf32>) -> vector<6x4x2x3xf32> {
-  %at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
-  %bt = vector.transpose %b, [0, 2, 1, 3]: vector<6x2x4x3xf32> to vector<6x4x2x3xf32>
-  %r = arith.addf %at, %bt : vector<6x4x2x3xf32>
-  return %r : vector<6x4x2x3xf32>
-}
-
 // -----
 
 // CHECK-DAG: #[[$LHS_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
diff --git a/mlir/test/Dialect/Vector/vector-reorder.mlir b/mlir/test/Dialect/Vector/vector-reorder.mlir
new file mode 100644
index 00000000000000..d7669ec2b54037
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-reorder.mlir
@@ -0,0 +1,321 @@
+// RUN: mlir-opt %s -test-vector-reorder-patterns -split-input-file | FileCheck %s
+
+//-----------------------------------------------------------------------------
+// [Pattern: ReorderElementwiseOpsOnBroadcast]
+//-----------------------------------------------------------------------------
+
+// CHECK-LABEL:   func.func @broadcast_scalar_with_bcast(
+// CHECK-SAME:     %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -> vector<1x4xindex> {
+// CHECK:           %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[ARG_1]] : index
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
+// CHECK:           return %[[BCAST]] : vector<1x4xindex>
+
+func.func @broadcast_scalar_with_bcast(%arg1: index, %arg2: index) -> vector<1x4xindex> {
+  %0 = vector.broadcast %arg1 : index to vector<1x4xindex>
+  %1 = vector.broadcast %arg2 : index to vector<1x4xindex>
+  %2 = arith.addi %0, %1 : vector<1x4xindex> return %2 : vector<1x4xindex>
+}
+
+// CHECK-LABEL:   func.func @broadcast_scalar_with_bcast_scalable(
+// CHECK-SAME:     %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -> vector<1x[4]xindex> {
+// CHECK:           %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[ARG_1]] : index
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x[4]xindex>
+// CHECK:           return %[[BCAST]] : vector<1x[4]xindex>
+
+func.func @broadcast_scalar_with_bcast_scalable(%arg1: index, %arg2: index) -> vector<1x[4]xindex> {
+  %0 = vector.broadcast %arg1 : index to vector<1x[4]xindex>
+  %1 = vector.broadcast %arg2 : index to vector<1x[4]xindex>
+  %2 = arith.addi %0, %1 : vector<1x[4]xindex>
+  return %2 : vector<1x[4]xindex>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @broadcast_scalar_with_bcast_and_splat(
+// CHECK-SAME:      %[[ARG1:.*]]: index,
+// CHECK-SAME:      %[[ARG2:.*]]: index) -> vector<1x4xindex> {
+// CHECK:           %[[ADD:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : index
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
+// CHECK:           return %[[BCAST]] : vector<1x4xindex>
+func.func @broadcast_scalar_with_bcast_and_splat(%arg1: index, %arg2: index) -> vector<1x4xindex> {
+  %0 = vector.splat %arg1 : vector<1x4xindex>
+  %1 = vector.broadcast %arg2 : index to vector<1x4xindex>
+  %2 = arith.addi %0, %1 : vector<1x4xindex>
+  return %2 : vector<1x4xindex>
+}
+
+// CHECK-LABEL:   func.func @broadcast_scalar_with_bcast_and_splat_scalable(
+// CHECK-SAME:      %[[ARG1:.*]]: index,
+// CHECK-SAME:      %[[ARG2:.*]]: index) -> vector<1x[4]xindex> {
+// CHECK:           %[[ADD:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : index
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x[4]xindex>
+// CHECK:           return %[[BCAST]] : vector<1x[4]xindex>
+func.func @broadcast_scalar_with_bcast_and_splat_scalable(%arg1: index, %arg2: index) -> vector<1x[4]xindex> {
+  %0 = vector.splat %arg1 : vector<1x[4]xindex>
+  %1 = vector.broadcast %arg2 : index to vector<1x[4]xindex>
+  %2 = arith.addi %0, %1 : vector<1x[4]xindex>
+  return %2 : vector<1x[4]xindex>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @broadcast_vector(
+// CHECK-SAME:      %[[ARG_0:.*]]: vector<4xf32>,
+// CHECK-SAME:      %[[ARG_1:.*]]: vector<4xf32>) -> vector<3x4xf32> {
+// CHECK:           %[[ADDF:.*]] = arith.addf %[[ARG_0]], %[[ARG_1]] : vector<4xf32>
+// CHECK:           %[[BCAST:.*]] = vector.broadcast %[[ADDF]] : vector<4xf32> to vector<3x4xf32>
+// CHECK:           return %[[BCAST]] : vector<3x4xf32>
+
+func.func @broadcast_vector(%arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vector<3x4xf32> {
+  %arg1_bcast = vector.broadcast %arg1 : vector<4xf32> to vector<3x4xf32>
+  %arg2_bcast = vector.broadcast %arg2 : vector<4xf32> to vector<3x4xf32>
+  %2 = arith.addf %arg1_bcast, %arg2_bcast : vector<3x4xf32>
+  return %2 : vector<3x4xf32>
+}
+
+// CHECK-LABEL:   func.func @broadcast_vector_scalable(
+// CHECK-SAME:      %[[ARG_0:.*]]: vector<[4]xf32>,
+// CHECK-SAME:      %[[ARG_1:.*]]: vector<[4]xf32>) -> vector<3x[4]xf32> {
+// CHECK:           %[[ADDF:.*]] = arith.addf %[[ARG_0]], %[[ARG_1]] : ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/102856


More information about the Mlir-commits mailing list