[llvm] 0965272 - [EarlyCSE] fold commutable intrinsics

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Sat Aug 29 09:11:35 PDT 2020


Author: Sanjay Patel
Date: 2020-08-29T12:11:01-04:00
New Revision: 096527214033772e8d80fdefd8a018b9bfa20021

URL: https://github.com/llvm/llvm-project/commit/096527214033772e8d80fdefd8a018b9bfa20021
DIFF: https://github.com/llvm/llvm-project/commit/096527214033772e8d80fdefd8a018b9bfa20021.diff

LOG: [EarlyCSE] fold commutable intrinsics

Handling the new min/max intrinsics is the motivation, but it
turns out that we have a bunch of other intrinsics with this
missing bit of analysis too.

The FP min/max tests show that we are intersecting the fast-math flags
(FMF) of the two calls, so that part should be safe too.

As noted in https://llvm.org/PR46897 , there is a commutative
property specifier for intrinsics, but no corresponding function
attribute, and so apparently no uses of that bit. We may want to
remove that next.

Follow-up patches should wire up Instruction::isCommutative()
to this IntrinsicInst specialization. That requires updating
callers to be aware of the more general commutative property
(not just binops).

Differential Revision: https://reviews.llvm.org/D86798

Added: 
    

Modified: 
    llvm/include/llvm/IR/IntrinsicInst.h
    llvm/lib/Transforms/Scalar/EarlyCSE.cpp
    llvm/test/Transforms/EarlyCSE/commute.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/IntrinsicInst.h b/llvm/include/llvm/IR/IntrinsicInst.h
index 7a8898464e66..01ea5dcb8140 100644
--- a/llvm/include/llvm/IR/IntrinsicInst.h
+++ b/llvm/include/llvm/IR/IntrinsicInst.h
@@ -52,6 +52,37 @@ class IntrinsicInst : public CallInst {
     return getCalledFunction()->getIntrinsicID();
   }
 
+  /// Return true if swapping the first two arguments to the intrinsic produces
+  /// the same result.
+  bool isCommutative() {
+    switch (getIntrinsicID()) {
+    case Intrinsic::maxnum:
+    case Intrinsic::minnum:
+    case Intrinsic::maximum:
+    case Intrinsic::minimum:
+    case Intrinsic::smax:
+    case Intrinsic::smin:
+    case Intrinsic::umax:
+    case Intrinsic::umin:
+    case Intrinsic::sadd_sat:
+    case Intrinsic::uadd_sat:
+    case Intrinsic::sadd_with_overflow:
+    case Intrinsic::uadd_with_overflow:
+    case Intrinsic::smul_with_overflow:
+    case Intrinsic::umul_with_overflow:
+    // TODO: These fixed-point math intrinsics have commutative first two
+    //       operands, but callers may not handle instructions with more than
+    //       two operands.
+    // case Intrinsic::smul_fix:
+    // case Intrinsic::umul_fix:
+    // case Intrinsic::smul_fix_sat:
+    // case Intrinsic::umul_fix_sat:
+      return true;
+    default:
+      return false;
+    }
+  }
+
   // Methods for support type inquiry through isa, cast, and dyn_cast:
   static bool classof(const CallInst *I) {
     if (const Function *CF = I->getCalledFunction())

diff  --git a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
index ddfc8555b0a0..51da10fc4879 100644
--- a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
+++ b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
@@ -288,6 +288,17 @@ static unsigned getHashValueImpl(SimpleValue Val) {
           isa<FreezeInst>(Inst)) &&
          "Invalid/unknown instruction");
 
