[Mlir-commits] [mlir] [mlir] Add `arith-int-range-narrowing` pass (PR #112404)

Jakub Kuderski llvmlistbot at llvm.org
Wed Oct 16 09:41:07 PDT 2024


================
@@ -184,6 +189,232 @@ struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
   DataFlowSolver &solver;
 };
 
+/// Checks whether `type` is a candidate for narrowing to `targetBitwidth`.
+/// For shaped types (e.g. vectors) the element type is inspected. The
+/// (element) type must be `index` or an integer strictly wider than
+/// `targetBitwidth`.
+///
+/// Returns `type` itself on success, including any shape, and a null type
+/// otherwise. Returning only the element type would make callers derive a
+/// scalar target type for shaped values.
+static Type checkArithType(Type type, unsigned targetBitwidth) {
+  Type elemType = getElementTypeOrSelf(type);
+  if (isa<IndexType>(elemType))
+    return type;
+
+  if (auto intType = dyn_cast<IntegerType>(elemType))
+    if (intType.getWidth() > targetBitwidth)
+      return type;
+
+  return nullptr;
+}
+
+/// Requires `op` to have at least one operand and at least one result, all
+/// of exactly the same type. If so, returns the result of `checkArithType`
+/// for that shared type and `targetBitwidth`. Otherwise returns a null type.
+static Type checkElementwiseOpType(Operation *op, unsigned targetBitwidth) {
+  if (op->getNumOperands() == 0 || op->getNumResults() == 0)
+    return nullptr;
+
+  // Every operand and every result must match the first operand's type.
+  Type commonType = op->getOperand(0).getType();
+  for (Value operand : op->getOperands())
+    if (operand.getType() != commonType)
+      return nullptr;
+  for (Value result : op->getResults())
+    if (result.getType() != commonType)
+      return nullptr;
+
+  return checkArithType(commonType, targetBitwidth);
+}
+
+/// Computes the union of the integer ranges that `solver` inferred for each
+/// of `values`. Returns `std::nullopt` if any value has no lattice state or
+/// an uninitialized one. Also returns `std::nullopt` if `values` is empty.
+static std::optional<ConstantIntRanges> getOperandsRange(DataFlowSolver &solver,
+                                                         ValueRange values) {
+  std::optional<ConstantIntRanges> combined;
+  for (Value value : values) {
+    const auto *lattice = solver.lookupState<IntegerValueRangeLattice>(value);
+    if (!lattice || lattice->getValue().isUninitialized())
+      return std::nullopt;
+
+    const ConstantIntRanges &valueRange = lattice->getValue().getValue();
+    combined = combined ? combined->rangeUnion(valueRange) : valueRange;
+  }
+  return combined;
+}
+
+/// Returns `srcType` with its integer-like (element) type replaced by a
+/// signless integer of `targetBitwidth` bits. Shaped types keep their shape.
+/// A non-shaped `srcType` must be an integer or `index` type.
+static Type getTargetType(Type srcType, unsigned targetBitwidth) {
+  Type narrowType = IntegerType::get(srcType.getContext(), targetBitwidth);
+  auto shapedType = dyn_cast<ShapedType>(srcType);
+  if (!shapedType) {
+    assert(srcType.isIntOrIndex() && "Invalid src type");
+    return narrowType;
+  }
+  return shapedType.clone(narrowType);
+}
+
+/// Returns true if `range` lies within both sets of bounds:
+///   range.smin() >= smin and range.smax() <= smax (signed), and
+///   range.umin() >= umin and range.umax() <= umax (unsigned).
+/// Operands may differ in bit width. Each pair is first extended to the
+/// wider of the two widths: sign-extended for the signed comparisons and
+/// zero-extended for the unsigned ones.
+static bool checkRange(const ConstantIntRanges &range, APInt smin, APInt smax,
+                       APInt umin, APInt umax) {
+  auto signedLE = [](const APInt &lhs, const APInt &rhs) -> bool {
+    unsigned width = std::max(lhs.getBitWidth(), rhs.getBitWidth());
+    return lhs.sext(width).sle(rhs.sext(width));
+  };
+  auto unsignedLE = [](const APInt &lhs, const APInt &rhs) -> bool {
+    unsigned width = std::max(lhs.getBitWidth(), rhs.getBitWidth());
+    return lhs.zext(width).ule(rhs.zext(width));
+  };
+  // Lower-bound checks are written as `bound <= value`.
+  return signedLE(smin, range.smin()) && signedLE(range.smax(), smax) &&
+         unsignedLE(umin, range.umin()) && unsignedLE(range.umax(), umax);
+}
+
+/// Casts the integer-like value `src` to `dstType`, creating ops with
+/// `builder` at `loc`. Returns `src` unchanged if the types already match.
+///
+/// Supported types are scalars, and shaped types whose element type is an
+/// integer or `index`. `src` and `dstType` must both be scalar or both be
+/// shaped. The cast op is chosen from the element types:
+///   - `arith.index_castui` when either side is `index`;
+///   - `arith.trunci` when narrowing;
+///   - `arith.extui` otherwise.
+/// Values are therefore reinterpreted as unsigned.
+static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
+  Type srcType = src.getType();
+  assert(isa<ShapedType>(srcType) == isa<ShapedType>(dstType) &&
+         "Mismatched shaped/scalar types");
+  // Decide the cast kind on element types so vector/tensor values work too.
+  // The arith cast ops used below accept shaped operands.
+  Type srcElemType = getElementTypeOrSelf(srcType);
+  Type dstElemType = getElementTypeOrSelf(dstType);
+  assert(srcElemType.isIntOrIndex() && "Invalid src type");
+  assert(dstElemType.isIntOrIndex() && "Invalid dst type");
+  if (srcType == dstType)
+    return src;
+
+  if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType))
+    return builder.create<arith::IndexCastUIOp>(loc, dstType, src);
+
+  auto srcInt = cast<IntegerType>(srcElemType);
+  auto dstInt = cast<IntegerType>(dstElemType);
+  if (dstInt.getWidth() < srcInt.getWidth())
+    return builder.create<arith::TruncIOp>(loc, dstType, src);
+
+  return builder.create<arith::ExtUIOp>(loc, dstType, src);
+}
+
+struct NarrowElementwise final
+    : public OpTraitRewritePattern<OpTrait::Elementwise> {
+  NarrowElementwise(MLIRContext *context, DataFlowSolver &s,
+                    ArrayRef<unsigned> target)
+      : OpTraitRewritePattern<OpTrait::Elementwise>(context), solver(s),
+        targetBitwidths(target) {}
+
+  using OpTraitRewritePattern::OpTraitRewritePattern;
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+
+    std::optional<ConstantIntRanges> range =
+        getOperandsRange(solver, op->getResults());
+    if (!range)
+      return failure();
+
+    for (unsigned targetBitwidth : targetBitwidths) {
+      Type srcType = checkElementwiseOpType(op, targetBitwidth);
+      if (!srcType)
+        continue;
+
+      // We are truncating op args to the desired bitwidth before the op and
+      // then extending op results back to the original width after. extui and
+      // exti will produce different results for negative values, so limit
+      // signed range to non-negative values.
+      auto smin = APInt::getZero(targetBitwidth);
+      auto smax = APInt::getSignedMaxValue(targetBitwidth);
+      auto umin = APInt::getMinValue(targetBitwidth);
+      auto umax = APInt::getMaxValue(targetBitwidth);
+      if (!checkRange(*range, smin, smax, umin, umax))
+        continue;
+
+      Type targetType = getTargetType(srcType, targetBitwidth);
+      if (targetType == srcType)
+        continue;
+
+      Location loc = op->getLoc();
+      IRMapping mapping;
+      for (Value arg : op->getOperands()) {
+        Value newArg = doCast(rewriter, loc, arg, targetType);
+        mapping.map(arg, newArg);
+      }
+
+      Operation *newOp = rewriter.clone(*op, mapping);
+      rewriter.modifyOpInPlace(newOp, [&]() {
+        for (OpResult res : newOp->getResults()) {
+          res.setType(targetType);
+        }
+      });
+      SmallVector<Value> newResults;
+      for (Value res : newOp->getResults())
+        newResults.emplace_back(doCast(rewriter, loc, res, srcType));
----------------
kuhar wrote:

nit: you can use `llvm::map_to_vector`

https://github.com/llvm/llvm-project/pull/112404


More information about the Mlir-commits mailing list