[Mlir-commits] [mlir] [mlir] Add `arith-int-range-narrowing` pass (PR #112404)
Jakub Kuderski
llvmlistbot at llvm.org
Sat Nov 2 17:02:45 PDT 2024
================
@@ -190,8 +195,244 @@ struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
DataFlowSolver &solver;
};
-struct IntRangeOptimizationsPass
- : public arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
+/// Check if `type` is an index type or an integer type with
+/// `getWidth() > targetBitwidth`; return `type` if so, null otherwise.
+static Type checkIntType(Type type, unsigned targetBitwidth) {
+ Type elemType = getElementTypeOrSelf(type);
+ if (isa<IndexType>(elemType))
+ return type;
+
+ if (auto intType = dyn_cast<IntegerType>(elemType))
+ if (intType.getWidth() > targetBitwidth)
+ return type;
+
+ return nullptr;
+}
+
+/// Check if the op has the same type for all operands and results, and that
+/// this type is suitable for truncation.
+/// Returns the common type, or null otherwise.
+static Type checkElementwiseOpType(Operation *op, unsigned targetBitwidth) {
+ if (op->getNumOperands() == 0 || op->getNumResults() == 0)
+ return nullptr;
+
+ Type type;
+ for (auto range :
+ {ValueRange(op->getOperands()), ValueRange(op->getResults())}) {
+ for (Value val : range) {
+ if (!type) {
+ type = val.getType();
+ continue;
+ }
+
+ if (type != val.getType())
+ return nullptr;
+ }
+ }
+
+ return checkIntType(type, targetBitwidth);
+}
+
+/// Return the union of the inferred ranges of all `operands`, or std::nullopt
+/// if any range is unknown.
+static std::optional<ConstantIntRanges> getOperandsRange(DataFlowSolver &solver,
+ ValueRange operands) {
+ std::optional<ConstantIntRanges> ret;
+ for (Value value : operands) {
+ auto *maybeInferredRange =
+ solver.lookupState<IntegerValueRangeLattice>(value);
+ if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
+ return std::nullopt;
+
+ const ConstantIntRanges &inferredRange =
+ maybeInferredRange->getValue().getValue();
+
+ if (!ret) {
+ ret = inferredRange;
+ } else {
+ ret = ret->rangeUnion(inferredRange);
+ }
+ }
+ return ret;
+}
+
+/// Return an integer type with `targetBitwidth` bits. If `srcType` is shaped,
+/// return a shaped type with the same shape and the narrowed element type.
+static Type getTargetType(Type srcType, unsigned targetBitwidth) {
+ auto dstType = IntegerType::get(srcType.getContext(), targetBitwidth);
+ if (auto shaped = dyn_cast<ShapedType>(srcType))
+ return shaped.clone(dstType);
+
+ assert(srcType.isIntOrIndex() && "Invalid src type");
+ return dstType;
+}
+
+/// Check that the provided `range` is within the signed bounds `[smin, smax]`
+/// and the unsigned bounds `[umin, umax]`.
+static bool checkRange(const ConstantIntRanges &range, APInt smin, APInt smax,
+ APInt umin, APInt umax) {
+ auto sge = [](APInt val1, APInt val2) -> bool {
+ unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
+ val1 = val1.sext(width);
+ val2 = val2.sext(width);
+ return val1.sge(val2);
+ };
+ auto sle = [](APInt val1, APInt val2) -> bool {
+ unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
+ val1 = val1.sext(width);
+ val2 = val2.sext(width);
+ return val1.sle(val2);
+ };
+ auto uge = [](APInt val1, APInt val2) -> bool {
+ unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
+ val1 = val1.zext(width);
+ val2 = val2.zext(width);
+ return val1.uge(val2);
+ };
+ auto ule = [](APInt val1, APInt val2) -> bool {
+ unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
+ val1 = val1.zext(width);
+ val2 = val2.zext(width);
+ return val1.ule(val2);
+ };
+ return sge(range.smin(), smin) && sle(range.smax(), smax) &&
+ uge(range.umin(), umin) && ule(range.umax(), umax);
+}
+
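+/// Cast `src` to `dstType`: uses `arith.index_castui` when either side is an
+/// index type, `arith.trunci` when narrowing, and `arith.extui` when widening.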
+static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
+ Type srcType = src.getType();
+ assert(isa<VectorType>(srcType) == isa<VectorType>(dstType) &&
+ "Mixing vector and non-vector types");
+ Type srcElemType = getElementTypeOrSelf(srcType);
+ Type dstElemType = getElementTypeOrSelf(dstType);
+ assert(srcElemType.isIntOrIndex() && "Invalid src type");
+ assert(dstElemType.isIntOrIndex() && "Invalid dst type");
+ if (srcType == dstType)
+ return src;
+
+ if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType))
+ return builder.create<arith::IndexCastUIOp>(loc, dstType, src);
+
+ auto srcInt = cast<IntegerType>(srcElemType);
+ auto dstInt = cast<IntegerType>(dstElemType);
+ if (dstInt.getWidth() < srcInt.getWidth())
+ return builder.create<arith::TruncIOp>(loc, dstType, src);
+
+ return builder.create<arith::ExtUIOp>(loc, dstType, src);
+}
+
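+/// Pattern over ops with the `Elementwise` trait that narrows their integer
+/// operands and results based on the ranges inferred by `solver`, using the
+/// given target bitwidths.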
+struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
+ NarrowElementwise(MLIRContext *context, DataFlowSolver &s,
+ ArrayRef<unsigned> target)
+ : OpTraitRewritePattern<OpTrait::Elementwise>(context), solver(s),
----------------
kuhar wrote:
```suggestion
: OpTraitRewritePattern(context), solver(s),
```
?
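A minimal standalone sketch (hypothetical `*Sketch` names, not the actual MLIR
classes) of why the suggestion works: inside a class derived from a class
template specialization, the base's injected-class-name is visible, so the
member initializer list can name the base without repeating its template
arguments.

```cpp
#include <iostream>

template <typename TraitT>
struct OpTraitRewritePatternSketch {
  explicit OpTraitRewritePatternSketch(int context) : context(context) {}
  int context;
};

struct ElementwiseTag {};

struct NarrowElementwiseSketch final
    : OpTraitRewritePatternSketch<ElementwiseTag> {
  // Same as writing OpTraitRewritePatternSketch<ElementwiseTag>(context).
  explicit NarrowElementwiseSketch(int context)
      : OpTraitRewritePatternSketch(context) {}
};

int main() {
  NarrowElementwiseSketch pattern(/*context=*/42);
  std::cout << pattern.context << "\n"; // prints 42
}
```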
https://github.com/llvm/llvm-project/pull/112404