[Mlir-commits] [mlir] [MLIR][Vector] Add a test of <2x[4]xf32> type on lowering multireduction (PR #100451)

Zhaoshi Zheng llvmlistbot at llvm.org
Thu Jul 25 09:36:16 PDT 2024


https://github.com/zhaoshiz updated https://github.com/llvm/llvm-project/pull/100451

>From f351842a8c5529f6b47b01d8447621d64bb8dbc4 Mon Sep 17 00:00:00 2001
From: Zhaoshi Zheng <zhaoshiz at quicinc.com>
Date: Wed, 24 Jul 2024 12:06:18 -0700
Subject: [PATCH 1/2] [MLIR][Vector] Add a test of <2x[4]xf32> type on lowering
 multireduction

---
 .../vector-multi-reduction-lowering.mlir      | 24 +++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
index f70d23a193229..fcc998b5bcb88 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -298,6 +298,30 @@ func.func @scalable_dim_1d(%A: vector<[4]xf32>, %B: f32, %C: vector<[4]xi1>) ->
 // CHECK:          %[[VAL_4:.*]] = vector.extract %[[VAL_3]][0] : f32 from vector<1xf32>
 // CHECK:          return %[[VAL_4]] : f32
 
+func.func @scalable_dim_2d(%A: vector<2x[4]xf32>, %B: vector<2xf32>, %C: vector<2x[4]xi1>) -> vector<2xf32> {
+    %0 = vector.mask %C { vector.multi_reduction <add>, %A, %B [1] : vector<2x[4]xf32> to vector<2xf32> } : vector<2x[4]xi1> -> vector<2xf32>
+    return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL:  func.func @scalable_dim_2d(
+// CHECK-SAME:                                      %[[ARG_0:.*]]: vector<2x[4]xf32>,
+// CHECK-SAME:                                      %[[ARG_1:.*]]: vector<2xf32>,
+// CHECK-SAME:                                      %[[ARG_2:.*]]: vector<2x[4]xi1>) -> vector<2xf32> {
+// CHECK-DAG:      %[[C1_idx:.*]] = arith.constant 1 : index
+// CHECK-DAG:      %[[C0_idx:.*]] = arith.constant 0 : index
+// CHECK-DAG:      %[[C0_2xf32:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
+// CHECK:          %[[ARG0_0:.*]] = vector.extract %[[ARG_0]][0] : vector<[4]xf32> from vector<2x[4]xf32>
+// CHECK:          %[[ARG1_0:.*]] = vector.extract %[[ARG_1]][0] : f32 from vector<2xf32>
+// CHECK:          %[[ARG2_0:.*]] = vector.extract %[[ARG_2]][0] : vector<[4]xi1> from vector<2x[4]xi1>
+// CHECK:          %[[REDUCE_0:.*]] = vector.mask %[[ARG2_0]] { vector.reduction <add>, %[[ARG0_0]], %[[ARG1_0]] : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
+// CHECK:          %[[INSERT_0:.*]] = vector.insertelement %[[REDUCE_0]], %[[C0_2xf32]][%[[C0_idx]] : index] : vector<2xf32>
+// CHECK:          %[[ARG0_1:.*]] = vector.extract %[[ARG_0]][1] : vector<[4]xf32> from vector<2x[4]xf32>
+// CHECK:          %[[ARG1_1:.*]] = vector.extract %[[ARG_1]][1] : f32 from vector<2xf32>
+// CHECK:          %[[ARG2_1:.*]] = vector.extract %[[ARG_2]][1] : vector<[4]xi1> from vector<2x[4]xi1>
+// CHECK:          %[[REDUCE_1:.*]] = vector.mask %[[ARG2_1]] { vector.reduction <add>, %[[ARG0_1]], %[[ARG1_1]] : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
+// CHECK:          %[[INSERT_1:.*]] = vector.insertelement %[[REDUCE_1]], %[[INSERT_0]][%[[C1_idx]] : index] : vector<2xf32>
+// CHECK:          return %[[INSERT_1]] : vector<2xf32>
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
     %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">

>From 7c55c38f03f66afe179488da0e9ec0e9ce2c08ee Mon Sep 17 00:00:00 2001
From: Zhaoshi Zheng <zhaoshiz at quicinc.com>
Date: Thu, 25 Jul 2024 09:35:00 -0700
Subject: [PATCH 2/2] Rename functions in test for better clarity

---
 .../Vector/vector-multi-reduction-lowering.mlir      | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
index fcc998b5bcb88..6e93923608cbf 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -249,11 +249,11 @@ func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc
 //  CHECK-SAME:   %[[INPUT:.+]]: vector<3x4x5xf32>, %[[ACC:.+]]: vector<4xf32>
 //       CHECK: vector.transpose %[[INPUT]], [1, 0, 2] : vector<3x4x5xf32> to vector<4x3x5xf32>
 
-func.func private @scalable_dims(%A : vector<8x[4]x2xf32>, %B: vector<8x[4]xf32>) -> vector<8x[4]xf32> {
+func.func private @vector_multi_reduction_non_scalable_dim(%A : vector<8x[4]x2xf32>, %B: vector<8x[4]xf32>) -> vector<8x[4]xf32> {
   %0 = vector.multi_reduction <add>, %A, %B [2] : vector<8x[4]x2xf32> to vector<8x[4]xf32>
   return %0 : vector<8x[4]xf32>
 }
-// CHECK-LABEL:   func.func private @scalable_dims(
+// CHECK-LABEL:   func.func private @vector_multi_reduction_non_scalable_dim(
 // CHECK-SAME:                                     %[[VAL_0:.*]]: vector<8x[4]x2xf32>,
 // CHECK-SAME:                                     %[[VAL_1:.*]]: vector<8x[4]xf32>) -> vector<8x[4]xf32> {
 // CHECK-DAG:       %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<[32]xf32>
@@ -282,12 +282,12 @@ func.func private @scalable_dims(%A : vector<8x[4]x2xf32>, %B: vector<8x[4]xf32>
 // CHECK:           return %[[VAL_163]] : vector<8x[4]xf32>
 
 // Check that OneDimMultiReductionToTwoDim handles scalable dim
-func.func @scalable_dim_1d(%A: vector<[4]xf32>, %B: f32, %C: vector<[4]xi1>) -> f32 {
+func.func @vector_multi_reduction_scalable_dim_1d(%A: vector<[4]xf32>, %B: f32, %C: vector<[4]xi1>) -> f32 {
     %0 = vector.mask %C { vector.multi_reduction <add>, %A, %B [0] : vector<[4]xf32> to f32 } : vector<[4]xi1> -> f32
     return %0 : f32
 }
 
-// CHECK-LABEL:  func.func @scalable_dim_1d(
+// CHECK-LABEL:  func.func @vector_multi_reduction_scalable_dim_1d(
 // CHECK-SAME:                                      %[[ARG_0:.*]]: vector<[4]xf32>,
 // CHECK-SAME:                                      %[[ARG_1:.*]]: f32,
 // CHECK-SAME:                                      %[[ARG_2:.*]]: vector<[4]xi1>) -> f32 {
@@ -298,12 +298,12 @@ func.func @scalable_dim_1d(%A: vector<[4]xf32>, %B: f32, %C: vector<[4]xi1>) ->
 // CHECK:          %[[VAL_4:.*]] = vector.extract %[[VAL_3]][0] : f32 from vector<1xf32>
 // CHECK:          return %[[VAL_4]] : f32
 
-func.func @scalable_dim_2d(%A: vector<2x[4]xf32>, %B: vector<2xf32>, %C: vector<2x[4]xi1>) -> vector<2xf32> {
+func.func @vector_multi_reduction_scalable_dim_2d(%A: vector<2x[4]xf32>, %B: vector<2xf32>, %C: vector<2x[4]xi1>) -> vector<2xf32> {
     %0 = vector.mask %C { vector.multi_reduction <add>, %A, %B [1] : vector<2x[4]xf32> to vector<2xf32> } : vector<2x[4]xi1> -> vector<2xf32>
     return %0 : vector<2xf32>
 }
 
-// CHECK-LABEL:  func.func @scalable_dim_2d(
+// CHECK-LABEL:  func.func @vector_multi_reduction_scalable_dim_2d(
 // CHECK-SAME:                                      %[[ARG_0:.*]]: vector<2x[4]xf32>,
 // CHECK-SAME:                                      %[[ARG_1:.*]]: vector<2xf32>,
 // CHECK-SAME:                                      %[[ARG_2:.*]]: vector<2x[4]xi1>) -> vector<2xf32> {



More information about the Mlir-commits mailing list