[Mlir-commits] [mlir] df538fd - [mlir][affine] Add single result affine.min/max -> affine.apply canonicalization.

Nicolas Vasilache llvmlistbot at llvm.org
Wed Jul 14 13:38:12 PDT 2021


Author: Nicolas Vasilache
Date: 2021-07-14T20:38:06Z
New Revision: df538fdaa985e7e64a572dd84b878c562899e5e8

URL: https://github.com/llvm/llvm-project/commit/df538fdaa985e7e64a572dd84b878c562899e5e8
DIFF: https://github.com/llvm/llvm-project/commit/df538fdaa985e7e64a572dd84b878c562899e5e8.diff

LOG: [mlir][affine] Add single result affine.min/max -> affine.apply canonicalization.

Differential Revision: https://reviews.llvm.org/D106014

Added: 
    

Modified: 
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/test/Dialect/Affine/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 9fa53a5fcdcf1..1d496574639c8 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2538,6 +2538,20 @@ struct MergeAffineMinMaxOp : public OpRewritePattern<T> {
   }
 };
 
+/// A single-result affine.min/max is equivalent to affine.apply of its map.
+template <typename T>
+struct CanonicalizeSingleResultAffineMinMaxOp : public OpRewritePattern<T> {
+  using OpRewritePattern<T>::OpRewritePattern;
+  LogicalResult matchAndRewrite(T op,
+                                PatternRewriter &rewriter) const override {
+    AffineMap map = op.map();
+    if (map.getNumResults() != 1)
+      return failure();
+    rewriter.replaceOpWithNewOp<AffineApplyOp>(op, map, op.getOperands());
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // AffineMinOp
 //===----------------------------------------------------------------------===//
@@ -2551,7 +2565,8 @@ OpFoldResult AffineMinOp::fold(ArrayRef<Attribute> operands) {
 
 void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
-  patterns.add<DeduplicateAffineMinMaxExpressions<AffineMinOp>,
+  patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMinOp>,
+               DeduplicateAffineMinMaxExpressions<AffineMinOp>,
                MergeAffineMinMaxOp<AffineMinOp>, SimplifyAffineOp<AffineMinOp>>(
       context);
 }
@@ -2569,7 +2584,8 @@ OpFoldResult AffineMaxOp::fold(ArrayRef<Attribute> operands) {
 
 void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
-  patterns.add<DeduplicateAffineMinMaxExpressions<AffineMaxOp>,
+  patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMaxOp>,
+               DeduplicateAffineMinMaxExpressions<AffineMaxOp>,
                MergeAffineMinMaxOp<AffineMaxOp>, SimplifyAffineOp<AffineMaxOp>>(
       context);
 }

diff  --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index 3d6bd57c27ffc..ba15d2a570244 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -870,7 +870,6 @@ func @dont_merge_affine_max_if_not_single_dim(%i0: index, %i1: index, %i2: index
   return %1: index
 }
 
-
 // -----
 
 // CHECK-LABEL: func @dont_merge_affine_max_if_not_single_sym
@@ -936,3 +935,21 @@ func @no_fold_of_store(%arg : memref<32xi8>, %holder: memref<memref<?xi8>>) {
   return
 }
 
+// -----
+
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<()[s0] -> (s0 + 16)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 * 4)>
+
+// Single-result affine.min and affine.max both canonicalize to affine.apply.
+// CHECK-LABEL: func @canonicalize_single_min_max
+// CHECK-SAME: (%[[I0:.+]]: index, %[[I1:.+]]: index)
+func @canonicalize_single_min_max(%i0: index, %i1: index) -> (index, index) {
+  // CHECK-NOT: affine.min
+  // CHECK-NEXT: affine.apply #[[$MAP0]]()[%[[I0]]]
+  %0 = affine.min affine_map<()[s0] -> (s0 + 16)> ()[%i0]
+
+  // CHECK-NOT: affine.max
+  // CHECK-NEXT: affine.apply #[[$MAP1]]()[%[[I1]]]
+  %1 = affine.max affine_map<()[s0] -> (s0 * 4)> ()[%i1]
+  return %0, %1: index, index
+}


        


More information about the Mlir-commits mailing list