[llvm] d7fe2cf - [InstCombine] Widen Sel width after Cmp to generate Max/Min intrinsics. (#118932)

via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 17 17:02:14 PST 2024


Author: tianleliu
Date: 2024-12-18T09:02:11+08:00
New Revision: d7fe2cf8a2854f05812b87faf3ce0da296fc5fe1

URL: https://github.com/llvm/llvm-project/commit/d7fe2cf8a2854f05812b87faf3ce0da296fc5fe1
DIFF: https://github.com/llvm/llvm-project/commit/d7fe2cf8a2854f05812b87faf3ce0da296fc5fe1.diff

LOG: [InstCombine]  Widen Sel width after Cmp to generate Max/Min intrinsics. (#118932)

When Sel(Cmp) operate on different integer types,

From: (K and N are bit widths, K < N; a and b are source operands.)
bN = Ext(bK)
cond = Cmp(aN, bN)
aK = Trunc aN
retK = Sel(cond, aK, bK)
To:
bN = Ext(bK)
cond = Cmp(aN, bN)
retN = Sel(cond, aN, bN)
retK = Trunc retN

Although the width of Sel's operands becomes larger, making
the type width of Sel the same as that of Cmp enables combining
them into max/min intrinsics, and also gives better performance
for SIMD instructions.
Alive2 correctness proofs: https://alive2.llvm.org/ce/z/Y4Kegm
                           https://alive2.llvm.org/ce/z/qFtjtR
Comparison of generated code:
                           https://gcc.godbolt.org/z/o97svGvYM
                           https://gcc.godbolt.org/z/59Ynj91ov

Added: 
    

Modified: 
    llvm/lib/Analysis/ValueTracking.cpp
    llvm/test/Transforms/InstCombine/minmax-fold.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index a43f5b6cec2f4e..14d7c2da8a9f8e 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -8803,40 +8803,10 @@ static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred,
   return matchFastFloatClamp(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal, LHS, RHS);
 }
 
-/// Helps to match a select pattern in case of a type mismatch.
-///
-/// The function processes the case when type of true and false values of a
-/// select instruction differs from type of the cmp instruction operands because
-/// of a cast instruction. The function checks if it is legal to move the cast
-/// operation after "select". If yes, it returns the new second value of
-/// "select" (with the assumption that cast is moved):
-/// 1. As operand of cast instruction when both values of "select" are same cast
-/// instructions.
-/// 2. As restored constant (by applying reverse cast operation) when the first
-/// value of the "select" is a cast operation and the second value is a
-/// constant.
-/// NOTE: We return only the new second value because the first value could be
-/// accessed as operand of cast instruction.
-static Value *lookThroughCast(CmpInst *CmpI, Value *V1, Value *V2,
-                              Instruction::CastOps *CastOp) {
-  auto *Cast1 = dyn_cast<CastInst>(V1);
-  if (!Cast1)
-    return nullptr;
-
-  *CastOp = Cast1->getOpcode();
-  Type *SrcTy = Cast1->getSrcTy();
-  if (auto *Cast2 = dyn_cast<CastInst>(V2)) {
-    // If V1 and V2 are both the same cast from the same type, look through V1.
-    if (*CastOp == Cast2->getOpcode() && SrcTy == Cast2->getSrcTy())
-      return Cast2->getOperand(0);
-    return nullptr;
-  }
-
-  auto *C = dyn_cast<Constant>(V2);
-  if (!C)
-    return nullptr;
-
+static Value *lookThroughCastConst(CmpInst *CmpI, Type *SrcTy, Constant *C,
+                                   Instruction::CastOps *CastOp) {
   const DataLayout &DL = CmpI->getDataLayout();
+
   Constant *CastedTo = nullptr;
   switch (*CastOp) {
   case Instruction::ZExt:
@@ -8912,6 +8882,63 @@ static Value *lookThroughCast(CmpInst *CmpI, Value *V1, Value *V2,
   return CastedTo;
 }
 
+/// Helps to match a select pattern in case of a type mismatch.
+///
+/// The function processes the case when type of true and false values of a
+/// select instruction differs from type of the cmp instruction operands because
+/// of a cast instruction. The function checks if it is legal to move the cast
+/// operation after "select". If yes, it returns the new second value of
+/// "select" (with the assumption that cast is moved):
+/// 1. As operand of cast instruction when both values of "select" are same cast
+/// instructions.
+/// 2. As restored constant (by applying reverse cast operation) when the first
+/// value of the "select" is a cast operation and the second value is a
+/// constant. It is implemented in lookThroughCastConst().
+/// 3. As the zext/sext of V2 that forms cmp operand 1, when V1 is a trunc and
+/// V2 is neither a cast nor a constant (the sel and cmp integer types differ).
+/// NOTE: We return only the new second value because the first value could be
+/// accessed as operand of cast instruction.
+static Value *lookThroughCast(CmpInst *CmpI, Value *V1, Value *V2,
+                              Instruction::CastOps *CastOp) {
+  auto *Cast1 = dyn_cast<CastInst>(V1);
+  if (!Cast1)
+    return nullptr;
+
+  *CastOp = Cast1->getOpcode();
+  Type *SrcTy = Cast1->getSrcTy();
+  if (auto *Cast2 = dyn_cast<CastInst>(V2)) {
+    // If V1 and V2 are both the same cast from the same type, look through V1.
+    if (*CastOp == Cast2->getOpcode() && SrcTy == Cast2->getSrcTy())
+      return Cast2->getOperand(0);
+    return nullptr;
+  }
+
+  auto *C = dyn_cast<Constant>(V2);
+  if (C)
+    return lookThroughCastConst(CmpI, SrcTy, C, CastOp);
+
+  Value *CastedTo = nullptr; // Stays null unless the trunc case below matches.
+  if (*CastOp == Instruction::Trunc) {
+    if (match(CmpI->getOperand(1), m_ZExtOrSExt(m_Specific(V2)))) {
+      // Only cmp operand 1 is checked. Here we have the following case:
+      //   %y_ext = sext iK %y to iN
+      //   %cond = cmp iN %x, %y_ext
+      //   %tr = trunc iN %x to iK
+      //   %narrowsel = select i1 %cond, iK %tr, iK %y
+      //
+      // Since trunc(ext(%y)) == %y, we can always move trunc after select:
+      //   %y_ext = sext iK %y to iN
+      //   %cond = cmp iN %x, %y_ext
+      //   %widesel = select i1 %cond, iN %x, iN %y_ext
+      //   %tr = trunc iN %widesel to iK
+      assert(V2->getType() == Cast1->getType() &&
+             "V2 and Cast1 should be the same type.");
+      CastedTo = CmpI->getOperand(1);
+    }
+  }
+
+  return CastedTo;
+}
 SelectPatternResult llvm::matchSelectPattern(Value *V, Value *&LHS, Value *&RHS,
                                              Instruction::CastOps *CastOp,
                                              unsigned Depth) {

diff  --git a/llvm/test/Transforms/InstCombine/minmax-fold.ll b/llvm/test/Transforms/InstCombine/minmax-fold.ll
index 2e267958d0476e..4d66e261c649cc 100644
--- a/llvm/test/Transforms/InstCombine/minmax-fold.ll
+++ b/llvm/test/Transforms/InstCombine/minmax-fold.ll
@@ -697,6 +697,34 @@ define zeroext i8 @look_through_cast2(i32 %x) {
   ret i8 %res
 }
 
+define i8 @look_through_cast_int_min(i8 %a, i32 %min) {
+; CHECK-LABEL: @look_through_cast_int_min(
+; CHECK-NEXT:    [[A32:%.*]] = sext i8 [[A:%.*]] to i32
+; CHECK-NEXT:    [[SEL1:%.*]] = call i32 @llvm.smin.i32(i32 [[MIN:%.*]], i32 [[A32]])
+; CHECK-NEXT:    [[SEL:%.*]] = trunc i32 [[SEL1]] to i8
+; CHECK-NEXT:    ret i8 [[SEL]]
+;
+  %a32 = sext i8 %a to i32
+  %cmp = icmp slt i32 %a32, %min ; extended narrow value is cmp operand 0
+  %min8 = trunc i32 %min to i8
+  %sel = select i1 %cmp, i8 %a, i8 %min8 ; expected output: i32 smin, then trunc
+  ret i8 %sel
+}
+
+define i16 @look_through_cast_int_max(i16 %a, i32 %max) {
+; CHECK-LABEL: @look_through_cast_int_max(
+; CHECK-NEXT:    [[A32:%.*]] = zext i16 [[A:%.*]] to i32
+; CHECK-NEXT:    [[SEL1:%.*]] = call i32 @llvm.smax.i32(i32 [[MAX:%.*]], i32 [[A32]])
+; CHECK-NEXT:    [[SEL:%.*]] = trunc i32 [[SEL1]] to i16
+; CHECK-NEXT:    ret i16 [[SEL]]
+;
+  %a32 = zext i16 %a to i32
+  %cmp = icmp sgt i32 %max, %a32 ; extended narrow value is cmp operand 1
+  %max8 = trunc i32 %max to i16 ; NOTE(review): named max8, but the type is i16
+  %sel = select i1 %cmp, i16 %max8, i16 %a ; expected output: i32 smax, then trunc
+  ret i16 %sel
+}
+
 define <2 x i8> @min_through_cast_vec1(<2 x i32> %x) {
 ; CHECK-LABEL: @min_through_cast_vec1(
 ; CHECK-NEXT:    [[RES1:%.*]] = call <2 x i32> @llvm.smin.v2i32(<2 x i32> [[X:%.*]], <2 x i32> <i32 510, i32 511>)
@@ -721,6 +749,34 @@ define <2 x i8> @min_through_cast_vec2(<2 x i32> %x) {
   ret <2 x i8> %res
 }
 
+define <8 x i8> @look_through_cast_int_min_vec(<8 x i8> %a, <8 x i32> %min) {
+; CHECK-LABEL: @look_through_cast_int_min_vec(
+; CHECK-NEXT:    [[A32:%.*]] = sext <8 x i8> [[A:%.*]] to <8 x i32>
+; CHECK-NEXT:    [[SEL1:%.*]] = call <8 x i32> @llvm.umin.v8i32(<8 x i32> [[MIN:%.*]], <8 x i32> [[A32]])
+; CHECK-NEXT:    [[SEL:%.*]] = trunc <8 x i32> [[SEL1]] to <8 x i8>
+; CHECK-NEXT:    ret <8 x i8> [[SEL]]
+;
+  %a32 = sext <8 x i8> %a to <8 x i32>
+  %cmp = icmp ult <8 x i32> %a32, %min ; sext'd value compared with an unsigned predicate
+  %min8 = trunc <8 x i32> %min to <8 x i8>
+  %sel = select <8 x i1> %cmp, <8 x i8> %a, <8 x i8> %min8 ; expected output: umin, then trunc
+  ret <8 x i8> %sel
+}
+
+define <8 x i32> @look_through_cast_int_max_vec(<8 x i32> %a, <8 x i64> %max) {
+; CHECK-LABEL: @look_through_cast_int_max_vec(
+; CHECK-NEXT:    [[A32:%.*]] = zext <8 x i32> [[A:%.*]] to <8 x i64>
+; CHECK-NEXT:    [[SEL1:%.*]] = call <8 x i64> @llvm.smax.v8i64(<8 x i64> [[MAX:%.*]], <8 x i64> [[A32]])
+; CHECK-NEXT:    [[SEL:%.*]] = trunc <8 x i64> [[SEL1]] to <8 x i32>
+; CHECK-NEXT:    ret <8 x i32> [[SEL]]
+;
+  %a32 = zext <8 x i32> %a to <8 x i64> ; NOTE(review): named a32, but the type is <8 x i64>
+  %cmp = icmp sgt <8 x i64> %a32, %max
+  %max8 = trunc <8 x i64> %max to <8 x i32> ; NOTE(review): named max8, but the type is <8 x i32>
+  %sel = select <8 x i1> %cmp, <8 x i32> %a, <8 x i32> %max8 ; expected output: smax, then trunc
+  ret <8 x i32> %sel
+}
+
 ; Remove a min/max op in a sequence with a common operand.
 ; PR35717: https://bugs.llvm.org/show_bug.cgi?id=35717
 


        


More information about the llvm-commits mailing list