[Mlir-commits] [mlir] [mlir][vector] Add tests for populateSinkVectorOpsPatterns (2/N) (PR #122338)
Andrzej Warzyński
llvmlistbot at llvm.org
Thu Jan 9 10:59:24 PST 2025
https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/122338
Adds tests for scalable vectors in:
* vector-sink.mlir
This test file exercises patterns grouped under
`populateSinkVectorOpsPatterns`, which includes:
* `ReorderElementwiseOpsOnBroadcast`,
* `ReorderCastOpsOnBroadcast`,
* `ReorderElementwiseOpsOnTranspose`.
Note: this is a follow-up to:
* https://github.com/llvm/llvm-project/pull/102286.
This PR adds tests for the latter two; tests
for `ReorderElementwiseOpsOnBroadcast` were added in #102286. When
referring to the previous PR, please note that in #102856 I renamed
* `populateSinkVectorBroadcastPatterns` to
* `populateSinkVectorOpsPatterns`.
>From 80ce22109ec768d0ecafe43db266b8a70046ba22 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 9 Jan 2025 18:50:30 +0000
Subject: [PATCH] [mlir][vector] Add tests for populateSinkVectorOpsPatterns
(2/N)
Adds tests for scalable vectors in:
* vector-sink.mlir
This test file exercises patterns grouped under
`populateSinkVectorOpsPatterns`, which includes:
* `ReorderElementwiseOpsOnBroadcast`,
* `ReorderCastOpsOnBroadcast`,
* `ReorderElementwiseOpsOnTranspose`.
Note: this is a follow-up to:
* https://github.com/llvm/llvm-project/pull/102286.
This PR adds tests for the latter two; tests
for `ReorderElementwiseOpsOnBroadcast` were added in #102286. When
referring to the previous PR, please note that in #102856 I renamed
* `populateSinkVectorBroadcastPatterns` to
* `populateSinkVectorOpsPatterns`.
---
mlir/test/Dialect/Vector/vector-sink.mlir | 103 ++++++++++++++++++++++
1 file changed, 103 insertions(+)
diff --git a/mlir/test/Dialect/Vector/vector-sink.mlir b/mlir/test/Dialect/Vector/vector-sink.mlir
index 5a3699333265c1..7ce840575a8031 100644
--- a/mlir/test/Dialect/Vector/vector-sink.mlir
+++ b/mlir/test/Dialect/Vector/vector-sink.mlir
@@ -228,6 +228,16 @@ func.func @broadcast_vector_extsi(%a : vector<4xi8>) -> vector<2x4xi32> {
// -----
+func.func @broadcast_vector_extsi_scalable(%a : vector<[4]xi8>) -> vector<2x[4]xi32> {
+  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<[4]xi8> to vector<[4]xi32>
+  // CHECK: vector.broadcast %[[EXT]] : vector<[4]xi32> to vector<2x[4]xi32>
+  %b = vector.broadcast %a : vector<[4]xi8> to vector<2x[4]xi8>
+  %r = arith.extsi %b : vector<2x[4]xi8> to vector<2x[4]xi32>
+  return %r : vector<2x[4]xi32>
+}
+
+// -----
+
func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
// CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x4xi32>
@@ -236,6 +246,16 @@ func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
return %r : vector<2x4xi32>
}
+// -----
+
+func.func @broadcast_scalar_extsi_scalable(%a : i8) -> vector<2x[4]xi32> {
+  // CHECK:      %[[SEXT:.+]] = arith.extsi %{{.+}} : i8 to i32
+  // CHECK:      vector.broadcast %[[SEXT]] : i32 to vector<2x[4]xi32>
+  %splat = vector.broadcast %a : i8 to vector<2x[4]xi8>
+  %widened = arith.extsi %splat : vector<2x[4]xi8> to vector<2x[4]xi32>
+  return %widened : vector<2x[4]xi32>
+}
+
//===----------------------------------------------------------------------===//
// [Pattern: ReorderElementwiseOpsOnTranspose]
//===----------------------------------------------------------------------===//
@@ -250,6 +270,16 @@ func.func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> {
// -----
+func.func @transpose_extsi_scalable(%a : vector<[4]x2xi8>) -> vector<2x[4]xi32> {
+  // CHECK:      %[[SEXT:.+]] = arith.extsi %{{.+}} : vector<[4]x2xi8> to vector<[4]x2xi32>
+  // CHECK:      vector.transpose %[[SEXT]], [1, 0] : vector<[4]x2xi32> to vector<2x[4]xi32>
+  %transposed = vector.transpose %a, [1, 0] : vector<[4]x2xi8> to vector<2x[4]xi8>
+  %widened = arith.extsi %transposed : vector<2x[4]xi8> to vector<2x[4]xi32>
+  return %widened : vector<2x[4]xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @transpose_elementwise_same_type
// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x2xf32>
@@ -265,6 +295,21 @@ func.func @transpose_elementwise_same_type(%a : vector<4x2xf32>, %b : vector<4x2
// -----
+// CHECK-LABEL: func @transpose_elementwise_same_type_scalable
+//  CHECK-SAME:   (%[[LHS:.+]]: vector<[4]x2xf32>, %[[RHS:.+]]: vector<[4]x2xf32>)
+//       CHECK:   %[[SUM:.+]] = arith.addf %[[LHS]], %[[RHS]] : vector<[4]x2xf32>
+//       CHECK:   %[[TR:.+]] = vector.transpose %[[SUM]], [1, 0]
+//       CHECK:   return %[[TR]]
+
+func.func @transpose_elementwise_same_type_scalable(%a : vector<[4]x2xf32>, %b : vector<[4]x2xf32>) -> vector<2x[4]xf32> {
+  %a_t = vector.transpose %a, [1, 0] : vector<[4]x2xf32> to vector<2x[4]xf32>
+  %b_t = vector.transpose %b, [1, 0] : vector<[4]x2xf32> to vector<2x[4]xf32>
+  %sum = arith.addf %a_t, %b_t : vector<2x[4]xf32>
+  return %sum : vector<2x[4]xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @transpose_elementwise_diff_operand_types
// CHECK-SAME: (%[[COND:.+]]: vector<4x2xi1>, %[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
// CHECK: %[[S:.+]] = arith.select %[[COND]], %[[A]], %[[B]] : vector<4x2xi1>, vector<4x2xf32>
@@ -280,6 +325,21 @@ func.func @transpose_elementwise_diff_operand_types(%cond: vector<4x2xi1>, %a :
// -----
+// CHECK-LABEL: func @transpose_elementwise_diff_operand_types_scalable
+//  CHECK-SAME:   (%[[MASK:.+]]: vector<[4]x2xi1>, %[[LHS:.+]]: vector<[4]x2xf32>, %[[RHS:.+]]: vector<[4]x2xf32>)
+//       CHECK:   %[[SEL:.+]] = arith.select %[[MASK]], %[[LHS]], %[[RHS]] : vector<[4]x2xi1>, vector<[4]x2xf32>
+//       CHECK:   %[[TR:.+]] = vector.transpose %[[SEL]], [1, 0] : vector<[4]x2xf32> to vector<2x[4]xf32>
+//       CHECK:   return %[[TR]]
+func.func @transpose_elementwise_diff_operand_types_scalable(%cond: vector<[4]x2xi1>, %a : vector<[4]x2xf32>, %b : vector<[4]x2xf32>) -> vector<2x[4]xf32> {
+  %cond_t = vector.transpose %cond, [1, 0] : vector<[4]x2xi1> to vector<2x[4]xi1>
+  %a_t = vector.transpose %a, [1, 0] : vector<[4]x2xf32> to vector<2x[4]xf32>
+  %b_t = vector.transpose %b, [1, 0] : vector<[4]x2xf32> to vector<2x[4]xf32>
+  %sel = arith.select %cond_t, %a_t, %b_t : vector<2x[4]xi1>, vector<2x[4]xf32>
+  return %sel : vector<2x[4]xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @transpose_elementwise_diff_operand_result_type
// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
// CHECK: %[[CMP:.+]] = arith.cmpf olt, %[[A]], %[[B]] : vector<4x2xf32>
@@ -294,6 +354,20 @@ func.func @transpose_elementwise_diff_operand_result_type(%a : vector<4x2xf32>,
// -----
+// CHECK-LABEL: func @transpose_elementwise_diff_operand_result_type_scalable
+//  CHECK-SAME:   (%[[LHS:.+]]: vector<[4]x2xf32>, %[[RHS:.+]]: vector<[4]x2xf32>)
+//       CHECK:   %[[LT:.+]] = arith.cmpf olt, %[[LHS]], %[[RHS]] : vector<[4]x2xf32>
+//       CHECK:   %[[TR:.+]] = vector.transpose %[[LT]], [1, 0] : vector<[4]x2xi1> to vector<2x[4]xi1>
+//       CHECK:   return %[[TR]]
+func.func @transpose_elementwise_diff_operand_result_type_scalable(%a : vector<[4]x2xf32>, %b : vector<[4]x2xf32>) -> vector<2x[4]xi1> {
+  %a_t = vector.transpose %a, [1, 0] : vector<[4]x2xf32> to vector<2x[4]xf32>
+  %b_t = vector.transpose %b, [1, 0] : vector<[4]x2xf32> to vector<2x[4]xf32>
+  %lt = arith.cmpf olt, %a_t, %b_t : vector<2x[4]xf32>
+  return %lt : vector<2x[4]xi1>
+}
+
+// -----
+
// CHECK-LABEL: func @transpose_elementwise_splat_constant
// CHECK-SAME: (%[[A:.+]]: vector<4x6x3x2xf32>)
// CHECK: %[[B:.+]] = arith.constant dense<5.000000e+00> : vector<4x6x3x2xf32>
@@ -310,6 +384,22 @@ func.func @transpose_elementwise_splat_constant(%a : vector<4x6x3x2xf32>) -> vec
// -----
+// CHECK-LABEL: func @transpose_elementwise_splat_constant_scalable
+// CHECK-SAME: (%[[A:.+]]: vector<[4]x6x3x2xf32>)
+// CHECK: %[[B:.+]] = arith.constant dense<5.000000e+00> : vector<[4]x6x3x2xf32>
+// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<[4]x6x3x2xf32>
+// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0, 3, 2] : vector<[4]x6x3x2xf32> to vector<6x[4]x2x3xf32>
+// CHECK: return %[[T]] : vector<6x[4]x2x3xf32>
+
+func.func @transpose_elementwise_splat_constant_scalable(%a : vector<[4]x6x3x2xf32>) -> vector<6x[4]x2x3xf32> {
+  %b = arith.constant dense<5.0> : vector<6x[4]x2x3xf32>
+  %at = vector.transpose %a, [1, 0, 3, 2]: vector<[4]x6x3x2xf32> to vector<6x[4]x2x3xf32>
+  %r = arith.addf %at, %b : vector<6x[4]x2x3xf32>
+  return %r : vector<6x[4]x2x3xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @transpose_elementwise_diff_map
// CHECK: vector.transpose
// CHECK: vector.transpose
@@ -320,3 +410,16 @@ func.func @transpose_elementwise_diff_map(%a : vector<4x6x3x2xf32>, %b: vector<6
%r = arith.addf %at, %bt : vector<6x4x2x3xf32>
return %r : vector<6x4x2x3xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @transpose_elementwise_diff_map_scalable
+//       CHECK:   vector.transpose
+//       CHECK:   vector.transpose
+//       CHECK:   arith.addf
+func.func @transpose_elementwise_diff_map_scalable(%a : vector<[4]x6x3x2xf32>, %b: vector<6x2x[4]x3xf32>) -> vector<6x[4]x2x3xf32> {
+  %a_t = vector.transpose %a, [1, 0, 3, 2] : vector<[4]x6x3x2xf32> to vector<6x[4]x2x3xf32>
+  %b_t = vector.transpose %b, [0, 2, 1, 3] : vector<6x2x[4]x3xf32> to vector<6x[4]x2x3xf32>
+  %sum = arith.addf %a_t, %b_t : vector<6x[4]x2x3xf32>
+  return %sum : vector<6x[4]x2x3xf32>
+}
More information about the Mlir-commits
mailing list