[Mlir-commits] [mlir] [mlir][int-range] Limit xor int range inference to i1 (PR #116968)

Ivan Butygin llvmlistbot at llvm.org
Wed Nov 20 05:07:36 PST 2024


https://github.com/Hardcode84 created https://github.com/llvm/llvm-project/pull/116968

Fixes https://github.com/llvm/llvm-project/issues/82168

`intrange::inferXor` computed incorrect ranges for integer types wider than i1 (see the example in the code comment). Limit the inference to i1 for now; for wider types it returns maxRange.

>From 89cfcce0296560ad3f2413f83c07d8fb5b73a0e5 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Wed, 20 Nov 2024 14:05:41 +0100
Subject: [PATCH] [mlir][int-range] Limit xor int range inference to i1

Fixes https://github.com/llvm/llvm-project/issues/82168

`intrange::inferXor` computed incorrect ranges for integer types wider than i1 (see the example in the code comment).
Limit the inference to i1 for now; for wider types it returns maxRange.
---
 .../Interfaces/Utils/InferIntRangeCommon.cpp  | 14 +++++++++
 .../Dialect/Arith/int-range-interface.mlir    | 30 +++++++++++++++++--
 2 files changed, 42 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 0b085b10b2b337..2c1276d577a55b 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -558,6 +558,20 @@ mlir::intrange::inferOr(ArrayRef<ConstantIntRanges> argRanges) {
 
 ConstantIntRanges
 mlir::intrange::inferXor(ArrayRef<ConstantIntRanges> argRanges) {
+  // TODO: The code below doesn't work for bitwidths > i1.
+  // For input ranges lhs=[2060639849, 2060639850], rhs=[2060639849, 2060639849]
+  // widenBitwiseBounds will produce:
+  // lhs:
+  // 2060639848  01111010110100101101111001101000
+  // 2060639851  01111010110100101101111001101011
+  // rhs:
+  // 2060639849  01111010110100101101111001101001
+  // 2060639849  01111010110100101101111001101001
+  // None of those combinations xor to 0, yet the in-range lhs value 2060639849 does.
+  unsigned width = argRanges[0].umin().getBitWidth();
+  if (width > 1)
+    return ConstantIntRanges::maxRange(width);
+
   auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]);
   auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]);
   auto xori = [](const APInt &a, const APInt &b) -> std::optional<APInt> {
diff --git a/mlir/test/Dialect/Arith/int-range-interface.mlir b/mlir/test/Dialect/Arith/int-range-interface.mlir
index afb0b4929bce70..4db846fa4656a3 100644
--- a/mlir/test/Dialect/Arith/int-range-interface.mlir
+++ b/mlir/test/Dialect/Arith/int-range-interface.mlir
@@ -454,9 +454,35 @@ func.func @ori(%arg0 : i128, %arg1 : i128) -> i1 {
     func.return %2 : i1
 }
 
+// CHECK-LABEL: func @xori_issue_82168
+// arith.cmpi was erroneously folded to %false, see Issue #82168.
+// CHECK: %[[R:.*]] = arith.cmpi eq, %{{.*}}, %{{.*}} : i64
+// CHECK: return %[[R]]
+func.func @xori_issue_82168() -> i1 {
+    %c0_i64 = arith.constant 0 : i64
+    %c2060639849_i64 = arith.constant 2060639849 : i64
+    %2 = test.with_bounds { umin = 2060639849 : i64, umax = 2060639850 : i64, smin = 2060639849 : i64, smax = 2060639850 : i64 } : i64
+    %3 = arith.xori %2, %c2060639849_i64 : i64
+    %4 = arith.cmpi eq, %3, %c0_i64 : i64
+    func.return %4 : i1
+}
+
+// CHECK-LABEL: func @xori_i1
+//   CHECK-DAG: %[[true:.*]] = arith.constant true
+//   CHECK-DAG: %[[false:.*]] = arith.constant false
+//       CHECK: return %[[true]], %[[false]]
+func.func @xori_i1() -> (i1, i1) {
+    %true = arith.constant true
+    %1 = test.with_bounds { umin = 0 : i1, umax = 0 : i1, smin = 0 : i1, smax = 0 : i1 } : i1
+    %2 = test.with_bounds { umin = 1 : i1, umax = 1 : i1, smin = 1 : i1, smax = 1 : i1 } : i1
+    %3 = arith.xori %1, %true : i1
+    %4 = arith.xori %2, %true : i1
+    func.return %3, %4 : i1, i1
+}
+
 // CHECK-LABEL: func @xori
-// CHECK: %[[false:.*]] = arith.constant false
-// CHECK: return %[[false]]
+// TODO: xor range inference is temporarily limited to i1, so this no longer folds.
+// CHECK-NOT: arith.constant false
 func.func @xori(%arg0 : i64, %arg1 : i64) -> i1 {
     %c0 = arith.constant 0 : i64
     %c7 = arith.constant 7 : i64



More information about the Mlir-commits mailing list