[Mlir-commits] [mlir] e49ae62 - [mlir][Arith] Make --unsigned-when-equivalent use dialect conversion

Krzysztof Drewniak llvmlistbot at llvm.org
Mon Jun 20 08:03:18 PDT 2022


Author: Krzysztof Drewniak
Date: 2022-06-20T15:03:07Z
New Revision: e49ae6284c38c7fef93b3b72af4c89a6e4836a45

URL: https://github.com/llvm/llvm-project/commit/e49ae6284c38c7fef93b3b72af4c89a6e4836a45
DIFF: https://github.com/llvm/llvm-project/commit/e49ae6284c38c7fef93b3b72af4c89a6e4836a45.diff

LOG: [mlir][Arith] Make --unsigned-when-equivalent use dialect conversion

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D128096

Added: 
    

Modified: 
    mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
index 5cecc69285bea..f84990d0a8c47 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
@@ -12,40 +12,55 @@
 #include "mlir/Analysis/IntRangeAnalysis.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
+#include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
 using namespace mlir::arith;
 
-using OpList = llvm::SmallVector<Operation *>;
-
-/// Returns true when a value is statically non-negative in that it has a lower
+/// Succeeds when a value is statically non-negative in that it has a lower
 /// bound on its value (if it is treated as signed) and that bound is
 /// non-negative.
