[llvm] goldsteinn/itofp binop fix (PR #85298)

via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 14 12:26:53 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-ir

Author: None (goldsteinn)

<details>
<summary>Changes</summary>

- **[InstCombine] Add test for `(fmul (sitofp x), 0)`; NFC**
- **[InstCombine] Fix behavior for `(fmul (sitofp x), 0)`**


---
Full diff: https://github.com/llvm/llvm-project/pull/85298.diff


4 Files Affected:

- (modified) llvm/include/llvm/IR/PatternMatch.h (+11) 
- (modified) llvm/lib/Transforms/InstCombine/InstructionCombining.cpp (+5) 
- (modified) llvm/test/Transforms/InstCombine/binop-itofp.ll (+28) 
- (modified) llvm/unittests/IR/PatternMatch.cpp (+63) 


``````````diff
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 487ae170210de9..49f44affbf88c6 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -544,6 +544,17 @@ struct is_zero {
 /// For vectors, this includes constants with undefined elements.
 inline is_zero m_Zero() { return is_zero(); }
 
+struct is_non_zero {
+  bool isValue(const APInt &C) { return !C.isZero(); }
+};
+
+/// Match any constant whose elements are all non-zero. For a scalar, this is
+/// the same as !m_Zero. For vectors, it ensures !m_Zero holds for all elements.
+inline cst_pred_ty<is_non_zero> m_NonZero() {
+  return cst_pred_ty<is_non_zero>();
+}
+inline api_pred_ty<is_non_zero> m_NonZero(const APInt *&V) { return V; }
+
 struct is_power2 {
   bool isValue(const APInt &C) { return C.isPowerOf2(); }
 };
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 90a18fcc125c45..13d01f83d12630 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -1491,6 +1491,11 @@ Instruction *InstCombinerImpl::foldFBinOpOfIntCastsFromSign(
                                 Op1IntC, FPTy, DL) != Op1FpC)
       return nullptr;
 
+    // With signed casts, FMul requires a non-zero constant operand.
+    if (OpsFromSigned && BO.getOpcode() == Instruction::FMul &&
+        !match(Op1IntC, m_NonZero()))
+      return nullptr;
+
     // First try to keep sign of cast the same.
     IntOps[1] = Op1IntC;
   }
diff --git a/llvm/test/Transforms/InstCombine/binop-itofp.ll b/llvm/test/Transforms/InstCombine/binop-itofp.ll
index f796273c84e082..fcb116afb6448c 100644
--- a/llvm/test/Transforms/InstCombine/binop-itofp.ll
+++ b/llvm/test/Transforms/InstCombine/binop-itofp.ll
@@ -1004,3 +1004,31 @@ define float @test_ui_add_with_signed_constant(i32 %shr.i) {
   %add = fadd float %sub, -16383.0
   ret float %add
 }
+
+
+;; Reduced form of bug noticed due to #82555
+@g_12 = global i1 false
+@g_2345 = global i32 1
+define i32 @missed_nonzero_check_on_constant_for_si_fmul(ptr %g_12, i1 %.b, ptr %g_2345) {
+; CHECK-LABEL: @missed_nonzero_check_on_constant_for_si_fmul(
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[DOTB:%.*]], i32 65529, i32 53264
+; CHECK-NEXT:    [[CONV_I:%.*]] = trunc i32 [[SEL]] to i16
+; CHECK-NEXT:    [[CONV1_I:%.*]] = sitofp i16 [[CONV_I]] to float
+; CHECK-NEXT:    [[MUL3_I_I:%.*]] = fmul float [[CONV1_I]], 0.000000e+00
+; CHECK-NEXT:    store i32 [[SEL]], ptr [[G_2345:%.*]], align 4
+; CHECK-NEXT:    [[A_0_COPYLOAD_CAST:%.*]] = bitcast float [[MUL3_I_I]] to i32
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i32 [[A_0_COPYLOAD_CAST]], -1
+; CHECK-NEXT:    [[CONV:%.*]] = zext i1 [[CMP]] to i32
+; CHECK-NEXT:    ret i32 [[CONV]]
+;
+  %.b1 = load i1, ptr %g_12, align 4
+  %sel = select i1 %.b, i32 65529, i32 53264
+  %conv.i = trunc i32 %sel to i16
+  %conv1.i = sitofp i16 %conv.i to float
+  %mul3.i.i = fmul float %conv1.i, 0.000000e+00
+  store i32 %sel, ptr %g_2345, align 4
+  %a.0.copyload.cast = bitcast float %mul3.i.i to i32
+  %cmp = icmp sgt i32 %a.0.copyload.cast, -1
+  %conv = zext i1 %cmp to i32
+  ret i32 %conv
+}
diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
index 533a30bfba45dd..63c7c4ca57d221 100644
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -614,6 +614,69 @@ TEST_F(PatternMatchTest, Power2) {
   EXPECT_TRUE(m_NegatedPower2OrZero().match(CZero));
 }
 
