[Mlir-commits] [mlir] 617ca92 - Revert "Added canonicalization for vector.multi_reduction"
Murali Vijayaraghavan
llvmlistbot at llvm.org
Wed Oct 5 14:47:34 PDT 2022
Author: Murali Vijayaraghavan
Date: 2022-10-05T21:43:51Z
New Revision: 617ca92bf155da73a8552345598a920777643e53
URL: https://github.com/llvm/llvm-project/commit/617ca92bf155da73a8552345598a920777643e53
DIFF: https://github.com/llvm/llvm-project/commit/617ca92bf155da73a8552345598a920777643e53.diff
LOG: Revert "Added canonicalization for vector.multi_reduction"
This reverts commit c16f3260a9255c7d9880f72de7d856f9ceeb1866.
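There's a bug in the commit: it creates a `ShapeCastOp` with a scalar result type (this happens when every source dimension is reduced, so the reduction's destination type is a scalar).
Reverting until that is fixed.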
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 575dfbb9a2114..94d0ea939d58d 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -396,7 +396,6 @@ def Vector_MultiDimReductionOp :
let assemblyFormat =
"$kind `,` $source `,` $acc attr-dict $reduction_dims `:` type($source) `to` type($dest)";
let hasFolder = 1;
- let hasCanonicalizer = 1;
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 751a592b60aa9..4929ca170b1d1 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -309,40 +309,6 @@ LogicalResult MultiDimReductionOp::verify() {
return success();
}
-namespace {
-// Only unit dimensions that are being reduced are folded. If the dimension is
-// unit, but not reduced, it is not folded, thereby keeping the output type the
-// same. If not all dimensions which are reduced are of unit dimension, this
-// transformation does nothing. This is just a generalization of
-// ElideSingleElementReduction for ReduceOp.
-struct ElideUnitDimsInMultiDimReduction
- : public OpRewritePattern<MultiDimReductionOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
- PatternRewriter &rewriter) const override {
- ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
- for (auto dim : enumerate(shape)) {
- if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
- return failure();
- }
- Location loc = reductionOp.getLoc();
- Value acc = reductionOp.getAcc();
- Value cast = rewriter.create<vector::ShapeCastOp>(
- loc, reductionOp.getDestType(), reductionOp.getSource());
- Value result = vector::makeArithReduction(rewriter, loc,
- reductionOp.getKind(), acc, cast);
- rewriter.replaceOp(reductionOp, result);
- return success();
- }
-};
-} // namespace
-
-void MultiDimReductionOp::getCanonicalizationPatterns(
- RewritePatternSet &results, MLIRContext *context) {
- results.add<ElideUnitDimsInMultiDimReduction>(context);
-}
-
//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 34d292f6d4c56..15c78c1118096 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1348,31 +1348,6 @@ func.func @vector_multi_reduction_single_parallel(%arg0: vector<2xf32>, %acc: ve
// -----
-// CHECK-LABEL: func @vector_multi_reduction_unit_dimensions(
-// CHECK-SAME: %[[SOURCE:.+]]: vector<5x1x4x1x20xf32>, %[[ACC:.+]]: vector<5x4x20xf32>
-func.func @vector_multi_reduction_unit_dimensions(%source: vector<5x1x4x1x20xf32>, %acc: vector<5x4x20xf32>) -> vector<5x4x20xf32> {
-// CHECK: %[[CAST:.+]] = vector.shape_cast %[[SOURCE]] : vector<5x1x4x1x20xf32> to vector<5x4x20xf32>
-// CHECK: %[[RESULT:.+]] = arith.mulf %[[ACC]], %[[CAST]] : vector<5x4x20xf32>
- %0 = vector.multi_reduction <mul>, %source, %acc [1, 3] : vector<5x1x4x1x20xf32> to vector<5x4x20xf32>
-
-// CHECK: return %[[RESULT]] : vector<5x4x20xf32>
- return %0 : vector<5x4x20xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @vector_multi_reduction_unit_dimensions_fail(
-// CHECK-SAME: %[[SRC:.+]]: vector<5x1x4x1x20xf32>, %[[ACCUM:.+]]: vector<5x1x1x20xf32>
-func.func @vector_multi_reduction_unit_dimensions_fail(%source: vector<5x1x4x1x20xf32>, %acc: vector<5x1x1x20xf32>) -> vector<5x1x1x20xf32> {
-// CHECK: %[[RES:.+]] = vector.multi_reduction <mul>, %[[SRC]], %[[ACCUM]] [2] : vector<5x1x4x1x20xf32> to vector<5x1x1x20xf32>
- %0 = vector.multi_reduction <mul>, %source, %acc [2] : vector<5x1x4x1x20xf32> to vector<5x1x1x20xf32>
-
-// CHECK: return %[[RES]] : vector<5x1x1x20xf32>
- return %0 : vector<5x1x1x20xf32>
-}
-
-// -----
-
// CHECK-LABEL: func @insert_strided_slice_full_range
// CHECK-SAME: %[[SOURCE:.+]]: vector<16x16xf16>, %{{.+}}: vector<16x16xf16>
func.func @insert_strided_slice_full_range(%source: vector<16x16xf16>, %dest: vector<16x16xf16>) -> vector<16x16xf16> {
More information about the Mlir-commits
mailing list