[llvm] f35e3fa - Add transforms for `(max/min (xor X, Pow2), X)` -> `(and/or X, Pow2/~Pow2)`
Noah Goldstein via llvm-commits
llvm-commits at lists.llvm.org
Fri Feb 24 13:22:30 PST 2023
Author: Noah Goldstein
Date: 2023-02-24T15:22:09-06:00
New Revision: f35e3fa53bb7173a8f8ccda8eb017a7ccd986800
URL: https://github.com/llvm/llvm-project/commit/f35e3fa53bb7173a8f8ccda8eb017a7ccd986800
DIFF: https://github.com/llvm/llvm-project/commit/f35e3fa53bb7173a8f8ccda8eb017a7ccd986800.diff
LOG: Add transforms for `(max/min (xor X, Pow2), X)` -> `(and/or X, Pow2/~Pow2)`
X ^ Pow2 is guaranteed to flip exactly one bit of X. We can use this to
speed up max/min by simply selecting X with the flipped bit set (or) or
cleared (and-not), respectively.
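
For intuition, here is a small standalone C++ check (not part of this patch,
just an illustration): since X ^ Pow2 differs from X in exactly one bit, the
min/max of the pair is decided by that bit alone, so it reduces to setting or
clearing it. The sketch below exhaustively verifies all six cases over 8-bit
values; the Alive2 links below are the formal proofs.

#include <algorithm>
#include <cassert>
#include <cstdint>

// Standalone sketch: exhaustively verify the six identities over 8-bit
// values. X ^ Pow2 differs from X in exactly one bit, so the comparison is
// decided by that bit alone.
int main() {
  for (int X = 0; X < 256; ++X) {
    for (int Shift = 0; Shift < 8; ++Shift) {
      uint8_t Pow2 = uint8_t(1u << Shift);
      uint8_t UA = uint8_t(X), UB = uint8_t(UA ^ Pow2);
      // Unsigned: the operand with the flipped bit set is the larger one.
      assert(std::max(UA, UB) == uint8_t(UA | Pow2));
      assert(std::min(UA, UB) == uint8_t(UA & ~Pow2));

      int8_t SA = int8_t(UA), SB = int8_t(UB);
      if (Shift != 7) {
        // Positive power of two: the sign bit is untouched, so the signed
        // ordering matches the unsigned one.
        assert(std::max(SA, SB) == int8_t(SA | Pow2));
        assert(std::min(SA, SB) == int8_t(SA & ~Pow2));
      } else {
        // Negative power of two (INT8_MIN): setting the sign bit makes the
        // value smaller, so the roles of or/and-not swap.
        assert(std::max(SA, SB) == int8_t(SA & ~Pow2));
        assert(std::min(SA, SB) == int8_t(SA | Pow2));
      }
    }
  }
  return 0;
}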
Alive2 Links:
smax-neg: https://alive2.llvm.org/ce/z/j3QYFs
smin-neg: https://alive2.llvm.org/ce/z/bFYnQW
smax-pos: https://alive2.llvm.org/ce/z/4xYSxR
smin-pos: https://alive2.llvm.org/ce/z/H3RPKj
umax : https://alive2.llvm.org/ce/z/P4oRcX
umin : https://alive2.llvm.org/ce/z/vWZG6p
Differential Revision: https://reviews.llvm.org/D144606
Added:
Modified:
llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
llvm/test/Transforms/InstCombine/minmax-of-xor-x.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index f6ae6741a317..b585de05e9ea 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1468,6 +1468,46 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
}
}
+ // (umax X, (xor X, Pow2))
+ // -> (or X, Pow2)
+ // (umin X, (xor X, Pow2))
+ // -> (and X, ~Pow2)
+ // (smax X, (xor X, Pos_Pow2))
+ // -> (or X, Pos_Pow2)
+ // (smin X, (xor X, Pos_Pow2))
+ // -> (and X, ~Pos_Pow2)
+ // (smax X, (xor X, Neg_Pow2))
+ // -> (and X, ~Neg_Pow2)
+ // (smin X, (xor X, Neg_Pow2))
+ // -> (or X, Neg_Pow2)
+ if ((match(I0, m_c_Xor(m_Specific(I1), m_Value(X))) ||
+ match(I1, m_c_Xor(m_Specific(I0), m_Value(X)))) &&
+ isKnownToBeAPowerOfTwo(X, /* OrZero */ true)) {
+ bool UseOr = IID == Intrinsic::smax || IID == Intrinsic::umax;
+ bool UseAndN = IID == Intrinsic::smin || IID == Intrinsic::umin;
+
+ if (IID == Intrinsic::smax || IID == Intrinsic::smin) {
+ auto KnownSign = getKnownSign(X, II, DL, &AC, &DT);
+ if (KnownSign == std::nullopt) {
+ UseOr = false;
+ UseAndN = false;
+ } else if (*KnownSign /* true is Signed. */) {
+ UseOr ^= true;
+ UseAndN ^= true;
+ Type *Ty = I0->getType();
+ // Negative power of 2 must be IntMin. It's possible to be able to
+ // prove negative / power of 2 without actually having known bits, so
+ // just get the value by hand.
+ X = Constant::getIntegerValue(
+ Ty, APInt::getSignedMinValue(Ty->getScalarSizeInBits()));
+ }
+ }
+ if (UseOr)
+ return BinaryOperator::CreateOr(I0, X);
+ else if (UseAndN)
+ return BinaryOperator::CreateAnd(I0, Builder.CreateNot(X));
+ }
+
// If we can eliminate ~A and Y is free to invert:
// max ~A, Y --> ~(min A, ~Y)
//
diff --git a/llvm/test/Transforms/InstCombine/minmax-of-xor-x.ll b/llvm/test/Transforms/InstCombine/minmax-of-xor-x.ll
index 077e913891f9..ff3093cd183e 100644
--- a/llvm/test/Transforms/InstCombine/minmax-of-xor-x.ll
+++ b/llvm/test/Transforms/InstCombine/minmax-of-xor-x.ll
@@ -15,8 +15,7 @@ declare void @barrier()
define <2 x i8> @umax_xor_Cpow2(<2 x i8> %x) {
; CHECK-LABEL: @umax_xor_Cpow2(
-; CHECK-NEXT: [[X_XOR:%.*]] = xor <2 x i8> [[X:%.*]], <i8 -128, i8 -128>
-; CHECK-NEXT: [[R:%.*]] = call <2 x i8> @llvm.umax.v2i8(<2 x i8> [[X]], <2 x i8> [[X_XOR]])
+; CHECK-NEXT: [[R:%.*]] = or <2 x i8> [[X:%.*]], <i8 -128, i8 -128>
; CHECK-NEXT: ret <2 x i8> [[R]]
;
%x_xor = xor <2 x i8> %x, <i8 128, i8 128>
@@ -26,8 +25,7 @@ define <2 x i8> @umax_xor_Cpow2(<2 x i8> %x) {
define i8 @umin_xor_Cpow2(i8 %x) {
; CHECK-LABEL: @umin_xor_Cpow2(
-; CHECK-NEXT: [[X_XOR:%.*]] = xor i8 [[X:%.*]], 64
-; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.umin.i8(i8 [[X]], i8 [[X_XOR]])
+; CHECK-NEXT: [[R:%.*]] = and i8 [[X:%.*]], -65
; CHECK-NEXT: ret i8 [[R]]
;
%x_xor = xor i8 %x, 64
@@ -37,8 +35,7 @@ define i8 @umin_xor_Cpow2(i8 %x) {
define i8 @smax_xor_Cpow2_pos(i8 %x) {
; CHECK-LABEL: @smax_xor_Cpow2_pos(
-; CHECK-NEXT: [[X_XOR:%.*]] = xor i8 [[X:%.*]], 32
-; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.smax.i8(i8 [[X]], i8 [[X_XOR]])
+; CHECK-NEXT: [[R:%.*]] = or i8 [[X:%.*]], 32
; CHECK-NEXT: ret i8 [[R]]
;
%x_xor = xor i8 %x, 32
@@ -48,8 +45,7 @@ define i8 @smax_xor_Cpow2_pos(i8 %x) {
define <2 x i8> @smin_xor_Cpow2_pos(<2 x i8> %x) {
; CHECK-LABEL: @smin_xor_Cpow2_pos(
-; CHECK-NEXT: [[X_XOR:%.*]] = xor <2 x i8> [[X:%.*]], <i8 16, i8 16>
-; CHECK-NEXT: [[R:%.*]] = call <2 x i8> @llvm.smin.v2i8(<2 x i8> [[X]], <2 x i8> [[X_XOR]])
+; CHECK-NEXT: [[R:%.*]] = and <2 x i8> [[X:%.*]], <i8 -17, i8 -17>
; CHECK-NEXT: ret <2 x i8> [[R]]
;
%x_xor = xor <2 x i8> %x, <i8 16, i8 16>
@@ -59,8 +55,7 @@ define <2 x i8> @smin_xor_Cpow2_pos(<2 x i8> %x) {
define <2 x i8> @smax_xor_Cpow2_neg(<2 x i8> %x) {
; CHECK-LABEL: @smax_xor_Cpow2_neg(
-; CHECK-NEXT: [[X_XOR:%.*]] = xor <2 x i8> [[X:%.*]], <i8 -128, i8 -128>
-; CHECK-NEXT: [[R:%.*]] = call <2 x i8> @llvm.smax.v2i8(<2 x i8> [[X]], <2 x i8> [[X_XOR]])
+; CHECK-NEXT: [[R:%.*]] = and <2 x i8> [[X:%.*]], <i8 127, i8 127>
; CHECK-NEXT: ret <2 x i8> [[R]]
;
%x_xor = xor <2 x i8> %x, <i8 128, i8 128>
@@ -70,8 +65,7 @@ define <2 x i8> @smax_xor_Cpow2_neg(<2 x i8> %x) {
define i8 @smin_xor_Cpow2_neg(i8 %x) {
; CHECK-LABEL: @smin_xor_Cpow2_neg(
-; CHECK-NEXT: [[X_XOR:%.*]] = xor i8 [[X:%.*]], -128
-; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.smin.i8(i8 [[X]], i8 [[X_XOR]])
+; CHECK-NEXT: [[R:%.*]] = or i8 [[X:%.*]], -128
; CHECK-NEXT: ret i8 [[R]]
;
%x_xor = xor i8 %x, 128
@@ -83,8 +77,7 @@ define i8 @umax_xor_pow2(i8 %x, i8 %y) {
; CHECK-LABEL: @umax_xor_pow2(
; CHECK-NEXT: [[NY:%.*]] = sub i8 0, [[Y:%.*]]
; CHECK-NEXT: [[YP2:%.*]] = and i8 [[NY]], [[Y]]
-; CHECK-NEXT: [[X_XOR:%.*]] = xor i8 [[YP2]], [[X:%.*]]
-; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.umax.i8(i8 [[X]], i8 [[X_XOR]])
+; CHECK-NEXT: [[R:%.*]] = or i8 [[YP2]], [[X:%.*]]
; CHECK-NEXT: ret i8 [[R]]
;
%ny = sub i8 0, %y
@@ -98,8 +91,8 @@ define <2 x i8> @umin_xor_pow2(<2 x i8> %x, <2 x i8> %y) {
; CHECK-LABEL: @umin_xor_pow2(
; CHECK-NEXT: [[NY:%.*]] = sub <2 x i8> zeroinitializer, [[Y:%.*]]
; CHECK-NEXT: [[YP2:%.*]] = and <2 x i8> [[NY]], [[Y]]
-; CHECK-NEXT: [[X_XOR:%.*]] = xor <2 x i8> [[YP2]], [[X:%.*]]
-; CHECK-NEXT: [[R:%.*]] = call <2 x i8> @llvm.umin.v2i8(<2 x i8> [[X]], <2 x i8> [[X_XOR]])
+; CHECK-NEXT: [[TMP1:%.*]] = xor <2 x i8> [[YP2]], <i8 -1, i8 -1>
+; CHECK-NEXT: [[R:%.*]] = and <2 x i8> [[TMP1]], [[X:%.*]]
; CHECK-NEXT: ret <2 x i8> [[R]]
;
%ny = sub <2 x i8> <i8 0, i8 0>, %y
@@ -146,8 +139,7 @@ define i8 @smax_xor_pow2_neg(i8 %x, i8 %y) {
; CHECK-NEXT: [[CMP:%.*]] = icmp slt i8 [[YP2]], 0
; CHECK-NEXT: br i1 [[CMP]], label [[NEG:%.*]], label [[POS:%.*]]
; CHECK: neg:
-; CHECK-NEXT: [[X_XOR:%.*]] = xor i8 [[YP2]], [[X:%.*]]
-; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.smax.i8(i8 [[X]], i8 [[X_XOR]])
+; CHECK-NEXT: [[R:%.*]] = and i8 [[X:%.*]], 127
; CHECK-NEXT: ret i8 [[R]]
; CHECK: pos:
; CHECK-NEXT: call void @barrier()
@@ -173,8 +165,8 @@ define i8 @smin_xor_pow2_pos(i8 %x, i8 %y) {
; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[YP2]], 0
; CHECK-NEXT: br i1 [[CMP]], label [[NEG:%.*]], label [[POS:%.*]]
; CHECK: neg:
-; CHECK-NEXT: [[X_XOR:%.*]] = xor i8 [[YP2]], [[X:%.*]]
-; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.smin.i8(i8 [[X]], i8 [[X_XOR]])
+; CHECK-NEXT: [[TMP1:%.*]] = xor i8 [[YP2]], -1
+; CHECK-NEXT: [[R:%.*]] = and i8 [[TMP1]], [[X:%.*]]
; CHECK-NEXT: ret i8 [[R]]
; CHECK: pos:
; CHECK-NEXT: call void @barrier()