[Mlir-commits] [mlir] b0b0043 - [mlir][Arith] Pass to switch signed ops for equivalent unsigned ones

Krzysztof Drewniak llvmlistbot at llvm.org
Tue Jun 14 14:18:34 PDT 2022


Author: Krzysztof Drewniak
Date: 2022-06-14T21:18:29Z
New Revision: b0b00432093be9680ed833af642bcafc3ca11586

URL: https://github.com/llvm/llvm-project/commit/b0b00432093be9680ed833af642bcafc3ca11586
DIFF: https://github.com/llvm/llvm-project/commit/b0b00432093be9680ed833af642bcafc3ca11586.diff

LOG: [mlir][Arith] Pass to switch signed ops for equivalent unsigned ones

If all the arguments to and results of an operation are known to be
non-negative when interpreted as signed (which also implies that all
computations producing those values did not experience signed
overflow), we can replace that operation with an equivalent one that
operates on unsigned values.

Such a replacement, when it is possible, can provide useful hints to
backends, such as by allowing LLVM to replace remainder with bitwise
operations in more cases.

Depends on D124022

Depends on D124023

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D124024

Added: 
    mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
    mlir/test/Dialect/Arithmetic/unsigned-when-equivalent.mlir

Modified: 
    mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h
    mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td
    mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h
index 1acea57102dbd..9b9331f23230a 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h
@@ -26,6 +26,10 @@ void populateArithmeticExpandOpsPatterns(RewritePatternSet &patterns);
 /// Create a pass to legalize Arithmetic ops for LLVM lowering.
 std::unique_ptr<Pass> createArithmeticExpandOpsPass();
 
+/// Create a pass to replace signed ops with unsigned ones where they are proven
+/// equivalent.
+std::unique_ptr<Pass> createArithmeticUnsignedWhenEquivalentPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td
index 1d84e2777b1cc..752d715087959 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td
@@ -33,4 +33,20 @@ def ArithmeticExpandOps : Pass<"arith-expand"> {
   let constructor = "mlir::arith::createArithmeticExpandOpsPass()";
 }
 
+def ArithmeticUnsignedWhenEquivalent : Pass<"arith-unsigned-when-equivalent"> {
+  let summary = "Replace signed ops with unsigned ones where they are proven equivalent";
+  let description = [{
+    Replace signed ops with their unsigned equivalents when integer range analysis
+    determines that their arguments and results are all guaranteed to be
+    non-negative when interpreted as signed integers. When this occurs,
+    we know that the semantics of the signed and unsigned operations are the same,
+    since they share the same behavior when their operands and results are in the
+    range [0, signed_max(type)].
+
+    The affected ops include division, remainder, shifts, min, max, and integer
+    comparisons.
+  }];
+  let constructor = "mlir::arith::createArithmeticUnsignedWhenEquivalentPass()";
+}
+
 #endif // MLIR_DIALECT_ARITHMETIC_TRANSFORMS_PASSES

