[Mlir-commits] [mlir] c16f326 - Added canonicalization for vector.multi_reduction

Murali Vijayaraghavan llvmlistbot at llvm.org
Wed Oct 5 11:44:05 PDT 2022


Author: Murali Vijayaraghavan
Date: 2022-10-05T18:43:33Z
New Revision: c16f3260a9255c7d9880f72de7d856f9ceeb1866

URL: https://github.com/llvm/llvm-project/commit/c16f3260a9255c7d9880f72de7d856f9ceeb1866
DIFF: https://github.com/llvm/llvm-project/commit/c16f3260a9255c7d9880f72de7d856f9ceeb1866.diff

LOG: Added canonicalization for vector.multi_reduction

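If every reduced dimension of a vector.multi_reduction is a unit dimension,
the op is folded into a vector.shape_cast of the source that is combined with
the accumulator using the arithmetic op for the reduction kind.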

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D134996

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 94d0ea939d58d..575dfbb9a2114 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -396,6 +396,7 @@ def Vector_MultiDimReductionOp :
   let assemblyFormat =
     "$kind `,` $source `,` $acc attr-dict $reduction_dims `:` type($source) `to` type($dest)";
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
 

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 4929ca170b1d1..751a592b60aa9 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -309,6 +309,56 @@ LogicalResult MultiDimReductionOp::verify() {
   return success();
 }
 
+namespace {
+// Canonicalizes a vector.multi_reduction whose reduced dimensions all have
+// size 1. Reducing over a unit dimension leaves every element unchanged, so
+// the op is replaced by the source, reshaped to the destination type, combined
+// with the accumulator through the arithmetic op matching the reduction kind.
+// Unit dimensions that are not reduced are left alone, which keeps the result
+// type the same. If any reduced dimension is not of size 1, the pattern does
+// not apply. This is a generalization of ElideSingleElementReduction for
+// ReduceOp.
+struct ElideUnitDimsInMultiDimReduction
+    : public OpRewritePattern<MultiDimReductionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
+                                PatternRewriter &rewriter) const override {
+    // Bail out as soon as a reduced dimension with more than one element is
+    // found.
+    ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
+    for (auto dim : enumerate(shape)) {
+      if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
+        return failure();
+    }
+    Location loc = reductionOp.getLoc();
+    Value acc = reductionOp.getAcc();
+    Value cast;
+    if (reductionOp.getDestType().isa<VectorType>()) {
+      cast = rewriter.create<vector::ShapeCastOp>(
+          loc, reductionOp.getDestType(), reductionOp.getSource());
+    } else {
+      // Every dimension is reduced, so the destination is a scalar and
+      // vector.shape_cast (which only produces vectors) cannot be used.
+      // All source dimensions are of size 1 here; extract the only element.
+      cast = rewriter.create<vector::ExtractOp>(
+          loc, reductionOp.getDestType(), reductionOp.getSource(),
+          rewriter.getI64ArrayAttr(SmallVector<int64_t>(shape.size(), 0)));
+    }
+    Value result = vector::makeArithReduction(rewriter, loc,
+                                              reductionOp.getKind(), acc, cast);
+    rewriter.replaceOp(reductionOp, result);
+    return success();
+  }
+};
+} // namespace
+
+// Registers the canonicalization patterns for vector.multi_reduction.
+void MultiDimReductionOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<ElideUnitDimsInMultiDimReduction>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ReductionOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 15c78c1118096..34d292f6d4c56 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1348,6 +1348,31 @@ func.func @vector_multi_reduction_single_parallel(%arg0: vector<2xf32>, %acc: ve
 
 // -----
 
+// CHECK-LABEL: func @vector_multi_reduction_unit_dimensions(
+//  CHECK-SAME: %[[SOURCE:.+]]: vector<5x1x4x1x20xf32>, %[[ACC:.+]]: vector<5x4x20xf32>
+func.func @vector_multi_reduction_unit_dimensions(%source: vector<5x1x4x1x20xf32>, %acc: vector<5x4x20xf32>) -> vector<5x4x20xf32> {
+//       CHECK:   %[[CAST:.+]] = vector.shape_cast  %[[SOURCE]] : vector<5x1x4x1x20xf32> to vector<5x4x20xf32>
+//       CHECK:   %[[RESULT:.+]] = arith.mulf  %[[ACC]], %[[CAST]] : vector<5x4x20xf32>
+    %0 = vector.multi_reduction <mul>, %source, %acc [1, 3] : vector<5x1x4x1x20xf32> to vector<5x4x20xf32>
+
+//       CHECK:     return %[[RESULT]] : vector<5x4x20xf32>
+    return %0 : vector<5x4x20xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @vector_multi_reduction_unit_dimensions_fail(
+//  CHECK-SAME: %[[SRC:.+]]: vector<5x1x4x1x20xf32>, %[[ACCUM:.+]]: vector<5x1x1x20xf32>
+func.func @vector_multi_reduction_unit_dimensions_fail(%source: vector<5x1x4x1x20xf32>, %acc: vector<5x1x1x20xf32>) -> vector<5x1x1x20xf32> {
+//       CHECK:   %[[RES:.+]] = vector.multi_reduction  <mul>, %[[SRC]], %[[ACCUM]] [2] : vector<5x1x4x1x20xf32> to vector<5x1x1x20xf32>
+    %0 = vector.multi_reduction <mul>, %source, %acc [2] : vector<5x1x4x1x20xf32> to vector<5x1x1x20xf32>
+
+//       CHECK:     return %[[RES]] : vector<5x1x1x20xf32>
+    return %0 : vector<5x1x1x20xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @insert_strided_slice_full_range
 //  CHECK-SAME: %[[SOURCE:.+]]: vector<16x16xf16>, %{{.+}}: vector<16x16xf16>
 func.func @insert_strided_slice_full_range(%source: vector<16x16xf16>, %dest: vector<16x16xf16>) -> vector<16x16xf16> {


        


More information about the Mlir-commits mailing list