[Mlir-commits] [mlir] [mlir][Arith] Let integer range narrowing handle negative values (PR #119642)
llvmlistbot at llvm.org
Wed Dec 11 17:01:04 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-arith
Author: Krzysztof Drewniak (krzysz00)
<details>
<summary>Changes</summary>
Update integer range narrowing to handle negative values.
The previous restriction to only narrowing known-non-negative values wasn't needed: both the signed and unsigned ranges are valid bounds on the values each variable can take in the program, except that one may be more precise than the other. So, if either the signed or the unsigned interpretation of the inputs and outputs allows for integer narrowing, the narrowing is permitted.
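For example (a sketch adapted from the updated `test_addi_neg` test in this diff, assuming 8 is among the pass's supported target bitwidths), a possibly-negative `index` value can now be narrowed using a signed cast instead of blocking the rewrite:

```mlir
// %neg is known to lie in [-1, 0] signed; previously this blocked narrowing.
%pos = test.with_bounds { umin = 0 : index, umax = 1 : index, smin = 0 : index, smax = 1 : index } : index
%neg = test.with_bounds { umin = 0 : index, umax = -1 : index, smin = -1 : index, smax = 0 : index } : index
// With this change, the addition is performed in i8 and sign-extended back.
%pos_i8 = arith.index_castui %pos : index to i8
%neg_i8 = arith.index_cast %neg : index to i8
%res_i8 = arith.addi %pos_i8, %neg_i8 : i8
%res = arith.index_cast %res_i8 : i8 to index
```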
This commit also updates the integer optimization rewrites to preserve the range state of constant-like operations and of operations that are narrowed, so that rewrites of other operations don't lose that range information.
---
Patch is 25.53 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/119642.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp (+137-118)
- (modified) mlir/test/Dialect/Arith/int-range-narrowing.mlir (+105-21)
``````````diff
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index efc4db7e4c9961..7619964304720f 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -46,6 +46,16 @@ static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
return inferredRange.getConstantValue();
}
+static void copyIntegerRange(DataFlowSolver &solver, Value oldVal,
+ Value newVal) {
+ assert(oldVal.getType() == newVal.getType() &&
+ "Can't copy integer ranges between different types");
+ auto *oldState = solver.lookupState<IntegerValueRangeLattice>(oldVal);
+ if (!oldState)
+ return;
+ solver.getOrCreateState<IntegerValueRangeLattice>(newVal)->join(*oldState);
+}
+
/// Patterned after SCCP
static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
PatternRewriter &rewriter,
@@ -80,6 +90,7 @@ static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
if (!constOp)
return failure();
+ copyIntegerRange(solver, value, constOp->getResult(0));
rewriter.replaceAllUsesWith(value, constOp->getResult(0));
return success();
}
@@ -195,56 +206,21 @@ struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
DataFlowSolver &solver;
};
-/// Check if `type` is index or integer type with `getWidth() > targetBitwidth`.
-static LogicalResult checkIntType(Type type, unsigned targetBitwidth) {
- Type elemType = getElementTypeOrSelf(type);
- if (isa<IndexType>(elemType))
- return success();
-
- if (auto intType = dyn_cast<IntegerType>(elemType))
- if (intType.getWidth() > targetBitwidth)
- return success();
-
- return failure();
-}
-
-/// Check if op have same type for all operands and results and this type
-/// is suitable for truncation.
-static LogicalResult checkElementwiseOpType(Operation *op,
- unsigned targetBitwidth) {
- if (op->getNumOperands() == 0 || op->getNumResults() == 0)
- return failure();
-
- Type type;
- for (Value val : llvm::concat<Value>(op->getOperands(), op->getResults())) {
- if (!type) {
- type = val.getType();
- continue;
- }
-
- if (type != val.getType())
- return failure();
- }
-
- return checkIntType(type, targetBitwidth);
-}
-
-/// Return union of all operands values ranges.
-static std::optional<ConstantIntRanges> getOperandsRange(DataFlowSolver &solver,
- ValueRange operands) {
- std::optional<ConstantIntRanges> ret;
- for (Value value : operands) {
+/// Gather ranges for all the values in `values`. Appends to the existing
+/// vector.
+static LogicalResult collectRanges(DataFlowSolver &solver, ValueRange values,
+ SmallVectorImpl<ConstantIntRanges> &ranges) {
+ for (Value val : values) {
auto *maybeInferredRange =
- solver.lookupState<IntegerValueRangeLattice>(value);
+ solver.lookupState<IntegerValueRangeLattice>(val);
if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
- return std::nullopt;
+ return failure();
const ConstantIntRanges &inferredRange =
maybeInferredRange->getValue().getValue();
-
- ret = (ret ? ret->rangeUnion(inferredRange) : inferredRange);
+ ranges.push_back(inferredRange);
}
- return ret;
+ return success();
}
/// Return int type truncated to `targetBitwidth`. If `srcType` is shaped,
@@ -258,41 +234,59 @@ static Type getTargetType(Type srcType, unsigned targetBitwidth) {
return dstType;
}
-/// Check provided `range` is inside `smin, smax, umin, umax` bounds.
-static LogicalResult checkRange(const ConstantIntRanges &range, APInt smin,
- APInt smax, APInt umin, APInt umax) {
- auto sge = [](APInt val1, APInt val2) -> bool {
- unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
- val1 = val1.sext(width);
- val2 = val2.sext(width);
- return val1.sge(val2);
- };
- auto sle = [](APInt val1, APInt val2) -> bool {
- unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
- val1 = val1.sext(width);
- val2 = val2.sext(width);
- return val1.sle(val2);
- };
- auto uge = [](APInt val1, APInt val2) -> bool {
- unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
- val1 = val1.zext(width);
- val2 = val2.zext(width);
- return val1.uge(val2);
- };
- auto ule = [](APInt val1, APInt val2) -> bool {
- unsigned width = std::max(val1.getBitWidth(), val2.getBitWidth());
- val1 = val1.zext(width);
- val2 = val2.zext(width);
- return val1.ule(val2);
- };
- return success(sge(range.smin(), smin) && sle(range.smax(), smax) &&
- uge(range.umin(), umin) && ule(range.umax(), umax));
+namespace {
+// Enum for tracking which type of truncation should be performed
+// to narrow an operation, if any.
+enum class CastKind : uint8_t { None, Signed, Unsigned, Both };
+} // namespace
+
+/// If the values within `range` can be represented using only `width` bits,
+/// return the kind of truncation needed to preserve that property.
+///
+/// This check relies on the fact that the signed and unsigned ranges are both
+/// always correct, but that one might be an approximation of the other,
+/// so we want to use the correct truncation operation.
+static CastKind checkTruncatability(const ConstantIntRanges &range,
+ unsigned targetWidth) {
+ unsigned srcWidth = range.smin().getBitWidth();
+ if (srcWidth <= targetWidth)
+ return CastKind::None;
+ unsigned removedWidth = srcWidth - targetWidth;
+ // The sign bits need to extend into the sign bit of the target width. For
+ // example, if we're truncating 64 bits to 32, we need 64 - 32 + 1 = 33 sign
+ // bits.
+ bool canTruncateSigned =
+ range.smin().getNumSignBits() >= (removedWidth + 1) &&
+ range.smax().getNumSignBits() >= (removedWidth + 1);
+ bool canTruncateUnsigned = range.umin().countLeadingZeros() >= removedWidth &&
+ range.umax().countLeadingZeros() >= removedWidth;
+ if (canTruncateSigned && canTruncateUnsigned)
+ return CastKind::Both;
+ if (canTruncateSigned)
+ return CastKind::Signed;
+ if (canTruncateUnsigned)
+ return CastKind::Unsigned;
+ return CastKind::None;
+}
+
+static CastKind mergeCastKinds(CastKind lhs, CastKind rhs) {
+ if (lhs == CastKind::None || rhs == CastKind::None)
+ return CastKind::None;
+ if (lhs == CastKind::Both)
+ return rhs;
+ if (rhs == CastKind::Both)
+ return lhs;
+ if (lhs == rhs)
+ return lhs;
+ return CastKind::None;
}
-static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
+static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType,
+ CastKind castKind) {
Type srcType = src.getType();
assert(isa<VectorType>(srcType) == isa<VectorType>(dstType) &&
"Mixing vector and non-vector types");
+ assert(castKind != CastKind::None && "Can't cast when casting isn't allowed");
Type srcElemType = getElementTypeOrSelf(srcType);
Type dstElemType = getElementTypeOrSelf(dstType);
assert(srcElemType.isIntOrIndex() && "Invalid src type");
@@ -300,14 +294,19 @@ static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
if (srcType == dstType)
return src;
- if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType))
+ if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType)) {
+ if (castKind == CastKind::Signed)
+ return builder.create<arith::IndexCastOp>(loc, dstType, src);
return builder.create<arith::IndexCastUIOp>(loc, dstType, src);
+ }
auto srcInt = cast<IntegerType>(srcElemType);
auto dstInt = cast<IntegerType>(dstElemType);
if (dstInt.getWidth() < srcInt.getWidth())
return builder.create<arith::TruncIOp>(loc, dstType, src);
+ if (castKind == CastKind::Signed)
+ return builder.create<arith::ExtSIOp>(loc, dstType, src);
return builder.create<arith::ExtUIOp>(loc, dstType, src);
}
@@ -319,36 +318,46 @@ struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
using OpTraitRewritePattern::OpTraitRewritePattern;
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
- std::optional<ConstantIntRanges> range =
- getOperandsRange(solver, op->getResults());
- if (!range)
- return failure();
+ SmallVector<ConstantIntRanges> ranges;
+ if (failed(collectRanges(solver, op->getOperands(), ranges)))
+ return rewriter.notifyMatchFailure(op, "input without specified range");
+ if (failed(collectRanges(solver, op->getResults(), ranges)))
+ return rewriter.notifyMatchFailure(op, "output without specified range");
+
+ if (op->getNumResults() == 0)
+ return rewriter.notifyMatchFailure(op, "can't narrow resultless op");
+ Type srcType = op->getResult(0).getType();
+ if (!llvm::all_equal(op->getResultTypes()))
+ return rewriter.notifyMatchFailure(op, "mismatched result types");
+ if (op->getNumOperands() == 0 ||
+ !llvm::all_of(op->getOperandTypes(),
+ [=](Type t) { return t == srcType; }))
+ return rewriter.notifyMatchFailure(
+ op, "no operands or operand types don't match result type");
for (unsigned targetBitwidth : targetBitwidths) {
- if (failed(checkElementwiseOpType(op, targetBitwidth)))
- continue;
-
- Type srcType = op->getResult(0).getType();
-
- // We are truncating op args to the desired bitwidth before the op and
- // then extending op results back to the original width after. extui and
- // exti will produce different results for negative values, so limit
- // signed range to non-negative values.
- auto smin = APInt::getZero(targetBitwidth);
- auto smax = APInt::getSignedMaxValue(targetBitwidth);
- auto umin = APInt::getMinValue(targetBitwidth);
- auto umax = APInt::getMaxValue(targetBitwidth);
- if (failed(checkRange(*range, smin, smax, umin, umax)))
+ CastKind castKind = CastKind::Both;
+ for (const ConstantIntRanges &range : ranges) {
+ castKind = mergeCastKinds(castKind,
+ checkTruncatability(range, targetBitwidth));
+ if (castKind == CastKind::None)
+ break;
+ }
+ if (castKind == CastKind::None)
continue;
-
Type targetType = getTargetType(srcType, targetBitwidth);
if (targetType == srcType)
continue;
Location loc = op->getLoc();
IRMapping mapping;
- for (Value arg : op->getOperands()) {
- Value newArg = doCast(rewriter, loc, arg, targetType);
+ for (auto [arg, argRange] : llvm::zip_first(op->getOperands(), ranges)) {
+ CastKind argCastKind = castKind;
+ // When dealing with `index` values, preserve non-negativity in the
+ // index_casts since we can't recover this in unsigned when equivalent.
+ if (argCastKind == CastKind::Signed && argRange.smin().isNonNegative())
+ argCastKind = CastKind::Both;
+ Value newArg = doCast(rewriter, loc, arg, targetType, argCastKind);
mapping.map(arg, newArg);
}
@@ -359,8 +368,12 @@ struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
}
});
SmallVector<Value> newResults;
- for (Value res : newOp->getResults())
- newResults.emplace_back(doCast(rewriter, loc, res, srcType));
+ for (auto [newRes, oldRes] :
+ llvm::zip_equal(newOp->getResults(), op->getResults())) {
+ Value castBack = doCast(rewriter, loc, newRes, srcType, castKind);
+ copyIntegerRange(solver, oldRes, castBack);
+ newResults.push_back(castBack);
+ }
rewriter.replaceOp(op, newResults);
return success();
@@ -382,21 +395,19 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
Value lhs = op.getLhs();
Value rhs = op.getRhs();
- std::optional<ConstantIntRanges> range =
- getOperandsRange(solver, {lhs, rhs});
- if (!range)
+ SmallVector<ConstantIntRanges> ranges;
+ if (failed(collectRanges(solver, op.getOperands(), ranges)))
return failure();
+ const ConstantIntRanges &lhsRange = ranges[0];
+ const ConstantIntRanges &rhsRange = ranges[1];
+ Type srcType = lhs.getType();
for (unsigned targetBitwidth : targetBitwidths) {
- Type srcType = lhs.getType();
- if (failed(checkIntType(srcType, targetBitwidth)))
- continue;
-
- auto smin = APInt::getSignedMinValue(targetBitwidth);
- auto smax = APInt::getSignedMaxValue(targetBitwidth);
- auto umin = APInt::getMinValue(targetBitwidth);
- auto umax = APInt::getMaxValue(targetBitwidth);
- if (failed(checkRange(*range, smin, smax, umin, umax)))
+ CastKind lhsCastKind = checkTruncatability(lhsRange, targetBitwidth);
+ CastKind rhsCastKind = checkTruncatability(rhsRange, targetBitwidth);
+ CastKind castKind = mergeCastKinds(lhsCastKind, rhsCastKind);
+ // Note: this includes target width > src width.
+ if (castKind == CastKind::None)
continue;
Type targetType = getTargetType(srcType, targetBitwidth);
@@ -406,11 +417,12 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
Location loc = op->getLoc();
IRMapping mapping;
for (Value arg : op->getOperands()) {
- Value newArg = doCast(rewriter, loc, arg, targetType);
+ Value newArg = doCast(rewriter, loc, arg, targetType, castKind);
mapping.map(arg, newArg);
}
Operation *newOp = rewriter.clone(*op, mapping);
+ copyIntegerRange(solver, op.getResult(), newOp->getResult(0));
rewriter.replaceOp(op, newOp->getResults());
return success();
}
@@ -425,19 +437,24 @@ struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
/// Fold index_cast(index_cast(%arg: i8, index), i8) -> %arg
/// This pattern assumes all passed `targetBitwidths` are not wider than index
/// type.
-struct FoldIndexCastChain final : OpRewritePattern<arith::IndexCastUIOp> {
+template <typename CastOp>
+struct FoldIndexCastChain final : OpRewritePattern<CastOp> {
FoldIndexCastChain(MLIRContext *context, ArrayRef<unsigned> target)
- : OpRewritePattern(context), targetBitwidths(target) {}
+ : OpRewritePattern<CastOp>(context), targetBitwidths(target) {}
- LogicalResult matchAndRewrite(arith::IndexCastUIOp op,
+ LogicalResult matchAndRewrite(CastOp op,
PatternRewriter &rewriter) const override {
- auto srcOp = op.getIn().getDefiningOp<arith::IndexCastUIOp>();
+ auto srcOp = op.getIn().template getDefiningOp<CastOp>();
if (!srcOp)
- return failure();
+ return rewriter.notifyMatchFailure(op, "doesn't come from an index cast");
+ ;
Value src = srcOp.getIn();
if (src.getType() != op.getType())
- return failure();
+ return rewriter.notifyMatchFailure(op, "outer types don't match");
+
+ if (!srcOp.getType().isIndex())
+ return rewriter.notifyMatchFailure(op, "intermediate type isn't index");
auto intType = dyn_cast<IntegerType>(op.getType());
if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth()))
@@ -517,7 +534,9 @@ void mlir::arith::populateIntRangeNarrowingPatterns(
ArrayRef<unsigned> bitwidthsSupported) {
patterns.add<NarrowElementwise, NarrowCmpI>(patterns.getContext(), solver,
bitwidthsSupported);
- patterns.add<FoldIndexCastChain>(patterns.getContext(), bitwidthsSupported);
+ patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>,
+ FoldIndexCastChain<arith::IndexCastOp>>(patterns.getContext(),
+ bitwidthsSupported);
}
std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
index 8893f299177ceb..ac0132480e4e18 100644
--- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir
+++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir
@@ -4,9 +4,14 @@
// Some basic tests
//===----------------------------------------------------------------------===//
-// Do not truncate negative values
+// Truncate possibly-negative values in a signed way
// CHECK-LABEL: func @test_addi_neg
-// CHECK: %[[RES:.*]] = arith.addi %{{.*}}, %{{.*}} : index
+// CHECK: %[[POS:.*]] = test.with_bounds {smax = 1 : index, smin = 0 : index, umax = 1 : index, umin = 0 : index} : index
+// CHECK: %[[NEG:.*]] = test.with_bounds {smax = 0 : index, smin = -1 : index, umax = -1 : index, umin = 0 : index} : index
+// CHECK: %[[POS_I8:.*]] = arith.index_castui %[[POS]] : index to i8
+// CHECK: %[[NEG_I8:.*]] = arith.index_cast %[[NEG]] : index to i8
+// CHECK: %[[RES_I8:.*]] = arith.addi %[[POS_I8]], %[[NEG_I8]] : i8
+// CHECK: %[[RES:.*]] = arith.index_cast %[[RES_I8]] : i8 to index
// CHECK: return %[[RES]] : index
func.func @test_addi_neg() -> index {
%0 = test.with_bounds { umin = 0 : index, umax = 1 : index, smin = 0 : index, smax = 1 : index } : index
@@ -146,14 +151,18 @@ func.func @addi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
return %r : i32
}
-// This case should not get optimized because of mixed extensions.
+// This can be optimized to i16 since we're dealing in [-128, 127] + [0, 255],
+// which is [-128, 382]
//
// CHECK-LABEL: func.func @addi_mixed_ext_i8
// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
-// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[EXT0]], %[[EXT1]] : i32
-// CHECK-NEXT: return %[[ADD]] : i32
+// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16
+// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16
+// CHECK-NEXT: %[[ADD:.+]] = arith.addi %[[LHS]], %[[RHS]] : i16
+// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[ADD]] : i16 to i32
+// CHECK-NEXT: return %[[RET]] : i32
func.func @addi_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 {
%a = arith.extsi %lhs : i8 to i32
%b = arith.extui %rhs : i8 to i32
@@ -181,15 +190,15 @@ func.func @addi_extsi_i16(%lhs: i8, %rhs: i8) -> i16 {
// arith.subi
//===----------------------------------------------------------------------===//
-// This patterns should only apply to `arith.subi` ops with sign-extended
-// arguments.
-//
// CHECK-LABEL: func.func @subi_extui_i8
// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
// CHECK-NEXT: %[[EXT0:.+]] = arith.extui %[[ARG0]] : i8 to i32
// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
-// CHECK-NEXT: %[[SUB:.+]] = arith.subi %[[EXT0]], %[[EXT1]] : i32
-// CHECK-NEXT: return %[[SUB]] : i32
+// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16
+// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16
+// CHECK-NEXT: %[[SUB:.+]] = arith.subi %[[LHS]], %[[RHS]] : i16
+// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[SUB]] : i16 to i32
+// CHECK-NEXT: return %[[RET]] : i32
func.func @subi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
%a = arith.extui %lhs : i8 to i32
%b = arith.extui %rhs : i8 to i32
@@ -197,14 +206,17 @@ func.func @subi_extui_i8(%lhs: i8, %rhs: i8) -> i32 {
return %r : i32
}
-// This case should not get optimized because of mixed extensions.
+// Despite the mixed sign and zero extensions, we can optimize here
//
// CHECK-LABEL: func.func @subi_mixed_ext_i8
// CHECK-SAME: (%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8)
// CHECK-NEXT: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32
// CHECK-NEXT: %[[EXT1:.+]] = arith.extui %[[ARG1]] : i8 to i32
-// CHECK-NEXT: %[[ADD:.+]] = arith.subi %[[EXT0]], %[[EXT1]] : i32
-// CHECK-NEXT: return %[[ADD]] : i32
+// CHECK-NEXT: %[[LHS:.+]] = arith.trunci %[[EXT0]] : i32 to i16
+// CHECK-NEXT: %[[RHS:.+]] = arith.trunci %[[EXT1]] : i32 to i16
+// CHECK-N...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/119642