[llvm] [LLVM][IR] Teach constant integer binop folds about vector ConstantInts. (PR #115739)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Nov 11 08:32:48 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-ir
@llvm/pr-subscribers-llvm-transforms
Author: Paul Walker (paulwalker-arm)
<details>
<summary>Changes</summary>
The existing logic mostly works; the main changes are:
* Use getScalarSizeInBits instead of IntegerType::getBitWidth
* Use ConstantInt::get(Type *, ...) instead of ConstantInt::get(LLVMContext &, ...), so the folded constant takes the operand's type (a splat when that type is a vector)
---
Full diff: https://github.com/llvm/llvm-project/pull/115739.diff
10 Files Affected:
- (modified) llvm/lib/IR/ConstantFold.cpp (+23-27)
- (modified) llvm/lib/IR/Constants.cpp (+7)
- (modified) llvm/test/Transforms/InstCombine/add.ll (+1)
- (modified) llvm/test/Transforms/InstCombine/div.ll (+1)
- (modified) llvm/test/Transforms/InstCombine/mul.ll (+1)
- (modified) llvm/test/Transforms/InstCombine/or.ll (+11-5)
- (modified) llvm/test/Transforms/InstCombine/rotate.ll (+1)
- (modified) llvm/test/Transforms/InstCombine/shift.ll (+1)
- (modified) llvm/test/Transforms/InstCombine/xor-ashr.ll (+2)
- (modified) llvm/test/Transforms/InstSimplify/bitcast-vector-fold.ll (+1-2)
``````````diff
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index cfe87937c372cd..2dbc6785c08b9d 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -231,26 +231,20 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
return nullptr;
case Instruction::ZExt:
if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
- uint32_t BitWidth = cast<IntegerType>(DestTy)->getBitWidth();
- return ConstantInt::get(V->getContext(),
- CI->getValue().zext(BitWidth));
+ uint32_t BitWidth = DestTy->getScalarSizeInBits();
+ return ConstantInt::get(DestTy, CI->getValue().zext(BitWidth));
}
return nullptr;
case Instruction::SExt:
if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
- uint32_t BitWidth = cast<IntegerType>(DestTy)->getBitWidth();
- return ConstantInt::get(V->getContext(),
- CI->getValue().sext(BitWidth));
+ uint32_t BitWidth = DestTy->getScalarSizeInBits();
+ return ConstantInt::get(DestTy, CI->getValue().sext(BitWidth));
}
return nullptr;
case Instruction::Trunc: {
- if (V->getType()->isVectorTy())
- return nullptr;
-
- uint32_t DestBitWidth = cast<IntegerType>(DestTy)->getBitWidth();
if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
- return ConstantInt::get(V->getContext(),
- CI->getValue().trunc(DestBitWidth));
+ uint32_t BitWidth = DestTy->getScalarSizeInBits();
+ return ConstantInt::get(DestTy, CI->getValue().trunc(BitWidth));
}
return nullptr;
@@ -807,44 +801,44 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1,
default:
break;
case Instruction::Add:
- return ConstantInt::get(CI1->getContext(), C1V + C2V);
+ return ConstantInt::get(C1->getType(), C1V + C2V);
case Instruction::Sub:
- return ConstantInt::get(CI1->getContext(), C1V - C2V);
+ return ConstantInt::get(C1->getType(), C1V - C2V);
case Instruction::Mul:
- return ConstantInt::get(CI1->getContext(), C1V * C2V);
+ return ConstantInt::get(C1->getType(), C1V * C2V);
case Instruction::UDiv:
assert(!CI2->isZero() && "Div by zero handled above");
- return ConstantInt::get(CI1->getContext(), C1V.udiv(C2V));
+ return ConstantInt::get(CI1->getType(), C1V.udiv(C2V));
case Instruction::SDiv:
assert(!CI2->isZero() && "Div by zero handled above");
if (C2V.isAllOnes() && C1V.isMinSignedValue())
return PoisonValue::get(CI1->getType()); // MIN_INT / -1 -> poison
- return ConstantInt::get(CI1->getContext(), C1V.sdiv(C2V));
+ return ConstantInt::get(CI1->getType(), C1V.sdiv(C2V));
case Instruction::URem:
assert(!CI2->isZero() && "Div by zero handled above");
- return ConstantInt::get(CI1->getContext(), C1V.urem(C2V));
+ return ConstantInt::get(C1->getType(), C1V.urem(C2V));
case Instruction::SRem:
assert(!CI2->isZero() && "Div by zero handled above");
if (C2V.isAllOnes() && C1V.isMinSignedValue())
- return PoisonValue::get(CI1->getType()); // MIN_INT % -1 -> poison
- return ConstantInt::get(CI1->getContext(), C1V.srem(C2V));
+ return PoisonValue::get(C1->getType()); // MIN_INT % -1 -> poison
+ return ConstantInt::get(C1->getType(), C1V.srem(C2V));
case Instruction::And:
- return ConstantInt::get(CI1->getContext(), C1V & C2V);
+ return ConstantInt::get(C1->getType(), C1V & C2V);
case Instruction::Or:
- return ConstantInt::get(CI1->getContext(), C1V | C2V);
+ return ConstantInt::get(C1->getType(), C1V | C2V);
case Instruction::Xor:
- return ConstantInt::get(CI1->getContext(), C1V ^ C2V);
+ return ConstantInt::get(C1->getType(), C1V ^ C2V);
case Instruction::Shl:
if (C2V.ult(C1V.getBitWidth()))
- return ConstantInt::get(CI1->getContext(), C1V.shl(C2V));
+ return ConstantInt::get(C1->getType(), C1V.shl(C2V));
return PoisonValue::get(C1->getType()); // too big shift is poison
case Instruction::LShr:
if (C2V.ult(C1V.getBitWidth()))
- return ConstantInt::get(CI1->getContext(), C1V.lshr(C2V));
+ return ConstantInt::get(C1->getType(), C1V.lshr(C2V));
return PoisonValue::get(C1->getType()); // too big shift is poison
case Instruction::AShr:
if (C2V.ult(C1V.getBitWidth()))
- return ConstantInt::get(CI1->getContext(), C1V.ashr(C2V));
+ return ConstantInt::get(C1->getType(), C1V.ashr(C2V));
return PoisonValue::get(C1->getType()); // too big shift is poison
}
}
@@ -877,7 +871,9 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1,
return ConstantFP::get(C1->getContext(), C3V);
}
}
- } else if (auto *VTy = dyn_cast<VectorType>(C1->getType())) {
+ }
+
+ if (auto *VTy = dyn_cast<VectorType>(C1->getType())) {
// Fast path for splatted constants.
if (Constant *C2Splat = C2->getSplatValue()) {
if (Instruction::isIntDivRem(Opcode) && C2Splat->isNullValue())
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index 7ae397871bdea2..3d6c4ad780dc24 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -441,6 +441,13 @@ Constant *Constant::getAggregateElement(unsigned Elt) const {
? CAZ->getElementValue(Elt)
: nullptr;
+ if (const auto *CI = dyn_cast<ConstantInt>(this))
+ return Elt < cast<VectorType>(getType())
+ ->getElementCount()
+ .getKnownMinValue()
+ ? ConstantInt::get(getContext(), CI->getValue())
+ : nullptr;
+
// FIXME: getNumElements() will fail for non-fixed vector types.
if (isa<ScalableVectorType>(getType()))
return nullptr;
diff --git a/llvm/test/Transforms/InstCombine/add.ll b/llvm/test/Transforms/InstCombine/add.ll
index 4b1159cf07e710..4825e588aa0856 100644
--- a/llvm/test/Transforms/InstCombine/add.ll
+++ b/llvm/test/Transforms/InstCombine/add.ll
@@ -1,5 +1,6 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+; RUN: opt < %s -passes=instcombine -use-constant-int-for-fixed-length-splat -S | FileCheck %s
declare void @use(i8)
declare void @use_i1(i1)
diff --git a/llvm/test/Transforms/InstCombine/div.ll b/llvm/test/Transforms/InstCombine/div.ll
index 33a8e12dfa1a68..6344966d6cac3b 100644
--- a/llvm/test/Transforms/InstCombine/div.ll
+++ b/llvm/test/Transforms/InstCombine/div.ll
@@ -1,5 +1,6 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+; RUN: opt < %s -passes=instcombine -use-constant-int-for-fixed-length-splat -S | FileCheck %s
declare void @use(i32)
diff --git a/llvm/test/Transforms/InstCombine/mul.ll b/llvm/test/Transforms/InstCombine/mul.ll
index e38ab1b9622b2c..e3108fc54c4f4c 100644
--- a/llvm/test/Transforms/InstCombine/mul.ll
+++ b/llvm/test/Transforms/InstCombine/mul.ll
@@ -1,5 +1,6 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+; RUN: opt < %s -passes=instcombine -use-constant-int-for-fixed-length-splat -S | FileCheck %s
declare i32 @llvm.abs.i32(i32, i1)
diff --git a/llvm/test/Transforms/InstCombine/or.ll b/llvm/test/Transforms/InstCombine/or.ll
index 4a886afd78a5f0..95f89e4ce11cd5 100644
--- a/llvm/test/Transforms/InstCombine/or.ll
+++ b/llvm/test/Transforms/InstCombine/or.ll
@@ -1,5 +1,6 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
-; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s --check-prefixes=CHECK,CONSTVEC
+; RUN: opt < %s -passes=instcombine -S -use-constant-int-for-fixed-length-splat | FileCheck %s --check-prefixes=CHECK,CONSTSPLAT
target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:32:64-f32:32:32-f64:32:64-v64:64:64-v128:128:128-a0:0:64-f80:128:128-n32:64"
declare void @use(i32)
@@ -399,10 +400,15 @@ define i32 @test30(i32 %A) {
}
define <2 x i32> @test30vec(<2 x i32> %A) {
-; CHECK-LABEL: @test30vec(
-; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i32> [[A:%.*]], splat (i32 -58312)
-; CHECK-NEXT: [[E:%.*]] = or disjoint <2 x i32> [[TMP1]], splat (i32 32962)
-; CHECK-NEXT: ret <2 x i32> [[E]]
+; CONSTVEC-LABEL: @test30vec(
+; CONSTVEC-NEXT: [[TMP1:%.*]] = and <2 x i32> [[A:%.*]], splat (i32 -58312)
+; CONSTVEC-NEXT: [[E:%.*]] = or disjoint <2 x i32> [[TMP1]], splat (i32 32962)
+; CONSTVEC-NEXT: ret <2 x i32> [[E]]
+;
+; CONSTSPLAT-LABEL: @test30vec(
+; CONSTSPLAT-NEXT: [[D:%.*]] = and <2 x i32> [[A:%.*]], splat (i32 -58312)
+; CONSTSPLAT-NEXT: [[E:%.*]] = or disjoint <2 x i32> [[D]], splat (i32 32962)
+; CONSTSPLAT-NEXT: ret <2 x i32> [[E]]
;
%B = or <2 x i32> %A, <i32 32962, i32 32962>
%C = and <2 x i32> %A, <i32 -65536, i32 -65536>
diff --git a/llvm/test/Transforms/InstCombine/rotate.ll b/llvm/test/Transforms/InstCombine/rotate.ll
index ea7c471594da0a..bae50736de0c33 100644
--- a/llvm/test/Transforms/InstCombine/rotate.ll
+++ b/llvm/test/Transforms/InstCombine/rotate.ll
@@ -1,5 +1,6 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+; RUN: opt < %s -passes=instcombine -use-constant-int-for-fixed-length-splat -S | FileCheck %s
target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:32:64-f32:32:32-f64:32:64-v64:64:64-v128:128:128-a0:0:64-f80:128:128"
diff --git a/llvm/test/Transforms/InstCombine/shift.ll b/llvm/test/Transforms/InstCombine/shift.ll
index d2ee97f39123b0..d72a1849c7dfd6 100644
--- a/llvm/test/Transforms/InstCombine/shift.ll
+++ b/llvm/test/Transforms/InstCombine/shift.ll
@@ -1,5 +1,6 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+; RUN: opt < %s -passes=instcombine -use-constant-int-for-fixed-length-splat -S | FileCheck %s
declare void @use(i64)
declare void @use_i32(i32)
diff --git a/llvm/test/Transforms/InstCombine/xor-ashr.ll b/llvm/test/Transforms/InstCombine/xor-ashr.ll
index 0c0554adcf1230..f5ccdeef2f382b 100644
--- a/llvm/test/Transforms/InstCombine/xor-ashr.ll
+++ b/llvm/test/Transforms/InstCombine/xor-ashr.ll
@@ -1,5 +1,7 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+; RUN: opt < %s -passes=instcombine -use-constant-int-for-fixed-length-splat -S | FileCheck %s
+
target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:16:32:64"
declare void @use16(i16)
diff --git a/llvm/test/Transforms/InstSimplify/bitcast-vector-fold.ll b/llvm/test/Transforms/InstSimplify/bitcast-vector-fold.ll
index 3f1672d66abf0d..b475b8199541d5 100644
--- a/llvm/test/Transforms/InstSimplify/bitcast-vector-fold.ll
+++ b/llvm/test/Transforms/InstSimplify/bitcast-vector-fold.ll
@@ -81,8 +81,7 @@ define <1 x i1> @test10() {
; CONSTVEC-NEXT: ret <1 x i1> [[RET]]
;
; CONSTSPLAT-LABEL: @test10(
-; CONSTSPLAT-NEXT: [[RET:%.*]] = icmp eq <1 x i64> splat (i64 -1), zeroinitializer
-; CONSTSPLAT-NEXT: ret <1 x i1> [[RET]]
+; CONSTSPLAT-NEXT: ret <1 x i1> zeroinitializer
;
%ret = icmp eq <1 x i64> <i64 bitcast (<1 x double> <double 0xFFFFFFFFFFFFFFFF> to i64)>, zeroinitializer
ret <1 x i1> %ret
``````````
</details>
https://github.com/llvm/llvm-project/pull/115739
More information about the llvm-commits mailing list