[Mlir-commits] [mlir] 8cb785c - [mlir][arith] Clean up ExpandOps pass

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 20 13:59:24 PST 2021


Author: Mogball
Date: 2021-12-20T21:59:11Z
New Revision: 8cb785cad12b2d0fb7de2e13f208dab73f27111a

URL: https://github.com/llvm/llvm-project/commit/8cb785cad12b2d0fb7de2e13f208dab73f27111a
DIFF: https://github.com/llvm/llvm-project/commit/8cb785cad12b2d0fb7de2e13f208dab73f27111a.diff

LOG: [mlir][arith] Clean up ExpandOps pass

Added: 
    

Modified: 
    mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
index d9ab927512387..d06c3043664dd 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
@@ -15,6 +15,13 @@
 
 using namespace mlir;
 
+/// Create an arith::ConstantOp holding `value` as an attribute of integer/index `type`.
+static Value createConst(Location loc, Type type, int value,
+                         PatternRewriter &rewriter) {
+  return rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getIntegerAttr(type, value));
+}
+
 namespace {
 
 /// Expands CeilDivUIOp (n, m) into
@@ -26,17 +33,14 @@ struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
     Location loc = op.getLoc();
     Value a = op.getLhs();
     Value b = op.getRhs();
-    Value zero = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getIntegerAttr(a.getType(), 0));
+    Value zero = createConst(loc, a.getType(), 0, rewriter);
     Value compare =
         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);
-    Value one = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getIntegerAttr(a.getType(), 1));
+    Value one = createConst(loc, a.getType(), 1, rewriter);
     Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one);
     Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b);
     Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
-    Value res = rewriter.create<SelectOp>(loc, compare, zero, plusOne);
-    rewriter.replaceOp(op, {res});
+    rewriter.replaceOpWithNewOp<SelectOp>(op, compare, zero, plusOne);
     return success();
   }
 };
@@ -49,16 +53,12 @@ struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
   LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
                                 PatternRewriter &rewriter) const final {
     Location loc = op.getLoc();
-    auto signedCeilDivIOp = cast<arith::CeilDivSIOp>(op);
-    Type type = signedCeilDivIOp.getType();
-    Value a = signedCeilDivIOp.getLhs();
-    Value b = signedCeilDivIOp.getRhs();
-    Value plusOne = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getIntegerAttr(type, 1));
-    Value zero = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getIntegerAttr(type, 0));
-    Value minusOne = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getIntegerAttr(type, -1));
+    Type type = op.getType();
+    Value a = op.getLhs();
+    Value b = op.getRhs();
+    Value plusOne = createConst(loc, type, 1, rewriter);
+    Value zero = createConst(loc, type, 0, rewriter);
+    Value minusOne = createConst(loc, type, -1, rewriter);
     // Compute x = (b>0) ? -1 : 1.
     Value compare =
         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
@@ -90,9 +90,8 @@ struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
     Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bPos);
     Value compareRes =
         rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
-    Value res = rewriter.create<SelectOp>(loc, compareRes, posRes, negRes);
     // Perform substitution and return success.
-    rewriter.replaceOp(op, {res});
+    rewriter.replaceOpWithNewOp<SelectOp>(op, compareRes, posRes, negRes);
     return success();
   }
 };
@@ -105,16 +104,12 @@ struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
   LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
                                 PatternRewriter &rewriter) const final {
     Location loc = op.getLoc();
-    arith::FloorDivSIOp signedFloorDivIOp = cast<arith::FloorDivSIOp>(op);
-    Type type = signedFloorDivIOp.getType();
-    Value a = signedFloorDivIOp.getLhs();
-    Value b = signedFloorDivIOp.getRhs();
-    Value plusOne = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getIntegerAttr(type, 1));
-    Value zero = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getIntegerAttr(type, 0));
-    Value minusOne = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getIntegerAttr(type, -1));
+    Type type = op.getType();
+    Value a = op.getLhs();
+    Value b = op.getRhs();
+    Value plusOne = createConst(loc, type, 1, rewriter);
+    Value zero = createConst(loc, type, 0, rewriter);
+    Value minusOne = createConst(loc, type, -1, rewriter);
     // Compute x = (b<0) ? 1 : -1.
     Value compare =
         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
@@ -144,9 +139,8 @@ struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
     Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bNeg);
     Value compareRes =
         rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
-    Value res = rewriter.create<SelectOp>(loc, compareRes, negRes, posRes);
     // Perform substitution and return success.
-    rewriter.replaceOp(op, {res});
+    rewriter.replaceOpWithNewOp<SelectOp>(op, compareRes, negRes, posRes);
     return success();
   }
 };


        


More information about the Mlir-commits mailing list