[Mlir-commits] [mlir] d471c85 - [mlir][int-range] Update int range inference for `arith.xori` (#117272)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 26 02:50:47 PST 2024


Author: Ivan Butygin
Date: 2024-11-26T13:50:43+03:00
New Revision: d471c85e654ad0111cdffe588b2b958b62eca29f

URL: https://github.com/llvm/llvm-project/commit/d471c85e654ad0111cdffe588b2b958b62eca29f
DIFF: https://github.com/llvm/llvm-project/commit/d471c85e654ad0111cdffe588b2b958b62eca29f.diff

LOG: [mlir][int-range] Update int range inference for `arith.xori` (#117272)

The previous implementation produced incorrect results for widths > i1 and
was disabled.

While the same algorithm could also be used for `andi` and `ori`, without
additional modifications it would produce less precise results.

Added: 
    

Modified: 
    mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
    mlir/test/Dialect/Arith/int-range-interface.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 2c1276d577a55b..7a73a94201f1d6 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -556,29 +556,25 @@ mlir::intrange::inferOr(ArrayRef<ConstantIntRanges> argRanges) {
                   /*isSigned=*/false);
 }
 
+/// Get bitmask of all bits which can change while iterating in
+/// [bound.umin(), bound.umax()].
+static APInt getVaryingBitsMask(const ConstantIntRanges &bound) {
+  APInt leftVal = bound.umin(), rightVal = bound.umax();
+  unsigned bitwidth = leftVal.getBitWidth();
+  unsigned differingBits = bitwidth - (leftVal ^ rightVal).countl_zero();
+  return APInt::getLowBitsSet(bitwidth, differingBits);
+}
+
 ConstantIntRanges
 mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) {
-  // TODO: The code below doesn't work for bitwidths > i1.
-  // For input ranges lhs=[2060639849, 2060639850], rhs=[2060639849, 2060639849]
-  // widenBitwiseBounds will produce:
-  // lhs:
-  // 2060639848  01111010110100101101111001101000
-  // 2060639851  01111010110100101101111001101011
-  // rhs:
-  // 2060639849  01111010110100101101111001101001
-  // 2060639849  01111010110100101101111001101001
-  // None of those combinations xor to 0, while intermediate values does.
-  unsigned width = argRanges[0].umin().getBitWidth();
-  if (width > 1)
-    return ConstantIntRanges::maxRange(width);
-
-  auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
-  auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
-  auto xori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
-    return a ^ b;
-  };
-  return minMaxBy(xori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
-                  /*isSigned=*/false);
+  // Construct mask of varying bits for both ranges, xor values and then replace
+  // masked bits with 0s and 1s to get min and max values respectively.
+  ConstantIntRanges lhs = argRanges[0], rhs = argRanges[1];
+  APInt mask = getVaryingBitsMask(lhs) | getVaryingBitsMask(rhs);
+  APInt res = lhs.umin() ^ rhs.umin();
+  APInt min = res & ~mask;
+  APInt max = res | mask;
+  return ConstantIntRanges::fromUnsigned(min, max);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index 4db846fa4656a3..48a3eb20eb7fb0 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -481,8 +481,8 @@ func.func @xori_i1() -> (i1, i1) {
 }
 
 // CHECK-LABEL: func @xori
-// TODO: xor folding is temporarily disabled
-// CHECK-NOT: arith.constant false
+// CHECK: %[[false:.*]] = arith.constant false
+// CHECK: return %[[false]]
 func.func @xori(%arg0 : i64, %arg1 : i64) -> i1 {
     %c0 = arith.constant 0 : i64
     %c7 = arith.constant 7 : i64


        


More information about the Mlir-commits mailing list