[Mlir-commits] [mlir] 501674d - [mlir][Vector] Further fix to avoid infinite loop in InnerOuterDimReductionConversion
Hanhan Wang
llvmlistbot at llvm.org
Wed Dec 15 13:55:17 PST 2021
Author: Hanhan Wang
Date: 2021-12-15T13:54:15-08:00
New Revision: 501674dc3b14277837d0c279cde5868135e7649c
URL: https://github.com/llvm/llvm-project/commit/501674dc3b14277837d0c279cde5868135e7649c
DIFF: https://github.com/llvm/llvm-project/commit/501674dc3b14277837d0c279cde5868135e7649c.diff
LOG: [mlir][Vector] Further fix to avoid infinite loop in InnerOuterDimReductionConversion
If all the dims are reduction dims, the op is already in inner-most/outer-most
reduction form, so no transpose is needed and the pattern should not apply.
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D115820
Added:
Modified:
mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
Removed:
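The guard described in the commit message can be sketched as a standalone C++ function. This is a hypothetical, simplified model: `needsTranspose` and its parameters are illustrative names, not MLIR APIs. The real pattern matches `vector.multi_reduction` ops and emits a transpose instead of returning a flag. Here "inner" form means the parallel dims lead and the reduction dims trail, and "outer" form means the reverse.

```cpp
#include <cassert>
#include <cstdint>
#include <numeric>
#include <set>
#include <vector>

// Hypothetical model of the check in InnerOuterDimReductionConversion.
// Returns true if a transpose would be needed to move the reduction dims
// to the inner-most (useInnerDims == true) or outer-most positions.
bool needsTranspose(int64_t srcRank, const std::vector<int64_t> &reductionDims,
                    bool useInnerDims) {
  std::set<int64_t> reductionSet;
  for (int64_t d : reductionDims) {
    assert(d >= 0 && d < srcRank && "reduction dim out of range");
    reductionSet.insert(d);
  }

  // Separate parallel dims (those not being reduced), in ascending order.
  std::vector<int64_t> parallelDims;
  for (int64_t i = 0; i < srcRank; ++i)
    if (!reductionSet.count(i))
      parallelDims.push_back(i);

  // Mirrors the guard added by this commit: with no parallel dims every dim
  // is reduced, so the op is already in inner-most and outer-most form.
  if (parallelDims.empty())
    return false;

  // Positions the parallel dims should occupy: leading for inner-most
  // reduction, trailing for outer-most reduction.
  const int64_t numParallel = static_cast<int64_t>(parallelDims.size());
  std::vector<int64_t> expected(parallelDims.size());
  std::iota(expected.begin(), expected.end(),
            useInnerDims ? int64_t{0} : srcRank - numParallel);
  return parallelDims != expected;
}
```

For the new test case, `needsTranspose(2, {0, 1}, false)` returns false. In this simplified model the final comparison would also return false when there are no parallel dims (both vectors are empty); the explicit early exit is kept only to mirror the shape of the committed fix.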
################################################################################
diff --git a/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
index 9e4edda64c894..965b29257ee65 100644
--- a/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
@@ -41,9 +41,6 @@ class InnerOuterDimReductionConversion
auto src = multiReductionOp.source();
auto loc = multiReductionOp.getLoc();
auto srcRank = multiReductionOp.getSourceVectorType().getRank();
- // If the rank is less than or equal to 1, there is nothing to do.
- if (srcRank <= 1)
- return failure();
// Separate reduction and parallel dims
auto reductionDimsRange =
@@ -59,6 +56,9 @@ class InnerOuterDimReductionConversion
parallelDims.push_back(i);
// Add transpose only if inner-most/outer-most dimensions are not parallel
+ // and there are parallel dims.
+ if (parallelDims.empty())
+ return failure();
if (useInnerDimsForReduction &&
(parallelDims ==
llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
index 1562fedc99d83..73679ce927f15 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
@@ -163,3 +163,10 @@ func @vector_reduction_1D(%arg0 : vector<2xf32>) -> f32 {
}
// CHECK-LABEL: func @vector_reduction_1D
// CHECK: return %{{.+}}
+
+func @vector_multi_reduction_to_scalar(%arg0: vector<2x3xf32>) -> f32 {
+  %0 = vector.multi_reduction <add>, %arg0 [0, 1] : vector<2x3xf32> to f32
+  return %0 : f32
+}
+// CHECK-LABEL: func @vector_multi_reduction_to_scalar
+// CHECK: return %{{.+}}