[llvm] f35e3fa - Add transforms for `(max/min (xor X, Pow2), X)` -> `(and/or X, Pow2/~Pow2)`

Noah Goldstein via llvm-commits llvm-commits at lists.llvm.org
Fri Feb 24 13:22:30 PST 2023


Author: Noah Goldstein
Date: 2023-02-24T15:22:09-06:00
New Revision: f35e3fa53bb7173a8f8ccda8eb017a7ccd986800

URL: https://github.com/llvm/llvm-project/commit/f35e3fa53bb7173a8f8ccda8eb017a7ccd986800
DIFF: https://github.com/llvm/llvm-project/commit/f35e3fa53bb7173a8f8ccda8eb017a7ccd986800.diff

LOG: Add transforms for `(max/min (xor X, Pow2), X)` -> `(and/or X, Pow2/~Pow2)`

X ^ Pow2 is guaranteed to flip exactly one bit. We can use this to speed
up max/min by simply selecting X with the flipped bit set or cleared
(or/andnot respectively).
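
As an informal sanity check alongside the Alive2 proofs below, the
following standalone C++ sketch (not part of this patch; the variable
names are illustrative only) exhaustively verifies the unsigned and
signed identities for all 8-bit values and all single-bit masks:

    // Exhaustive check of the min/max-of-(x ^ pow2) identities on i8.
    #include <algorithm>
    #include <cassert>
    #include <cstdint>

    int main() {
      for (int xi = 0; xi < 256; ++xi) {
        uint8_t x = static_cast<uint8_t>(xi);
        for (int shift = 0; shift < 8; ++shift) {
          uint8_t pow2 = static_cast<uint8_t>(1u << shift);
          uint8_t flipped = x ^ pow2;

          // Unsigned: the larger of x and x^pow2 is x with the bit set,
          // the smaller is x with the bit cleared.
          assert(std::max(x, flipped) == (x | pow2));
          assert(std::min(x, flipped) == uint8_t(x & ~pow2));

          int8_t sx = static_cast<int8_t>(x);
          int8_t sflipped = static_cast<int8_t>(flipped);
          if (pow2 == 0x80) {
            // Negative power of two (the sign bit): flipping it crosses
            // zero, so signed max clears the bit and signed min sets it.
            assert(std::max(sx, sflipped) == int8_t(x & ~pow2));
            assert(std::min(sx, sflipped) == int8_t(x | pow2));
          } else {
            // Positive power of two: same direction as the unsigned case.
            assert(std::max(sx, sflipped) == int8_t(x | pow2));
            assert(std::min(sx, sflipped) == int8_t(x & ~pow2));
          }
        }
      }
      return 0;
    }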

Alive2 Links:
smax-neg: https://alive2.llvm.org/ce/z/j3QYFs
smin-neg: https://alive2.llvm.org/ce/z/bFYnQW
smax-pos: https://alive2.llvm.org/ce/z/4xYSxR
smin-pos: https://alive2.llvm.org/ce/z/H3RPKj
umax    : https://alive2.llvm.org/ce/z/P4oRcX
umin    : https://alive2.llvm.org/ce/z/vWZG6p

Differential Revision: https://reviews.llvm.org/D144606

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
    llvm/test/Transforms/InstCombine/minmax-of-xor-x.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index f6ae6741a317..b585de05e9ea 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1468,6 +1468,46 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       }
     }
 
+    // (umax X, (xor X, Pow2))
+    //      -> (or X, Pow2)
+    // (umin X, (xor X, Pow2))
+    //      -> (and X, ~Pow2)
+    // (smax X, (xor X, Pos_Pow2))
+    //      -> (or X, Pos_Pow2)
+    // (smin X, (xor X, Pos_Pow2))
+    //      -> (and X, ~Pos_Pow2)
+    // (smax X, (xor X, Neg_Pow2))
+    //      -> (and X, ~Neg_Pow2)
+    // (smin X, (xor X, Neg_Pow2))
+    //      -> (or X, Neg_Pow2)
+    if ((match(I0, m_c_Xor(m_Specific(I1), m_Value(X))) ||
+         match(I1, m_c_Xor(m_Specific(I0), m_Value(X)))) &&
+        isKnownToBeAPowerOfTwo(X, /* OrZero */ true)) {
+      bool UseOr = IID == Intrinsic::smax || IID == Intrinsic::umax;
+      bool UseAndN = IID == Intrinsic::smin || IID == Intrinsic::umin;
+
+      if (IID == Intrinsic::smax || IID == Intrinsic::smin) {
+        auto KnownSign = getKnownSign(X, II, DL, &AC, &DT);
+        if (KnownSign == std::nullopt) {
+          UseOr = false;
+          UseAndN = false;
+        } else if (*KnownSign /* true is Signed. */) {
+          UseOr ^= true;
+          UseAndN ^= true;
+          Type *Ty = I0->getType();
+          // Negative power of 2 must be IntMin. It's possible to be able to
+          // prove negative / power of 2 without actually having known bits, so
+          // just get the value by hand.
+          X = Constant::getIntegerValue(
+              Ty, APInt::getSignedMinValue(Ty->getScalarSizeInBits()));
+        }
+      }
+      if (UseOr)
+        return BinaryOperator::CreateOr(I0, X);
+      else if (UseAndN)
+        return BinaryOperator::CreateAnd(I0, Builder.CreateNot(X));
+    }
+
     // If we can eliminate ~A and Y is free to invert:
     // max ~A, Y --> ~(min A, ~Y)
     //

diff --git a/llvm/test/Transforms/InstCombine/minmax-of-xor-x.ll b/llvm/test/Transforms/InstCombine/minmax-of-xor-x.ll
index 077e913891f9..ff3093cd183e 100644
--- a/llvm/test/Transforms/InstCombine/minmax-of-xor-x.ll
+++ b/llvm/test/Transforms/InstCombine/minmax-of-xor-x.ll
@@ -15,8 +15,7 @@ declare void @barrier()
 
 define <2 x i8> @umax_xor_Cpow2(<2 x i8> %x) {
 ; CHECK-LABEL: @umax_xor_Cpow2(
-; CHECK-NEXT:    [[X_XOR:%.*]] = xor <2 x i8> [[X:%.*]], <i8 -128, i8 -128>
-; CHECK-NEXT:    [[R:%.*]] = call <2 x i8> @llvm.umax.v2i8(<2 x i8> [[X]], <2 x i8> [[X_XOR]])
+; CHECK-NEXT:    [[R:%.*]] = or <2 x i8> [[X:%.*]], <i8 -128, i8 -128>
 ; CHECK-NEXT:    ret <2 x i8> [[R]]
 ;
   %x_xor = xor <2 x i8> %x, <i8 128, i8 128>
@@ -26,8 +25,7 @@ define <2 x i8> @umax_xor_Cpow2(<2 x i8> %x) {
 
 define i8 @umin_xor_Cpow2(i8 %x) {
 ; CHECK-LABEL: @umin_xor_Cpow2(
-; CHECK-NEXT:    [[X_XOR:%.*]] = xor i8 [[X:%.*]], 64
-; CHECK-NEXT:    [[R:%.*]] = call i8 @llvm.umin.i8(i8 [[X]], i8 [[X_XOR]])
+; CHECK-NEXT:    [[R:%.*]] = and i8 [[X:%.*]], -65
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %x_xor = xor i8 %x, 64
@@ -37,8 +35,7 @@ define i8 @umin_xor_Cpow2(i8 %x) {
 
 define i8 @smax_xor_Cpow2_pos(i8 %x) {
 ; CHECK-LABEL: @smax_xor_Cpow2_pos(
-; CHECK-NEXT:    [[X_XOR:%.*]] = xor i8 [[X:%.*]], 32
-; CHECK-NEXT:    [[R:%.*]] = call i8 @llvm.smax.i8(i8 [[X]], i8 [[X_XOR]])
+; CHECK-NEXT:    [[R:%.*]] = or i8 [[X:%.*]], 32
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %x_xor = xor i8 %x, 32
@@ -48,8 +45,7 @@ define i8 @smax_xor_Cpow2_pos(i8 %x) {
 
 define <2 x i8> @smin_xor_Cpow2_pos(<2 x i8> %x) {
 ; CHECK-LABEL: @smin_xor_Cpow2_pos(
-; CHECK-NEXT:    [[X_XOR:%.*]] = xor <2 x i8> [[X:%.*]], <i8 16, i8 16>
-; CHECK-NEXT:    [[R:%.*]] = call <2 x i8> @llvm.smin.v2i8(<2 x i8> [[X]], <2 x i8> [[X_XOR]])
+; CHECK-NEXT:    [[R:%.*]] = and <2 x i8> [[X:%.*]], <i8 -17, i8 -17>
 ; CHECK-NEXT:    ret <2 x i8> [[R]]
 ;
   %x_xor = xor <2 x i8> %x, <i8 16, i8 16>
@@ -59,8 +55,7 @@ define <2 x i8> @smin_xor_Cpow2_pos(<2 x i8> %x) {
 
 define <2 x i8> @smax_xor_Cpow2_neg(<2 x i8> %x) {
 ; CHECK-LABEL: @smax_xor_Cpow2_neg(
-; CHECK-NEXT:    [[X_XOR:%.*]] = xor <2 x i8> [[X:%.*]], <i8 -128, i8 -128>
-; CHECK-NEXT:    [[R:%.*]] = call <2 x i8> @llvm.smax.v2i8(<2 x i8> [[X]], <2 x i8> [[X_XOR]])
+; CHECK-NEXT:    [[R:%.*]] = and <2 x i8> [[X:%.*]], <i8 127, i8 127>
 ; CHECK-NEXT:    ret <2 x i8> [[R]]
 ;
   %x_xor = xor <2 x i8> %x, <i8 128, i8 128>
@@ -70,8 +65,7 @@ define <2 x i8> @smax_xor_Cpow2_neg(<2 x i8> %x) {
 
 define i8 @smin_xor_Cpow2_neg(i8 %x) {
 ; CHECK-LABEL: @smin_xor_Cpow2_neg(
-; CHECK-NEXT:    [[X_XOR:%.*]] = xor i8 [[X:%.*]], -128
-; CHECK-NEXT:    [[R:%.*]] = call i8 @llvm.smin.i8(i8 [[X]], i8 [[X_XOR]])
+; CHECK-NEXT:    [[R:%.*]] = or i8 [[X:%.*]], -128
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %x_xor = xor i8 %x, 128
@@ -83,8 +77,7 @@ define i8 @umax_xor_pow2(i8 %x, i8 %y) {
 ; CHECK-LABEL: @umax_xor_pow2(
 ; CHECK-NEXT:    [[NY:%.*]] = sub i8 0, [[Y:%.*]]
 ; CHECK-NEXT:    [[YP2:%.*]] = and i8 [[NY]], [[Y]]
-; CHECK-NEXT:    [[X_XOR:%.*]] = xor i8 [[YP2]], [[X:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = call i8 @llvm.umax.i8(i8 [[X]], i8 [[X_XOR]])
+; CHECK-NEXT:    [[R:%.*]] = or i8 [[YP2]], [[X:%.*]]
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %ny = sub i8 0, %y
@@ -98,8 +91,8 @@ define <2 x i8> @umin_xor_pow2(<2 x i8> %x, <2 x i8> %y) {
 ; CHECK-LABEL: @umin_xor_pow2(
 ; CHECK-NEXT:    [[NY:%.*]] = sub <2 x i8> zeroinitializer, [[Y:%.*]]
 ; CHECK-NEXT:    [[YP2:%.*]] = and <2 x i8> [[NY]], [[Y]]
-; CHECK-NEXT:    [[X_XOR:%.*]] = xor <2 x i8> [[YP2]], [[X:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = call <2 x i8> @llvm.umin.v2i8(<2 x i8> [[X]], <2 x i8> [[X_XOR]])
+; CHECK-NEXT:    [[TMP1:%.*]] = xor <2 x i8> [[YP2]], <i8 -1, i8 -1>
+; CHECK-NEXT:    [[R:%.*]] = and <2 x i8> [[TMP1]], [[X:%.*]]
 ; CHECK-NEXT:    ret <2 x i8> [[R]]
 ;
   %ny = sub <2 x i8> <i8 0, i8 0>, %y
@@ -146,8 +139,7 @@ define i8 @smax_xor_pow2_neg(i8 %x, i8 %y) {
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i8 [[YP2]], 0
 ; CHECK-NEXT:    br i1 [[CMP]], label [[NEG:%.*]], label [[POS:%.*]]
 ; CHECK:       neg:
-; CHECK-NEXT:    [[X_XOR:%.*]] = xor i8 [[YP2]], [[X:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = call i8 @llvm.smax.i8(i8 [[X]], i8 [[X_XOR]])
+; CHECK-NEXT:    [[R:%.*]] = and i8 [[X:%.*]], 127
 ; CHECK-NEXT:    ret i8 [[R]]
 ; CHECK:       pos:
 ; CHECK-NEXT:    call void @barrier()
@@ -173,8 +165,8 @@ define i8 @smin_xor_pow2_pos(i8 %x, i8 %y) {
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i8 [[YP2]], 0
 ; CHECK-NEXT:    br i1 [[CMP]], label [[NEG:%.*]], label [[POS:%.*]]
 ; CHECK:       neg:
-; CHECK-NEXT:    [[X_XOR:%.*]] = xor i8 [[YP2]], [[X:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = call i8 @llvm.smin.i8(i8 [[X]], i8 [[X_XOR]])
+; CHECK-NEXT:    [[TMP1:%.*]] = xor i8 [[YP2]], -1
+; CHECK-NEXT:    [[R:%.*]] = and i8 [[TMP1]], [[X:%.*]]
 ; CHECK-NEXT:    ret i8 [[R]]
 ; CHECK:       pos:
 ; CHECK-NEXT:    call void @barrier()

