[llvm] [InstCombine] (uitofp bool X) * Y --> X ? Y : 0 (PR #96216)

Alex MacLean via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 21 11:29:33 PDT 2024


https://github.com/AlexMaclean updated https://github.com/llvm/llvm-project/pull/96216

>From ff28bc6d54e59c42066ccee21c1e8f8a89c9ef05 Mon Sep 17 00:00:00 2001
From: Alex MacLean <amaclean at nvidia.com>
Date: Thu, 13 Jun 2024 23:40:59 +0000
Subject: [PATCH 1/2] [InstCombine] (uitofp bool X) * Y --> X ? Y : 0

---
 .../InstCombine/InstCombineMulDivRem.cpp      | 10 ++++++
 llvm/test/Transforms/InstCombine/fmul-bool.ll | 35 +++++++++++++++++++
 2 files changed, 45 insertions(+)
 create mode 100644 llvm/test/Transforms/InstCombine/fmul-bool.ll

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 8fcb3544f682a..dc2c92a84c650 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -879,6 +879,16 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) {
     if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL))
       return BinaryOperator::CreateFMulFMF(X, NegC, &I);
 
+  if (I.hasNoNaNs() && I.hasNoSignedZeros()) {
+    // (uitofp bool X) * Y --> X ? Y : 0
+    // Y * (uitofp bool X) --> X ? Y : 0
+    // Note INF * 0 is NaN.
+    if (match(Op0, m_UIToFP(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1))
+      return SelectInst::Create(X, Op1, ConstantFP::get(I.getType(), 0.0));
+    if (match(Op1, m_UIToFP(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1))
+      return SelectInst::Create(X, Op0, ConstantFP::get(I.getType(), 0.0));
+  }
+
   // (select A, B, C) * (select A, D, E) --> select A, (B*D), (C*E)
   if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1))
     return replaceInstUsesWith(I, V);
