[Mlir-commits] [mlir] [mlir][int-range] Update int range inference for `arith.xori` (PR #117272)
Ivan Butygin
llvmlistbot at llvm.org
Fri Nov 22 06:38:10 PST 2024
https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/117272
From effbb1f1d0a597876e6aaa110edf3a8836b3cde1 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Fri, 22 Nov 2024 02:33:59 +0100
Subject: [PATCH] [mlir][int-range] Update int range inference for `arith.xori`
The previous implementation produced incorrect results for widths > i1 and was therefore disabled.
The same algorithm could also be used for `andi` and `ori`, but without additional modifications it would produce less precise results.
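To make the failure of the previous approach concrete, here is a minimal standalone sketch (not MLIR code) that replays the counterexample from the removed TODO comment. It assumes plain `uint32_t` values in place of `APInt`; `Range32`, `widen`, `cornerXor` and `exactXor` are hypothetical names. `widen` imitates `widenBitwiseBounds` by clearing (lower bound) or setting (upper bound) every bit at or below the highest bit where the bounds differ, and `cornerXor` imitates the old `minMaxBy` call by xoring only the widened corners. Because xor is not monotonic, those corners do not bound the result.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical illustration type (not part of MLIR): inclusive unsigned bounds.
struct Range32 {
  uint32_t lo, hi;
};

// Imitates widenBitwiseBounds: bits at or below the highest bit where lo and
// hi differ are cleared in the lower bound and set in the upper bound.
Range32 widen(Range32 r) {
  assert(r.lo <= r.hi);
  uint32_t diff = r.lo ^ r.hi, mask = 0;
  while (diff != 0) {
    mask = (mask << 1) | 1u;
    diff >>= 1;
  }
  return {r.lo & ~mask, r.hi | mask};
}

// Old-style estimate: xor only the four widened corners and take min/max.
Range32 cornerXor(Range32 lhs, Range32 rhs) {
  Range32 a = widen(lhs), b = widen(rhs);
  const uint32_t corners[4] = {a.lo ^ b.lo, a.lo ^ b.hi, a.hi ^ b.lo,
                               a.hi ^ b.hi};
  return {*std::min_element(corners, corners + 4),
          *std::max_element(corners, corners + 4)};
}

// Exact range by enumerating every operand pair; only practical for small
// ranges. uint64_t loop counters avoid wrap-around when hi == UINT32_MAX.
Range32 exactXor(Range32 lhs, Range32 rhs) {
  uint32_t lo = UINT32_MAX, hi = 0;
  for (uint64_t x = lhs.lo; x <= lhs.hi; ++x)
    for (uint64_t y = rhs.lo; y <= rhs.hi; ++y) {
      uint32_t v = static_cast<uint32_t>(x ^ y);
      lo = std::min(lo, v);
      hi = std::max(hi, v);
    }
  return {lo, hi};
}
```

For lhs = [2060639849, 2060639850] and rhs = [2060639849, 2060639849], `cornerXor` yields [1, 2], while `exactXor` yields [0, 3]: the corner-based estimate misses both true extremes, so it is unsound rather than merely imprecise.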
---
.../Interfaces/Utils/InferIntRangeCommon.cpp | 38 +++++++++----------
.../Dialect/Arith/int-range-interface.mlir | 4 +-
2 files changed, 19 insertions(+), 23 deletions(-)
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 2c1276d577a55b..7a73a94201f1d6 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -556,29 +556,25 @@ mlir::intrange::inferOr(ArrayRef<ConstantIntRanges> argRanges) {
/*isSigned=*/false);
}
+/// Get bitmask of all bits which can change while iterating in
+/// [bound.umin(), bound.umax()].
+static APInt getVaryingBitsMask(const ConstantIntRanges &bound) {
+ APInt leftVal = bound.umin(), rightVal = bound.umax();
+ unsigned bitwidth = leftVal.getBitWidth();
+ unsigned differingBits = bitwidth - (leftVal ^ rightVal).countl_zero();
+ return APInt::getLowBitsSet(bitwidth, differingBits);
+}
+
ConstantIntRanges
mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) {
- // TODO: The code below doesn't work for bitwidths > i1.
- // For input ranges lhs=[2060639849, 2060639850], rhs=[2060639849, 2060639849]
- // widenBitwiseBounds will produce:
- // lhs:
- // 2060639848 01111010110100101101111001101000
- // 2060639851 01111010110100101101111001101011
- // rhs:
- // 2060639849 01111010110100101101111001101001
- // 2060639849 01111010110100101101111001101001
- // None of those combinations xor to 0, while intermediate values does.
- unsigned width = argRanges[0].umin().getBitWidth();
- if (width > 1)
- return ConstantIntRanges::maxRange(width);
-
- auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
- auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
- auto xori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
- return a ^ b;
- };
- return minMaxBy(xori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
- /*isSigned=*/false);
+ // Construct mask of varying bits for both ranges, xor values and then replace
+ // masked bits with 0s and 1s to get min and max values respectively.
+ ConstantIntRanges lhs = argRanges[0], rhs = argRanges[1];
+ APInt mask = getVaryingBitsMask(lhs) | getVaryingBitsMask(rhs);
+ APInt res = lhs.umin() ^ rhs.umin();
+ APInt min = res & ~mask;
+ APInt max = res | mask;
+ return ConstantIntRanges::fromUnsigned(min, max);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index 4db846fa4656a3..48a3eb20eb7fb0 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -481,8 +481,8 @@ func.func @xori_i1() -> (i1, i1) {
}
// CHECK-LABEL: func @xori
-// TODO: xor folding is temporarily disabled
-// CHECK-NOT: arith.constant false
+// CHECK: %[[false:.*]] = arith.constant false
+// CHECK: return %[[false]]
func.func @xori(%arg0 : i64, %arg1 : i64) -> i1 {
%c0 = arith.constant 0 : i64
%c7 = arith.constant 7 : i64
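The new `inferXor` relies on one observation: every value in [umin, umax] shares the bits above the highest bit where umin and umax differ, so outside the union of both operands' varying-bit masks, `x ^ y` always equals `umin(lhs) ^ umin(rhs)`, and only the masked bits can take either value. Below is a minimal standalone sketch of that logic, assuming `uint64_t` values of at most 32 bits in place of `APInt` (the loop stands in for `countl_zero` plus `APInt::getLowBitsSet`). `URange`, `varyingBitsMask`, `xorRange` and `isSoundForWidth` are hypothetical names, not MLIR API. The exhaustive checker verifies soundness for every pair of ranges at a small bit width.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical illustration type (not MLIR API): inclusive unsigned bounds,
// with values assumed to fit in at most 32 bits.
struct URange {
  uint64_t lo, hi;
};

// Mirrors getVaryingBitsMask: ones at and below the highest bit where the
// bounds differ, i.e. the only bits that can change across [lo, hi].
uint64_t varyingBitsMask(URange r) {
  uint64_t diff = r.lo ^ r.hi, mask = 0;
  while (diff != 0) {
    mask = (mask << 1) | 1;
    diff >>= 1;
  }
  return mask;
}

// Mirrors the new inferXor: bits outside the combined mask are fixed to
// lhs.lo ^ rhs.lo; masked bits are forced to 0 for the minimum and to 1 for
// the maximum.
URange xorRange(URange lhs, URange rhs) {
  uint64_t mask = varyingBitsMask(lhs) | varyingBitsMask(rhs);
  uint64_t res = lhs.lo ^ rhs.lo;
  return {res & ~mask, res | mask};
}

// Exhaustively checks that xorRange contains x ^ y for every pair of operand
// ranges at the given bit width. The cost grows very quickly with the width.
bool isSoundForWidth(unsigned width) {
  assert(width >= 1 && width <= 6);
  const uint64_t maxVal = (uint64_t(1) << width) - 1;
  for (uint64_t a = 0; a <= maxVal; ++a)
    for (uint64_t b = a; b <= maxVal; ++b)
      for (uint64_t c = 0; c <= maxVal; ++c)
        for (uint64_t d = c; d <= maxVal; ++d) {
          URange r = xorRange({a, b}, {c, d});
          for (uint64_t x = a; x <= b; ++x)
            for (uint64_t y = c; y <= d; ++y) {
              uint64_t v = x ^ y;
              if (v < r.lo || v > r.hi)
                return false;
            }
        }
  return true;
}
```

For the operands from the removed TODO comment (lhs = [2060639849, 2060639850], rhs = [2060639849, 2060639849]), the combined mask is `0b11` and `umin(lhs) ^ umin(rhs)` is 0, so `xorRange` returns [0, 3], which here matches the exact range. In general, the result may be wider than the exact range, because every masked bit is treated as free even when some combinations are unreachable.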