+TEST_F(PatternMatchTest, NonZero) {
+  EXPECT_FALSE(m_NonZero().match(IRB.getInt32(0)));
+  EXPECT_TRUE(m_NonZero().match(IRB.getInt32(1)));
+
+  Type *I8Ty = IRB.getInt8Ty();
+
+  {
+    SmallVector<Constant *, 2> VecElemIdxs;
+    VecElemIdxs.push_back(ConstantInt::get(I8Ty, 0));
+    VecElemIdxs.push_back(ConstantInt::get(I8Ty, 1));
+    EXPECT_FALSE(m_NonZero().match(ConstantVector::get(VecElemIdxs)));
+  }
+
+  {
+    SmallVector<Constant *, 2> VecElemIdxs;
+    VecElemIdxs.push_back(ConstantInt::get(I8Ty, 0));
+    VecElemIdxs.push_back(ConstantInt::get(I8Ty, 0));
+    EXPECT_FALSE(m_NonZero().match(ConstantVector::get(VecElemIdxs)));
+  }
+
+  {
+    SmallVector<Constant *, 2> VecElemIdxs;
+    VecElemIdxs.push_back(ConstantInt::get(I8Ty, 1));
+    VecElemIdxs.push_back(ConstantInt::get(I8Ty, 2));
+    EXPECT_TRUE(m_NonZero().match(ConstantVector::get(VecElemIdxs)));
+  }
+
+  {
+    SmallVector<Constant *, 2> VecElemIdxs;
+    VecElemIdxs.push_back(UndefValue::get(I8Ty));
+    VecElemIdxs.push_back(ConstantInt::get(I8Ty, 2));
+    EXPECT_TRUE(m_NonZero().match(ConstantVector::get(VecElemIdxs)));
+  }
+
+  {
+    SmallVector<Constant *, 2> VecElemIdxs;
+    VecElemIdxs.push_back(PoisonValue::get(I8Ty));
+    VecElemIdxs.push_back(ConstantInt::get(I8Ty, 2));
+    EXPECT_TRUE(m_NonZero().match(ConstantVector::get(VecElemIdxs)));
+  }
+
+  {
+    SmallVector<Constant *, 2> VecElemIdxs;
+    VecElemIdxs.push_back(UndefValue::get(I8Ty));
+    VecElemIdxs.push_back(ConstantInt::get(I8Ty, 0));
+    EXPECT_FALSE(m_NonZero().match(ConstantVector::get(VecElemIdxs)));
+  }
+
+  {
+    SmallVector<Constant *, 2> VecElemIdxs;
+    VecElemIdxs.push_back(PoisonValue::get(I8Ty));
+    VecElemIdxs.push_back(ConstantInt::get(I8Ty, 0));
+    EXPECT_FALSE(m_NonZero().match(ConstantVector::get(VecElemIdxs)));
+  }
+
+  {
+    SmallVector<Constant *, 2> VecElemIdxs;
+    VecElemIdxs.push_back(PoisonValue::get(I8Ty));
+    VecElemIdxs.push_back(UndefValue::get(I8Ty));
+    EXPECT_FALSE(m_NonZero().match(ConstantVector::get(VecElemIdxs)));
+  }
+}
+
 TEST_F(PatternMatchTest, Not) {
   Value *C1 = IRB.getInt32(1);
   Value *C2 = IRB.getInt32(2);

``````````

</details>


https://github.com/llvm/llvm-project/pull/85298


More information about the llvm-commits mailing list