[Mlir-commits] [mlir] [mlir] Add `arith-int-range-narrowing` pass (PR #112404)

Jakub Kuderski llvmlistbot at llvm.org
Sun Nov 3 19:25:41 PST 2024


================
@@ -190,8 +195,263 @@ struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
   DataFlowSolver &solver;
 };
 
-struct IntRangeOptimizationsPass
-    : public arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
+/// Returns true if the element type of `type` (the type itself, or its element
+/// type when `type` is shaped, per `getElementTypeOrSelf`) is `index`, or is an
+/// `IntegerType` strictly wider than `targetBitwidth` bits.
+static bool checkIntType(Type type, unsigned targetBitwidth) {
+  Type elementTy = getElementTypeOrSelf(type);
+  // `index` is accepted unconditionally, independent of `targetBitwidth`.
+  if (isa<IndexType>(elementTy))
+    return true;
+
+  auto integerTy = dyn_cast<IntegerType>(elementTy);
+  return integerTy && integerTy.getWidth() > targetBitwidth;
+}
+
+/// Returns true if `op` has at least one operand and at least one result, every
+/// operand and result has exactly the same type, and that shared type is
+/// accepted by `checkIntType` for `targetBitwidth`.
+static bool checkElementwiseOpType(Operation *op, unsigned targetBitwidth) {
+  if (op->getNumOperands() == 0 || op->getNumResults() == 0)
+    return false;
+
+  // Operands come first in the concatenated range, so the first operand's type
+  // is the reference every other value must match.
+  Type commonType = op->getOperand(0).getType();
+  bool allSameType =
+      llvm::all_of(llvm::concat<Value>(op->getOperands(), op->getResults()),
+                   [&](Value v) { return v.getType() == commonType; });
+  if (!allSameType)
+    return false;
+
+  return checkIntType(commonType, targetBitwidth);
+}
+
+/// Returns the union of the integer ranges that `solver` inferred for every
+/// value in `operands`. Returns `std::nullopt` when `operands` is empty, or when
+/// any operand has no `IntegerValueRangeLattice` state or an uninitialized one.
+static std::optional<ConstantIntRanges> getOperandsRange(DataFlowSolver &solver,
+                                                         ValueRange operands) {
+  std::optional<ConstantIntRanges> accumulated;
+  for (Value operand : operands) {
+    const auto *lattice = solver.lookupState<IntegerValueRangeLattice>(operand);
+    // A single operand with unknown range makes the whole union unknown.
+    if (!lattice)
+      return std::nullopt;
+    const auto &latticeValue = lattice->getValue();
+    if (latticeValue.isUninitialized())
+      return std::nullopt;
+
+    const ConstantIntRanges &range = latticeValue.getValue();
+    if (accumulated)
+      accumulated = accumulated->rangeUnion(range);
+    else
+      accumulated = range;
+  }
+  return accumulated;
+}
+
+/// Builds the `IntegerType` of width `targetBitwidth` in the context of
+/// `srcType`. If `srcType` is a `ShapedType`, returns a clone of it whose
+/// element type is that integer type. Otherwise `srcType` is asserted to be
+/// an integer or index type, and the bare integer type is returned.
+static Type getTargetType(Type srcType, unsigned targetBitwidth) {
+  Type narrowedElemType =
+      IntegerType::get(srcType.getContext(), targetBitwidth);
+  auto shapedSrc = dyn_cast<ShapedType>(srcType);
+  if (!shapedSrc) {
+    assert(srcType.isIntOrIndex() && "Invalid src type");
+    return narrowedElemType;
+  }
+  return shapedSrc.clone(narrowedElemType);
+}
+
+/// Check provided `range` is inside `smin, smax, umin, umax` bounds.
+static bool checkRange(const ConstantIntRanges &range, APInt smin, APInt smax,
----------------
kuhar wrote:

also here

https://github.com/llvm/llvm-project/pull/112404


More information about the Mlir-commits mailing list