[Mlir-commits] [mlir] 501674d - [mlir][Vector] Further fix to avoid infinite loop in InnerOuterDimReductionConversion

Hanhan Wang llvmlistbot at llvm.org
Wed Dec 15 13:55:17 PST 2021


Author: Hanhan Wang
Date: 2021-12-15T13:54:15-08:00
New Revision: 501674dc3b14277837d0c279cde5868135e7649c

URL: https://github.com/llvm/llvm-project/commit/501674dc3b14277837d0c279cde5868135e7649c
DIFF: https://github.com/llvm/llvm-project/commit/501674dc3b14277837d0c279cde5868135e7649c.diff

LOG: [mlir][Vector] Further fix to avoid infinite loop in InnerOuterDimReductionConversion

If all the dims are reduction dims, it is already in inner-most/outer-most
reduction form.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D115820

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
    mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
index 9e4edda64c894..965b29257ee65 100644
--- a/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
@@ -41,9 +41,6 @@ class InnerOuterDimReductionConversion
     auto src = multiReductionOp.source();
     auto loc = multiReductionOp.getLoc();
     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
-    // If the rank is less than or equal to 1, there is nothing to do.
-    if (srcRank <= 1)
-      return failure();
 
     // Separate reduction and parallel dims
     auto reductionDimsRange =
@@ -59,6 +56,9 @@ class InnerOuterDimReductionConversion
         parallelDims.push_back(i);
 
     // Add transpose only if inner-most/outer-most dimensions are not parallel
+    // and there are parallel dims.
+    if (parallelDims.empty())
+      return failure();
     if (useInnerDimsForReduction &&
         (parallelDims ==
          llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))

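The new early exit is easier to see in isolation. Below is a minimal, dependency-free C++ model of the precondition. It decides when `InnerOuterDimReductionConversion` should return `failure()` instead of inserting a transpose. `shouldSkipTranspose` and the boolean `reductionMask` are hypothetical names, not MLIR API. The outer-reduction branch is an assumption: this hunk only shows the inner-reduction comparison, so the sketch takes the outer case to be its mirror image (parallel dims trailing). In this model, rank-0 and rank-1 sources are covered by the general checks, which is consistent with dropping the separate `srcRank <= 1` guard.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical model (not the MLIR API). reductionMask[i] is true when source
// dim i is reduced. Returns true when the pattern should bail out (failure()).
bool shouldSkipTranspose(const std::vector<bool> &reductionMask,
                         bool useInnerDimsForReduction) {
  const int64_t rank = static_cast<int64_t>(reductionMask.size());

  // Collect the parallel (non-reduced) dims in ascending order.
  std::vector<int64_t> parallelDims;
  for (int64_t i = 0; i < rank; ++i)
    if (!reductionMask[static_cast<size_t>(i)])
      parallelDims.push_back(i);

  // Early exit mirroring this commit: with no parallel dims the op is
  // already in both inner-most and outer-most reduction form.
  if (parallelDims.empty())
    return true;

  // Inner reduction wants the parallel dims leading, i.e. [0, p).
  // Outer reduction (assumed mirror check) wants them trailing, [rank - p, rank).
  const int64_t numParallel = static_cast<int64_t>(parallelDims.size());
  std::vector<int64_t> expected(static_cast<size_t>(numParallel));
  std::iota(expected.begin(), expected.end(),
            useInnerDimsForReduction ? int64_t{0} : rank - numParallel);
  return parallelDims == expected;
}
```

For the new test case (`vector<2x3xf32>` reduced over `[0, 1]`) the mask is `{true, true}`. The function therefore returns `true` in both modes, and no transpose is created for the pattern to match again.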
diff  --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
index 1562fedc99d83..73679ce927f15 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
@@ -163,3 +163,10 @@ func @vector_reduction_1D(%arg0 : vector<2xf32>) -> f32 {
 }
 // CHECK-LABEL: func @vector_reduction_1D
 //       CHECK:   return %{{.+}}
+
+func @vector_multi_reduction_to_scalar(%arg0: vector<2x3xf32>) -> f32 {
+  %0 = vector.multi_reduction <add>, %arg0 [0, 1] : vector<2x3xf32> to f32
+  return %0 : f32
+}
+// CHECK-LABEL: func @vector_multi_reduction_to_scalar
+//       CHECK:   return %{{.+}}
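The added test reduces every dim of a `vector<2x3xf32>`, so the result is a single `f32`. As a rough reference for what that `<add>` reduction computes, here is a small C++ sketch that folds all six elements into one accumulator. `reduceAllAdd` is a hypothetical helper for illustration only. It is not code that MLIR generates, and it starts from `0.0f` as the additive identity rather than modeling any particular lowering order.

```cpp
#include <array>
#include <cassert>

// Hypothetical reference for `vector.multi_reduction <add>, %arg0 [0, 1]`
// on a vector<2x3xf32>: both dims are reduced, so every element is summed
// into a single scalar.
float reduceAllAdd(const std::array<std::array<float, 3>, 2> &v) {
  float acc = 0.0f;             // additive identity
  for (const auto &row : v)     // dim 0 (reduced)
    for (float x : row)         // dim 1 (reduced)
      acc += x;
  return acc;
}
```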
