[Mlir-commits] [mlir] [MLIR][Vector] Allow Scalable Dim in OneDimMultiReductionToTwoDim (PR #89978)

Zhaoshi Zheng llvmlistbot at llvm.org
Thu Apr 25 09:06:40 PDT 2024


https://github.com/zhaoshiz updated https://github.com/llvm/llvm-project/pull/89978

>From ecb7103f964e59d19115e05543ec0efc9b0252b5 Mon Sep 17 00:00:00 2001
From: Zhaoshi Zheng <zhaoshiz at quicinc.com>
Date: Tue, 19 Mar 2024 09:07:12 -0700
Subject: [PATCH 1/3] [MLIR][Vector] Allow Scalable Dim in
 OneDimMultiReductionToTwoDim

Preserve the scalable dimension so that vector.multi_reduction of a 1-D
scalable vector, e.g. vector<[4]xf32>, is lowered correctly.

(cherry picked from commit 35f95c21e7895181084cc3a14c4e70eb8d1e6eee)
---
 .../Transforms/LowerVectorMultiReduction.cpp   |  9 ++++++---
 .../vector-multi-reduction-lowering.mlir       | 18 ++++++++++++++++++
 2 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 2f21c50c63473b..240fedf4ccb651 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -438,7 +438,9 @@ struct OneDimMultiReductionToTwoDim
     auto srcVectorType = multiReductionOp.getSourceVectorType();
     auto srcShape = srcVectorType.getShape();
     auto castedType = VectorType::get(ArrayRef<int64_t>{1, srcShape.back()},
-                                      srcVectorType.getElementType());
+                                      srcVectorType.getElementType(),
+        ArrayRef<bool>{false, srcVectorType.getScalableDims().back()});
+
     auto accType =
         VectorType::get(ArrayRef<int64_t>{1}, srcVectorType.getElementType());
     assert(!llvm::isa<VectorType>(multiReductionOp.getDestType()) &&
@@ -455,10 +457,11 @@ struct OneDimMultiReductionToTwoDim
         loc, accType, multiReductionOp.getAcc());
     Value castMask;
     if (maskableOp.isMasked()) {
-      auto maskType = llvm::cast<ShapedType>(mask.getType());
+      auto maskType = llvm::cast<VectorType>(mask.getType());
       auto castMaskType =
           VectorType::get(ArrayRef<int64_t>{1, maskType.getShape().back()},
-                          maskType.getElementType());
+                          maskType.getElementType(),
+          ArrayRef<bool>{false, maskType.getScalableDims().back()});
       castMask = rewriter.create<vector::BroadcastOp>(loc, castMaskType, mask);
     }
 
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
index 22808aa7d6acc3..a5fe83118a0cf2 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -253,6 +253,7 @@ func.func private @scalable_dims(%A : vector<8x[4]x2xf32>, %B: vector<8x[4]xf32>
   %0 = vector.multi_reduction <add>, %A, %B [2] : vector<8x[4]x2xf32> to vector<8x[4]xf32>
   return %0 : vector<8x[4]xf32>
 }
+
 // CHECK-LABEL:   func.func private @scalable_dims(
 // CHECK-SAME:                                     %[[VAL_0:.*]]: vector<8x[4]x2xf32>,
 // CHECK-SAME:                                     %[[VAL_1:.*]]: vector<8x[4]xf32>) -> vector<8x[4]xf32> {
