[Mlir-commits] [mlir] 51a5707 - [mlir][vector] Add missing canonicalisation for vector.multi_reduction

Andrzej Warzynski llvmlistbot at llvm.org
Fri Aug 11 12:07:07 PDT 2023


Author: Andrzej Warzynski
Date: 2023-08-11T19:06:23Z
New Revision: 51a57074bc63842970c4c160b05c1a7e42db7523

URL: https://github.com/llvm/llvm-project/commit/51a57074bc63842970c4c160b05c1a7e42db7523
DIFF: https://github.com/llvm/llvm-project/commit/51a57074bc63842970c4c160b05c1a7e42db7523.diff

LOG: [mlir][vector] Add missing canonicalisation for vector.multi_reduction

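Make sure that when canonicalising a masked `vector.multi_reduction` (in
`ElideUnitDimsInMultiDimReduction`) and creating `arith.select` to replace
the mask, the scalability of the mask is preserved. Previously, the
`vector.shape_cast` that drops the unit dimensions from the mask built its
result type from the destination shape only, discarding the scalable dims
(e.g. a `vector<1x[4]x1xi1>` mask was cast to the fixed-length
`vector<1x4xi1>` instead of `vector<1x[4]xi1>`).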

Differential Revision: https://reviews.llvm.org/D157732

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 9ff35593a79419..0d4f8952244f9d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -381,7 +381,8 @@ struct ElideUnitDimsInMultiDimReduction
     if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
       if (mask) {
         VectorType newMaskType =
-            VectorType::get(dstVecType.getShape(), rewriter.getI1Type());
+            VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
+                            dstVecType.getScalableDims());
         mask = rewriter.create<vector::ShapeCastOp>(loc, newMaskType, mask);
       }
       cast = rewriter.create<vector::ShapeCastOp>(

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 126d8dbc4c1999..be266bbc6c9ac8 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -8,7 +8,6 @@ func.func @create_vector_mask_to_constant_mask() -> (vector<4x3xi1>) {
   %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
   return %0 : vector<4x3xi1>
 }
-
 // -----
 
 // CHECK-LABEL: create_scalable_vector_mask_to_constant_mask
@@ -1320,6 +1319,24 @@ func.func @vector_multi_reduction_unit_dimensions(%source: vector<5x1x4x1x20xf32
     return %0 : vector<5x4x20xf32>
 }
 
+// -----
+// CHECK-LABEL:   func.func @vector_multi_reduction_scalable(
+// CHECK-SAME:     %[[VAL_0:.*]]: vector<1x[4]x1xf32>,
+// CHECK-SAME:     %[[VAL_1:.*]]: vector<1x[4]xf32>,
+// CHECK-SAME:     %[[VAL_2:.*]]: vector<1x[4]x1xi1>)
+func.func @vector_multi_reduction_scalable(%source: vector<1x[4]x1xf32>,
+                                           %acc: vector<1x[4]xf32>,
+                                           %mask: vector<1x[4]x1xi1>) -> vector<1x[4]xf32> {
+// CHECK:           %[[VAL_3:.*]] = vector.shape_cast %[[VAL_2]] : vector<1x[4]x1xi1> to vector<1x[4]xi1>
+// CHECK:           %[[VAL_4:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x[4]x1xf32> to vector<1x[4]xf32>
+// CHECK:           %[[VAL_5:.*]] = arith.addf %[[VAL_1]], %[[VAL_4]] : vector<1x[4]xf32>
+// CHECK:           %[[VAL_6:.*]] = arith.select %[[VAL_3]], %[[VAL_5]], %[[VAL_4]] : vector<1x[4]xi1>, vector<1x[4]xf32>
+    %0 = vector.mask %mask { vector.multi_reduction <add>, %source, %acc [2] : vector<1x[4]x1xf32> to vector<1x[4]xf32> } :
+          vector<1x[4]x1xi1> -> vector<1x[4]xf32>
+
+    return %0 : vector<1x[4]xf32>
+}
+
 // -----
 
 // CHECK-LABEL: func @masked_vector_multi_reduction_unit_dimensions


        


More information about the Mlir-commits mailing list