-static bool staticallyNonNegative(IntRangeAnalysis &analysis, Value v) {
+static LogicalResult staticallyNonNegative(IntRangeAnalysis &analysis,
+                                           Value v) {
   Optional<ConstantIntRanges> result = analysis.getResult(v);
   if (!result.hasValue())
-    return false;
+    return failure();
   const ConstantIntRanges &range = result.getValue();
-  return (range.smin().isNonNegative());
+  return success(range.smin().isNonNegative());
 }
 
-/// Identify all operations in a block that have signed equivalents and have
-/// operands and results that are statically non-negative.
-template <typename... Ts>
-static void getConvertableOps(Operation *root, OpList &toRewrite,
-                              IntRangeAnalysis &analysis) {
+/// Succeeds if an op can be converted to its unsigned equivalent without
+/// changing its semantics. This is the case when none of its operands or
+/// results can be below 0 when analyzed from a signed perspective.
+static LogicalResult staticallyNonNegative(IntRangeAnalysis &analysis,
+                                           Operation *op) {
   auto nonNegativePred = [&analysis](Value v) -> bool {
-    return staticallyNonNegative(analysis, v);
+    return succeeded(staticallyNonNegative(analysis, v));
   };
-  root->walk([&nonNegativePred, &toRewrite](Operation *orig) {
-    if (isa<Ts...>(orig) &&
-        llvm::all_of(orig->getOperands(), nonNegativePred) &&
-        llvm::all_of(orig->getResults(), nonNegativePred)) {
-      toRewrite.push_back(orig);
-    }
-  });
+  return success(llvm::all_of(op->getOperands(), nonNegativePred) &&
+                 llvm::all_of(op->getResults(), nonNegativePred));
 }
 
+/// Succeeds when `op` has a signed predicate and all its operands are
+/// non-negative, so the predicate can be changed to its unsigned equivalent.
+/// The i1 result is not checked, as it would spuriously look negative.
+static LogicalResult isCmpIConvertable(IntRangeAnalysis &analysis, CmpIOp op) {
+  CmpIPredicate pred = op.getPredicate();
+  switch (pred) {
+  case CmpIPredicate::sle:
+  case CmpIPredicate::slt:
+  case CmpIPredicate::sge:
+  case CmpIPredicate::sgt:
+    return success(llvm::all_of(op.getOperands(), [&analysis](Value v) -> bool {
+      return succeeded(staticallyNonNegative(analysis, v));
+    }));
+  default:
+    return failure();
+  }
+}
+
+/// Return the unsigned equivalent of a signed comparison predicate,
+/// or the predicate itself if there is none.
 static CmpIPredicate toUnsignedPred(CmpIPredicate pred) {
   switch (pred) {
   case CmpIPredicate::sle:
@@ -61,72 +76,30 @@ static CmpIPredicate toUnsignedPred(CmpIPredicate pred) {
   }
 }
 
-/// Find all cmpi ops that can be replaced by their unsigned equivalents.
-static void getConvertableCmpi(Operation *root, OpList &toRewrite,
-                               IntRangeAnalysis &analysis) {
-  auto nonNegativePred = [&analysis](Value v) -> bool {
-    return staticallyNonNegative(analysis, v);
-  };
-  root->walk([&nonNegativePred, &toRewrite](arith::CmpIOp orig) {
-    CmpIPredicate pred = orig.getPredicate();
-    if (toUnsignedPred(pred) != pred &&
-        // i1 will spuriously and trivially show up as pontentially negative,
-        // so don't check the results
-        llvm::all_of(orig->getOperands(), nonNegativePred)) {
-      toRewrite.push_back(orig.getOperation());
-    }
-  });
-}
-
-/// Return ops to be replaced in the order they should be rewritten.
-static OpList getMatching(Operation *root, IntRangeAnalysis &analysis) {
-  OpList ret;
-  getConvertableOps<DivSIOp, CeilDivSIOp, FloorDivSIOp, RemSIOp, MinSIOp,
-                    MaxSIOp, ExtSIOp>(root, ret, analysis);
-  // Since these are in-place changes, they don't need to be topological order
-  // like the others.
-  getConvertableCmpi(root, ret, analysis);
-  return ret;
-}
+namespace {
+template <typename Signed, typename Unsigned>
+struct ConvertOpToUnsigned : OpConversionPattern<Signed> {
+  using OpConversionPattern<Signed>::OpConversionPattern;
 
-template <typename T, typename U>
-static bool rewriteOp(Operation *op, OpBuilder &b) {
-  if (isa<T>(op)) {
-    OpBuilder::InsertionGuard guard(b);
-    b.setInsertionPoint(op);
-    Operation *newOp = b.create<U>(op->getLoc(), op->getResultTypes(),
-                                   op->getOperands(), op->getAttrs());
-    op->replaceAllUsesWith(newOp->getResults());
-    op->erase();
-    return true;
+  LogicalResult matchAndRewrite(Signed op, typename Signed::Adaptor adaptor,
+                                ConversionPatternRewriter &rw) const override {
+    rw.replaceOpWithNewOp<Unsigned>(op, op->getResultTypes(),
+                                    adaptor.getOperands(), op->getAttrs());
+    return success();
   }
-  return false;
-}
+};
 
-static bool rewriteCmpI(Operation *op, OpBuilder &b) {
-  if (auto cmpOp = dyn_cast<CmpIOp>(op)) {
-    cmpOp.setPredicateAttr(CmpIPredicateAttr::get(
-        b.getContext(), toUnsignedPred(cmpOp.getPredicate())));
-    return true;
-  }
-  return false;
-}
+struct ConvertCmpIToUnsigned : OpConversionPattern<CmpIOp> {
+  using OpConversionPattern<CmpIOp>::OpConversionPattern;
 
-static void rewrite(Operation *root, const OpList &toReplace) {
-  OpBuilder b(root->getContext());
-  b.setInsertionPoint(root);
-  for (Operation *op : toReplace) {
-    rewriteOp<DivSIOp, DivUIOp>(op, b) ||
-        rewriteOp<CeilDivSIOp, CeilDivUIOp>(op, b) ||
-        rewriteOp<FloorDivSIOp, DivUIOp>(op, b) ||
-        rewriteOp<RemSIOp, RemUIOp>(op, b) ||
-        rewriteOp<MinSIOp, MinUIOp>(op, b) ||
-        rewriteOp<MaxSIOp, MaxUIOp>(op, b) ||
-        rewriteOp<ExtSIOp, ExtUIOp>(op, b) || rewriteCmpI(op, b);
+  LogicalResult matchAndRewrite(CmpIOp op, CmpIOpAdaptor adaptor,
+                                ConversionPatternRewriter &rw) const override {
+    rw.replaceOpWithNewOp<CmpIOp>(op, toUnsignedPred(op.getPredicate()),
+                                  adaptor.getLhs(), adaptor.getRhs());
+    return success();
   }
-}
+};
 
-namespace {
 struct ArithmeticUnsignedWhenEquivalentPass
     : public ArithmeticUnsignedWhenEquivalentBase<
           ArithmeticUnsignedWhenEquivalentPass> {
@@ -135,8 +108,37 @@ struct ArithmeticUnsignedWhenEquivalentPass
   /// ensures that analysis results are not invalidated during rewriting.
   void runOnOperation() override {
     Operation *op = getOperation();
+    MLIRContext *ctx = op->getContext();
     IntRangeAnalysis analysis(op);
-    rewrite(op, getMatching(op, analysis));
+
+    ConversionTarget target(*ctx);
+    target.addLegalDialect<ArithmeticDialect>();
+    // Only signed ops that have a pattern below may be marked illegal: an
+    // illegal op with no pattern (e.g. ceildivui) fails the conversion.
+    target
+        .addDynamicallyLegalOp<DivSIOp, CeilDivSIOp, FloorDivSIOp, RemSIOp,
+                               MinSIOp, MaxSIOp, ExtSIOp>(
+            [&analysis](Operation *op) -> Optional<bool> {
+              return failed(staticallyNonNegative(analysis, op));
+            });
+    target.addDynamicallyLegalOp<CmpIOp>(
+        [&analysis](CmpIOp op) -> Optional<bool> {
+          return failed(isCmpIConvertable(analysis, op));
+        });
+
+    RewritePatternSet patterns(ctx);
+    patterns.add<ConvertOpToUnsigned<DivSIOp, DivUIOp>,
+                 ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>,
+                 ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>,
+                 ConvertOpToUnsigned<RemSIOp, RemUIOp>,
+                 ConvertOpToUnsigned<MinSIOp, MinUIOp>,
+                 ConvertOpToUnsigned<MaxSIOp, MaxUIOp>,
+                 ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>(
+        ctx);
+
+    if (failed(applyPartialConversion(op, target, std::move(patterns)))) {
+      signalPassFailure();
+    }
   }
 };
 } // end anonymous namespace


        


More information about the Mlir-commits mailing list