[llvm] e471cd1 - [EarlyCSE] Support CSE for commutative intrinsics with over 2 args (#67255)

via llvm-commits llvm-commits at lists.llvm.org
Sun Sep 24 06:23:04 PDT 2023


Author: XChy
Date: 2023-09-24T21:23:00+08:00
New Revision: e471cd1d7382b98108f888072cbe016a3afc4558

URL: https://github.com/llvm/llvm-project/commit/e471cd1d7382b98108f888072cbe016a3afc4558
DIFF: https://github.com/llvm/llvm-project/commit/e471cd1d7382b98108f888072cbe016a3afc4558.diff

LOG: [EarlyCSE] Support CSE for commutative intrinsics with over 2 args (#67255)

Extends EarlyCSE to CSE commutative intrinsics with more than two arguments: the first two (commutative) operands are hashed and compared order-insensitively, and the remaining trailing operands must match exactly.
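
For illustration, this is the kind of pattern EarlyCSE can now fold (taken from the updated commute.ll test below): the two calls differ only in the order of their first two, commutative, operands, while the trailing scale argument matches, so the second call is redundant.

  %x = call i16 @llvm.smul.fix.i16(i16 %a, i16 %b, i32 3)
  %y = call i16 @llvm.smul.fix.i16(i16 %b, i16 %a, i32 3)
  ; With this change, EarlyCSE replaces all uses of %y with %x.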

Added: 
    

Modified: 
    llvm/lib/Transforms/Scalar/EarlyCSE.cpp
    llvm/test/Transforms/EarlyCSE/commute.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
index 4990fa9f8b5ea36..f736d429cb63816 100644
--- a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
+++ b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
@@ -314,14 +314,14 @@ static unsigned getHashValueImpl(SimpleValue Val) {
          "Invalid/unknown instruction");
 
   // Handle intrinsics with commutative operands.
-  // TODO: Extend this to handle intrinsics with >2 operands where the 1st
-  //       2 operands are commutative.
   auto *II = dyn_cast<IntrinsicInst>(Inst);
-  if (II && II->isCommutative() && II->arg_size() == 2) {
+  if (II && II->isCommutative() && II->arg_size() >= 2) {
     Value *LHS = II->getArgOperand(0), *RHS = II->getArgOperand(1);
     if (LHS > RHS)
       std::swap(LHS, RHS);
-    return hash_combine(II->getOpcode(), LHS, RHS);
+    return hash_combine(
+        II->getOpcode(), LHS, RHS,
+        hash_combine_range(II->value_op_begin() + 2, II->value_op_end()));
   }
 
   // gc.relocate is 'special' call: its second and third operands are
@@ -396,13 +396,14 @@ static bool isEqualImpl(SimpleValue LHS, SimpleValue RHS) {
            LHSCmp->getSwappedPredicate() == RHSCmp->getPredicate();
   }
 
-  // TODO: Extend this for >2 args by matching the trailing N-2 args.
   auto *LII = dyn_cast<IntrinsicInst>(LHSI);
   auto *RII = dyn_cast<IntrinsicInst>(RHSI);
   if (LII && RII && LII->getIntrinsicID() == RII->getIntrinsicID() &&
-      LII->isCommutative() && LII->arg_size() == 2) {
+      LII->isCommutative() && LII->arg_size() >= 2) {
     return LII->getArgOperand(0) == RII->getArgOperand(1) &&
-           LII->getArgOperand(1) == RII->getArgOperand(0);
+           LII->getArgOperand(1) == RII->getArgOperand(0) &&
+           std::equal(LII->arg_begin() + 2, LII->arg_end(),
+                      RII->arg_begin() + 2, RII->arg_end());
   }
 
   // See comment above in `getHashValue()`.

diff --git a/llvm/test/Transforms/EarlyCSE/commute.ll b/llvm/test/Transforms/EarlyCSE/commute.ll
index 6aaaf992e441423..1cf7ddda7f0dd7f 100644
--- a/llvm/test/Transforms/EarlyCSE/commute.ll
+++ b/llvm/test/Transforms/EarlyCSE/commute.ll
@@ -999,14 +999,10 @@ define i4 @smin_umin(i4 %a, i4 %b) {
   ret i4 %o
 }
 
-; TODO: handle >2 args
-
 define i16 @smul_fix(i16 %a, i16 %b) {
 ; CHECK-LABEL: @smul_fix(
 ; CHECK-NEXT:    [[X:%.*]] = call i16 @llvm.smul.fix.i16(i16 [[A:%.*]], i16 [[B:%.*]], i32 3)
-; CHECK-NEXT:    [[Y:%.*]] = call i16 @llvm.smul.fix.i16(i16 [[B]], i16 [[A]], i32 3)
-; CHECK-NEXT:    [[O:%.*]] = or i16 [[X]], [[Y]]
-; CHECK-NEXT:    ret i16 [[O]]
+; CHECK-NEXT:    ret i16 [[X]]
 ;
   %x = call i16 @llvm.smul.fix.i16(i16 %a, i16 %b, i32 3)
   %y = call i16 @llvm.smul.fix.i16(i16 %b, i16 %a, i32 3)
@@ -1014,14 +1010,10 @@ define i16 @smul_fix(i16 %a, i16 %b) {
   ret i16 %o
 }
 
-; TODO: handle >2 args
-
 define i16 @umul_fix(i16 %a, i16 %b, i32 %s) {
 ; CHECK-LABEL: @umul_fix(
 ; CHECK-NEXT:    [[X:%.*]] = call i16 @llvm.umul.fix.i16(i16 [[A:%.*]], i16 [[B:%.*]], i32 1)
-; CHECK-NEXT:    [[Y:%.*]] = call i16 @llvm.umul.fix.i16(i16 [[B]], i16 [[A]], i32 1)
-; CHECK-NEXT:    [[O:%.*]] = or i16 [[X]], [[Y]]
-; CHECK-NEXT:    ret i16 [[O]]
+; CHECK-NEXT:    ret i16 [[X]]
 ;
   %x = call i16 @llvm.umul.fix.i16(i16 %a, i16 %b, i32 1)
   %y = call i16 @llvm.umul.fix.i16(i16 %b, i16 %a, i32 1)
@@ -1029,14 +1021,10 @@ define i16 @umul_fix(i16 %a, i16 %b, i32 %s) {
   ret i16 %o
 }
 
-; TODO: handle >2 args
-
 define <3 x i16> @smul_fix_sat(<3 x i16> %a, <3 x i16> %b) {
 ; CHECK-LABEL: @smul_fix_sat(
 ; CHECK-NEXT:    [[X:%.*]] = call <3 x i16> @llvm.smul.fix.sat.v3i16(<3 x i16> [[A:%.*]], <3 x i16> [[B:%.*]], i32 2)
-; CHECK-NEXT:    [[Y:%.*]] = call <3 x i16> @llvm.smul.fix.sat.v3i16(<3 x i16> [[B]], <3 x i16> [[A]], i32 2)
-; CHECK-NEXT:    [[O:%.*]] = or <3 x i16> [[X]], [[Y]]
-; CHECK-NEXT:    ret <3 x i16> [[O]]
+; CHECK-NEXT:    ret <3 x i16> [[X]]
 ;
   %x = call <3 x i16> @llvm.smul.fix.sat.v3i16(<3 x i16> %a, <3 x i16> %b, i32 2)
   %y = call <3 x i16> @llvm.smul.fix.sat.v3i16(<3 x i16> %b, <3 x i16> %a, i32 2)
@@ -1044,14 +1032,10 @@ define <3 x i16> @smul_fix_sat(<3 x i16> %a, <3 x i16> %b) {
   ret <3 x i16> %o
 }
 
-; TODO: handle >2 args
-
 define <3 x i16> @umul_fix_sat(<3 x i16> %a, <3 x i16> %b) {
 ; CHECK-LABEL: @umul_fix_sat(
 ; CHECK-NEXT:    [[X:%.*]] = call <3 x i16> @llvm.umul.fix.sat.v3i16(<3 x i16> [[A:%.*]], <3 x i16> [[B:%.*]], i32 3)
-; CHECK-NEXT:    [[Y:%.*]] = call <3 x i16> @llvm.umul.fix.sat.v3i16(<3 x i16> [[B]], <3 x i16> [[A]], i32 3)
-; CHECK-NEXT:    [[O:%.*]] = or <3 x i16> [[X]], [[Y]]
-; CHECK-NEXT:    ret <3 x i16> [[O]]
+; CHECK-NEXT:    ret <3 x i16> [[X]]
 ;
   %x = call <3 x i16> @llvm.umul.fix.sat.v3i16(<3 x i16> %a, <3 x i16> %b, i32 3)
   %y = call <3 x i16> @llvm.umul.fix.sat.v3i16(<3 x i16> %b, <3 x i16> %a, i32 3)
@@ -1085,17 +1069,26 @@ define i16 @umul_fix_scale(i16 %a, i16 %b, i32 %s) {
   ret i16 %o
 }
 
-; TODO: handle >2 args
-
 define float @fma(float %a, float %b, float %c) {
 ; CHECK-LABEL: @fma(
 ; CHECK-NEXT:    [[X:%.*]] = call float @llvm.fma.f32(float [[A:%.*]], float [[B:%.*]], float [[C:%.*]])
-; CHECK-NEXT:    [[Y:%.*]] = call float @llvm.fma.f32(float [[B]], float [[A]], float [[C]])
+; CHECK-NEXT:    ret float 1.000000e+00
+;
+  %x = call float @llvm.fma.f32(float %a, float %b, float %c)
+  %y = call float @llvm.fma.f32(float %b, float %a, float %c)
+  %r = fdiv nnan float %x, %y
+  ret float %r
+}
+
+define float @fma_fail(float %a, float %b, float %c) {
+; CHECK-LABEL: @fma_fail(
+; CHECK-NEXT:    [[X:%.*]] = call float @llvm.fma.f32(float [[A:%.*]], float [[B:%.*]], float [[C:%.*]])
+; CHECK-NEXT:    [[Y:%.*]] = call float @llvm.fma.f32(float [[A]], float [[C]], float [[B]])
 ; CHECK-NEXT:    [[R:%.*]] = fdiv nnan float [[X]], [[Y]]
 ; CHECK-NEXT:    ret float [[R]]
 ;
   %x = call float @llvm.fma.f32(float %a, float %b, float %c)
-  %y = call float @llvm.fma.f32(float %b, float %a, float %c)
+  %y = call float @llvm.fma.f32(float %a, float %c, float %b)
   %r = fdiv nnan float %x, %y
   ret float %r
 }
@@ -1113,17 +1106,39 @@ define float @fma_different_add_ops(float %a, float %b, float %c, float %d) {
   ret float %r
 }
 
-; TODO: handle >2 args
-
 define <2 x double> @fmuladd(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
 ; CHECK-LABEL: @fmuladd(
 ; CHECK-NEXT:    [[X:%.*]] = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[A:%.*]], <2 x double> [[B:%.*]], <2 x double> [[C:%.*]])
-; CHECK-NEXT:    [[Y:%.*]] = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[B]], <2 x double> [[A]], <2 x double> [[C]])
+; CHECK-NEXT:    ret <2 x double> <double 1.000000e+00, double 1.000000e+00>
+;
+  %x = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> %a, <2 x double> %b, <2 x double> %c)
+  %y = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> %b, <2 x double> %a, <2 x double> %c)
+  %r = fdiv nnan <2 x double> %x, %y
+  ret <2 x double> %r
+}
+
+define <2 x double> @fmuladd_fail1(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
+; CHECK-LABEL: @fmuladd_fail1(
+; CHECK-NEXT:    [[X:%.*]] = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[A:%.*]], <2 x double> [[B:%.*]], <2 x double> [[C:%.*]])
+; CHECK-NEXT:    [[Y:%.*]] = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[C]], <2 x double> [[B]], <2 x double> [[A]])
 ; CHECK-NEXT:    [[R:%.*]] = fdiv nnan <2 x double> [[X]], [[Y]]
 ; CHECK-NEXT:    ret <2 x double> [[R]]
 ;
   %x = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> %a, <2 x double> %b, <2 x double> %c)
-  %y = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> %b, <2 x double> %a, <2 x double> %c)
+  %y = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> %c, <2 x double> %b, <2 x double> %a)
+  %r = fdiv nnan <2 x double> %x, %y
+  ret <2 x double> %r
+}
+
+define <2 x double> @fmuladd_fail2(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
+; CHECK-LABEL: @fmuladd_fail2(
+; CHECK-NEXT:    [[X:%.*]] = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[A:%.*]], <2 x double> [[B:%.*]], <2 x double> [[C:%.*]])
+; CHECK-NEXT:    [[Y:%.*]] = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[A]], <2 x double> [[C]], <2 x double> [[B]])
+; CHECK-NEXT:    [[R:%.*]] = fdiv nnan <2 x double> [[X]], [[Y]]
+; CHECK-NEXT:    ret <2 x double> [[R]]
+;
+  %x = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> %a, <2 x double> %b, <2 x double> %c)
+  %y = call <2 x double> @llvm.fmuladd.v2f64(<2 x double> %a, <2 x double> %c, <2 x double> %b)
   %r = fdiv nnan <2 x double> %x, %y
   ret <2 x double> %r
 }




More information about the llvm-commits mailing list