+  // Handle intrinsics with commutative operands.
+  // TODO: Extend this to handle intrinsics with >2 operands where the 1st
+  //       2 operands are commutative.
+  auto *II = dyn_cast<IntrinsicInst>(Inst);
+  if (II && II->isCommutative() && II->getNumArgOperands() == 2) {
+    Value *LHS = II->getArgOperand(0), *RHS = II->getArgOperand(1);
+    if (LHS > RHS)
+      std::swap(LHS, RHS);
+    return hash_combine(II->getOpcode(), LHS, RHS);
+  }
+
   // Mix in the opcode.
   return hash_combine(
       Inst->getOpcode(),
@@ -340,6 +351,15 @@ static bool isEqualImpl(SimpleValue LHS, SimpleValue RHS) {
            LHSCmp->getSwappedPredicate() == RHSCmp->getPredicate();
   }
 
+  // TODO: Extend this for >2 args by matching the trailing N-2 args.
+  auto *LII = dyn_cast<IntrinsicInst>(LHSI);
+  auto *RII = dyn_cast<IntrinsicInst>(RHSI);
+  if (LII && RII && LII->getIntrinsicID() == RII->getIntrinsicID() &&
+      LII->isCommutative() && LII->getNumArgOperands() == 2) {
+    return LII->getArgOperand(0) == RII->getArgOperand(1) &&
+           LII->getArgOperand(1) == RII->getArgOperand(0);
+  }
+
   // Min/max/abs can occur with commuted operands, non-canonical predicates,
   // and/or non-canonical operands.
   // Selects can be non-trivially equivalent via inverted conditions and swaps.

diff  --git a/llvm/test/Transforms/EarlyCSE/commute.ll b/llvm/test/Transforms/EarlyCSE/commute.ll
index 41cd3832fc9d..abecc3903a6f 100644
--- a/llvm/test/Transforms/EarlyCSE/commute.ll
+++ b/llvm/test/Transforms/EarlyCSE/commute.ll
@@ -766,9 +766,7 @@ define i32 @PR41083_2(i32 %p) {
 define float @maxnum(float %a, float %b) {
 ; CHECK-LABEL: @maxnum(
 ; CHECK-NEXT:    [[X:%.*]] = call float @llvm.maxnum.f32(float [[A:%.*]], float [[B:%.*]])
-; CHECK-NEXT:    [[Y:%.*]] = call float @llvm.maxnum.f32(float [[B]], float [[A]])
-; CHECK-NEXT:    [[R:%.*]] = fdiv nnan float [[X]], [[Y]]
-; CHECK-NEXT:    ret float [[R]]
+; CHECK-NEXT:    ret float 1.000000e+00
 ;
   %x = call float @llvm.maxnum.f32(float %a, float %b)
   %y = call float @llvm.maxnum.f32(float %b, float %a)
@@ -779,9 +777,7 @@ define float @maxnum(float %a, float %b) {
 define <2 x float> @minnum(<2 x float> %a, <2 x float> %b) {
 ; CHECK-LABEL: @minnum(
 ; CHECK-NEXT:    [[X:%.*]] = call fast <2 x float> @llvm.minnum.v2f32(<2 x float> [[A:%.*]], <2 x float> [[B:%.*]])
-; CHECK-NEXT:    [[Y:%.*]] = call fast <2 x float> @llvm.minnum.v2f32(<2 x float> [[B]], <2 x float> [[A]])
-; CHECK-NEXT:    [[R:%.*]] = fdiv nnan <2 x float> [[X]], [[Y]]
-; CHECK-NEXT:    ret <2 x float> [[R]]
+; CHECK-NEXT:    ret <2 x float> <float 1.000000e+00, float 1.000000e+00>
 ;
   %x = call fast <2 x float> @llvm.minnum.v2f32(<2 x float> %a, <2 x float> %b)
   %y = call fast <2 x float> @llvm.minnum.v2f32(<2 x float> %b, <2 x float> %a)
@@ -791,10 +787,8 @@ define <2 x float> @minnum(<2 x float> %a, <2 x float> %b) {
 
 define <2 x double> @maximum(<2 x double> %a, <2 x double> %b) {
 ; CHECK-LABEL: @maximum(
-; CHECK-NEXT:    [[X:%.*]] = call fast <2 x double> @llvm.maximum.v2f64(<2 x double> [[A:%.*]], <2 x double> [[B:%.*]])
-; CHECK-NEXT:    [[Y:%.*]] = call <2 x double> @llvm.maximum.v2f64(<2 x double> [[B]], <2 x double> [[A]])
-; CHECK-NEXT:    [[R:%.*]] = fdiv nnan <2 x double> [[X]], [[Y]]
-; CHECK-NEXT:    ret <2 x double> [[R]]
+; CHECK-NEXT:    [[X:%.*]] = call <2 x double> @llvm.maximum.v2f64(<2 x double> [[A:%.*]], <2 x double> [[B:%.*]])
+; CHECK-NEXT:    ret <2 x double> <double 1.000000e+00, double 1.000000e+00>
 ;
   %x = call fast <2 x double> @llvm.maximum.v2f64(<2 x double> %a, <2 x double> %b)
   %y = call <2 x double> @llvm.maximum.v2f64(<2 x double> %b, <2 x double> %a)
@@ -804,10 +798,8 @@ define <2 x double> @maximum(<2 x double> %a, <2 x double> %b) {
 
 define double @minimum(double %a, double %b) {
 ; CHECK-LABEL: @minimum(
-; CHECK-NEXT:    [[X:%.*]] = call nsz double @llvm.minimum.f64(double [[A:%.*]], double [[B:%.*]])
-; CHECK-NEXT:    [[Y:%.*]] = call ninf double @llvm.minimum.f64(double [[B]], double [[A]])
-; CHECK-NEXT:    [[R:%.*]] = fdiv nnan double [[X]], [[Y]]
-; CHECK-NEXT:    ret double [[R]]
+; CHECK-NEXT:    [[X:%.*]] = call double @llvm.minimum.f64(double [[A:%.*]], double [[B:%.*]])
+; CHECK-NEXT:    ret double 1.000000e+00
 ;
   %x = call nsz double @llvm.minimum.f64(double %a, double %b)
   %y = call ninf double @llvm.minimum.f64(double %b, double %a)
@@ -817,11 +809,8 @@ define double @minimum(double %a, double %b) {
 define i16 @sadd_ov(i16 %a, i16 %b) {
 ; CHECK-LABEL: @sadd_ov(
 ; CHECK-NEXT:    [[X:%.*]] = call { i16, i1 } @llvm.sadd.with.overflow.i16(i16 [[A:%.*]], i16 [[B:%.*]])
-; CHECK-NEXT:    [[Y:%.*]] = call { i16, i1 } @llvm.sadd.with.overflow.i16(i16 [[B]], i16 [[A]])
 ; CHECK-NEXT:    [[X1:%.*]] = extractvalue { i16, i1 } [[X]], 0
-; CHECK-NEXT:    [[Y1:%.*]] = extractvalue { i16, i1 } [[Y]], 0
-; CHECK-NEXT:    [[O:%.*]] = or i16 [[X1]], [[Y1]]
-; CHECK-NEXT:    ret i16 [[O]]
+; CHECK-NEXT:    ret i16 [[X1]]
 ;
   %x = call {i16, i1} @llvm.sadd.with.overflow.i16(i16 %a, i16 %b)
   %y = call {i16, i1} @llvm.sadd.with.overflow.i16(i16 %b, i16 %a)
@@ -834,11 +823,8 @@ define i16 @sadd_ov(i16 %a, i16 %b) {
 define <5 x i65> @uadd_ov(<5 x i65> %a, <5 x i65> %b) {
 ; CHECK-LABEL: @uadd_ov(
 ; CHECK-NEXT:    [[X:%.*]] = call { <5 x i65>, <5 x i1> } @llvm.uadd.with.overflow.v5i65(<5 x i65> [[A:%.*]], <5 x i65> [[B:%.*]])
-; CHECK-NEXT:    [[Y:%.*]] = call { <5 x i65>, <5 x i1> } @llvm.uadd.with.overflow.v5i65(<5 x i65> [[B]], <5 x i65> [[A]])
 ; CHECK-NEXT:    [[X1:%.*]] = extractvalue { <5 x i65>, <5 x i1> } [[X]], 0
-; CHECK-NEXT:    [[Y1:%.*]] = extractvalue { <5 x i65>, <5 x i1> } [[Y]], 0
-; CHECK-NEXT:    [[O:%.*]] = or <5 x i65> [[X1]], [[Y1]]
-; CHECK-NEXT:    ret <5 x i65> [[O]]
+; CHECK-NEXT:    ret <5 x i65> [[X1]]
 ;
   %x = call {<5 x i65>, <5 x i1>} @llvm.uadd.with.overflow.v5i65(<5 x i65> %a, <5 x i65> %b)
   %y = call {<5 x i65>, <5 x i1>} @llvm.uadd.with.overflow.v5i65(<5 x i65> %b, <5 x i65> %a)
@@ -851,11 +837,8 @@ define <5 x i65> @uadd_ov(<5 x i65> %a, <5 x i65> %b) {
 define i37 @smul_ov(i37 %a, i37 %b) {
 ; CHECK-LABEL: @smul_ov(
 ; CHECK-NEXT:    [[X:%.*]] = call { i37, i1 } @llvm.smul.with.overflow.i37(i37 [[A:%.*]], i37 [[B:%.*]])
-; CHECK-NEXT:    [[Y:%.*]] = call { i37, i1 } @llvm.smul.with.overflow.i37(i37 [[B]], i37 [[A]])
 ; CHECK-NEXT:    [[X1:%.*]] = extractvalue { i37, i1 } [[X]], 0
-; CHECK-NEXT:    [[Y1:%.*]] = extractvalue { i37, i1 } [[Y]], 0
-; CHECK-NEXT:    [[O:%.*]] = or i37 [[X1]], [[Y1]]
-; CHECK-NEXT:    ret i37 [[O]]
+; CHECK-NEXT:    ret i37 [[X1]]
 ;
   %x = call {i37, i1} @llvm.smul.with.overflow.i37(i37 %a, i37 %b)
   %y = call {i37, i1} @llvm.smul.with.overflow.i37(i37 %b, i37 %a)
@@ -868,11 +851,8 @@ define i37 @smul_ov(i37 %a, i37 %b) {
 define <2 x i31> @umul_ov(<2 x i31> %a, <2 x i31> %b) {
 ; CHECK-LABEL: @umul_ov(
 ; CHECK-NEXT:    [[X:%.*]] = call { <2 x i31>, <2 x i1> } @llvm.umul.with.overflow.v2i31(<2 x i31> [[A:%.*]], <2 x i31> [[B:%.*]])
-; CHECK-NEXT:    [[Y:%.*]] = call { <2 x i31>, <2 x i1> } @llvm.umul.with.overflow.v2i31(<2 x i31> [[B]], <2 x i31> [[A]])
 ; CHECK-NEXT:    [[X1:%.*]] = extractvalue { <2 x i31>, <2 x i1> } [[X]], 0
-; CHECK-NEXT:    [[Y1:%.*]] = extractvalue { <2 x i31>, <2 x i1> } [[Y]], 0
-; CHECK-NEXT:    [[O:%.*]] = or <2 x i31> [[X1]], [[Y1]]
-; CHECK-NEXT:    ret <2 x i31> [[O]]
+; CHECK-NEXT:    ret <2 x i31> [[X1]]
 ;
   %x = call {<2 x i31>, <2 x i1>} @llvm.umul.with.overflow.v2i31(<2 x i31> %a, <2 x i31> %b)
   %y = call {<2 x i31>, <2 x i1>} @llvm.umul.with.overflow.v2i31(<2 x i31> %b, <2 x i31> %a)
@@ -885,9 +865,7 @@ define <2 x i31> @umul_ov(<2 x i31> %a, <2 x i31> %b) {
 define i64 @sadd_sat(i64 %a, i64 %b) {
 ; CHECK-LABEL: @sadd_sat(
 ; CHECK-NEXT:    [[X:%.*]] = call i64 @llvm.sadd.sat.i64(i64 [[A:%.*]], i64 [[B:%.*]])
-; CHECK-NEXT:    [[Y:%.*]] = call i64 @llvm.sadd.sat.i64(i64 [[B]], i64 [[A]])
-; CHECK-NEXT:    [[O:%.*]] = or i64 [[X]], [[Y]]
-; CHECK-NEXT:    ret i64 [[O]]
+; CHECK-NEXT:    ret i64 [[X]]
 ;
   %x = call i64 @llvm.sadd.sat.i64(i64 %a, i64 %b)
   %y = call i64 @llvm.sadd.sat.i64(i64 %b, i64 %a)
@@ -898,9 +876,7 @@ define i64 @sadd_sat(i64 %a, i64 %b) {
 define <2 x i64> @uadd_sat(<2 x i64> %a, <2 x i64> %b) {
 ; CHECK-LABEL: @uadd_sat(
 ; CHECK-NEXT:    [[X:%.*]] = call <2 x i64> @llvm.uadd.sat.v2i64(<2 x i64> [[A:%.*]], <2 x i64> [[B:%.*]])
-; CHECK-NEXT:    [[Y:%.*]] = call <2 x i64> @llvm.uadd.sat.v2i64(<2 x i64> [[B]], <2 x i64> [[A]])
-; CHECK-NEXT:    [[O:%.*]] = or <2 x i64> [[X]], [[Y]]
-; CHECK-NEXT:    ret <2 x i64> [[O]]
+; CHECK-NEXT:    ret <2 x i64> [[X]]
 ;
   %x = call <2 x i64> @llvm.uadd.sat.v2i64(<2 x i64> %a, <2 x i64> %b)
   %y = call <2 x i64> @llvm.uadd.sat.v2i64(<2 x i64> %b, <2 x i64> %a)
@@ -911,9 +887,7 @@ define <2 x i64> @uadd_sat(<2 x i64> %a, <2 x i64> %b) {
 define <2 x i64> @smax(<2 x i64> %a, <2 x i64> %b) {
 ; CHECK-LABEL: @smax(
 ; CHECK-NEXT:    [[X:%.*]] = call <2 x i64> @llvm.smax.v2i64(<2 x i64> [[A:%.*]], <2 x i64> [[B:%.*]])
-; CHECK-NEXT:    [[Y:%.*]] = call <2 x i64> @llvm.smax.v2i64(<2 x i64> [[B]], <2 x i64> [[A]])
-; CHECK-NEXT:    [[O:%.*]] = or <2 x i64> [[X]], [[Y]]
-; CHECK-NEXT:    ret <2 x i64> [[O]]
+; CHECK-NEXT:    ret <2 x i64> [[X]]
 ;
   %x = call <2 x i64> @llvm.smax.v2i64(<2 x i64> %a, <2 x i64> %b)
   %y = call <2 x i64> @llvm.smax.v2i64(<2 x i64> %b, <2 x i64> %a)
@@ -924,9 +898,7 @@ define <2 x i64> @smax(<2 x i64> %a, <2 x i64> %b) {
 define i4 @smin(i4 %a, i4 %b) {
 ; CHECK-LABEL: @smin(
 ; CHECK-NEXT:    [[X:%.*]] = call i4 @llvm.smin.i4(i4 [[A:%.*]], i4 [[B:%.*]])
-; CHECK-NEXT:    [[Y:%.*]] = call i4 @llvm.smin.i4(i4 [[B]], i4 [[A]])
-; CHECK-NEXT:    [[O:%.*]] = or i4 [[X]], [[Y]]
-; CHECK-NEXT:    ret i4 [[O]]
+; CHECK-NEXT:    ret i4 [[X]]
 ;
   %x = call i4 @llvm.smin.i4(i4 %a, i4 %b)
   %y = call i4 @llvm.smin.i4(i4 %b, i4 %a)
@@ -937,9 +909,7 @@ define i4 @smin(i4 %a, i4 %b) {
 define i67 @umax(i67 %a, i67 %b) {
 ; CHECK-LABEL: @umax(
 ; CHECK-NEXT:    [[X:%.*]] = call i67 @llvm.umax.i67(i67 [[A:%.*]], i67 [[B:%.*]])
-; CHECK-NEXT:    [[Y:%.*]] = call i67 @llvm.umax.i67(i67 [[B]], i67 [[A]])
-; CHECK-NEXT:    [[O:%.*]] = or i67 [[X]], [[Y]]
-; CHECK-NEXT:    ret i67 [[O]]
+; CHECK-NEXT:    ret i67 [[X]]
 ;
   %x = call i67 @llvm.umax.i67(i67 %a, i67 %b)
   %y = call i67 @llvm.umax.i67(i67 %b, i67 %a)
@@ -950,9 +920,7 @@ define i67 @umax(i67 %a, i67 %b) {
 define <3 x i17> @umin(<3 x i17> %a, <3 x i17> %b) {
 ; CHECK-LABEL: @umin(
 ; CHECK-NEXT:    [[X:%.*]] = call <3 x i17> @llvm.umin.v3i17(<3 x i17> [[A:%.*]], <3 x i17> [[B:%.*]])
-; CHECK-NEXT:    [[Y:%.*]] = call <3 x i17> @llvm.umin.v3i17(<3 x i17> [[B]], <3 x i17> [[A]])
-; CHECK-NEXT:    [[O:%.*]] = or <3 x i17> [[X]], [[Y]]
-; CHECK-NEXT:    ret <3 x i17> [[O]]
+; CHECK-NEXT:    ret <3 x i17> [[X]]
 ;
   %x = call <3 x i17> @llvm.umin.v3i17(<3 x i17> %a, <3 x i17> %b)
   %y = call <3 x i17> @llvm.umin.v3i17(<3 x i17> %b, <3 x i17> %a)
@@ -960,6 +928,8 @@ define <3 x i17> @umin(<3 x i17> %a, <3 x i17> %b) {
   ret <3 x i17> %o
 }
 
+; Negative test - mismatched intrinsics
+
 define i4 @smin_umin(i4 %a, i4 %b) {
 ; CHECK-LABEL: @smin_umin(
 ; CHECK-NEXT:    [[X:%.*]] = call i4 @llvm.smin.i4(i4 [[A:%.*]], i4 [[B:%.*]])
@@ -973,6 +943,8 @@ define i4 @smin_umin(i4 %a, i4 %b) {
   ret i4 %o
 }
 
+; TODO: handle >2 args
+
 define i16 @smul_fix(i16 %a, i16 %b) {
 ; CHECK-LABEL: @smul_fix(
 ; CHECK-NEXT:    [[X:%.*]] = call i16 @llvm.smul.fix.i16(i16 [[A:%.*]], i16 [[B:%.*]], i32 3)
@@ -986,6 +958,8 @@ define i16 @smul_fix(i16 %a, i16 %b) {
   ret i16 %o
 }
 
+; TODO: handle >2 args
+
 define i16 @umul_fix(i16 %a, i16 %b, i32 %s) {
 ; CHECK-LABEL: @umul_fix(
 ; CHECK-NEXT:    [[X:%.*]] = call i16 @llvm.umul.fix.i16(i16 [[A:%.*]], i16 [[B:%.*]], i32 1)
@@ -999,6 +973,8 @@ define i16 @umul_fix(i16 %a, i16 %b, i32 %s) {
   ret i16 %o
 }
 
+; TODO: handle >2 args
+
 define <3 x i16> @smul_fix_sat(<3 x i16> %a, <3 x i16> %b) {
 ; CHECK-LABEL: @smul_fix_sat(
 ; CHECK-NEXT:    [[X:%.*]] = call <3 x i16> @llvm.smul.fix.sat.v3i16(<3 x i16> [[A:%.*]], <3 x i16> [[B:%.*]], i32 2)
@@ -1012,6 +988,8 @@ define <3 x i16> @smul_fix_sat(<3 x i16> %a, <3 x i16> %b) {
   ret <3 x i16> %o
 }
 
+; TODO: handle >2 args
+
 define <3 x i16> @umul_fix_sat(<3 x i16> %a, <3 x i16> %b) {
 ; CHECK-LABEL: @umul_fix_sat(
 ; CHECK-NEXT:    [[X:%.*]] = call <3 x i16> @llvm.umul.fix.sat.v3i16(<3 x i16> [[A:%.*]], <3 x i16> [[B:%.*]], i32 3)


        


More information about the llvm-commits mailing list