[Mlir-commits] [mlir] 039bdcc - [MLIR] Canonicalize sub/add of a constant and another sub/add of a constant
William S. Moses
llvmlistbot at llvm.org
Mon May 3 08:53:16 PDT 2021
Author: William S. Moses
Date: 2021-05-03T11:49:23-04:00
New Revision: 039bdcc0a8a213faa8f32837c44b81ce41f41ab0
URL: https://github.com/llvm/llvm-project/commit/039bdcc0a8a213faa8f32837c44b81ce41f41ab0
DIFF: https://github.com/llvm/llvm-project/commit/039bdcc0a8a213faa8f32837c44b81ce41f41ab0.diff
LOG: [MLIR] Canonicalize sub/add of a constant and another sub/add of a constant
Differential Revision: https://reviews.llvm.org/D101705
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Dialect/Standard/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 99b8889467e38..6152b6b4b41a6 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -277,6 +277,7 @@ def AddIOp : IntBinaryOp<"addi", [Commutative]> {
```
}];
let hasFolder = 1;
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
@@ -1792,6 +1793,7 @@ def SubFOp : FloatBinaryOp<"subf"> {
def SubIOp : IntBinaryOp<"subi"> {
let summary = "integer subtraction operation";
let hasFolder = 1;
+  // Canonicalization patterns are registered by
+  // SubIOp::getCanonicalizationPatterns in Ops.cpp.
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index fc73947acb9d7..51b832805cca8 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -283,6 +283,62 @@ static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
}));
}
+/// Reassociate an addition of a constant with another addi/subi that has a
+/// constant operand, merging the two constants into a single one:
+///   c0 + (c1 + x) -> (c0 + c1) + x
+///   c0 + (c1 - x) -> (c0 + c1) - x
+///   c0 + (x - c1) -> (c0 - c1) + x
+/// addi is commutative, so the constant may appear on either side of both the
+/// outer and the inner addi. This may result in some operations being
+/// reordered, but the computed value is unchanged.
+struct AddConstantReorder : public OpRewritePattern<AddIOp> {
+  using OpRewritePattern<AddIOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AddIOp addop,
+                                PatternRewriter &rewriter) const override {
+    // m_ConstantInt also matches splat vector/tensor constants, but the merged
+    // constant below is materialized with getIntegerAttr, i.e. as a scalar
+    // IntegerAttr, which is not a valid value for a shaped result type.
+    // Only rewrite scalar integer/index additions.
+    if (!addop.getType().isIntOrIndex())
+      return failure();
+
+    for (int i = 0; i < 2; i++) {
+      APInt origConst;
+      APInt midConst;
+      if (!matchPattern(addop.getOperand(i), m_ConstantInt(&origConst)))
+        continue;
+      Value other = addop.getOperand(1 - i);
+
+      if (auto midAddOp = other.getDefiningOp<AddIOp>()) {
+        for (int j = 0; j < 2; j++) {
+          if (matchPattern(midAddOp.getOperand(j),
+                           m_ConstantInt(&midConst))) {
+            // origConst + (midConst + x) == (origConst + midConst) + x
+            auto nextConstant = rewriter.create<ConstantOp>(
+                addop.getLoc(), rewriter.getIntegerAttr(
+                                    addop.getType(), origConst + midConst));
+            rewriter.replaceOpWithNewOp<AddIOp>(addop, nextConstant,
+                                                midAddOp.getOperand(1 - j));
+            return success();
+          }
+        }
+      }
+
+      if (auto midSubOp = other.getDefiningOp<SubIOp>()) {
+        if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
+          // origConst + (midConst - x) == (origConst + midConst) - x
+          auto nextConstant = rewriter.create<ConstantOp>(
+              addop.getLoc(),
+              rewriter.getIntegerAttr(addop.getType(), origConst + midConst));
+          rewriter.replaceOpWithNewOp<SubIOp>(addop, nextConstant,
+                                              midSubOp.getOperand(1));
+          return success();
+        }
+        if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
+          // origConst + (x - midConst) == (origConst - midConst) + x
+          auto nextConstant = rewriter.create<ConstantOp>(
+              addop.getLoc(),
+              rewriter.getIntegerAttr(addop.getType(), origConst - midConst));
+          rewriter.replaceOpWithNewOp<AddIOp>(addop, nextConstant,
+                                              midSubOp.getOperand(0));
+          return success();
+        }
+      }
+    }
+    return failure();
+  }
+};
+
+void AddIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                         MLIRContext *context) {
+  results.insert<AddConstantReorder>(context);
+}
+
//===----------------------------------------------------------------------===//
// AndOp
//===----------------------------------------------------------------------===//
@@ -1706,6 +1762,153 @@ OpFoldResult SubIOp::fold(ArrayRef<Attribute> operands) {
[](APInt a, APInt b) { return a - b; });
}
+/// Canonicalize a sub of a constant and (constant +/- something) to simply be
+/// a single operation that merges the two constants:
+///   c0 - (c1 + x) -> (c0 - c1) - x
+///   c0 - (c1 - x) -> (c0 - c1) + x
+///   c0 - (x - c1) -> (c0 + c1) - x
+///   (c1 + x) - c0 -> (c1 - c0) + x
+///   (c1 - x) - c0 -> (c1 - c0) - x
+///   (x - c1) - c0 -> x - (c1 + c0)
+/// The inner addi is commutative, so its constant may be on either side.
+struct SubConstantReorder : public OpRewritePattern<SubIOp> {
+  using OpRewritePattern<SubIOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SubIOp subOp,
+                                PatternRewriter &rewriter) const override {
+    // m_ConstantInt also matches splat vector/tensor constants, but the merged
+    // constant below is materialized with getIntegerAttr, i.e. as a scalar
+    // IntegerAttr, which is not a valid value for a shaped result type.
+    // Only rewrite scalar integer/index subtractions.
+    if (!subOp.getType().isIntOrIndex())
+      return failure();
+
+    APInt origConst;
+    APInt midConst;
+
+    // Constant on the LHS: the LHS is produced by a constant op, so only the
+    // RHS can be an addi/subi to merge with.
+    if (matchPattern(subOp.getOperand(0), m_ConstantInt(&origConst))) {
+      if (auto midAddOp = subOp.getOperand(1).getDefiningOp<AddIOp>()) {
+        for (int j = 0; j < 2; j++) {
+          if (matchPattern(midAddOp.getOperand(j), m_ConstantInt(&midConst))) {
+            // origConst - (midConst + x) == (origConst - midConst) - x
+            auto nextConstant = rewriter.create<ConstantOp>(
+                subOp.getLoc(),
+                rewriter.getIntegerAttr(subOp.getType(), origConst - midConst));
+            rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
+                                                midAddOp.getOperand(1 - j));
+            return success();
+          }
+        }
+      }
+
+      if (auto midSubOp = subOp.getOperand(1).getDefiningOp<SubIOp>()) {
+        if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
+          // origConst - (midConst - x) == (origConst - midConst) + x
+          auto nextConstant = rewriter.create<ConstantOp>(
+              subOp.getLoc(),
+              rewriter.getIntegerAttr(subOp.getType(), origConst - midConst));
+          rewriter.replaceOpWithNewOp<AddIOp>(subOp, nextConstant,
+                                              midSubOp.getOperand(1));
+          return success();
+        }
+
+        if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
+          // origConst - (x - midConst) == (origConst + midConst) - x
+          auto nextConstant = rewriter.create<ConstantOp>(
+              subOp.getLoc(),
+              rewriter.getIntegerAttr(subOp.getType(), origConst + midConst));
+          rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
+                                              midSubOp.getOperand(0));
+          return success();
+        }
+      }
+    }
+
+    // Constant on the RHS: the RHS is produced by a constant op, so only the
+    // LHS can be an addi/subi to merge with.
+    if (matchPattern(subOp.getOperand(1), m_ConstantInt(&origConst))) {
+      if (auto midAddOp = subOp.getOperand(0).getDefiningOp<AddIOp>()) {
+        for (int j = 0; j < 2; j++) {
+          if (matchPattern(midAddOp.getOperand(j), m_ConstantInt(&midConst))) {
+            // (midConst + x) - origConst == (midConst - origConst) + x
+            auto nextConstant = rewriter.create<ConstantOp>(
+                subOp.getLoc(),
+                rewriter.getIntegerAttr(subOp.getType(), midConst - origConst));
+            rewriter.replaceOpWithNewOp<AddIOp>(subOp, nextConstant,
+                                                midAddOp.getOperand(1 - j));
+            return success();
+          }
+        }
+      }
+
+      if (auto midSubOp = subOp.getOperand(0).getDefiningOp<SubIOp>()) {
+        if (matchPattern(midSubOp.getOperand(0), m_ConstantInt(&midConst))) {
+          // (midConst - x) - origConst == (midConst - origConst) - x
+          auto nextConstant = rewriter.create<ConstantOp>(
+              subOp.getLoc(),
+              rewriter.getIntegerAttr(subOp.getType(), midConst - origConst));
+          rewriter.replaceOpWithNewOp<SubIOp>(subOp, nextConstant,
+                                              midSubOp.getOperand(1));
+          return success();
+        }
+
+        if (matchPattern(midSubOp.getOperand(1), m_ConstantInt(&midConst))) {
+          // (x - midConst) - origConst == x - (midConst + origConst)
+          auto nextConstant = rewriter.create<ConstantOp>(
+              subOp.getLoc(),
+              rewriter.getIntegerAttr(subOp.getType(), midConst + origConst));
+          rewriter.replaceOpWithNewOp<SubIOp>(subOp, midSubOp.getOperand(0),
+                                              nextConstant);
+          return success();
+        }
+      }
+    }
+    return failure();
+  }
+};
+
+void SubIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                         MLIRContext *context) {
+  results.insert<SubConstantReorder>(context);
+}
+
//===----------------------------------------------------------------------===//
// UIToFPOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index 8db3065af47d0..15dbde7d2757f 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -428,3 +428,113 @@ func @truncConstant(%arg0: i8) -> i16 {
%tr = trunci %c-2 : i32 to i16
return %tr : i16
}
+
+// -----
+
+// 42 + (17 + x) == 59 + x
+// CHECK-LABEL: @tripleAddAdd
+// CHECK: %[[cres:.+]] = constant 59 : index
+// CHECK: %[[res:.+]] = addi %arg0, %[[cres]] : index
+// CHECK: return %[[res]]
+func @tripleAddAdd(%arg0: index) -> index {
+  %c17 = constant 17 : index
+  %c42 = constant 42 : index
+  %add1 = addi %c17, %arg0 : index
+  %add2 = addi %c42, %add1 : index
+  return %add2 : index
+}
+
+// 42 + (17 - x) == 59 - x
+// CHECK-LABEL: @tripleAddSub0
+// CHECK: %[[cres:.+]] = constant 59 : index
+// CHECK: %[[res:.+]] = subi %[[cres]], %arg0 : index
+// CHECK: return %[[res]]
+func @tripleAddSub0(%arg0: index) -> index {
+  %c17 = constant 17 : index
+  %c42 = constant 42 : index
+  %sub1 = subi %c17, %arg0 : index
+  %add2 = addi %c42, %sub1 : index
+  return %add2 : index
+}
+
+// 42 + (x - 17) == 25 + x
+// CHECK-LABEL: @tripleAddSub1
+// CHECK: %[[cres:.+]] = constant 25 : index
+// CHECK: %[[res:.+]] = addi %arg0, %[[cres]] : index
+// CHECK: return %[[res]]
+func @tripleAddSub1(%arg0: index) -> index {
+  %c17 = constant 17 : index
+  %c42 = constant 42 : index
+  %sub1 = subi %arg0, %c17 : index
+  %add2 = addi %c42, %sub1 : index
+  return %add2 : index
+}
+
+// 42 - (17 + x) == 25 - x
+// CHECK-LABEL: @tripleSubAdd0
+// CHECK: %[[cres:.+]] = constant 25 : index
+// CHECK: %[[res:.+]] = subi %[[cres]], %arg0 : index
+// CHECK: return %[[res]]
+func @tripleSubAdd0(%arg0: index) -> index {
+  %c17 = constant 17 : index
+  %c42 = constant 42 : index
+  %add1 = addi %c17, %arg0 : index
+  %sub2 = subi %c42, %add1 : index
+  return %sub2 : index
+}
+
+// (17 + x) - 42 == -25 + x
+// CHECK-LABEL: @tripleSubAdd1
+// CHECK: %[[cres:.+]] = constant -25 : index
+// CHECK: %[[res:.+]] = addi %arg0, %[[cres]] : index
+// CHECK: return %[[res]]
+func @tripleSubAdd1(%arg0: index) -> index {
+  %c17 = constant 17 : index
+  %c42 = constant 42 : index
+  %add1 = addi %c17, %arg0 : index
+  %sub2 = subi %add1, %c42 : index
+  return %sub2 : index
+}
+
+// 42 - (17 - x) == 25 + x
+// CHECK-LABEL: @tripleSubSub0
+// CHECK: %[[cres:.+]] = constant 25 : index
+// CHECK: %[[res:.+]] = addi %arg0, %[[cres]] : index
+// CHECK: return %[[res]]
+func @tripleSubSub0(%arg0: index) -> index {
+  %c17 = constant 17 : index
+  %c42 = constant 42 : index
+  %sub1 = subi %c17, %arg0 : index
+  %sub2 = subi %c42, %sub1 : index
+  return %sub2 : index
+}
+
+// (17 - x) - 42 == -25 - x
+// CHECK-LABEL: @tripleSubSub1
+// CHECK: %[[cres:.+]] = constant -25 : index
+// CHECK: %[[res:.+]] = subi %[[cres]], %arg0 : index
+// CHECK: return %[[res]]
+func @tripleSubSub1(%arg0: index) -> index {
+  %c17 = constant 17 : index
+  %c42 = constant 42 : index
+  %sub1 = subi %c17, %arg0 : index
+  %sub2 = subi %sub1, %c42 : index
+  return %sub2 : index
+}
+
+// 42 - (x - 17) == 59 - x
+// CHECK-LABEL: @tripleSubSub2
+// CHECK: %[[cres:.+]] = constant 59 : index
+// CHECK: %[[res:.+]] = subi %[[cres]], %arg0 : index
+// CHECK: return %[[res]]
+func @tripleSubSub2(%arg0: index) -> index {
+  %c17 = constant 17 : index
+  %c42 = constant 42 : index
+  %sub1 = subi %arg0, %c17 : index
+  %sub2 = subi %c42, %sub1 : index
+  return %sub2 : index
+}
+
+// (x - 17) - 42 == x - 59
+// CHECK-LABEL: @tripleSubSub3
+// CHECK: %[[cres:.+]] = constant 59 : index
+// CHECK: %[[res:.+]] = subi %arg0, %[[cres]] : index
+// CHECK: return %[[res]]
+func @tripleSubSub3(%arg0: index) -> index {
+  %c17 = constant 17 : index
+  %c42 = constant 42 : index
+  %sub1 = subi %arg0, %c17 : index
+  %sub2 = subi %sub1, %c42 : index
+  return %sub2 : index
+}
More information about the Mlir-commits
mailing list