[llvm] goldsteinn/itofp binop fix (PR #85298)

via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 14 12:26:26 PDT 2024


https://github.com/goldsteinn created https://github.com/llvm/llvm-project/pull/85298

- **[InstCombine] Add test for `(fmul (sitofp x), 0)`; NFC**
- **[InstCombine] Fix behavior for `(fmul (sitofp x), 0)`**


>From 3bebd347ac022586d524ddd7305c70c7c1817bde Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Thu, 14 Mar 2024 14:22:47 -0500
Subject: [PATCH 1/2] [InstCombine] Add test for `(fmul (sitofp x), 0)`; NFC

---
 .../Transforms/InstCombine/binop-itofp.ll     | 22 +++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/llvm/test/Transforms/InstCombine/binop-itofp.ll b/llvm/test/Transforms/InstCombine/binop-itofp.ll
index f796273c84e082..e571d0f37e284f 100644
--- a/llvm/test/Transforms/InstCombine/binop-itofp.ll
+++ b/llvm/test/Transforms/InstCombine/binop-itofp.ll
@@ -1004,3 +1004,25 @@ define float @test_ui_add_with_signed_constant(i32 %shr.i) {
   %add = fadd float %sub, -16383.0
   ret float %add
 }
+
+
+;; Reduced form of the (fmul (sitofp x), 0.0) bug introduced by #82555
+@g_12 = global i1 false
+@g_2345 = global i32 1
+define i32 @missed_nonzero_check_on_constant_for_si_fmul(ptr %g_12, i1 %.b, ptr %g_2345) {
+; CHECK-LABEL: @missed_nonzero_check_on_constant_for_si_fmul(
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[DOTB:%.*]], i32 65529, i32 53264
+; CHECK-NEXT:    store i32 [[SEL]], ptr [[G_2345:%.*]], align 4
+; CHECK-NEXT:    ret i32 1
+;
+  %.b1 = load i1, ptr %g_12, align 4
+  %sel = select i1 %.b, i32 65529, i32 53264
+  %conv.i = trunc i32 %sel to i16
+  %conv1.i = sitofp i16 %conv.i to float
+  %mul3.i.i = fmul float %conv1.i, 0.000000e+00
+  store i32 %sel, ptr %g_2345, align 4
+  %a.0.copyload.cast = bitcast float %mul3.i.i to i32
+  %cmp = icmp sgt i32 %a.0.copyload.cast, -1
+  %conv = zext i1 %cmp to i32
+  ret i32 %conv
+}

>From 65dacea20974b0e81f45d7a5b162547b91f1ffe7 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Thu, 14 Mar 2024 14:22:50 -0500
Subject: [PATCH 2/2] [InstCombine] Fix behavior for `(fmul (sitofp x), 0)`

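The bug was introduced in #82555.

We were missing a check that the constant is non-zero for the signed + mul
transform: `(fmul (sitofp x), 0.0)` is -0.0 when x is negative, whereas
`(sitofp (mul x, 0))` is always +0.0.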
---
 llvm/include/llvm/IR/PatternMatch.h           | 11 ++++
 .../InstCombine/InstructionCombining.cpp      |  5 ++
 .../Transforms/InstCombine/binop-itofp.ll     |  8 ++-
 llvm/unittests/IR/PatternMatch.cpp            | 63 +++++++++++++++++++
 4 files changed, 86 insertions(+), 1 deletion(-)

diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 487ae170210de9..49f44affbf88c6 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -544,6 +544,17 @@ struct is_zero {
 /// For vectors, this includes constants with undefined elements.
 inline is_zero m_Zero() { return is_zero(); }
 
+struct is_non_zero {
+  bool isValue(const APInt &C) { return !C.isZero(); }
+};
+
+/// Match an integer constant whose elements are all non-zero. For vectors,
+/// undef/poison elements are skipped, but at least one must be non-zero.
+inline cst_pred_ty<is_non_zero> m_NonZero() {
+  return cst_pred_ty<is_non_zero>();
+}
+inline api_pred_ty<is_non_zero> m_NonZero(const APInt *&V) { return V; }
+
 struct is_power2 {
   bool isValue(const APInt &C) { return C.isPowerOf2(); }
 };
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 90a18fcc125c45..13d01f83d12630 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -1491,6 +1491,11 @@ Instruction *InstCombinerImpl::foldFBinOpOfIntCastsFromSign(
                                 Op1IntC, FPTy, DL) != Op1FpC)
       return nullptr;
 
+    // Signed + Mul needs C != 0: for X < 0, (sitofp X) * 0.0 is -0.0, not +0.0.
+    // NOTE(review): (sitofp 0) * C is -0.0 for C < 0 too; is X != 0 checked?
+    if (OpsFromSigned && BO.getOpcode() == Instruction::FMul &&
+        !match(Op1IntC, m_NonZero()))
+      return nullptr;
     // First try to keep sign of cast the same.
     IntOps[1] = Op1IntC;
   }
diff --git a/llvm/test/Transforms/InstCombine/binop-itofp.ll b/llvm/test/Transforms/InstCombine/binop-itofp.ll
index e571d0f37e284f..fcb116afb6448c 100644
--- a/llvm/test/Transforms/InstCombine/binop-itofp.ll
+++ b/llvm/test/Transforms/InstCombine/binop-itofp.ll
@@ -1012,8 +1012,14 @@ define float @test_ui_add_with_signed_constant(i32 %shr.i) {
 define i32 @missed_nonzero_check_on_constant_for_si_fmul(ptr %g_12, i1 %.b, ptr %g_2345) {
 ; CHECK-LABEL: @missed_nonzero_check_on_constant_for_si_fmul(
 ; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[DOTB:%.*]], i32 65529, i32 53264
+; CHECK-NEXT:    [[CONV_I:%.*]] = trunc i32 [[SEL]] to i16
+; CHECK-NEXT:    [[CONV1_I:%.*]] = sitofp i16 [[CONV_I]] to float
+; CHECK-NEXT:    [[MUL3_I_I:%.*]] = fmul float [[CONV1_I]], 0.000000e+00
 ; CHECK-NEXT:    store i32 [[SEL]], ptr [[G_2345:%.*]], align 4
-; CHECK-NEXT:    ret i32 1
+; CHECK-NEXT:    [[A_0_COPYLOAD_CAST:%.*]] = bitcast float [[MUL3_I_I]] to i32
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i32 [[A_0_COPYLOAD_CAST]], -1
+; CHECK-NEXT:    [[CONV:%.*]] = zext i1 [[CMP]] to i32
+; CHECK-NEXT:    ret i32 [[CONV]]
 ;
   %.b1 = load i1, ptr %g_12, align 4
   %sel = select i1 %.b, i32 65529, i32 53264
diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
index 533a30bfba45dd..63c7c4ca57d221 100644
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -614,6 +614,69 @@ TEST_F(PatternMatchTest, Power2) {
   EXPECT_TRUE(m_NegatedPower2OrZero().match(CZero));
 }
 
+TEST_F(PatternMatchTest, NonZero) {
+  EXPECT_FALSE(m_NonZero().match(IRB.getInt32(0)));
+  EXPECT_TRUE(m_NonZero().match(IRB.getInt32(1)));
+
+  Type *I8Ty = IRB.getInt8Ty();
+
+  {
+    // <0, 1>: a single zero element rejects the whole vector.
+    Constant *Elts[] = {ConstantInt::get(I8Ty, 0), ConstantInt::get(I8Ty, 1)};
+    Constant *Vec = ConstantVector::get(Elts);
+    EXPECT_FALSE(m_NonZero().match(Vec));
+  }
+
+  {
+    // <0, 0>: an all-zero vector.
+    Constant *Elts[] = {ConstantInt::get(I8Ty, 0), ConstantInt::get(I8Ty, 0)};
+    Constant *Vec = ConstantVector::get(Elts);
+    EXPECT_FALSE(m_NonZero().match(Vec));
+  }
+
+  {
+    // <1, 2>: every element is non-zero.
+    Constant *Elts[] = {ConstantInt::get(I8Ty, 1), ConstantInt::get(I8Ty, 2)};
+    Constant *Vec = ConstantVector::get(Elts);
+    EXPECT_TRUE(m_NonZero().match(Vec));
+  }
+
+  {
+    // <undef, 2>: undef elements are skipped.
+    Constant *Elts[] = {UndefValue::get(I8Ty), ConstantInt::get(I8Ty, 2)};
+    Constant *Vec = ConstantVector::get(Elts);
+    EXPECT_TRUE(m_NonZero().match(Vec));
+  }
+
+  {
+    // <poison, 2>: poison elements are skipped.
+    Constant *Elts[] = {PoisonValue::get(I8Ty), ConstantInt::get(I8Ty, 2)};
+    Constant *Vec = ConstantVector::get(Elts);
+    EXPECT_TRUE(m_NonZero().match(Vec));
+  }
+
+  {
+    // <undef, 0>: the remaining defined element is zero.
+    Constant *Elts[] = {UndefValue::get(I8Ty), ConstantInt::get(I8Ty, 0)};
+    Constant *Vec = ConstantVector::get(Elts);
+    EXPECT_FALSE(m_NonZero().match(Vec));
+  }
+
+  {
+    // <poison, 0>: the remaining defined element is zero.
+    Constant *Elts[] = {PoisonValue::get(I8Ty), ConstantInt::get(I8Ty, 0)};
+    Constant *Vec = ConstantVector::get(Elts);
+    EXPECT_FALSE(m_NonZero().match(Vec));
+  }
+
+  {
+    // <poison, undef>: no defined element at all, so nothing matches.
+    Constant *Elts[] = {PoisonValue::get(I8Ty), UndefValue::get(I8Ty)};
+    Constant *Vec = ConstantVector::get(Elts);
+    EXPECT_FALSE(m_NonZero().match(Vec));
+  }
+}
+
 TEST_F(PatternMatchTest, Not) {
   Value *C1 = IRB.getInt32(1);
   Value *C2 = IRB.getInt32(2);



More information about the llvm-commits mailing list