[Mlir-commits] [mlir] 77f4c91 - [mlir][vector] Allow transposing multi_reduction when the parallel dim is in the middle
Benjamin Kramer
llvmlistbot at llvm.org
Thu Jan 26 09:06:52 PST 2023
Author: Benjamin Kramer
Date: 2023-01-26T18:06:42+01:00
New Revision: 77f4c91c5e052f43245555fffbb0649191e3e32f
URL: https://github.com/llvm/llvm-project/commit/77f4c91c5e052f43245555fffbb0649191e3e32f
DIFF: https://github.com/llvm/llvm-project/commit/77f4c91c5e052f43245555fffbb0649191e3e32f.diff
LOG: [mlir][vector] Allow transposing multi_reduction when the parallel dim is in the middle
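The check for the outer lowering wasn't quite right: it bailed out unless the parallel dims were exactly the leading dims, so a parallel dim in the middle (e.g. reducing dims [0, 2] of a vector<3x4x5xf32>) was never transposed. It now bails out only when the parallel dims are already the trailing dims, i.e. when no transpose is needed.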
Differential Revision: https://reviews.llvm.org/D142483
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
index e89059cb9390a..117fdcb84c809 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
@@ -77,8 +77,9 @@ class InnerOuterDimReductionConversion
return failure();
if (!useInnerDimsForReduction &&
- (parallelDims !=
- llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
+ (parallelDims == llvm::to_vector<4>(llvm::seq<int64_t>(
+ reductionDims.size(),
+ parallelDims.size() + reductionDims.size()))))
return failure();
SmallVector<int64_t, 4> indices;
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
index ee4ab7a1c5c8f..5647089d2ed5b 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -234,3 +234,13 @@ func.func @vectorize_dynamic_transpose_reduction(%arg0: tensor<?x?x?xf32>, %arg1
// CHECK: %[[VAL_159:.*]] = vector.mask %[[VAL_158]] { vector.reduction <add>
// CHECK: %[[VAL_160:.*]] = vector.insertelement %[[VAL_159]]
+// -----
+
+func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
+ %0 = vector.multi_reduction <add>, %arg0, %acc [0, 2] : vector<3x4x5xf32> to vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @vector_multi_reduction_parallel_middle
+// CHECK-SAME: %[[INPUT:.+]]: vector<3x4x5xf32>, %[[ACC:.+]]: vector<4xf32>
+// CHECK: vector.transpose %[[INPUT]], [1, 0, 2] : vector<3x4x5xf32> to vector<4x3x5xf32>
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
index 8a8bf86bfd38b..9f22972a09e5c 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
@@ -162,6 +162,15 @@ func.func @vector_reduction_outer(%arg0: vector<2x3x4x5xi32>, %acc: vector<2x3xi
// CHECK: %[[RESULT_VEC:.+]] = vector.shape_cast %[[R18]] : vector<6xi32> to vector<2x3xi32>
// CHECK: return %[[RESULT_VEC]] : vector<2x3xi32>
+func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
+ %0 = vector.multi_reduction <add>, %arg0, %acc [0, 2] : vector<3x4x5xf32> to vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @vector_multi_reduction_parallel_middle
+// CHECK-SAME: %[[INPUT:.+]]: vector<3x4x5xf32>, %[[ACC:.+]]: vector<4xf32>
+// CHECK: vector.transpose %[[INPUT]], [0, 2, 1] : vector<3x4x5xf32> to vector<3x5x4xf32>
+
// This test is mainly to catch a bug that running
// `InnerOuterDimReductionConversion` on this function results in an
// infinite loop. So just check that some value is returned.
More information about the Mlir-commits
mailing list