[llvm] 0965272 - [EarlyCSE] fold commutable intrinsics
Sanjay Patel via llvm-commits
llvm-commits at lists.llvm.org
Sat Aug 29 09:11:35 PDT 2020
Author: Sanjay Patel
Date: 2020-08-29T12:11:01-04:00
New Revision: 096527214033772e8d80fdefd8a018b9bfa20021
URL: https://github.com/llvm/llvm-project/commit/096527214033772e8d80fdefd8a018b9bfa20021
DIFF: https://github.com/llvm/llvm-project/commit/096527214033772e8d80fdefd8a018b9bfa20021.diff
LOG: [EarlyCSE] fold commutable intrinsics
Handling the new min/max intrinsics is the motivation, but it
turns out that we have a bunch of other intrinsics with this
missing bit of analysis too.
The FP min/max tests show that we are intersecting FMF,
so that part should be safe too.
As noted in https://llvm.org/PR46897, there is a commutative
property specifier for intrinsics, but no corresponding function
attribute, and so apparently no uses of that bit. We may want to
remove that next.
Follow-up patches should wire up Instruction::isCommutative()
to this IntrinsicInst specialization. That requires updating
callers to be aware of the more general commutative property
(not just binops).
Differential Revision: https://reviews.llvm.org/D86798
Added:
Modified:
llvm/include/llvm/IR/IntrinsicInst.h
llvm/lib/Transforms/Scalar/EarlyCSE.cpp
llvm/test/Transforms/EarlyCSE/commute.ll
Removed:
################################################################################
diff --git a/llvm/include/llvm/IR/IntrinsicInst.h b/llvm/include/llvm/IR/IntrinsicInst.h
index 7a8898464e66..01ea5dcb8140 100644
--- a/llvm/include/llvm/IR/IntrinsicInst.h
+++ b/llvm/include/llvm/IR/IntrinsicInst.h
@@ -52,6 +52,37 @@ class IntrinsicInst : public CallInst {
return getCalledFunction()->getIntrinsicID();
}
+ /// Return true if swapping the first two arguments to the intrinsic produces
+ /// the same result.
+ bool isCommutative() {
+ switch (getIntrinsicID()) {
+ case Intrinsic::maxnum:
+ case Intrinsic::minnum:
+ case Intrinsic::maximum:
+ case Intrinsic::minimum:
+ case Intrinsic::smax:
+ case Intrinsic::smin:
+ case Intrinsic::umax:
+ case Intrinsic::umin:
+ case Intrinsic::sadd_sat:
+ case Intrinsic::uadd_sat:
+ case Intrinsic::sadd_with_overflow:
+ case Intrinsic::uadd_with_overflow:
+ case Intrinsic::smul_with_overflow:
+ case Intrinsic::umul_with_overflow:
+ // TODO: These fixed-point math intrinsics have commutative first two
+ // operands, but callers may not handle instructions with more than
+ // two operands.
+ // case Intrinsic::smul_fix:
+ // case Intrinsic::umul_fix:
+ // case Intrinsic::smul_fix_sat:
+ // case Intrinsic::umul_fix_sat:
+ return true;
+ default:
+ return false;
+ }
+ }
+
// Methods for support type inquiry through isa, cast, and dyn_cast:
static bool classof(const CallInst *I) {
if (const Function *CF = I->getCalledFunction())
diff --git a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
index ddfc8555b0a0..51da10fc4879 100644
--- a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
+++ b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
@@ -288,6 +288,17 @@ static unsigned getHashValueImpl(SimpleValue Val) {
isa<FreezeInst>(Inst)) &&
"Invalid/unknown instruction");
+ // Handle intrinsics with commutative operands.
+ // TODO: Extend this to handle intrinsics with >2 operands where the 1st
+ // 2 operands are commutative.
+ auto *II = dyn_cast<IntrinsicInst>(Inst);
+ if (II && II->isCommutative() && II->getNumArgOperands() == 2) {
+ Value *LHS = II->getArgOperand(0), *RHS = II->getArgOperand(1);
+ if (LHS > RHS)
+ std::swap(LHS, RHS);
+ return hash_combine(II->getOpcode(), LHS, RHS);
+ }
+
// Mix in the opcode.
return hash_combine(
Inst->getOpcode(),
@@ -340,6 +351,15 @@ static bool isEqualImpl(SimpleValue LHS, SimpleValue RHS) {
LHSCmp->getSwappedPredicate() == RHSCmp->getPredicate();
}
+ // TODO: Extend this for >2 args by matching the trailing N-2 args.
+ auto *LII = dyn_cast<IntrinsicInst>(LHSI);
+ auto *RII = dyn_cast<IntrinsicInst>(RHSI);
+ if (LII && RII && LII->getIntrinsicID() == RII->getIntrinsicID() &&
+ LII->isCommutative() && LII->getNumArgOperands() == 2) {
+ return LII->getArgOperand(0) == RII->getArgOperand(1) &&
+ LII->getArgOperand(1) == RII->getArgOperand(0);
+ }
+
// Min/max/abs can occur with commuted operands, non-canonical predicates,
// and/or non-canonical operands.
// Selects can be non-trivially equivalent via inverted conditions and swaps.
diff --git a/llvm/test/Transforms/EarlyCSE/commute.ll b/llvm/test/Transforms/EarlyCSE/commute.ll
index 41cd3832fc9d..abecc3903a6f 100644
--- a/llvm/test/Transforms/EarlyCSE/commute.ll
+++ b/llvm/test/Transforms/EarlyCSE/commute.ll
@@ -766,9 +766,7 @@ define i32 @PR41083_2(i32 %p) {
define float @maxnum(float %a, float %b) {
; CHECK-LABEL: @maxnum(
; CHECK-NEXT: [[X:%.*]] = call float @llvm.maxnum.f32(float [[A:%.*]], float [[B:%.*]])
-; CHECK-NEXT: [[Y:%.*]] = call float @llvm.maxnum.f32(float [[B]], float [[A]])
-; CHECK-NEXT: [[R:%.*]] = fdiv nnan float [[X]], [[Y]]
-; CHECK-NEXT: ret float [[R]]
+; CHECK-NEXT: ret float 1.000000e+00
;
%x = call float @llvm.maxnum.f32(float %a, float %b)
%y = call float @llvm.maxnum.f32(float %b, float %a)
@@ -779,9 +777,7 @@ define float @maxnum(float %a, float %b) {
define <2 x float> @minnum(<2 x float> %a, <2 x float> %b) {
; CHECK-LABEL: @minnum(
; CHECK-NEXT: [[X:%.*]] = call fast <2 x float> @llvm.minnum.v2f32(<2 x float> [[A:%.*]], <2 x float> [[B:%.*]])
-; CHECK-NEXT: [[Y:%.*]] = call fast <2 x float> @llvm.minnum.v2f32(<2 x float> [[B]], <2 x float> [[A]])
-; CHECK-NEXT: [[R:%.*]] = fdiv nnan <2 x float> [[X]], [[Y]]
-; CHECK-NEXT: ret <2 x float> [[R]]
+; CHECK-NEXT: ret <2 x float> <float 1.000000e+00, float 1.000000e+00>
;
%x = call fast <2 x float> @llvm.minnum.v2f32(<2 x float> %a, <2 x float> %b)
%y = call fast <2 x float> @llvm.minnum.v2f32(<2 x float> %b, <2 x float> %a)
@@ -791,10 +787,8 @@ define <2 x float> @minnum(<2 x float> %a, <2 x float> %b) {
define <2 x double> @maximum(<2 x double> %a, <2 x double> %b) {
; CHECK-LABEL: @maximum(
-; CHECK-NEXT: [[X:%.*]] = call fast <2 x double> @llvm.maximum.v2f64(<2 x double> [[A:%.*]], <2 x double> [[B:%.*]])
-; CHECK-NEXT: [[Y:%.*]] = call <2 x double> @llvm.maximum.v2f64(<2 x double> [[B]], <2 x double> [[A]])
-; CHECK-NEXT: [[R:%.*]] = fdiv nnan <2 x double> [[X]], [[Y]]
-; CHECK-NEXT: ret <2 x double> [[R]]
+; CHECK-NEXT: [[X:%.*]] = call <2 x double> @llvm.maximum.v2f64(<2 x double> [[A:%.*]], <2 x double> [[B:%.*]])
+; CHECK-NEXT: ret <2 x double> <double 1.000000e+00, double 1.000000e+00>
;
%x = call fast <2 x double> @llvm.maximum.v2f64(<2 x double> %a, <2 x double> %b)
%y = call <2 x double> @llvm.maximum.v2f64(<2 x double> %b, <2 x double> %a)
@@ -804,10 +798,8 @@ define <2 x double> @maximum(<2 x double> %a, <2 x double> %b) {
define double @minimum(double %a, double %b) {
; CHECK-LABEL: @minimum(
-; CHECK-NEXT: [[X:%.*]] = call nsz double @llvm.minimum.f64(double [[A:%.*]], double [[B:%.*]])
-; CHECK-NEXT: [[Y:%.*]] = call ninf double @llvm.minimum.f64(double [[B]], double [[A]])
-; CHECK-NEXT: [[R:%.*]] = fdiv nnan double [[X]], [[Y]]
-; CHECK-NEXT: ret double [[R]]
+; CHECK-NEXT: [[X:%.*]] = call double @llvm.minimum.f64(double [[A:%.*]], double [[B:%.*]])
+; CHECK-NEXT: ret double 1.000000e+00
;
%x = call nsz double @llvm.minimum.f64(double %a, double %b)
%y = call ninf double @llvm.minimum.f64(double %b, double %a)
@@ -817,11 +809,8 @@ define double @minimum(double %a, double %b) {
define i16 @sadd_ov(i16 %a, i16 %b) {
; CHECK-LABEL: @sadd_ov(
; CHECK-NEXT: [[X:%.*]] = call { i16, i1 } @llvm.sadd.with.overflow.i16(i16 [[A:%.*]], i16 [[B:%.*]])
-; CHECK-NEXT: [[Y:%.*]] = call { i16, i1 } @llvm.sadd.with.overflow.i16(i16 [[B]], i16 [[A]])
; CHECK-NEXT: [[X1:%.*]] = extractvalue { i16, i1 } [[X]], 0
-; CHECK-NEXT: [[Y1:%.*]] = extractvalue { i16, i1 } [[Y]], 0
-; CHECK-NEXT: [[O:%.*]] = or i16 [[X1]], [[Y1]]
-; CHECK-NEXT: ret i16 [[O]]
+; CHECK-NEXT: ret i16 [[X1]]
;
%x = call {i16, i1} @llvm.sadd.with.overflow.i16(i16 %a, i16 %b)
%y = call {i16, i1} @llvm.sadd.with.overflow.i16(i16 %b, i16 %a)
@@ -834,11 +823,8 @@ define i16 @sadd_ov(i16 %a, i16 %b) {
define <5 x i65> @uadd_ov(<5 x i65> %a, <5 x i65> %b) {
; CHECK-LABEL: @uadd_ov(
; CHECK-NEXT: [[X:%.*]] = call { <5 x i65>, <5 x i1> } @llvm.uadd.with.overflow.v5i65(<5 x i65> [[A:%.*]], <5 x i65> [[B:%.*]])
-; CHECK-NEXT: [[Y:%.*]] = call { <5 x i65>, <5 x i1> } @llvm.uadd.with.overflow.v5i65(<5 x i65> [[B]], <5 x i65> [[A]])
; CHECK-NEXT: [[X1:%.*]] = extractvalue { <5 x i65>, <5 x i1> } [[X]], 0
-; CHECK-NEXT: [[Y1:%.*]] = extractvalue { <5 x i65>, <5 x i1> } [[Y]], 0
-; CHECK-NEXT: [[O:%.*]] = or <5 x i65> [[X1]], [[Y1]]
-; CHECK-NEXT: ret <5 x i65> [[O]]
+; CHECK-NEXT: ret <5 x i65> [[X1]]
;
%x = call {<5 x i65>, <5 x i1>} @llvm.uadd.with.overflow.v5i65(<5 x i65> %a, <5 x i65> %b)
%y = call {<5 x i65>, <5 x i1>} @llvm.uadd.with.overflow.v5i65(<5 x i65> %b, <5 x i65> %a)
@@ -851,11 +837,8 @@ define <5 x i65> @uadd_ov(<5 x i65> %a, <5 x i65> %b) {
define i37 @smul_ov(i37 %a, i37 %b) {
; CHECK-LABEL: @smul_ov(
; CHECK-NEXT: [[X:%.*]] = call { i37, i1 } @llvm.smul.with.overflow.i37(i37 [[A:%.*]], i37 [[B:%.*]])
-; CHECK-NEXT: [[Y:%.*]] = call { i37, i1 } @llvm.smul.with.overflow.i37(i37 [[B]], i37 [[A]])
; CHECK-NEXT: [[X1:%.*]] = extractvalue { i37, i1 } [[X]], 0
-; CHECK-NEXT: [[Y1:%.*]] = extractvalue { i37, i1 } [[Y]], 0
-; CHECK-NEXT: [[O:%.*]] = or i37 [[X1]], [[Y1]]
-; CHECK-NEXT: ret i37 [[O]]
+; CHECK-NEXT: ret i37 [[X1]]
;
%x = call {i37, i1} @llvm.smul.with.overflow.i37(i37 %a, i37 %b)
%y = call {i37, i1} @llvm.smul.with.overflow.i37(i37 %b, i37 %a)
@@ -868,11 +851,8 @@ define i37 @smul_ov(i37 %a, i37 %b) {
define <2 x i31> @umul_ov(<2 x i31> %a, <2 x i31> %b) {
; CHECK-LABEL: @umul_ov(
; CHECK-NEXT: [[X:%.*]] = call { <2 x i31>, <2 x i1> } @llvm.umul.with.overflow.v2i31(<2 x i31> [[A:%.*]], <2 x i31> [[B:%.*]])
-; CHECK-NEXT: [[Y:%.*]] = call { <2 x i31>, <2 x i1> } @llvm.umul.with.overflow.v2i31(<2 x i31> [[B]], <2 x i31> [[A]])
; CHECK-NEXT: [[X1:%.*]] = extractvalue { <2 x i31>, <2 x i1> } [[X]], 0
-; CHECK-NEXT: [[Y1:%.*]] = extractvalue { <2 x i31>, <2 x i1> } [[Y]], 0
-; CHECK-NEXT: [[O:%.*]] = or <2 x i31> [[X1]], [[Y1]]
-; CHECK-NEXT: ret <2 x i31> [[O]]
+; CHECK-NEXT: ret <2 x i31> [[X1]]
;
%x = call {<2 x i31>, <2 x i1>} @llvm.umul.with.overflow.v2i31(<2 x i31> %a, <2 x i31> %b)
%y = call {<2 x i31>, <2 x i1>} @llvm.umul.with.overflow.v2i31(<2 x i31> %b, <2 x i31> %a)
@@ -885,9 +865,7 @@ define <2 x i31> @umul_ov(<2 x i31> %a, <2 x i31> %b) {
define i64 @sadd_sat(i64 %a, i64 %b) {
; CHECK-LABEL: @sadd_sat(
; CHECK-NEXT: [[X:%.*]] = call i64 @llvm.sadd.sat.i64(i64 [[A:%.*]], i64 [[B:%.*]])
-; CHECK-NEXT: [[Y:%.*]] = call i64 @llvm.sadd.sat.i64(i64 [[B]], i64 [[A]])
-; CHECK-NEXT: [[O:%.*]] = or i64 [[X]], [[Y]]
-; CHECK-NEXT: ret i64 [[O]]
+; CHECK-NEXT: ret i64 [[X]]
;
%x = call i64 @llvm.sadd.sat.i64(i64 %a, i64 %b)
%y = call i64 @llvm.sadd.sat.i64(i64 %b, i64 %a)
@@ -898,9 +876,7 @@ define i64 @sadd_sat(i64 %a, i64 %b) {
define <2 x i64> @uadd_sat(<2 x i64> %a, <2 x i64> %b) {
; CHECK-LABEL: @uadd_sat(
; CHECK-NEXT: [[X:%.*]] = call <2 x i64> @llvm.uadd.sat.v2i64(<2 x i64> [[A:%.*]], <2 x i64> [[B:%.*]])
-; CHECK-NEXT: [[Y:%.*]] = call <2 x i64> @llvm.uadd.sat.v2i64(<2 x i64> [[B]], <2 x i64> [[A]])
-; CHECK-NEXT: [[O:%.*]] = or <2 x i64> [[X]], [[Y]]
-; CHECK-NEXT: ret <2 x i64> [[O]]
+; CHECK-NEXT: ret <2 x i64> [[X]]
;
%x = call <2 x i64> @llvm.uadd.sat.v2i64(<2 x i64> %a, <2 x i64> %b)
%y = call <2 x i64> @llvm.uadd.sat.v2i64(<2 x i64> %b, <2 x i64> %a)
@@ -911,9 +887,7 @@ define <2 x i64> @uadd_sat(<2 x i64> %a, <2 x i64> %b) {
define <2 x i64> @smax(<2 x i64> %a, <2 x i64> %b) {
; CHECK-LABEL: @smax(
; CHECK-NEXT: [[X:%.*]] = call <2 x i64> @llvm.smax.v2i64(<2 x i64> [[A:%.*]], <2 x i64> [[B:%.*]])
-; CHECK-NEXT: [[Y:%.*]] = call <2 x i64> @llvm.smax.v2i64(<2 x i64> [[B]], <2 x i64> [[A]])
-; CHECK-NEXT: [[O:%.*]] = or <2 x i64> [[X]], [[Y]]
-; CHECK-NEXT: ret <2 x i64> [[O]]
+; CHECK-NEXT: ret <2 x i64> [[X]]
;
%x = call <2 x i64> @llvm.smax.v2i64(<2 x i64> %a, <2 x i64> %b)
%y = call <2 x i64> @llvm.smax.v2i64(<2 x i64> %b, <2 x i64> %a)
@@ -924,9 +898,7 @@ define <2 x i64> @smax(<2 x i64> %a, <2 x i64> %b) {
define i4 @smin(i4 %a, i4 %b) {
; CHECK-LABEL: @smin(
; CHECK-NEXT: [[X:%.*]] = call i4 @llvm.smin.i4(i4 [[A:%.*]], i4 [[B:%.*]])
-; CHECK-NEXT: [[Y:%.*]] = call i4 @llvm.smin.i4(i4 [[B]], i4 [[A]])
-; CHECK-NEXT: [[O:%.*]] = or i4 [[X]], [[Y]]
-; CHECK-NEXT: ret i4 [[O]]
+; CHECK-NEXT: ret i4 [[X]]
;
%x = call i4 @llvm.smin.i4(i4 %a, i4 %b)
%y = call i4 @llvm.smin.i4(i4 %b, i4 %a)
@@ -937,9 +909,7 @@ define i4 @smin(i4 %a, i4 %b) {
define i67 @umax(i67 %a, i67 %b) {
; CHECK-LABEL: @umax(
; CHECK-NEXT: [[X:%.*]] = call i67 @llvm.umax.i67(i67 [[A:%.*]], i67 [[B:%.*]])
-; CHECK-NEXT: [[Y:%.*]] = call i67 @llvm.umax.i67(i67 [[B]], i67 [[A]])
-; CHECK-NEXT: [[O:%.*]] = or i67 [[X]], [[Y]]
-; CHECK-NEXT: ret i67 [[O]]
+; CHECK-NEXT: ret i67 [[X]]
;
%x = call i67 @llvm.umax.i67(i67 %a, i67 %b)
%y = call i67 @llvm.umax.i67(i67 %b, i67 %a)
@@ -950,9 +920,7 @@ define i67 @umax(i67 %a, i67 %b) {
define <3 x i17> @umin(<3 x i17> %a, <3 x i17> %b) {
; CHECK-LABEL: @umin(
; CHECK-NEXT: [[X:%.*]] = call <3 x i17> @llvm.umin.v3i17(<3 x i17> [[A:%.*]], <3 x i17> [[B:%.*]])
-; CHECK-NEXT: [[Y:%.*]] = call <3 x i17> @llvm.umin.v3i17(<3 x i17> [[B]], <3 x i17> [[A]])
-; CHECK-NEXT: [[O:%.*]] = or <3 x i17> [[X]], [[Y]]
-; CHECK-NEXT: ret <3 x i17> [[O]]
+; CHECK-NEXT: ret <3 x i17> [[X]]
;
%x = call <3 x i17> @llvm.umin.v3i17(<3 x i17> %a, <3 x i17> %b)
%y = call <3 x i17> @llvm.umin.v3i17(<3 x i17> %b, <3 x i17> %a)
@@ -960,6 +928,8 @@ define <3 x i17> @umin(<3 x i17> %a, <3 x i17> %b) {
ret <3 x i17> %o
}
+; Negative test - mismatched intrinsics
+
define i4 @smin_umin(i4 %a, i4 %b) {
; CHECK-LABEL: @smin_umin(
; CHECK-NEXT: [[X:%.*]] = call i4 @llvm.smin.i4(i4 [[A:%.*]], i4 [[B:%.*]])
@@ -973,6 +943,8 @@ define i4 @smin_umin(i4 %a, i4 %b) {
ret i4 %o
}
+; TODO: handle >2 args
+
define i16 @smul_fix(i16 %a, i16 %b) {
; CHECK-LABEL: @smul_fix(
; CHECK-NEXT: [[X:%.*]] = call i16 @llvm.smul.fix.i16(i16 [[A:%.*]], i16 [[B:%.*]], i32 3)
@@ -986,6 +958,8 @@ define i16 @smul_fix(i16 %a, i16 %b) {
ret i16 %o
}
+; TODO: handle >2 args
+
define i16 @umul_fix(i16 %a, i16 %b, i32 %s) {
; CHECK-LABEL: @umul_fix(
; CHECK-NEXT: [[X:%.*]] = call i16 @llvm.umul.fix.i16(i16 [[A:%.*]], i16 [[B:%.*]], i32 1)
@@ -999,6 +973,8 @@ define i16 @umul_fix(i16 %a, i16 %b, i32 %s) {
ret i16 %o
}
+; TODO: handle >2 args
+
define <3 x i16> @smul_fix_sat(<3 x i16> %a, <3 x i16> %b) {
; CHECK-LABEL: @smul_fix_sat(
; CHECK-NEXT: [[X:%.*]] = call <3 x i16> @llvm.smul.fix.sat.v3i16(<3 x i16> [[A:%.*]], <3 x i16> [[B:%.*]], i32 2)
@@ -1012,6 +988,8 @@ define <3 x i16> @smul_fix_sat(<3 x i16> %a, <3 x i16> %b) {
ret <3 x i16> %o
}
+; TODO: handle >2 args
+
define <3 x i16> @umul_fix_sat(<3 x i16> %a, <3 x i16> %b) {
; CHECK-LABEL: @umul_fix_sat(
; CHECK-NEXT: [[X:%.*]] = call <3 x i16> @llvm.umul.fix.sat.v3i16(<3 x i16> [[A:%.*]], <3 x i16> [[B:%.*]], i32 3)
More information about the llvm-commits
mailing list