[llvm] d7fe2cf - [InstCombine] Widen Sel width after Cmp to generate Max/Min intrinsics. (#118932)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Dec 17 17:02:14 PST 2024
Author: tianleliu
Date: 2024-12-18T09:02:11+08:00
New Revision: d7fe2cf8a2854f05812b87faf3ce0da296fc5fe1
URL: https://github.com/llvm/llvm-project/commit/d7fe2cf8a2854f05812b87faf3ce0da296fc5fe1
DIFF: https://github.com/llvm/llvm-project/commit/d7fe2cf8a2854f05812b87faf3ce0da296fc5fe1.diff
LOG: [InstCombine] Widen Sel width after Cmp to generate Max/Min intrinsics. (#118932)
When the Sel and its Cmp operate on different integer types,
From: (K and N denote bit widths, K < N; a and b are the source operands.)
bN = Ext(bK)
cond = Cmp(aN, bN)
aK = Trunc aN
retK = Sel(cond, aK, bK)
To:
bN = Ext(bK)
cond = Cmp(aN, bN)
retN = Sel(cond, aN, bN)
retK = Trunc retN
Though the Sel's operands become wider, making the Sel's type width
match the Cmp's allows the pattern to be combined into max/min
intrinsics and also gives better performance for SIMD instructions.
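For example (illustrative IR, not taken from the patch; K = 8, N = 32,
value names hypothetical), the rewrite turns

  %b32  = sext i8 %b to i32
  %cond = icmp slt i32 %a32, %b32
  %a8   = trunc i32 %a32 to i8
  %ret  = select i1 %cond, i8 %a8, i8 %b

into

  %b32  = sext i8 %b to i32
  %cond = icmp slt i32 %a32, %b32
  %wide = select i1 %cond, i32 %a32, i32 %b32
  %ret  = trunc i32 %wide to i8

where the widened select can then be folded to
%wide = call i32 @llvm.smin.i32(i32 %a32, i32 %b32).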
References for correctness:
https://alive2.llvm.org/ce/z/qFtjtR
References for generated-code comparison:
https://gcc.godbolt.org/z/o97svGvYM
https://gcc.godbolt.org/z/59Ynj91ov
Added:
Modified:
llvm/lib/Analysis/ValueTracking.cpp
llvm/test/Transforms/InstCombine/minmax-fold.ll
Removed:
################################################################################
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index a43f5b6cec2f4e..14d7c2da8a9f8e 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -8803,40 +8803,10 @@ static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred,
return matchFastFloatClamp(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal, LHS, RHS);
}
-/// Helps to match a select pattern in case of a type mismatch.
-///
-/// The function processes the case when type of true and false values of a
-/// select instruction differs from type of the cmp instruction operands because
-/// of a cast instruction. The function checks if it is legal to move the cast
-/// operation after "select". If yes, it returns the new second value of
-/// "select" (with the assumption that cast is moved):
-/// 1. As operand of cast instruction when both values of "select" are same cast
-/// instructions.
-/// 2. As restored constant (by applying reverse cast operation) when the first
-/// value of the "select" is a cast operation and the second value is a
-/// constant.
-/// NOTE: We return only the new second value because the first value could be
-/// accessed as operand of cast instruction.
-static Value *lookThroughCast(CmpInst *CmpI, Value *V1, Value *V2,
- Instruction::CastOps *CastOp) {
- auto *Cast1 = dyn_cast<CastInst>(V1);
- if (!Cast1)
- return nullptr;
-
- *CastOp = Cast1->getOpcode();
- Type *SrcTy = Cast1->getSrcTy();
- if (auto *Cast2 = dyn_cast<CastInst>(V2)) {
- // If V1 and V2 are both the same cast from the same type, look through V1.
- if (*CastOp == Cast2->getOpcode() && SrcTy == Cast2->getSrcTy())
- return Cast2->getOperand(0);
- return nullptr;
- }
-
- auto *C = dyn_cast<Constant>(V2);
- if (!C)
- return nullptr;
-
+static Value *lookThroughCastConst(CmpInst *CmpI, Type *SrcTy, Constant *C,
+ Instruction::CastOps *CastOp) {
const DataLayout &DL = CmpI->getDataLayout();
+
Constant *CastedTo = nullptr;
switch (*CastOp) {
case Instruction::ZExt:
@@ -8912,6 +8882,63 @@ static Value *lookThroughCast(CmpInst *CmpI, Value *V1, Value *V2,
return CastedTo;
}
+/// Helps to match a select pattern in case of a type mismatch.
+///
+/// The function processes the case when type of true and false values of a
+/// select instruction differs from type of the cmp instruction operands because
+/// of a cast instruction. The function checks if it is legal to move the cast
+/// operation after "select". If yes, it returns the new second value of
+/// "select" (with the assumption that cast is moved):
+/// 1. As operand of cast instruction when both values of "select" are same cast
+/// instructions.
+/// 2. As restored constant (by applying reverse cast operation) when the first
+/// value of the "select" is a cast operation and the second value is a
+/// constant. It is implemented in lookThroughCastConst().
+/// 3. As one operand is cast instruction and the other is not. The operands in
+///    sel(cmp) are in different type integer.
+/// NOTE: We return only the new second value because the first value could be
+/// accessed as operand of cast instruction.
+static Value *lookThroughCast(CmpInst *CmpI, Value *V1, Value *V2,
+ Instruction::CastOps *CastOp) {
+ auto *Cast1 = dyn_cast<CastInst>(V1);
+ if (!Cast1)
+ return nullptr;
+
+ *CastOp = Cast1->getOpcode();
+ Type *SrcTy = Cast1->getSrcTy();
+ if (auto *Cast2 = dyn_cast<CastInst>(V2)) {
+ // If V1 and V2 are both the same cast from the same type, look through V1.
+ if (*CastOp == Cast2->getOpcode() && SrcTy == Cast2->getSrcTy())
+ return Cast2->getOperand(0);
+ return nullptr;
+ }
+
+ auto *C = dyn_cast<Constant>(V2);
+ if (C)
+ return lookThroughCastConst(CmpI, SrcTy, C, CastOp);
+
+ Value *CastedTo = nullptr;
+ if (*CastOp == Instruction::Trunc) {
+ if (match(CmpI->getOperand(1), m_ZExtOrSExt(m_Specific(V2)))) {
+ // Here we have the following case:
+ // %y_ext = sext iK %y to iN
+ // %cond = cmp iN %x, %y_ext
+ // %tr = trunc iN %x to iK
+ // %narrowsel = select i1 %cond, iK %tr, iK %y
+ //
+ // We can always move trunc after select operation:
+ // %y_ext = sext iK %y to iN
+ // %cond = cmp iN %x, %y_ext
+ // %widesel = select i1 %cond, iN %x, iN %y_ext
+ // %tr = trunc iN %widesel to iK
+ assert(V2->getType() == Cast1->getType() &&
+ "V2 and Cast1 should be the same type.");
+ CastedTo = CmpI->getOperand(1);
+ }
+ }
+
+ return CastedTo;
+}
SelectPatternResult llvm::matchSelectPattern(Value *V, Value *&LHS, Value *&RHS,
Instruction::CastOps *CastOp,
unsigned Depth) {
diff --git a/llvm/test/Transforms/InstCombine/minmax-fold.ll b/llvm/test/Transforms/InstCombine/minmax-fold.ll
index 2e267958d0476e..4d66e261c649cc 100644
--- a/llvm/test/Transforms/InstCombine/minmax-fold.ll
+++ b/llvm/test/Transforms/InstCombine/minmax-fold.ll
@@ -697,6 +697,34 @@ define zeroext i8 @look_through_cast2(i32 %x) {
ret i8 %res
}
+define i8 @look_through_cast_int_min(i8 %a, i32 %min) {
+; CHECK-LABEL: @look_through_cast_int_min(
+; CHECK-NEXT: [[A32:%.*]] = sext i8 [[A:%.*]] to i32
+; CHECK-NEXT: [[SEL1:%.*]] = call i32 @llvm.smin.i32(i32 [[MIN:%.*]], i32 [[A32]])
+; CHECK-NEXT: [[SEL:%.*]] = trunc i32 [[SEL1]] to i8
+; CHECK-NEXT: ret i8 [[SEL]]
+;
+ %a32 = sext i8 %a to i32
+ %cmp = icmp slt i32 %a32, %min
+ %min8 = trunc i32 %min to i8
+ %sel = select i1 %cmp, i8 %a, i8 %min8
+ ret i8 %sel
+}
+
+define i16 @look_through_cast_int_max(i16 %a, i32 %max) {
+; CHECK-LABEL: @look_through_cast_int_max(
+; CHECK-NEXT: [[A32:%.*]] = zext i16 [[A:%.*]] to i32
+; CHECK-NEXT: [[SEL1:%.*]] = call i32 @llvm.smax.i32(i32 [[MAX:%.*]], i32 [[A32]])
+; CHECK-NEXT: [[SEL:%.*]] = trunc i32 [[SEL1]] to i16
+; CHECK-NEXT: ret i16 [[SEL]]
+;
+ %a32 = zext i16 %a to i32
+ %cmp = icmp sgt i32 %max, %a32
+ %max8 = trunc i32 %max to i16
+ %sel = select i1 %cmp, i16 %max8, i16 %a
+ ret i16 %sel
+}
+
define <2 x i8> @min_through_cast_vec1(<2 x i32> %x) {
; CHECK-LABEL: @min_through_cast_vec1(
; CHECK-NEXT: [[RES1:%.*]] = call <2 x i32> @llvm.smin.v2i32(<2 x i32> [[X:%.*]], <2 x i32> <i32 510, i32 511>)
@@ -721,6 +749,34 @@ define <2 x i8> @min_through_cast_vec2(<2 x i32> %x) {
ret <2 x i8> %res
}
+define <8 x i8> @look_through_cast_int_min_vec(<8 x i8> %a, <8 x i32> %min) {
+; CHECK-LABEL: @look_through_cast_int_min_vec(
+; CHECK-NEXT: [[A32:%.*]] = sext <8 x i8> [[A:%.*]] to <8 x i32>
+; CHECK-NEXT: [[SEL1:%.*]] = call <8 x i32> @llvm.umin.v8i32(<8 x i32> [[MIN:%.*]], <8 x i32> [[A32]])
+; CHECK-NEXT: [[SEL:%.*]] = trunc <8 x i32> [[SEL1]] to <8 x i8>
+; CHECK-NEXT: ret <8 x i8> [[SEL]]
+;
+ %a32 = sext <8 x i8> %a to <8 x i32>
+ %cmp = icmp ult <8 x i32> %a32, %min
+ %min8 = trunc <8 x i32> %min to <8 x i8>
+ %sel = select <8 x i1> %cmp, <8 x i8> %a, <8 x i8> %min8
+ ret <8 x i8> %sel
+}
+
+define <8 x i32> @look_through_cast_int_max_vec(<8 x i32> %a, <8 x i64> %max) {
+; CHECK-LABEL: @look_through_cast_int_max_vec(
+; CHECK-NEXT: [[A32:%.*]] = zext <8 x i32> [[A:%.*]] to <8 x i64>
+; CHECK-NEXT: [[SEL1:%.*]] = call <8 x i64> @llvm.smax.v8i64(<8 x i64> [[MAX:%.*]], <8 x i64> [[A32]])
+; CHECK-NEXT: [[SEL:%.*]] = trunc <8 x i64> [[SEL1]] to <8 x i32>
+; CHECK-NEXT: ret <8 x i32> [[SEL]]
+;
+ %a32 = zext <8 x i32> %a to <8 x i64>
+ %cmp = icmp sgt <8 x i64> %a32, %max
+ %max8 = trunc <8 x i64> %max to <8 x i32>
+ %sel = select <8 x i1> %cmp, <8 x i32> %a, <8 x i32> %max8
+ ret <8 x i32> %sel
+}
+
; Remove a min/max op in a sequence with a common operand.
; PR35717: https://bugs.llvm.org/show_bug.cgi?id=35717