[Mlir-commits] [mlir] 9c3d3ee - [mlir] vector.multi_reduction canonicalizes to vector.shape_cast (or
Murali Vijayaraghavan
llvmlistbot at llvm.org
Wed Oct 5 17:19:32 PDT 2022
Author: Murali Vijayaraghavan
Date: 2022-10-06T00:11:31Z
New Revision: 9c3d3eeb51b7a3f6428bab7bd46452ce18029060
URL: https://github.com/llvm/llvm-project/commit/9c3d3eeb51b7a3f6428bab7bd46452ce18029060
DIFF: https://github.com/llvm/llvm-project/commit/9c3d3eeb51b7a3f6428bab7bd46452ce18029060.diff
LOG: [mlir] vector.multi_reduction canonicalizes to vector.shape_cast (or
vector.extract, if the result is a scalar) only if all reduction
dimensions are of size 1.
Differential Revision: https://reviews.llvm.org/D135333
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 94d0ea939d58d..575dfbb9a2114 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -396,6 +396,7 @@ def Vector_MultiDimReductionOp :
let assemblyFormat =
"$kind `,` $source `,` $acc attr-dict $reduction_dims `:` type($source) `to` type($dest)";
let hasFolder = 1;
+ let hasCanonicalizer = 1;
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 4929ca170b1d1..5e0a177e4e1c2 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -309,6 +309,50 @@ LogicalResult MultiDimReductionOp::verify() {
return success();
}
+namespace {
+/// Folds a vector.multi_reduction in which every reduced dimension has extent
+/// 1. Such a reduction sees exactly one source element per result element, so
+/// it is rewritten as a reshape of the source (vector.shape_cast, or a
+/// vector.extract at position 0 in every dimension when the result is a
+/// scalar) that is then combined with the accumulator using the op's kind.
+///
+/// Unit dimensions that are NOT reduced are left in place, so the result type
+/// is unchanged. If any reduced dimension has an extent other than 1 the
+/// pattern does not apply. This generalizes ElideSingleElementReduction for
+/// ReduceOp.
+struct ElideUnitDimsInMultiDimReduction
+    : public OpRewritePattern<MultiDimReductionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
+                                PatternRewriter &rewriter) const override {
+    ArrayRef<int64_t> srcShape = reductionOp.getSourceVectorType().getShape();
+
+    // Precondition: every reduced dimension must be a unit dimension.
+    for (size_t dimIdx = 0, rank = srcShape.size(); dimIdx < rank; ++dimIdx) {
+      if (reductionOp.isReducedDim(dimIdx) && srcShape[dimIdx] != 1)
+        return failure();
+    }
+
+    Location loc = reductionOp.getLoc();
+    Type resultType = reductionOp.getDestType();
+    Value source = reductionOp.getSource();
+
+    // Drop the unit reduced dimensions so the source matches the result type.
+    Value collapsed;
+    if (!resultType.isa<VectorType>()) {
+      // A scalar result means every dimension is reduced, and all of them are
+      // unit, so the one and only element can simply be extracted.
+      SmallVector<int64_t> zeroPosition(srcShape.size(), 0);
+      collapsed = rewriter.create<vector::ExtractOp>(
+          loc, resultType, source, rewriter.getI64ArrayAttr(zeroPosition));
+    } else {
+      collapsed =
+          rewriter.create<vector::ShapeCastOp>(loc, resultType, source);
+    }
+
+    // Fold the accumulator in with the reduction's combining operation.
+    Value combined = vector::makeArithReduction(
+        rewriter, loc, reductionOp.getKind(), reductionOp.getAcc(), collapsed);
+    rewriter.replaceOp(reductionOp, combined);
+    return success();
+  }
+};
+} // namespace
+
+/// Registers the canonicalization patterns for vector.multi_reduction. At
+/// present the only one is ElideUnitDimsInMultiDimReduction, which removes
+/// reductions over dimensions that all have extent 1.
+void MultiDimReductionOp::getCanonicalizationPatterns(
+ RewritePatternSet &results, MLIRContext *context) {
+ results.add<ElideUnitDimsInMultiDimReduction>(context);
+}
+
//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 15c78c1118096..c3d9ae24c0891 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1348,6 +1348,44 @@ func.func @vector_multi_reduction_single_parallel(%arg0: vector<2xf32>, %acc: ve
// -----
+// Both reduced dimensions (1 and 3) have extent 1, so the reduction is expected
+// to become a vector.shape_cast followed by an elementwise combine with %acc.
+// CHECK-LABEL: func @vector_multi_reduction_unit_dimensions(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<5x1x4x1x20xf32>, %[[ACC:.+]]: vector<5x4x20xf32>
+func.func @vector_multi_reduction_unit_dimensions(%source: vector<5x1x4x1x20xf32>, %acc: vector<5x4x20xf32>) -> vector<5x4x20xf32> {
+// CHECK: %[[CAST:.+]] = vector.shape_cast %[[SOURCE]] : vector<5x1x4x1x20xf32> to vector<5x4x20xf32>
+// CHECK: %[[RESULT:.+]] = arith.mulf %[[ACC]], %[[CAST]] : vector<5x4x20xf32>
+ %0 = vector.multi_reduction <mul>, %source, %acc [1, 3] : vector<5x1x4x1x20xf32> to vector<5x4x20xf32>
+
+// CHECK: return %[[RESULT]] : vector<5x4x20xf32>
+ return %0 : vector<5x4x20xf32>
+}
+
+// -----
+
+// Negative case: reduced dimension 2 has extent 4, so the unit-dimension
+// elision must not apply and the multi_reduction is expected to remain.
+// CHECK-LABEL: func @vector_multi_reduction_unit_dimensions_fail(
+// CHECK-SAME: %[[SRC:.+]]: vector<5x1x4x1x20xf32>, %[[ACCUM:.+]]: vector<5x1x20xf32>
+func.func @vector_multi_reduction_unit_dimensions_fail(%source: vector<5x1x4x1x20xf32>, %acc: vector<5x1x20xf32>) -> vector<5x1x20xf32> {
+// CHECK: %[[RES:.+]] = vector.multi_reduction <mul>, %[[SRC]], %[[ACCUM]] [1, 2] : vector<5x1x4x1x20xf32> to vector<5x1x20xf32>
+ %0 = vector.multi_reduction <mul>, %source, %acc [1, 2] : vector<5x1x4x1x20xf32> to vector<5x1x20xf32>
+
+// CHECK: return %[[RES]] : vector<5x1x20xf32>
+ return %0 : vector<5x1x20xf32>
+}
+
+// -----
+
+// Every dimension is reduced and every dimension has extent 1, so the result
+// is a scalar: the single element is expected to be extracted at [0, 0, 0] and
+// combined with %acc.
+// CHECK-LABEL: func @vector_multi_reduction_unit_dimensions_single_elem(
+// CHECK-SAME: %[[SOURCE:.+]]: vector<1x1x1xf32>, %[[ACC:.+]]: f32
+func.func @vector_multi_reduction_unit_dimensions_single_elem(%source: vector<1x1x1xf32>, %acc: f32) -> f32 {
+// CHECK: %[[CAST:.+]] = vector.extract %[[SOURCE]][0, 0, 0] : vector<1x1x1xf32>
+// CHECK: %[[RESULT:.+]] = arith.mulf %[[ACC]], %[[CAST]] : f32
+ %0 = vector.multi_reduction <mul>, %source, %acc [0,1,2] : vector<1x1x1xf32> to f32
+
+// CHECK: return %[[RESULT]] : f32
+ return %0 : f32
+}
+
+// -----
+
// CHECK-LABEL: func @insert_strided_slice_full_range
// CHECK-SAME: %[[SOURCE:.+]]: vector<16x16xf16>, %{{.+}}: vector<16x16xf16>
func.func @insert_strided_slice_full_range(%source: vector<16x16xf16>, %dest: vector<16x16xf16>) -> vector<16x16xf16> {
More information about the Mlir-commits mailing list