[Mlir-commits] [mlir] 039bdcc - [MLIR] Canonicalize sub/add of a constant and another sub/add of a constant

William S. Moses llvmlistbot at llvm.org
Mon May 3 08:53:16 PDT 2021


Author: William S. Moses
Date: 2021-05-03T11:49:23-04:00
New Revision: 039bdcc0a8a213faa8f32837c44b81ce41f41ab0

URL: https://github.com/llvm/llvm-project/commit/039bdcc0a8a213faa8f32837c44b81ce41f41ab0
DIFF: https://github.com/llvm/llvm-project/commit/039bdcc0a8a213faa8f32837c44b81ce41f41ab0.diff

LOG: [MLIR] Canonicalize sub/add of a constant and another sub/add of a constant

Differential Revision: https://reviews.llvm.org/D101705

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Dialect/Standard/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 99b8889467e38..6152b6b4b41a6 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -277,6 +277,7 @@ def AddIOp : IntBinaryOp<"addi", [Commutative]> {
     ```
   }];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1792,6 +1793,7 @@ def SubFOp : FloatBinaryOp<"subf"> {
 def SubIOp : IntBinaryOp<"subi"> {
   let summary = "integer subtraction operation";
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index fc73947acb9d7..51b832805cca8 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -283,6 +283,62 @@ static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
       }));
 }
 
+/// Canonicalize an addition of a constant and another add/sub that also has a
+/// constant operand by folding the two constants together:
+///   c0 + (c1 + x) -> (c0 + c1) + x
+///   c0 + (c1 - x) -> (c0 + c1) - x
+///   c0 + (x - c1) -> (c0 - c1) + x
+/// The outer constant may sit on either side of the addition, and the inner
+/// add may have its constant on either side. Constants are whatever
+/// `m_ConstantInt` matches (integer/index scalars and splat vectors/tensors);
+/// the merged constant is materialized with an attribute kind matching the
+/// result type. The reassociation relies on addi/subi wrapping in the same way
+/// as the APInt arithmetic used to combine the constants.
+struct AddConstantReorder : public OpRewritePattern<AddIOp> {
+  using OpRewritePattern<AddIOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AddIOp addop,
+                                PatternRewriter &rewriter) const override {
+    // AddIOp is commutative, so try the outer constant on either side.
+    for (unsigned i = 0; i < 2; ++i) {
+      APInt origConst;
+      if (!matchPattern(addop.getOperand(i), m_ConstantInt(&origConst)))
+        continue;
+      Value other = addop.getOperand(1 - i);
+      APInt midConst;
+
+      if (auto midAddOp = other.getDefiningOp<AddIOp>()) {
+        for (unsigned j = 0; j < 2; ++j) {
+          if (!matchPattern(midAddOp.getOperand(j), m_ConstantInt(&midConst)))
+            continue;
+          // origConst + (midConst + x) == (origConst + midConst) + x
+          Value merged = createConstant(rewriter, addop, origConst + midConst);
+          rewriter.replaceOpWithNewOp<AddIOp>(addop, merged,
+                                              midAddOp.getOperand(1 - j));
+          return success();
+        }
+      }
+
+      if (auto midSubOp = other.getDefiningOp<SubIOp>()) {
+        if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
+          // origConst + (midConst - x) == (origConst + midConst) - x
+          Value merged = createConstant(rewriter, addop, origConst + midConst);
+          rewriter.replaceOpWithNewOp<SubIOp>(addop, merged,
+                                              midSubOp.getOperand(1));
+          return success();
+        }
+        if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
+          // origConst + (x - midConst) == (origConst - midConst) + x
+          Value merged = createConstant(rewriter, addop, origConst - midConst);
+          rewriter.replaceOpWithNewOp<AddIOp>(addop, merged,
+                                              midSubOp.getOperand(0));
+          return success();
+        }
+      }
+    }
+    return failure();
+  }
+
+private:
+  /// Materializes `value` as a `constant` op of `op`'s result type. Shaped
+  /// (vector/tensor) results need a splat DenseElementsAttr; an IntegerAttr is
+  /// only valid for integer and index types.
+  static Value createConstant(PatternRewriter &rewriter, Operation *op,
+                              const APInt &value) {
+    Type type = op->getResult(0).getType();
+    Attribute attr;
+    if (auto shapedType = type.dyn_cast<ShapedType>())
+      attr = DenseElementsAttr::get(shapedType, llvm::makeArrayRef(value));
+    else
+      attr = rewriter.getIntegerAttr(type, value);
+    return rewriter.create<ConstantOp>(op->getLoc(), attr);
+  }
+};
+
+void AddIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                         MLIRContext *context) {
+  results.insert<AddConstantReorder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // AndOp
 //===----------------------------------------------------------------------===//
@@ -1706,6 +1762,153 @@ OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) {
                                         [](APInt a, APInt b) { return a - b; });
 }
 
+/// Canonicalize a subtraction where one operand is a constant and the other is
+/// an add/sub that also has a constant operand, merging the two constants into
+/// a single operation:
+///   c0 - (c1 + x) -> (c0 - c1) - x
+///   c0 - (c1 - x) -> (c0 - c1) + x
+///   c0 - (x - c1) -> (c0 + c1) - x
+///   (c1 + x) - c0 -> (c1 - c0) + x
+///   (c1 - x) - c0 -> (c1 - c0) - x
+///   (x - c1) - c0 -> x - (c1 + c0)
+/// Constants are whatever `m_ConstantInt` matches (integer/index scalars and
+/// splat vectors/tensors); the merged constant is materialized with an
+/// attribute kind matching the result type.
+struct SubConstantReorder : public OpRewritePattern<SubIOp> {
+  using OpRewritePattern<SubIOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SubIOp subOp,
+                                PatternRewriter &rewriter) const override {
+    Value lhs = subOp.getOperand(0);
+    Value rhs = subOp.getOperand(1);
+    APInt origConst;
+    APInt midConst;
+
+    // Constant minuend: origConst - (...). Only `rhs` can be defined by an
+    // add/sub here, because `lhs` was just matched as a constant op.
+    if (matchPattern(lhs, m_ConstantInt(&origConst))) {
+      if (auto midAddOp = rhs.getDefiningOp<AddIOp>()) {
+        for (unsigned j = 0; j < 2; ++j) {
+          if (!matchPattern(midAddOp.getOperand(j), m_ConstantInt(&midConst)))
+            continue;
+          // origConst - (midConst + x) == (origConst - midConst) - x
+          Value merged = createConstant(rewriter, subOp, origConst - midConst);
+          rewriter.replaceOpWithNewOp<SubIOp>(subOp, merged,
+                                              midAddOp.getOperand(1 - j));
+          return success();
+        }
+      }
+
+      if (auto midSubOp = rhs.getDefiningOp<SubIOp>()) {
+        if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
+          // origConst - (midConst - x) == (origConst - midConst) + x
+          Value merged = createConstant(rewriter, subOp, origConst - midConst);
+          rewriter.replaceOpWithNewOp<AddIOp>(subOp, merged,
+                                              midSubOp.getOperand(1));
+          return success();
+        }
+        if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
+          // origConst - (x - midConst) == (origConst + midConst) - x
+          Value merged = createConstant(rewriter, subOp, origConst + midConst);
+          rewriter.replaceOpWithNewOp<SubIOp>(subOp, merged,
+                                              midSubOp.getOperand(0));
+          return success();
+        }
+      }
+    }
+
+    // Constant subtrahend: (...) - origConst. Only `lhs` can be defined by an
+    // add/sub here, because `rhs` was just matched as a constant op.
+    if (matchPattern(rhs, m_ConstantInt(&origConst))) {
+      if (auto midAddOp = lhs.getDefiningOp<AddIOp>()) {
+        for (unsigned j = 0; j < 2; ++j) {
+          if (!matchPattern(midAddOp.getOperand(j), m_ConstantInt(&midConst)))
+            continue;
+          // (midConst + x) - origConst == (midConst - origConst) + x
+          Value merged = createConstant(rewriter, subOp, midConst - origConst);
+          rewriter.replaceOpWithNewOp<AddIOp>(subOp, merged,
+                                              midAddOp.getOperand(1 - j));
+          return success();
+        }
+      }
+
+      if (auto midSubOp = lhs.getDefiningOp<SubIOp>()) {
+        if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
+          // (midConst - x) - origConst == (midConst - origConst) - x
+          Value merged = createConstant(rewriter, subOp, midConst - origConst);
+          rewriter.replaceOpWithNewOp<SubIOp>(subOp, merged,
+                                              midSubOp.getOperand(1));
+          return success();
+        }
+        if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
+          // (x - midConst) - origConst == x - (midConst + origConst)
+          Value merged = createConstant(rewriter, subOp, midConst + origConst);
+          rewriter.replaceOpWithNewOp<SubIOp>(subOp, midSubOp.getOperand(0),
+                                              merged);
+          return success();
+        }
+      }
+    }
+    return failure();
+  }
+
+private:
+  /// Materializes `value` as a `constant` op of `op`'s result type. Shaped
+  /// (vector/tensor) results need a splat DenseElementsAttr; an IntegerAttr is
+  /// only valid for integer and index types.
+  static Value createConstant(PatternRewriter &rewriter, Operation *op,
+                              const APInt &value) {
+    Type type = op->getResult(0).getType();
+    Attribute attr;
+    if (auto shapedType = type.dyn_cast<ShapedType>())
+      attr = DenseElementsAttr::get(shapedType, llvm::makeArrayRef(value));
+    else
+      attr = rewriter.getIntegerAttr(type, value);
+    return rewriter.create<ConstantOp>(op->getLoc(), attr);
+  }
+};
+
+void SubIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                         MLIRContext *context) {
+  results.insert<SubConstantReorder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // UIToFPOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index 8db3065af47d0..15dbde7d2757f 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -428,3 +428,113 @@ func @truncConstant(%arg0: i8) -> i16 {
   %tr = trunci %c-2 : i32 to i16
   return %tr : i16
 }
+
+// -----
+
+// 42 + (17 + x) == 59 + x
+// CHECK-LABEL: @tripleAddAdd
+//       CHECK:   %[[cres:.+]] = constant 59 : index
+//       CHECK:   %[[res:.+]] = addi %arg0, %[[cres]] : index
+//       CHECK:   return %[[res]]
+func @tripleAddAdd(%arg0: index) -> index {
+  %c17 = constant 17 : index
+  %c42 = constant 42 : index
+  %add1 = addi %c17, %arg0 : index
+  %add2 = addi %c42, %add1 : index
+  return %add2 : index
+}
+
+// 42 + (17 - x) == 59 - x
+// CHECK-LABEL: @tripleAddSub0
+//       CHECK:   %[[cres:.+]] = constant 59 : index
+//       CHECK:   %[[res:.+]] = subi %[[cres]], %arg0 : index
+//       CHECK:   return %[[res]]
+func @tripleAddSub0(%arg0: index) -> index {
+  %c17 = constant 17 : index
+  %c42 = constant 42 : index
+  %add1 = subi %c17, %arg0 : index
+  %add2 = addi %c42, %add1 : index
+  return %add2 : index
+}
+
+// 42 + (x - 17) == 25 + x
+// CHECK-LABEL: @tripleAddSub1
+//       CHECK:   %[[cres:.+]] = constant 25 : index
+//       CHECK:   %[[res:.+]] = addi %arg0, %[[cres]] : index
+//       CHECK:   return %[[res]]
+func @tripleAddSub1(%arg0: index) -> index {
+  %c17 = constant 17 : index
+  %c42 = constant 42 : index
+  %add1 = subi %arg0, %c17 : index
+  %add2 = addi %c42, %add1 : index
+  return %add2 : index
+}
+
+// 42 - (17 + x) == 25 - x
+// CHECK-LABEL: @tripleSubAdd0
+//       CHECK:   %[[cres:.+]] = constant 25 : index
+//       CHECK:   %[[res:.+]] = subi %[[cres]], %arg0 : index
+//       CHECK:   return %[[res]]
+func @tripleSubAdd0(%arg0: index) -> index {
+  %c17 = constant 17 : index
+  %c42 = constant 42 : index
+  %add1 = addi %c17, %arg0 : index
+  %add2 = subi %c42, %add1 : index
+  return %add2 : index
+}
+
+// (17 + x) - 42 == -25 + x
+// CHECK-LABEL: @tripleSubAdd1
+//       CHECK:   %[[cres:.+]] = constant -25 : index
+//       CHECK:   %[[res:.+]] = addi %arg0, %[[cres]] : index
+//       CHECK:   return %[[res]]
+func @tripleSubAdd1(%arg0: index) -> index {
+  %c17 = constant 17 : index
+  %c42 = constant 42 : index
+  %add1 = addi %c17, %arg0 : index
+  %add2 = subi %add1, %c42 : index
+  return %add2 : index
+}
+
+// 42 - (17 - x) == 25 + x
+// CHECK-LABEL: @tripleSubSub0
+//       CHECK:   %[[cres:.+]] = constant 25 : index
+//       CHECK:   %[[res:.+]] = addi %arg0, %[[cres]] : index
+//       CHECK:   return %[[res]]
+func @tripleSubSub0(%arg0: index) -> index {
+  %c17 = constant 17 : index
+  %c42 = constant 42 : index
+  %add1 = subi %c17, %arg0 : index
+  %add2 = subi %c42, %add1 : index
+  return %add2 : index
+}
+
+// (17 - x) - 42 == -25 - x
+// CHECK-LABEL: @tripleSubSub1
+//       CHECK:   %[[cres:.+]] = constant -25 : index
+//       CHECK:   %[[res:.+]] = subi %[[cres]], %arg0 : index
+//       CHECK:   return %[[res]]
+func @tripleSubSub1(%arg0: index) -> index {
+  %c17 = constant 17 : index
+  %c42 = constant 42 : index
+  %add1 = subi %c17, %arg0 : index
+  %add2 = subi %add1, %c42 : index
+  return %add2 : index
+}
+
+// 42 - (x - 17) == 59 - x
+// CHECK-LABEL: @tripleSubSub2
+//       CHECK:   %[[cres:.+]] = constant 59 : index
+//       CHECK:   %[[res:.+]] = subi %[[cres]], %arg0 : index
+//       CHECK:   return %[[res]]
+func @tripleSubSub2(%arg0: index) -> index {
+  %c17 = constant 17 : index
+  %c42 = constant 42 : index
+  %add1 = subi %arg0, %c17 : index
+  %add2 = subi %c42, %add1 : index
+  return %add2 : index
+}
+
+// (x - 17) - 42 == x - 59
+// CHECK-LABEL: @tripleSubSub3
+//       CHECK:   %[[cres:.+]] = constant 59 : index
+//       CHECK:   %[[res:.+]] = subi %arg0, %[[cres]] : index
+//       CHECK:   return %[[res]]
+func @tripleSubSub3(%arg0: index) -> index {
+  %c17 = constant 17 : index
+  %c42 = constant 42 : index
+  %add1 = subi %arg0, %c17 : index
+  %add2 = subi %add1, %c42 : index
+  return %add2 : index
+}


        


More information about the Mlir-commits mailing list