[Mlir-commits] [mlir] [mlir][int-range] Update int range inference for `arith.xori` (PR #117272)

Ivan Butygin llvmlistbot at llvm.org
Fri Nov 22 06:38:10 PST 2024


https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/117272

>From effbb1f1d0a597876e6aaa110edf3a8836b3cde1 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Fri, 22 Nov 2024 02:33:59 +0100
Subject: [PATCH] [mlir][int-range] Update int range inference for `arith.xori`

The previous implementation produced incorrect results for bitwidths wider than i1 and was disabled.

While the same algorithm can also be used for `andi` and `ori`, without additional modifications it would produce less precise results than the existing implementations.
---
 .../Interfaces/Utils/InferIntRangeCommon.cpp  | 38 +++++++++----------
 .../Dialect/Arith/int-range-interface.mlir    |  4 +-
 2 files changed, 19 insertions(+), 23 deletions(-)
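
For readers without an LLVM checkout, the varying-bits approach in this patch can be sketched on plain 32-bit integers. This is only an illustration under stated assumptions: `URange`, `varyingBitsMask` and `xorRange` are hypothetical names, and the bit-smearing shifts stand in for the `countl_zero` / `APInt::getLowBitsSet` calls used in the actual change.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the unsigned half of ConstantIntRanges.
// This is not the MLIR type, just a pair of bounds for illustration.
struct URange {
  uint32_t umin;
  uint32_t umax;
};

// Mask covering the highest bit in which umin and umax differ, plus every
// bit below it. Bits above that position are identical for every value in
// [umin, umax], so only the masked bits can vary.
uint32_t varyingBitsMask(URange r) {
  assert(r.umin <= r.umax && "expected a non-empty unsigned range");
  uint32_t m = r.umin ^ r.umax;
  // Smear the highest set bit downwards into all lower positions.
  m |= m >> 1;
  m |= m >> 2;
  m |= m >> 4;
  m |= m >> 8;
  m |= m >> 16;
  return m;
}

// Conservative unsigned bounds for lhs ^ rhs: xor the fixed prefixes, then
// clear the varying bits for the minimum and set them for the maximum.
URange xorRange(URange lhs, URange rhs) {
  uint32_t mask = varyingBitsMask(lhs) | varyingBitsMask(rhs);
  uint32_t res = lhs.umin ^ rhs.umin;
  return {res & ~mask, res | mask};
}
```

With the inputs quoted in the removed TODO, lhs=[2060639849, 2060639850] and rhs=[2060639849, 2060639849], this sketch gives the range [0, 3], which contains both reachable results (0 and 3). Two constant inputs give a zero mask, so the result collapses to a single value.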

diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 2c1276d577a55b..7a73a94201f1d6 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -556,29 +556,25 @@ mlir::intrange::inferOr(ArrayRef<ConstantIntRanges> argRanges) {
                   /*isSigned=*/false);
 }
 
+/// Get bitmask of all bits which can change while iterating in
+/// [bound.umin(), bound.umax()].
+static APInt getVaryingBitsMask(const ConstantIntRanges &bound) {
+  APInt leftVal = bound.umin(), rightVal = bound.umax();
+  unsigned bitwidth = leftVal.getBitWidth();
+  unsigned differingBits = bitwidth - (leftVal ^ rightVal).countl_zero();
+  return APInt::getLowBitsSet(bitwidth, differingBits);
+}
+
 ConstantIntRanges
 mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) {
-  // TODO: The code below doesn't work for bitwidths > i1.
-  // For input ranges lhs=[2060639849, 2060639850], rhs=[2060639849, 2060639849]
-  // widenBitwiseBounds will produce:
-  // lhs:
-  // 2060639848  01111010110100101101111001101000
-  // 2060639851  01111010110100101101111001101011
-  // rhs:
-  // 2060639849  01111010110100101101111001101001
-  // 2060639849  01111010110100101101111001101001
-  // None of those combinations xor to 0, while intermediate values does.
-  unsigned width = argRanges[0].umin().getBitWidth();
-  if (width > 1)
-    return ConstantIntRanges::maxRange(width);
-
-  auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
-  auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
-  auto xori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
-    return a ^ b;
-  };
-  return minMaxBy(xori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
-                  /*isSigned=*/false);
+  // Construct mask of varying bits for both ranges, xor values and then replace
+  // masked bits with 0s and 1s to get min and max values respectively.
+  ConstantIntRanges lhs = argRanges[0], rhs = argRanges[1];
+  APInt mask = getVaryingBitsMask(lhs) | getVaryingBitsMask(rhs);
+  APInt res = lhs.umin() ^ rhs.umin();
+  APInt min = res & ~mask;
+  APInt max = res | mask;
+  return ConstantIntRanges::fromUnsigned(min, max);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index 4db846fa4656a3..48a3eb20eb7fb0 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -481,8 +481,8 @@ func.func @xori_i1() -> (i1, i1) {
 }
 
 // CHECK-LABEL: func @xori
-// TODO: xor folding is temporarily disabled
-// CHECK-NOT: arith.constant false
+// CHECK: %[[false:.*]] = arith.constant false
+// CHECK: return %[[false]]
 func.func @xori(%arg0 : i64, %arg1 : i64) -> i1 {
     %c0 = arith.constant 0 : i64
     %c7 = arith.constant 7 : i64


