[Mlir-commits] [mlir] 4626bd0 - [mlir][vector] Disable folding for masked reductions

Diego Caballero llvmlistbot at llvm.org
Thu Jan 19 15:07:30 PST 2023


Author: Diego Caballero
Date: 2023-01-19T23:06:53Z
New Revision: 4626bd0b91102af234125f7f8ff0daaffb7a1fa4

URL: https://github.com/llvm/llvm-project/commit/4626bd0b91102af234125f7f8ff0daaffb7a1fa4
DIFF: https://github.com/llvm/llvm-project/commit/4626bd0b91102af234125f7f8ff0daaffb7a1fa4.diff

LOG: [mlir][vector] Disable folding for masked reductions

Masked reductions can't be folded into plain arith ops until we can mask
those arith ops.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D141645

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 777133df07e39..974854e5637f5 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -361,6 +361,12 @@ struct ElideUnitDimsInMultiDimReduction
 
   LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
                                 PatternRewriter &rewriter) const override {
+    // Masked reductions can't be folded until we can propagate the mask to the
+    // resulting operation.
+    auto maskableOp = cast<MaskableOpInterface>(reductionOp.getOperation());
+    if (maskableOp.isMasked())
+      return failure();
+
     ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
     for (const auto &dim : enumerate(shape)) {
       if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
@@ -518,6 +524,12 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
 
   LogicalResult matchAndRewrite(ReductionOp reductionOp,
                                 PatternRewriter &rewriter) const override {
+    // Masked reductions can't be folded until we can propagate the mask to the
+    // resulting operation.
+    auto maskableOp = cast<MaskableOpInterface>(reductionOp.getOperation());
+    if (maskableOp.isMasked())
+      return failure();
+
     if (reductionOp.getVectorType().getDimSize(0) != 1)
       return failure();
 

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 081d94e5805fc..fb890b6a0f44a 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1371,6 +1371,20 @@ func.func @vector_multi_reduction_unit_dimensions(%source: vector<5x1x4x1x20xf32
 
 // -----
 
+// Masked reductions can't be folded.
+
+// CHECK-LABEL: func @masked_vector_multi_reduction_unit_dimensions
+func.func @masked_vector_multi_reduction_unit_dimensions(%source: vector<5x1x4x1x20xf32>,
+                                                         %acc: vector<5x4x20xf32>,
+                                                         %mask: vector<5x1x4x1x20xi1>) -> vector<5x4x20xf32> {
+//       CHECK:   vector.mask %{{.*}} { vector.multi_reduction <mul>
+    %0 = vector.mask %mask { vector.multi_reduction <mul>, %source, %acc [1, 3] : vector<5x1x4x1x20xf32> to vector<5x4x20xf32> } :
+           vector<5x1x4x1x20xi1> -> vector<5x4x20xf32>
+    return %0 : vector<5x4x20xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @vector_multi_reduction_unit_dimensions_fail(
 //  CHECK-SAME: %[[SRC:.+]]: vector<5x1x4x1x20xf32>, %[[ACCUM:.+]]: vector<5x1x20xf32>
 func.func @vector_multi_reduction_unit_dimensions_fail(%source: vector<5x1x4x1x20xf32>, %acc: vector<5x1x20xf32>) -> vector<5x1x20xf32> {
@@ -1921,6 +1935,18 @@ func.func @reduce_one_element_vector_addf(%a : vector<1xf32>, %b: f32) -> f32 {
 
 // -----
 
+// CHECK-LABEL: func @masked_reduce_one_element_vector_addf
+//       CHECK:   vector.mask %{{.*}} { vector.reduction <add>
+func.func @masked_reduce_one_element_vector_addf(%a: vector<1xf32>,
+                                                 %b: f32,
+                                                 %mask: vector<1xi1>) -> f32 {
+  %s = vector.mask %mask { vector.reduction <add>, %a, %b : vector<1xf32> into f32 }
+         : vector<1xi1> -> f32
+  return %s : f32
+}
+
+// -----
+
 // CHECK-LABEL: func @reduce_one_element_vector_mulf
 //  CHECK-SAME: (%[[V:.+]]: vector<1xf32>, %[[B:.+]]: f32)
 //       CHECK:   %[[A:.+]] = vector.extract %[[V]][0] : vector<1xf32>


        


More information about the Mlir-commits mailing list