diff  --git a/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
index 7f5c9ca82cec7..f140715e603ee 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRArithmeticTransforms
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   ExpandOps.cpp
+  UnsignedWhenEquivalent.cpp
 
   ADDITIONAL_HEADER_DIRS
   {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arithmetic/Transforms
@@ -10,9 +11,11 @@ add_mlir_dialect_library(MLIRArithmeticTransforms
   MLIRArithmeticTransformsIncGen
 
   LINK_LIBS PUBLIC
+  MLIRAnalysis
   MLIRArithmeticDialect
   MLIRBufferizationDialect
   MLIRBufferizationTransforms
+  MLIRInferIntRangeInterface
   MLIRIR
   MLIRMemRefDialect
   MLIRPass

diff  --git a/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
new file mode 100644
index 0000000000000..30fb51725dcb0
--- /dev/null
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
@@ -0,0 +1,144 @@
+//===- UnsignedWhenEquivalent.cpp - Pass to replace signed operations with
+// unsigned
+// ones when all their arguments and results are statically non-negative --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Analysis/IntRangeAnalysis.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
+
+using namespace mlir;
+using namespace mlir::arith;
+
+using OpList = llvm::SmallVector<Operation *>;
+
+/// Returns true when `analysis` has a result for `v` and the signed lower
+/// bound (`smin`) of that range is non-negative; returns false when the
+/// analysis has no result for `v`.
+static bool staticallyNonNegative(IntRangeAnalysis &analysis, Value v) {
+  Optional<ConstantIntRanges> result = analysis.getResult(v);
+  if (!result.hasValue())
+    return false;
+  const ConstantIntRanges &range = result.getValue();
+  return (range.smin().isNonNegative());
+}
+
+/// Walk `root` and append to `toRewrite` every op of one of the types `Ts`
+/// whose operands and results are all statically non-negative.
+template <typename... Ts>
+static void getConvertableOps(Operation *root, OpList &toRewrite,
+                              IntRangeAnalysis &analysis) {
+  auto nonNegativePred = [&analysis](Value v) -> bool {
+    return staticallyNonNegative(analysis, v);
+  };
+  root->walk([&nonNegativePred, &toRewrite](Operation *orig) {
+    if (isa<Ts...>(orig) &&
+        llvm::all_of(orig->getOperands(), nonNegativePred) &&
+        llvm::all_of(orig->getResults(), nonNegativePred)) {
+      toRewrite.push_back(orig);
+    }
+  });
+}
+
+static CmpIPredicate toUnsignedPred(CmpIPredicate pred) {
+  switch (pred) { // Signed ordering predicates map to their unsigned forms.
+  case CmpIPredicate::sle:
+    return CmpIPredicate::ule;
+  case CmpIPredicate::slt:
+    return CmpIPredicate::ult;
+  case CmpIPredicate::sge:
+    return CmpIPredicate::uge;
+  case CmpIPredicate::sgt:
+    return CmpIPredicate::ugt;
+  default: // Any other predicate is returned unchanged.
+    return pred;
+  }
+}
+
+/// Collect signed-predicate cmpi ops whose operands are all non-negative.
+static void getConvertableCmpi(Operation *root, OpList &toRewrite,
+                               IntRangeAnalysis &analysis) {
+  auto nonNegativePred = [&analysis](Value v) -> bool {
+    return staticallyNonNegative(analysis, v);
+  };
+  root->walk([&nonNegativePred, &toRewrite](arith::CmpIOp orig) {
+    CmpIPredicate pred = orig.getPredicate();
+    if (toUnsignedPred(pred) != pred &&
+        // i1 will spuriously and trivially show up as potentially negative,
+        // so don't check the results
+        llvm::all_of(orig->getOperands(), nonNegativePred)) {
+      toRewrite.push_back(orig.getOperation());
+    }
+  });
+}
+
+/// Return ops to be replaced in the order they should be rewritten.
+static OpList getMatching(Operation *root, IntRangeAnalysis &analysis) {
+  OpList ret;
+  getConvertableOps<DivSIOp, CeilDivSIOp, FloorDivSIOp, RemSIOp, MinSIOp,
+                    MaxSIOp, ExtSIOp>(root, ret, analysis);
+  // The cmpi rewrites are in-place changes, so they don't need to be in
+  // topological order like the others.
+  getConvertableCmpi(root, ret, analysis);
+  return ret;
+}
+
+template <typename T, typename U>
+static void rewriteOp(Operation *op, OpBuilder &b) {
+  if (isa<T>(op)) { // Does nothing unless `op` is a `T`.
+    OpBuilder::InsertionGuard guard(b);
+    b.setInsertionPoint(op);
+    Operation *newOp = b.create<U>(op->getLoc(), op->getResultTypes(),
+                                   op->getOperands(), op->getAttrs());
+    op->replaceAllUsesWith(newOp->getResults());
+    op->erase(); // `op` is erased here; the caller must not use it again.
+  }
+}
+
+static void rewriteCmpI(Operation *op, OpBuilder &b) {
+  if (auto cmpOp = dyn_cast<CmpIOp>(op)) { // Other ops are left untouched.
+    cmpOp.setPredicateAttr(CmpIPredicateAttr::get( // In place: op is kept.
+        b.getContext(), toUnsignedPred(cmpOp.getPredicate())));
+  }
+}
+
+/// Rewrite every op in `toReplace` into its unsigned counterpart.
+///
+/// `rewriteOp` erases the op it matches, so each op is dispatched to exactly
+/// one rewriter. Probing an already-erased op with a further `isa<>` or
+/// `dyn_cast<>` would read freed memory.
+static void rewrite(Operation *root, const OpList &toReplace) {
+  OpBuilder b(root->getContext());
+  b.setInsertionPoint(root);
+  for (Operation *op : toReplace) {
+    if (isa<DivSIOp>(op))
+      rewriteOp<DivSIOp, DivUIOp>(op, b);
+    else if (isa<CeilDivSIOp>(op))
+      rewriteOp<CeilDivSIOp, CeilDivUIOp>(op, b);
+    else if (isa<FloorDivSIOp>(op))
+      // With non-negative operands, floor division equals truncating division.
+      rewriteOp<FloorDivSIOp, DivUIOp>(op, b);
+    else if (isa<RemSIOp>(op))
+      rewriteOp<RemSIOp, RemUIOp>(op, b);
+    else if (isa<MinSIOp>(op))
+      rewriteOp<MinSIOp, MinUIOp>(op, b);
+    else if (isa<MaxSIOp>(op))
+      rewriteOp<MaxSIOp, MaxUIOp>(op, b);
+    else if (isa<ExtSIOp>(op))
+      rewriteOp<ExtSIOp, ExtUIOp>(op, b);
+    else
+      // cmpi is updated in place; rewriteCmpI ignores any other op kind.
+      rewriteCmpI(op, b);
+  }
+}
+
+namespace {
+struct ArithmeticUnsignedWhenEquivalentPass
+    : public ArithmeticUnsignedWhenEquivalentBase<
+          ArithmeticUnsignedWhenEquivalentPass> {
+  /// Implementation structure: first find all equivalent ops and collect them,
+  /// then rewrite the collected ops in a second step. This ensures that
+  /// analysis results are not invalidated during rewriting.
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    IntRangeAnalysis analysis(op);
+    rewrite(op, getMatching(op, analysis));
+  }
+};
+} // end anonymous namespace
+
+std::unique_ptr<Pass> // Factory named by `constructor` in Passes.td.
+mlir::arith::createArithmeticUnsignedWhenEquivalentPass() {
+  return std::make_unique<ArithmeticUnsignedWhenEquivalentPass>();
+}

diff  --git a/mlir/test/Dialect/Arithmetic/unsigned-when-equivalent.mlir b/mlir/test/Dialect/Arithmetic/unsigned-when-equivalent.mlir
new file mode 100644
index 0000000000000..558c9f4be5b9e
--- /dev/null
+++ b/mlir/test/Dialect/Arithmetic/unsigned-when-equivalent.mlir
@@ -0,0 +1,88 @@
+// RUN: mlir-opt -arith-unsigned-when-equivalent %s | FileCheck %s
+
+// CHECK-LABEL: func @not_with_maybe_overflow
+// CHECK: arith.divsi
+// CHECK: arith.ceildivsi
+// CHECK: arith.floordivsi
+// CHECK: arith.remsi
+// CHECK: arith.minsi
+// CHECK: arith.maxsi
+// CHECK: arith.extsi
+// CHECK: arith.cmpi sle
+// CHECK: arith.cmpi slt
+// CHECK: arith.cmpi sge
+// CHECK: arith.cmpi sgt
+func.func @not_with_maybe_overflow(%arg0 : i32) {
+    %ci32_smax = arith.constant 0x7fffffff : i32
+    %c1 = arith.constant 1 : i32
+    %c4 = arith.constant 4 : i32
+    %0 = arith.minui %arg0, %ci32_smax : i32
+    %1 = arith.addi %0, %c1 : i32
+    %2 = arith.divsi %1, %c4 : i32
+    %3 = arith.ceildivsi %1, %c4 : i32
+    %4 = arith.floordivsi %1, %c4 : i32
+    %5 = arith.remsi %1, %c4 : i32
+    %6 = arith.minsi %1, %c4 : i32
+    %7 = arith.maxsi %1, %c4 : i32
+    %8 = arith.extsi %1 : i32 to i64
+    %9 = arith.cmpi sle, %1, %c4 : i32
+    %10 = arith.cmpi slt, %1, %c4 : i32
+    %11 = arith.cmpi sge, %1, %c4 : i32
+    %12 = arith.cmpi sgt, %1, %c4 : i32
+    func.return
+}
+
+// CHECK-LABEL: func @yes_with_no_overflow
+// CHECK: arith.divui
+// CHECK: arith.ceildivui
+// CHECK: arith.divui
+// CHECK: arith.remui
+// CHECK: arith.minui
+// CHECK: arith.maxui
+// CHECK: arith.extui
+// CHECK: arith.cmpi ule
+// CHECK: arith.cmpi ult
+// CHECK: arith.cmpi uge
+// CHECK: arith.cmpi ugt
+func.func @yes_with_no_overflow(%arg0 : i32) {
+    %ci32_almost_smax = arith.constant 0x7ffffffe : i32
+    %c1 = arith.constant 1 : i32
+    %c4 = arith.constant 4 : i32
+    %0 = arith.minui %arg0, %ci32_almost_smax : i32
+    %1 = arith.addi %0, %c1 : i32
+    %2 = arith.divsi %1, %c4 : i32
+    %3 = arith.ceildivsi %1, %c4 : i32
+    %4 = arith.floordivsi %1, %c4 : i32
+    %5 = arith.remsi %1, %c4 : i32
+    %6 = arith.minsi %1, %c4 : i32
+    %7 = arith.maxsi %1, %c4 : i32
+    %8 = arith.extsi %1 : i32 to i64
+    %9 = arith.cmpi sle, %1, %c4 : i32
+    %10 = arith.cmpi slt, %1, %c4 : i32
+    %11 = arith.cmpi sge, %1, %c4 : i32
+    %12 = arith.cmpi sgt, %1, %c4 : i32
+    func.return
+}
+
+// CHECK-LABEL: func @preserves_structure
+// CHECK: scf.for %[[arg1:.*]] =
+// CHECK: %[[v:.*]] = arith.remui %[[arg1]]
+// CHECK: %[[w:.*]] = arith.addi %[[v]], %[[v]]
+// CHECK: %[[test:.*]] = arith.cmpi ule, %[[w]]
+// CHECK: scf.if %[[test]]
+// CHECK: memref.store %[[w]]
+func.func @preserves_structure(%arg0 : memref<8xindex>) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c4 = arith.constant 4 : index
+    %c8 = arith.constant 8 : index
+    scf.for %arg1 = %c0 to %c8 step %c1 {
+        %v = arith.remsi %arg1, %c4 : index
+        %w = arith.addi %v, %v : index
+        %test = arith.cmpi sle, %w, %c4 : index
+        scf.if %test {
+            memref.store %w, %arg0[%arg1] : memref<8xindex>
+        }
+    }
+    func.return
+}


        


More information about the Mlir-commits mailing list