[Mlir-commits] [mlir] e49ae62 - [mlir][Arith] Make --unsigned-when-equivalent use dialect conversion
Krzysztof Drewniak
llvmlistbot at llvm.org
Mon Jun 20 08:03:18 PDT 2022
Author: Krzysztof Drewniak
Date: 2022-06-20T15:03:07Z
New Revision: e49ae6284c38c7fef93b3b72af4c89a6e4836a45
URL: https://github.com/llvm/llvm-project/commit/e49ae6284c38c7fef93b3b72af4c89a6e4836a45
DIFF: https://github.com/llvm/llvm-project/commit/e49ae6284c38c7fef93b3b72af4c89a6e4836a45.diff
LOG: [mlir][Arith] Make --unsigned-when-equivalent use dialect conversion
Reviewed By: Mogball
Differential Revision: https://reviews.llvm.org/D128096
Added:
Modified:
mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
index 5cecc69285bea..f84990d0a8c47 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
@@ -12,40 +12,55 @@
#include "mlir/Analysis/IntRangeAnalysis.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
+#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
using namespace mlir::arith;
-using OpList = llvm::SmallVector<Operation *>;
-
-/// Returns true when a value is statically non-negative in that it has a lower
+/// Succeeds when `v` is statically non-negative, in that it has a lower
/// bound on its value (if it is treated as signed) and that bound is
/// non-negative.
-static bool staticallyNonNegative(IntRangeAnalysis &analysis, Value v) {
+static LogicalResult staticallyNonNegative(IntRangeAnalysis &analysis,
+ Value v) {
Optional<ConstantIntRanges> result = analysis.getResult(v);
if (!result.hasValue())
- return false;
+ return failure(); // The analysis has no range for `v`.
const ConstantIntRanges &range = result.getValue();
- return (range.smin().isNonNegative());
+ return success(range.smin().isNonNegative());
}
-/// Identify all operations in a block that have signed equivalents and have
-/// operands and results that are statically non-negative.
-template <typename... Ts>
-static void getConvertableOps(Operation *root, OpList &toRewrite,
- IntRangeAnalysis &analysis) {
+/// Succeeds if an op can be converted to its unsigned equivalent without
+/// changing its semantics. This is the case when all of its operands and
+/// results are statically known to be non-negative (treated as signed).
+static LogicalResult staticallyNonNegative(IntRangeAnalysis &analysis,
+ Operation *op) {
auto nonNegativePred = [&analysis](Value v) -> bool {
- return staticallyNonNegative(analysis, v);
+ return succeeded(staticallyNonNegative(analysis, v));
};
- root->walk([&nonNegativePred, &toRewrite](Operation *orig) {
- if (isa<Ts...>(orig) &&
- llvm::all_of(orig->getOperands(), nonNegativePred) &&
- llvm::all_of(orig->getResults(), nonNegativePred)) {
- toRewrite.push_back(orig);
- }
- });
+ return success(llvm::all_of(op->getOperands(), nonNegativePred) &&
+ llvm::all_of(op->getResults(), nonNegativePred));
}
+/// Succeeds when `op` uses a signed ordering predicate and all its operands
+/// are statically non-negative, so it can use the unsigned predicate instead.
+/// The i1 result is not checked: it would spuriously look possibly negative.
+static LogicalResult isCmpIConvertable(IntRangeAnalysis &analysis, CmpIOp op) {
+ CmpIPredicate pred = op.getPredicate();
+ switch (pred) {
+ case CmpIPredicate::sle:
+ case CmpIPredicate::slt:
+ case CmpIPredicate::sge:
+ case CmpIPredicate::sgt:
+ return success(llvm::all_of(op.getOperands(), [&analysis](Value v) -> bool {
+ return succeeded(staticallyNonNegative(analysis, v));
+ }));
+ default:
+ return failure();
+ }
+}
+
+/// Return the unsigned equivalent of a signed comparison predicate,
+/// or the predicate itself if there is none.
static CmpIPredicate toUnsignedPred(CmpIPredicate pred) {
switch (pred) {
case CmpIPredicate::sle:
@@ -61,72 +76,30 @@ static CmpIPredicate toUnsignedPred(CmpIPredicate pred) {
}
}
-/// Find all cmpi ops that can be replaced by their unsigned equivalents.
-static void getConvertableCmpi(Operation *root, OpList &toRewrite,
- IntRangeAnalysis &analysis) {
- auto nonNegativePred = [&analysis](Value v) -> bool {
- return staticallyNonNegative(analysis, v);
- };
- root->walk([&nonNegativePred, &toRewrite](arith::CmpIOp orig) {
- CmpIPredicate pred = orig.getPredicate();
- if (toUnsignedPred(pred) != pred &&
- // i1 will spuriously and trivially show up as pontentially negative,
- // so don't check the results
- llvm::all_of(orig->getOperands(), nonNegativePred)) {
- toRewrite.push_back(orig.getOperation());
- }
- });
-}
-
-/// Return ops to be replaced in the order they should be rewritten.
-static OpList getMatching(Operation *root, IntRangeAnalysis &analysis) {
- OpList ret;
- getConvertableOps<DivSIOp, CeilDivSIOp, FloorDivSIOp, RemSIOp, MinSIOp,
- MaxSIOp, ExtSIOp>(root, ret, analysis);
- // Since these are in-place changes, they don't need to be topological order
- // like the others.
- getConvertableCmpi(root, ret, analysis);
- return ret;
-}
+namespace {
+/// Replaces a `Signed` op with the `Unsigned` op that has the same result
+/// types and attributes, built from the adaptor's (remapped) operands. Which
+/// ops reach this pattern is decided by the conversion target's legality.
+template <typename Signed, typename Unsigned>
+struct ConvertOpToUnsigned : OpConversionPattern<Signed> {
+ using OpConversionPattern<Signed>::OpConversionPattern;
-template <typename T, typename U>
-static bool rewriteOp(Operation *op, OpBuilder &b) {
- if (isa<T>(op)) {
- OpBuilder::InsertionGuard guard(b);
- b.setInsertionPoint(op);
- Operation *newOp = b.create<U>(op->getLoc(), op->getResultTypes(),
- op->getOperands(), op->getAttrs());
- op->replaceAllUsesWith(newOp->getResults());
- op->erase();
- return true;
+ LogicalResult matchAndRewrite(Signed op, typename Signed::Adaptor adaptor,
+ ConversionPatternRewriter &rw) const override {
+ rw.replaceOpWithNewOp<Unsigned>(op, op->getResultTypes(),
+ adaptor.getOperands(), op->getAttrs());
+ return success();
}
- return false;
-}
+};
-static bool rewriteCmpI(Operation *op, OpBuilder &b) {
- if (auto cmpOp = dyn_cast<CmpIOp>(op)) {
- cmpOp.setPredicateAttr(CmpIPredicateAttr::get(
- b.getContext(), toUnsignedPred(cmpOp.getPredicate())));
- return true;
- }
- return false;
-}
+/// Replaces a signed `cmpi` with one using the equivalent unsigned predicate.
+/// Like ConvertOpToUnsigned, the new op takes its operands from the adaptor,
+/// so values already replaced during the conversion are used.
+struct ConvertCmpIToUnsigned : OpConversionPattern<CmpIOp> {
+ using OpConversionPattern<CmpIOp>::OpConversionPattern;
-static void rewrite(Operation *root, const OpList &toReplace) {
- OpBuilder b(root->getContext());
- b.setInsertionPoint(root);
- for (Operation *op : toReplace) {
- rewriteOp<DivSIOp, DivUIOp>(op, b) ||
- rewriteOp<CeilDivSIOp, CeilDivUIOp>(op, b) ||
- rewriteOp<FloorDivSIOp, DivUIOp>(op, b) ||
- rewriteOp<RemSIOp, RemUIOp>(op, b) ||
- rewriteOp<MinSIOp, MinUIOp>(op, b) ||
- rewriteOp<MaxSIOp, MaxUIOp>(op, b) ||
- rewriteOp<ExtSIOp, ExtUIOp>(op, b) || rewriteCmpI(op, b);
+ LogicalResult matchAndRewrite(CmpIOp op, CmpIOpAdaptor adaptor,
+ ConversionPatternRewriter &rw) const override {
+ rw.replaceOpWithNewOp<CmpIOp>(op, toUnsignedPred(op.getPredicate()),
+ adaptor.getLhs(), adaptor.getRhs());
+ return success();
}
-}
+};
-namespace {
struct ArithmeticUnsignedWhenEquivalentPass
: public ArithmeticUnsignedWhenEquivalentBase<
ArithmeticUnsignedWhenEquivalentPass> {
@@ -135,8 +108,35 @@ struct ArithmeticUnsignedWhenEquivalentPass
/// ensures that analysis results are not invalidated during rewriting.
void runOnOperation() override {
Operation *op = getOperation();
+ MLIRContext *ctx = op->getContext();
IntRangeAnalysis analysis(op);
- rewrite(op, getMatching(op, analysis));
+
+ ConversionTarget target(*ctx);
+ target.addLegalDialect<ArithmeticDialect>();
+ // A signed op is illegal (must be rewritten) when the analysis proves all
+ // its operands and results non-negative. List only ops that have a pattern
+ // below: in a partial conversion, an illegal op that no pattern can
+ // legalize makes the whole conversion fail.
+ target
+ .addDynamicallyLegalOp<DivSIOp, CeilDivSIOp, FloorDivSIOp, RemSIOp,
+ MinSIOp, MaxSIOp, ExtSIOp>(
+ [&analysis](Operation *op) -> Optional<bool> {
+ return failed(staticallyNonNegative(analysis, op));
+ });
+ // cmpi is illegal only for signed ordering predicates on non-negative
+ // operands; the rewritten (unsigned) cmpi is then legal.
+ target.addDynamicallyLegalOp<CmpIOp>(
+ [&analysis](CmpIOp op) -> Optional<bool> {
+ return failed(isCmpIConvertable(analysis, op));
+ });
+
+ RewritePatternSet patterns(ctx);
+ patterns.add<ConvertOpToUnsigned<DivSIOp, DivUIOp>,
+ ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>,
+ ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>,
+ ConvertOpToUnsigned<RemSIOp, RemUIOp>,
+ ConvertOpToUnsigned<MinSIOp, MinUIOp>,
+ ConvertOpToUnsigned<MaxSIOp, MaxUIOp>,
+ ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>(
+ ctx);
+
+ if (failed(applyPartialConversion(op, target, std::move(patterns)))) {
+ signalPassFailure();
+ }
}
};
} // end anonymous namespace
More information about the Mlir-commits
mailing list