@@ -281,6 +282,23 @@ func.func private @scalable_dims(%A : vector<8x[4]x2xf32>, %B: vector<8x[4]xf32>
 // CHECK:           %[[VAL_163:.*]] = vector.shape_cast %[[VAL_162]] : vector<[32]xf32> to vector<8x[4]xf32>
 // CHECK:           return %[[VAL_163]] : vector<8x[4]xf32>
 
+// Check that OneDimMultiReductionToTwoDim handles scalable dim
+func.func private @scalable_dim_1d(%A: vector<[4]xf32>, %B: f32, %C: vector<[4]xi1>) -> f32 {
+    %0 = vector.mask %C { vector.multi_reduction <add>, %A, %B [0] : vector<[4]xf32> to f32 } : vector<[4]xi1> -> f32
+    return %0 : f32
+}
+
+// CHECK-LABEL:  func.func private @scalable_dim_1d(
+// CHECK-SAME:                                      %[[ARG_0:.*]]: vector<[4]xf32>,
+// CHECK-SAME:                                      %[[ARG_1:.*]]: f32,
+// CHECK-SAME:                                      %[[ARG_2:.*]]: vector<[4]xi1>) -> f32 {
+// CHECK-DAG:      %[[VAL_0:.*]] = arith.constant 0 : index
+// CHECK-DAG:      %[[VAL_1:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+// CHECK:          %[[VAL_2:.*]] = vector.mask %[[ARG_2]] { vector.reduction <add>, %[[ARG_0]], %[[ARG_1]] : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
+// CHECK:          %[[VAL_3:.*]] = vector.insertelement %[[VAL_2]], %[[VAL_1]][%[[VAL_0]] : index] : vector<1xf32>
+// CHECK:          %[[VAL_4:.*]] = vector.extract %[[VAL_3]][0] : f32 from vector<1xf32>
+// CHECK:          return %[[VAL_4]] : f32
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
     %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">

>From a4a33e470e5ffe05168a7c01768c701742d7e6da Mon Sep 17 00:00:00 2001
From: Zhaoshi Zheng <zhaoshiz at quicinc.com>
Date: Tue, 19 Mar 2024 09:07:12 -0700
Subject: [PATCH 2/3] [MLIR][Vector] Allow scalable dim in
 OneDimMultiReductionToTwoDim

Preserve the scalable dimension so that vector.multi_reduction of a 1-D
scalable vector, e.g. vector<[4]xf32>, is lowered correctly.
---
 .../Vector/Transforms/LowerVectorMultiReduction.cpp    | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 240fedf4ccb651..ac576ed0b4f097 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -437,8 +437,8 @@ struct OneDimMultiReductionToTwoDim
     auto loc = multiReductionOp.getLoc();
     auto srcVectorType = multiReductionOp.getSourceVectorType();
     auto srcShape = srcVectorType.getShape();
-    auto castedType = VectorType::get(ArrayRef<int64_t>{1, srcShape.back()},
-                                      srcVectorType.getElementType(),
+    auto castedType = VectorType::get(
+        ArrayRef<int64_t>{1, srcShape.back()}, srcVectorType.getElementType(),
         ArrayRef<bool>{false, srcVectorType.getScalableDims().back()});
 
     auto accType =
@@ -458,9 +458,9 @@ struct OneDimMultiReductionToTwoDim
     Value castMask;
     if (maskableOp.isMasked()) {
       auto maskType = llvm::cast<VectorType>(mask.getType());
-      auto castMaskType =
-          VectorType::get(ArrayRef<int64_t>{1, maskType.getShape().back()},
-                          maskType.getElementType(),
+      auto castMaskType = VectorType::get(
+          ArrayRef<int64_t>{1, maskType.getShape().back()},
+          maskType.getElementType(),
           ArrayRef<bool>{false, maskType.getScalableDims().back()});
       castMask = rewriter.create<vector::BroadcastOp>(loc, castMaskType, mask);
     }

>From ce1a1b291310fface4a3e75eebb9384fcd7e67fb Mon Sep 17 00:00:00 2001
From: Zhaoshi Zheng <zhaoshiz at quicinc.com>
Date: Tue, 19 Mar 2024 09:07:12 -0700
Subject: [PATCH 3/3] [MLIR][Vector] Allow scalable dim in
 OneDimMultiReductionToTwoDim

Preserve the scalable dimension so that vector.multi_reduction of a 1-D
scalable vector, e.g. vector<[4]xf32>, is lowered correctly.
---
 .../test/Dialect/Vector/vector-multi-reduction-lowering.mlir | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
index a5fe83118a0cf2..f70d23a1932297 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -253,7 +253,6 @@ func.func private @scalable_dims(%A : vector<8x[4]x2xf32>, %B: vector<8x[4]xf32>
   %0 = vector.multi_reduction <add>, %A, %B [2] : vector<8x[4]x2xf32> to vector<8x[4]xf32>
   return %0 : vector<8x[4]xf32>
 }
-
 // CHECK-LABEL:   func.func private @scalable_dims(
 // CHECK-SAME:                                     %[[VAL_0:.*]]: vector<8x[4]x2xf32>,
 // CHECK-SAME:                                     %[[VAL_1:.*]]: vector<8x[4]xf32>) -> vector<8x[4]xf32> {
@@ -283,12 +282,12 @@ func.func private @scalable_dims(%A : vector<8x[4]x2xf32>, %B: vector<8x[4]xf32>
 // CHECK:           return %[[VAL_163]] : vector<8x[4]xf32>
 
 // Check that OneDimMultiReductionToTwoDim handles scalable dim
-func.func private @scalable_dim_1d(%A: vector<[4]xf32>, %B: f32, %C: vector<[4]xi1>) -> f32 {
+func.func @scalable_dim_1d(%A: vector<[4]xf32>, %B: f32, %C: vector<[4]xi1>) -> f32 {
     %0 = vector.mask %C { vector.multi_reduction <add>, %A, %B [0] : vector<[4]xf32> to f32 } : vector<[4]xi1> -> f32
     return %0 : f32
 }
 
-// CHECK-LABEL:  func.func private @scalable_dim_1d(
+// CHECK-LABEL:  func.func @scalable_dim_1d(
 // CHECK-SAME:                                      %[[ARG_0:.*]]: vector<[4]xf32>,
 // CHECK-SAME:                                      %[[ARG_1:.*]]: f32,
 // CHECK-SAME:                                      %[[ARG_2:.*]]: vector<[4]xi1>) -> f32 {



More information about the Mlir-commits mailing list