[Mlir-commits] [mlir] e9504c5 - [mlir][vector] Add tests for populateSinkVectorOpsPatterns (2/N) (#122338)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jan 15 06:54:20 PST 2025
Author: Andrzej Warzyński
Date: 2025-01-15T14:54:16Z
New Revision: e9504c52edd796a22a879b381f17bd8ed235bfd4
URL: https://github.com/llvm/llvm-project/commit/e9504c52edd796a22a879b381f17bd8ed235bfd4
DIFF: https://github.com/llvm/llvm-project/commit/e9504c52edd796a22a879b381f17bd8ed235bfd4.diff
LOG: [mlir][vector] Add tests for populateSinkVectorOpsPatterns (2/N) (#122338)
Adds tests for scalable vectors in:
* "vector-sink.mlir".
This test file exercises patterns included in
`populateSinkVectorOpsPatterns`:
* `ReorderElementwiseOpsOnBroadcast`,
* `ReorderCastOpsOnBroadcast`,
* `ReorderElementwiseOpsOnTranspose`.
This PR focuses on adding tests for the latter two patterns
(`ReorderCastOpsOnBroadcast` and `ReorderElementwiseOpsOnTranspose`).
Tests for `ReorderElementwiseOpsOnBroadcast` were added in #102286. Please
note that in PR #102856, I renamed:
* `populateSinkVectorBroadcastPatterns`, to
* `populateSinkVectorOpsPatterns`.
Added:
Modified:
mlir/test/Dialect/Vector/vector-sink.mlir
Removed:
################################################################################
diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir
index 5a3699333265c1..7ce840575a8031 100644
--- a/mlir/test/Dialect/Vector/vector-sink.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink.mlir
@@ -228,6 +228,16 @@ func.func @broadcast_vector_extsi(%a : vector<4xi8>) -> vector<2x4xi32> {
// -----
+func.func @broadcast_vector_extsi_scalable(%a : vector<[4]xi8>) -> vector<2x[4]xi32> {
+ // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<[4]xi8> to vector<[4]xi32>
+ // CHECK: vector.broadcast %[[EXT:.+]] : vector<[4]xi32> to vector<2x[4]xi32>
+ %b = vector.broadcast %a : vector<[4]xi8> to vector<2x[4]xi8>
+ %r = arith.extsi %b : vector<2x[4]xi8> to vector<2x[4]xi32>
+ return %r : vector<2x[4]xi32>
+}
+
+// -----
+
func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
// CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x4xi32>
@@ -236,6 +246,16 @@ func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
return %r : vector<2x4xi32>
}
+// -----
+
+func.func @broadcast_scalar_extsi_scalable(%a : i8) -> vector<2x[4]xi32> {
+ // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
+ // CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x[4]xi32>
+ %b = vector.broadcast %a : i8 to vector<2x[4]xi8>
+ %r = arith.extsi %b : vector<2x[4]xi8> to vector<2x[4]xi32>
+ return %r : vector<2x[4]xi32>
+}
+
//===----------------------------------------------------------------------===//
// [Pattern: ReorderElementwiseOpsOnTranspose]
//===----------------------------------------------------------------------===//
@@ -250,6 +270,16 @@ func.func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> {
// -----
+func.func @transpose_extsi_scalable(%a : vector<[4]x2xi8>) -> vector<2x[4]xi32> {
+ // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<[4]x2xi8> to vector<[4]x2xi32>
+ // CHECK: vector.transpose %[[EXT]], [1, 0] : vector<[4]x2xi32> to vector<2x[4]xi32>
+ %b = vector.transpose %a, [1, 0]: vector<[4]x2xi8> to vector<2x[4]xi8>
+ %r = arith.extsi %b : vector<2x[4]xi8> to vector<2x[4]xi32>
+ return %r : vector<2x[4]xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @transpose_elementwise_same_type
// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x2xf32>
@@ -265,6 +295,21 @@ func.func @transpose_elementwise_same_type(%a : vector<4x2xf32>, %b : vector<4x2
// -----
+// CHECK-LABEL: func @transpose_elementwise_same_type_scalable
+// CHECK-SAME: (%[[A:.+]]: vector<[4]x2xf32>, %[[B:.+]]: vector<[4]x2xf32>)
+// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<[4]x2xf32>
+// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0]
+// CHECK: return %[[T]]
+
+func.func @transpose_elementwise_same_type_scalable(%a : vector<[4]x2xf32>, %b : vector<[4]x2xf32>) -> vector<2x[4]xf32> {
+ %at = vector.transpose %a, [1, 0]: vector<[4]x2xf32> to vector<2x[4]xf32>
+ %bt = vector.transpose %b, [1, 0]: vector<[4]x2xf32> to vector<2x[4]xf32>
+ %r = arith.addf %at, %bt : vector<2x[4]xf32>
+ return %r : vector<2x[4]xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @transpose_elementwise_diff_operand_types
// CHECK-SAME: (%[[COND:.+]]: vector<4x2xi1>, %[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
// CHECK: %[[S:.+]] = arith.select %[[COND]], %[[A]], %[[B]] : vector<4x2xi1>, vector<4x2xf32>
@@ -280,6 +325,21 @@ func.func @transpose_elementwise_diff_operand_types(%cond: vector<4x2xi1>, %a :
// -----
+// CHECK-LABEL: func @transpose_elementwise_diff_operand_types_scalable
+// CHECK-SAME: (%[[COND:.+]]: vector<[4]x2xi1>, %[[A:.+]]: vector<[4]x2xf32>, %[[B:.+]]: vector<[4]x2xf32>)
+// CHECK: %[[S:.+]] = arith.select %[[COND]], %[[A]], %[[B]] : vector<[4]x2xi1>, vector<[4]x2xf32>
+// CHECK: %[[T:.+]] = vector.transpose %[[S]], [1, 0] : vector<[4]x2xf32> to vector<2x[4]xf32>
+// CHECK: return %[[T]]
+func.func @transpose_elementwise_diff_operand_types_scalable(%cond: vector<[4]x2xi1>, %a : vector<[4]x2xf32>, %b : vector<[4]x2xf32>) -> vector<2x[4]xf32> {
+ %condt = vector.transpose %cond, [1, 0]: vector<[4]x2xi1> to vector<2x[4]xi1>
+ %at = vector.transpose %a, [1, 0]: vector<[4]x2xf32> to vector<2x[4]xf32>
+ %bt = vector.transpose %b, [1, 0]: vector<[4]x2xf32> to vector<2x[4]xf32>
+ %r = arith.select %condt, %at, %bt : vector<2x[4]xi1>, vector<2x[4]xf32>
+ return %r : vector<2x[4]xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @transpose_elementwise_diff_operand_result_type
// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
// CHECK: %[[CMP:.+]] = arith.cmpf olt, %[[A]], %[[B]] : vector<4x2xf32>
@@ -294,6 +354,20 @@ func.func @transpose_elementwise_diff_operand_result_type(%a : vector<4x2xf32>,
// -----
+// CHECK-LABEL: func @transpose_elementwise_diff_operand_result_type_scalable
+// CHECK-SAME: (%[[A:.+]]: vector<[4]x2xf32>, %[[B:.+]]: vector<[4]x2xf32>)
+// CHECK: %[[CMP:.+]] = arith.cmpf olt, %[[A]], %[[B]] : vector<[4]x2xf32>
+// CHECK: %[[T:.+]] = vector.transpose %[[CMP]], [1, 0] : vector<[4]x2xi1> to vector<2x[4]xi1>
+// CHECK: return %[[T]]
+func.func @transpose_elementwise_diff_operand_result_type_scalable(%a : vector<[4]x2xf32>, %b : vector<[4]x2xf32>) -> vector<2x[4]xi1> {
+ %at = vector.transpose %a, [1, 0]: vector<[4]x2xf32> to vector<2x[4]xf32>
+ %bt = vector.transpose %b, [1, 0]: vector<[4]x2xf32> to vector<2x[4]xf32>
+ %r = arith.cmpf olt, %at, %bt : vector<2x[4]xf32>
+ return %r : vector<2x[4]xi1>
+}
+
+// -----
+
// CHECK-LABEL: func @transpose_elementwise_splat_constant
// CHECK-SAME: (%[[A:.+]]: vector<4x6x3x2xf32>)
// CHECK: %[[B:.+]] = arith.constant dense<5.000000e+00> : vector<4x6x3x2xf32>
@@ -310,6 +384,22 @@ func.func @transpose_elementwise_splat_constant(%a : vector<4x6x3x2xf32>) -> vec
// -----
+// CHECK-LABEL: func @transpose_elementwise_splat_constant_scalable
+// CHECK-SAME: (%[[A:.+]]: vector<[4]x6x3x2xf32>)
+// CHECK: %[[B:.+]] = arith.constant dense<5.000000e+00> : vector<[4]x6x3x2xf32>
+// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<[4]x6x3x2xf32>
+// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0, 3, 2] : vector<[4]x6x3x2xf32> to vector<6x[4]x2x3xf32>
+// CHECK: return %[[T:.+]] : vector<6x[4]x2x3xf32>
+
+func.func @transpose_elementwise_splat_constant_scalable(%a : vector<[4]x6x3x2xf32>) -> vector<6x[4]x2x3xf32> {
+ %b = arith.constant dense<5.0> : vector<6x[4]x2x3xf32>
+ %at = vector.transpose %a, [1, 0, 3, 2]: vector<[4]x6x3x2xf32> to vector<6x[4]x2x3xf32>
+ %r = arith.addf %at, %b : vector<6x[4]x2x3xf32>
+ return %r : vector<6x[4]x2x3xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @transpose_elementwise_diff_map
// CHECK: vector.transpose
// CHECK: vector.transpose
@@ -320,3 +410,16 @@ func.func @transpose_elementwise_diff_map(%a : vector<4x6x3x2xf32>, %b: vector<6
%r = arith.addf %at, %bt : vector<6x4x2x3xf32>
return %r : vector<6x4x2x3xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @transpose_elementwise_diff_map_scalable
+// CHECK: vector.transpose
+// CHECK: vector.transpose
+// CHECK: arith.addf
+func.func @transpose_elementwise_diff_map_scalable(%a : vector<[4]x6x3x2xf32>, %b: vector<6x2x[4]x3xf32>) -> vector<6x[4]x2x3xf32> {
+ %at = vector.transpose %a, [1, 0, 3, 2]: vector<[4]x6x3x2xf32> to vector<6x[4]x2x3xf32>
+ %bt = vector.transpose %b, [0, 2, 1, 3]: vector<6x2x[4]x3xf32> to vector<6x[4]x2x3xf32>
+ %r = arith.addf %at, %bt : vector<6x[4]x2x3xf32>
+ return %r : vector<6x[4]x2x3xf32>
+}
More information about the Mlir-commits mailing list