[Mlir-commits] [mlir] andrzej/group reorder tests (PR #102856)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Aug 12 00:18:14 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Andrzej Warzyński (banach-space)
<details>
<summary>Changes</summary>
- **[mlir][vector] Add tests for `populateSinkVectorBroadcastPatterns` (1/n)**
- **[mlir][vector] Group tests for re-order patterns**
---
Patch is 32.20 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/102856.diff
7 Files Affected:
- (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h (+4)
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+1)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+17-9)
- (removed) mlir/test/Dialect/Vector/sink-vector-broadcast.mlir (-127)
- (modified) mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir (-112)
- (added) mlir/test/Dialect/Vector/vector-reorder.mlir (+321)
- (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+3-2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 10970fd03e6eb2..2c17ffd49d1d41 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -148,6 +148,10 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
+/// Patterns that re-order transpose Ops.
+void populateReoderVectorTransposePatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
/// Patterns that fold chained vector reductions. These patterns assume that
/// elementwise operations (e.g., `arith.addf` with vector operands) are
/// cheaper than vector reduction.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 48b3abbeee7010..e2f9ca1fc75027 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3452,6 +3452,7 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
if (!getDisableMultiReductionToContractPatterns())
vector::populateVectorReductionToContractPatterns(patterns);
+ vector::populateReoderVectorTransposePatterns(patterns);
vector::populateSinkVectorBroadcastPatterns(patterns);
patterns.add<linalg::LinalgCopyVTRForwardingPattern,
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 7f59a378e03512..922e3c61a77310 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -979,15 +979,18 @@ struct ReorderElementwiseOpsOnBroadcast final
if (!llvm::isa<ShapedType>(op->getResults()[0].getType()))
return failure();
if (!OpTrait::hasElementwiseMappableTraits(op))
+ return rewriter.notifyMatchFailure(
+ op, "Op doesn't have ElementwiseMappableTraits");
+ if (op->getNumOperands() == 0)
return failure();
- if (op->getNumOperands() == 0 ||
- op->getResults()[0].getType() != op->getOperand(0).getType()) {
- return failure();
- }
- // Avoid operations that only accept vector types, since broadcast
- // source might be scalar types.
+ if (op->getResults()[0].getType() != op->getOperand(0).getType())
+ return rewriter.notifyMatchFailure(op,
+ "result and operand type mismatch");
if (isa<vector::FMAOp>(op)) {
- return failure();
+ return rewriter.notifyMatchFailure(
+ op,
+ "Op only accepts vector types - not supported as broadcast source "
+ "might be a scalar");
}
// Get the type of the lhs operand
@@ -2027,8 +2030,7 @@ void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
void mlir::vector::populateVectorReductionToContractPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<MultiReduceToContract, CombineContractBroadcast,
- CombineContractABTranspose, CombineContractResultTranspose,
- ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnTranspose>(
+ CombineContractABTranspose, CombineContractResultTranspose>(
patterns.getContext(), benefit);
}
@@ -2046,6 +2048,12 @@ void mlir::vector::populateSinkVectorBroadcastPatterns(
patterns.getContext(), benefit);
}
+void mlir::vector::populateReoderVectorTransposePatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<ReorderElementwiseOpsOnTranspose>(patterns.getContext(),
+ benefit);
+}
+
void mlir::vector::populateChainedVectorReductionFoldingPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<ChainedReduction>(patterns.getContext(), benefit);
diff --git a/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir b/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir
deleted file mode 100644
index e7863a9e8b7b78..00000000000000
--- a/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir
+++ /dev/null
@@ -1,127 +0,0 @@
-// RUN: mlir-opt %s -test-sink-vector-broadcast -split-input-file | FileCheck %s
-
-// CHECK-LABEL: func.func @broadcast_scalar_with_bcast(
-// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -> vector<1x4xindex> {
-// CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[ARG_1]] : index
-// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
-// CHECK: return %[[BCAST]] : vector<1x4xindex>
-
-func.func @broadcast_scalar_with_bcast( %arg1: index, %arg2: index) -> vector<1x4xindex> {
- %0 = vector.broadcast %arg1 : index to vector<1x4xindex>
- %1 = vector.broadcast %arg2 : index to vector<1x4xindex>
- %2 = arith.addi %0, %1 : vector<1x4xindex>
- return %2 : vector<1x4xindex>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @broadcast_scalar_with_bcast_and_splat(
-// CHECK-SAME: %[[ARG1:.*]]: index,
-// CHECK-SAME: %[[ARG2:.*]]: index) -> vector<1x4xindex> {
-// CHECK: %[[ADD:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : index
-// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
-// CHECK: return %[[BCAST]] : vector<1x4xindex>
-func.func @broadcast_scalar_with_bcast_and_splat( %arg1: index, %arg2: index) -> vector<1x4xindex> {
- %0 = vector.splat %arg1 : vector<1x4xindex>
- %1 = vector.broadcast %arg2 : index to vector<1x4xindex>
- %2 = arith.addi %0, %1 : vector<1x4xindex>
- return %2 : vector<1x4xindex>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @broadcast_vector(
-// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>,
-// CHECK-SAME: %[[ARG_1:.*]]: vector<4xf32>) -> vector<3x4xf32> {
-// CHECK: %[[ADDF:.*]] = arith.addf %[[ARG_0]], %[[ARG_1]] : vector<4xf32>
-// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADDF]] : vector<4xf32> to vector<3x4xf32>
-// CHECK: return %[[BCAST]] : vector<3x4xf32>
-
-func.func @broadcast_vector( %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vector<3x4xf32> {
- %arg1_bcast = vector.broadcast %arg1 : vector<4xf32> to vector<3x4xf32>
- %arg2_bcast = vector.broadcast %arg2 : vector<4xf32> to vector<3x4xf32>
- %2 = arith.addf %arg1_bcast, %arg2_bcast : vector<3x4xf32>
- return %2 : vector<3x4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @broadcast_scalar_and_vec(
-// CHECK-SAME: %[[ARG1:.*]]: index,
-// CHECK-SAME: %[[ARG2:.*]]: vector<4xindex>) -> vector<1x4xindex> {
-// CHECK: %[[SPLAT:.*]] = vector.splat %[[ARG1]] : vector<1x4xindex>
-// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG2]] : vector<4xindex> to vector<1x4xindex>
-// CHECK: %[[ADD:.*]] = arith.addi %[[SPLAT]], %[[BCAST]] : vector<1x4xindex>
-// CHECK: return %[[ADD]] : vector<1x4xindex>
-func.func @broadcast_scalar_and_vec( %arg1: index, %arg2: vector<4xindex>) -> vector<1x4xindex> {
- %0 = vector.splat %arg1 : vector<1x4xindex>
- %1 = vector.broadcast %arg2 : vector<4xindex> to vector<1x4xindex>
- %2 = arith.addi %0, %1 : vector<1x4xindex>
- return %2 : vector<1x4xindex>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @broadcast_vector_and_scalar(
-// CHECK-SAME: %[[ARG_0:.*]]: i32,
-// CHECK-SAME: %[[ARG_1:.*]]: vector<4xi32>) -> vector<4xi32> {
-// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : i32 to vector<4xi32>
-// CHECK: %[[ADD:.*]] = arith.addi %[[BCAST]], %[[ARG_1]] : vector<4xi32>
-// CHECK: return %[[ADD]] : vector<4xi32>
-
-func.func @broadcast_vector_and_scalar( %arg1: i32, %arg2: vector<4xi32>) -> vector<4xi32> {
- %arg1_bcast = vector.broadcast %arg1 : i32 to vector<4xi32>
- %2 = arith.addi %arg1_bcast, %arg2 : vector<4xi32>
- return %2 : vector<4xi32>
-}
-
-// -----
-
-#matmat_accesses = [
- affine_map<(i, j, k) -> (i, k)>,
- affine_map<(i, j, k) -> (k, j)>,
- affine_map<(i, j, k) -> (i, j)>
-]
-#matmat_trait = {
- indexing_maps = #matmat_accesses,
- iterator_types = ["parallel", "parallel", "reduction"]
-}
-
-// CHECK-LABEL: func.func @broadcast_not_elementwise() -> vector<2x2xf32> {
-// CHECK-DAG: %[[VAL_0:.*]] = arith.constant dense<1.000000e+00> : vector<2x2xf32>
-// CHECK-DAG: %[[VAL_1:.*]] = arith.constant dense<2.000000e+00> : vector<2x2xf32>
-// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<3.000000e+00> : vector<2x2xf32>
-// CHECK: %[[VAL_3:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-func.func @broadcast_not_elementwise() -> vector<2x2xf32> {
- %f1 = arith.constant 1.0: f32
- %f2 = arith.constant 2.0: f32
- %f3 = arith.constant 3.0: f32
-
- %A = vector.broadcast %f1 : f32 to vector<2x2xf32>
- %B = vector.broadcast %f2 : f32 to vector<2x2xf32>
- %C = vector.broadcast %f3 : f32 to vector<2x2xf32>
- %mm1 = vector.contract #matmat_trait %A, %B, %C
- : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
-
- return %mm1 : vector<2x2xf32>
-}
-
-// CHECK-LABEL: func.func @dont_sink_cmp(
-// CHECK: %[[BROADCAST:.+]] = vector.broadcast
-// CHECK: %[[RETURN:.+]] = arith.cmpf uno, %[[BROADCAST]], %[[BROADCAST]]
-// CHECK: return %[[RETURN]]
-func.func @dont_sink_cmp(%arg0 : f32, %arg1 : vector<1xf32>) -> vector<1xi1> {
- %0 = vector.broadcast %arg0 : f32 to vector<1xf32>
- %1 = arith.cmpf uno, %0, %0 : vector<1xf32>
- return %1 : vector<1xi1>
-}
-
-// CHECK-LABEL: func.func @dont_sink_fma(
- // CHECK: %[[BROADCAST:.+]] = vector.broadcast
- // CHECK: %[[RESULT:.+]] = vector.fma %[[BROADCAST]]
- // CHECK: return %[[RESULT]]
-func.func @dont_sink_fma(%arg0 : f32) -> vector<1xf32> {
- %0 = vector.broadcast %arg0 : f32 to vector<1xf32>
- %1 = vector.fma %0, %0, %0 : vector<1xf32>
- return %1 : vector<1xf32>
-}
diff --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
index 23a44b7c03f8f4..24070dbf017a58 100644
--- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
+++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
@@ -245,118 +245,6 @@ func.func @contract_broadcast_would_have_no_reduction_dim_pair(%arg0 : vector<1x
}
-//===----------------------------------------------------------------------===//
-// Reorder casting ops and vector ops. The casting ops have almost identical
-// pattern, so only arith.extsi op is tested.
-//===----------------------------------------------------------------------===//
-
-// -----
-
-func.func @broadcast_vector_extsi(%a : vector<4xi8>) -> vector<2x4xi32> {
- // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4xi8> to vector<4xi32>
- // CHECK: vector.broadcast %[[EXT:.+]] : vector<4xi32> to vector<2x4xi32>
- %b = vector.broadcast %a : vector<4xi8> to vector<2x4xi8>
- %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
- return %r : vector<2x4xi32>
-}
-
-// -----
-
-func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
- // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
- // CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x4xi32>
- %b = vector.broadcast %a : i8 to vector<2x4xi8>
- %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
- return %r : vector<2x4xi32>
-}
-
-// -----
-
-func.func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> {
- // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4x2xi8> to vector<4x2xi32>
- // CHECK: vector.transpose %[[EXT]], [1, 0] : vector<4x2xi32> to vector<2x4xi32>
- %b = vector.transpose %a, [1, 0]: vector<4x2xi8> to vector<2x4xi8>
- %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
- return %r : vector<2x4xi32>
-}
-
-//===----------------------------------------------------------------------===//
-// Reorder elementwise ops and vector ops.
-//===----------------------------------------------------------------------===//
-
-// -----
-
-// CHECK-LABEL: func @transpose_elementwise_same_type
-// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
-// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x2xf32>
-// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0]
-// CHECK: return %[[T]]
-
-func.func @transpose_elementwise_same_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> {
- %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
- %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
- %r = arith.addf %at, %bt : vector<2x4xf32>
- return %r : vector<2x4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @transpose_elementwise_diff_operand_types
-// CHECK-SAME: (%[[COND:.+]]: vector<4x2xi1>, %[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
-// CHECK: %[[S:.+]] = arith.select %[[COND]], %[[A]], %[[B]] : vector<4x2xi1>, vector<4x2xf32>
-// CHECK: %[[T:.+]] = vector.transpose %[[S]], [1, 0] : vector<4x2xf32> to vector<2x4xf32>
-// CHECK: return %[[T]]
-func.func @transpose_elementwise_diff_operand_types(%cond: vector<4x2xi1>, %a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> {
- %condt = vector.transpose %cond, [1, 0]: vector<4x2xi1> to vector<2x4xi1>
- %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
- %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
- %r = arith.select %condt, %at, %bt : vector<2x4xi1>, vector<2x4xf32>
- return %r : vector<2x4xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @transpose_elementwise_diff_operand_result_type
-// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
-// CHECK: %[[CMP:.+]] = arith.cmpf olt, %[[A]], %[[B]] : vector<4x2xf32>
-// CHECK: %[[T:.+]] = vector.transpose %[[CMP]], [1, 0] : vector<4x2xi1> to vector<2x4xi1>
-// CHECK: return %[[T]]
-func.func @transpose_elementwise_diff_operand_result_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xi1> {
- %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
- %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
- %r = arith.cmpf olt, %at, %bt : vector<2x4xf32>
- return %r : vector<2x4xi1>
-}
-
-// -----
-
-// CHECK-LABEL: func @transpose_elementwise_splat_constant
-// CHECK-SAME: (%[[A:.+]]: vector<4x6x3x2xf32>)
-// CHECK: %[[B:.+]] = arith.constant dense<5.000000e+00> : vector<4x6x3x2xf32>
-// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x6x3x2xf32>
-// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0, 3, 2] : vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
-// CHECK: return %[[T:.+]] : vector<6x4x2x3xf32>
-
-func.func @transpose_elementwise_splat_constant(%a : vector<4x6x3x2xf32>) -> vector<6x4x2x3xf32> {
- %b = arith.constant dense<5.0> : vector<6x4x2x3xf32>
- %at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
- %r = arith.addf %at, %b : vector<6x4x2x3xf32>
- return %r : vector<6x4x2x3xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @transpose_elementwise_diff_map
-// CHECK: vector.transpose
-// CHECK: vector.transpose
-// CHECK: arith.addf
-func.func @transpose_elementwise_diff_map(%a : vector<4x6x3x2xf32>, %b: vector<6x2x4x3xf32>) -> vector<6x4x2x3xf32> {
- %at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32>
- %bt = vector.transpose %b, [0, 2, 1, 3]: vector<6x2x4x3xf32> to vector<6x4x2x3xf32>
- %r = arith.addf %at, %bt : vector<6x4x2x3xf32>
- return %r : vector<6x4x2x3xf32>
-}
-
// -----
// CHECK-DAG: #[[$LHS_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
diff --git a/mlir/test/Dialect/Vector/vector-reorder.mlir b/mlir/test/Dialect/Vector/vector-reorder.mlir
new file mode 100644
index 00000000000000..d7669ec2b54037
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-reorder.mlir
@@ -0,0 +1,321 @@
+// RUN: mlir-opt %s -test-vector-reorder-patterns -split-input-file | FileCheck %s
+
+//-----------------------------------------------------------------------------
+// [Pattern: ReorderElementwiseOpsOnBroadcast]
+//-----------------------------------------------------------------------------
+
+// CHECK-LABEL: func.func @broadcast_scalar_with_bcast(
+// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -> vector<1x4xindex> {
+// CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[ARG_1]] : index
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
+// CHECK: return %[[BCAST]] : vector<1x4xindex>
+
+func.func @broadcast_scalar_with_bcast(%arg1: index, %arg2: index) -> vector<1x4xindex> {
+ %0 = vector.broadcast %arg1 : index to vector<1x4xindex>
+ %1 = vector.broadcast %arg2 : index to vector<1x4xindex>
+ %2 = arith.addi %0, %1 : vector<1x4xindex> return %2 : vector<1x4xindex>
+}
+
+// CHECK-LABEL: func.func @broadcast_scalar_with_bcast_scalable(
+// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -> vector<1x[4]xindex> {
+// CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[ARG_1]] : index
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x[4]xindex>
+// CHECK: return %[[BCAST]] : vector<1x[4]xindex>
+
+func.func @broadcast_scalar_with_bcast_scalable(%arg1: index, %arg2: index) -> vector<1x[4]xindex> {
+ %0 = vector.broadcast %arg1 : index to vector<1x[4]xindex>
+ %1 = vector.broadcast %arg2 : index to vector<1x[4]xindex>
+ %2 = arith.addi %0, %1 : vector<1x[4]xindex>
+ return %2 : vector<1x[4]xindex>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_scalar_with_bcast_and_splat(
+// CHECK-SAME: %[[ARG1:.*]]: index,
+// CHECK-SAME: %[[ARG2:.*]]: index) -> vector<1x4xindex> {
+// CHECK: %[[ADD:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : index
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
+// CHECK: return %[[BCAST]] : vector<1x4xindex>
+func.func @broadcast_scalar_with_bcast_and_splat(%arg1: index, %arg2: index) -> vector<1x4xindex> {
+ %0 = vector.splat %arg1 : vector<1x4xindex>
+ %1 = vector.broadcast %arg2 : index to vector<1x4xindex>
+ %2 = arith.addi %0, %1 : vector<1x4xindex>
+ return %2 : vector<1x4xindex>
+}
+
+// CHECK-LABEL: func.func @broadcast_scalar_with_bcast_and_splat_scalable(
+// CHECK-SAME: %[[ARG1:.*]]: index,
+// CHECK-SAME: %[[ARG2:.*]]: index) -> vector<1x[4]xindex> {
+// CHECK: %[[ADD:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : index
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x[4]xindex>
+// CHECK: return %[[BCAST]] : vector<1x[4]xindex>
+func.func @broadcast_scalar_with_bcast_and_splat_scalable(%arg1: index, %arg2: index) -> vector<1x[4]xindex> {
+ %0 = vector.splat %arg1 : vector<1x[4]xindex>
+ %1 = vector.broadcast %arg2 : index to vector<1x[4]xindex>
+ %2 = arith.addi %0, %1 : vector<1x[4]xindex>
+ return %2 : vector<1x[4]xindex>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_vector(
+// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>,
+// CHECK-SAME: %[[ARG_1:.*]]: vector<4xf32>) -> vector<3x4xf32> {
+// CHECK: %[[ADDF:.*]] = arith.addf %[[ARG_0]], %[[ARG_1]] : vector<4xf32>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADDF]] : vector<4xf32> to vector<3x4xf32>
+// CHECK: return %[[BCAST]] : vector<3x4xf32>
+
+func.func @broadcast_vector(%arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vector<3x4xf32> {
+ %arg1_bcast = vector.broadcast %arg1 : vector<4xf32> to vector<3x4xf32>
+ %arg2_bcast = vector.broadcast %arg2 : vector<4xf32> to vector<3x4xf32>
+ %2 = arith.addf %arg1_bcast, %arg2_bcast : vector<3x4xf32>
+ return %2 : vector<3x4xf32>
+}
+
+// CHECK-LABEL: func.func @broadcast_vector_scalable(
+// CHECK-SAME: %[[ARG_0:.*]]: vector<[4]xf32>,
+// CHECK-SAME: %[[ARG_1:.*]]: vector<[4]xf32>) -> vector<3x[4]xf32> {
+// CHECK: %[[ADDF:.*]] = arith.addf %[[ARG_0]], %[[ARG_1]] : ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/102856
More information about the Mlir-commits
mailing list