diff --git a/llvm/test/Transforms/InstCombine/fmul-bool.ll b/llvm/test/Transforms/InstCombine/fmul-bool.ll
new file mode 100644
index 0000000000000..73bb4f39b106b
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/fmul-bool.ll
@@ -0,0 +1,35 @@
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+; (uitofp i1 Y) * X --> Y ? X : 0, valid only when the fmul has nnan and nsz
+
+define float @fmul_bool(float %x, i1 %y) {
+; CHECK-LABEL: @fmul_bool(
+; CHECK-NEXT:    [[M:%.*]] = select i1 [[Y:%.*]], float [[X:%.*]], float 0.000000e+00
+; CHECK-NEXT:    ret float [[M]]
+;
+  %z = uitofp i1 %y to float
+  %m = fmul nnan nsz float %z, %x
+  ret float %m
+}
+
+define <2 x float> @fmul_bool_vec(<2 x float> %x, <2 x i1> %y) {
+; CHECK-LABEL: @fmul_bool_vec(
+; CHECK-NEXT:    [[M:%.*]] = select <2 x i1> [[Y:%.*]], <2 x float> [[X:%.*]], <2 x float> zeroinitializer
+; CHECK-NEXT:    ret <2 x float> [[M]]
+;
+  %z = uitofp <2 x i1> %y to <2 x float>
+  %m = fmul nnan nsz <2 x float> %z, %x
+  ret <2 x float> %m
+}
+
+define <2 x float> @fmul_bool_vec_commute(<2 x float> %px, <2 x i1> %y) {
+; CHECK-LABEL: @fmul_bool_vec_commute(
+; CHECK-NEXT:    [[X:%.*]] = fmul nnan nsz <2 x float> [[PX:%.*]], [[PX]]
+; CHECK-NEXT:    [[M:%.*]] = select <2 x i1> [[Y:%.*]], <2 x float> [[X]], <2 x float> zeroinitializer
+; CHECK-NEXT:    ret <2 x float> [[M]]
+;
+  %x = fmul nnan nsz <2 x float> %px, %px  ; thwart complexity-based canonicalization
+  %z = uitofp <2 x i1> %y to <2 x float>
+  %m = fmul nnan nsz <2 x float> %x, %z
+  ret <2 x float> %m
+}

>From b24c2e187b9f56a878ae2a86927c794e31ff1559 Mon Sep 17 00:00:00 2001
From: Alex MacLean <amaclean at nvidia.com>
Date: Fri, 21 Jun 2024 18:29:14 +0000
Subject: [PATCH 2/2] address review comments - copy the fmul's fast-math flags onto the new select

---
 .../InstCombine/InstCombineMulDivRem.cpp         | 16 ++++++++++++----
 llvm/test/Transforms/InstCombine/fmul-bool.ll    |  6 +++---
 2 files changed, 15 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index dc2c92a84c650..c3f1c12d2f564 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -883,10 +883,18 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) {
     // (uitofp bool X) * Y --> X ? Y : 0
     // Y * (uitofp bool X) --> X ? Y : 0
     // Note INF * 0 is NaN.
-    if (match(Op0, m_UIToFP(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1))
-      return SelectInst::Create(X, Op1, ConstantFP::get(I.getType(), 0.0));
-    if (match(Op1, m_UIToFP(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1))
-      return SelectInst::Create(X, Op0, ConstantFP::get(I.getType(), 0.0));
+    if (match(Op0, m_UIToFP(m_Value(X))) &&
+        X->getType()->isIntOrIntVectorTy(1)) {
+      auto *SI = SelectInst::Create(X, Op1, ConstantFP::get(I.getType(), 0.0));
+      SI->copyFastMathFlags(I.getFastMathFlags());
+      return SI;
+    }
+    if (match(Op1, m_UIToFP(m_Value(X))) &&
+        X->getType()->isIntOrIntVectorTy(1)) {
+      auto *SI = SelectInst::Create(X, Op0, ConstantFP::get(I.getType(), 0.0));
+      SI->copyFastMathFlags(I.getFastMathFlags());
+      return SI;
+    }
   }
 
   // (select A, B, C) * (select A, D, E) --> select A, (B*D), (C*E)
diff --git a/llvm/test/Transforms/InstCombine/fmul-bool.ll b/llvm/test/Transforms/InstCombine/fmul-bool.ll
index 73bb4f39b106b..263ca539669bb 100644
--- a/llvm/test/Transforms/InstCombine/fmul-bool.ll
+++ b/llvm/test/Transforms/InstCombine/fmul-bool.ll
@@ -4,7 +4,7 @@
 
 define float @fmul_bool(float %x, i1 %y) {
 ; CHECK-LABEL: @fmul_bool(
-; CHECK-NEXT:    [[M:%.*]] = select i1 [[Y:%.*]], float [[X:%.*]], float 0.000000e+00
+; CHECK-NEXT:    [[M:%.*]] = select nnan nsz i1 [[Y:%.*]], float [[X:%.*]], float 0.000000e+00
 ; CHECK-NEXT:    ret float [[M]]
 ;
   %z = uitofp i1 %y to float
@@ -14,7 +14,7 @@ define float @fmul_bool(float %x, i1 %y) {
 
 define <2 x float> @fmul_bool_vec(<2 x float> %x, <2 x i1> %y) {
 ; CHECK-LABEL: @fmul_bool_vec(
-; CHECK-NEXT:    [[M:%.*]] = select <2 x i1> [[Y:%.*]], <2 x float> [[X:%.*]], <2 x float> zeroinitializer
+; CHECK-NEXT:    [[M:%.*]] = select nnan nsz <2 x i1> [[Y:%.*]], <2 x float> [[X:%.*]], <2 x float> zeroinitializer
 ; CHECK-NEXT:    ret <2 x float> [[M]]
 ;
   %z = uitofp <2 x i1> %y to <2 x float>
@@ -25,7 +25,7 @@ define <2 x float> @fmul_bool_vec(<2 x float> %x, <2 x i1> %y) {
 define <2 x float> @fmul_bool_vec_commute(<2 x float> %px, <2 x i1> %y) {
 ; CHECK-LABEL: @fmul_bool_vec_commute(
 ; CHECK-NEXT:    [[X:%.*]] = fmul nnan nsz <2 x float> [[PX:%.*]], [[PX]]
-; CHECK-NEXT:    [[M:%.*]] = select <2 x i1> [[Y:%.*]], <2 x float> [[X]], <2 x float> zeroinitializer
+; CHECK-NEXT:    [[M:%.*]] = select nnan nsz <2 x i1> [[Y:%.*]], <2 x float> [[X]], <2 x float> zeroinitializer
 ; CHECK-NEXT:    ret <2 x float> [[M]]
 ;
   %x = fmul nnan nsz <2 x float> %px, %px  ; thwart complexity-based canonicalization



More information about the llvm-commits mailing list