[llvm] [LLVM][IR] Teach constant integer binop folds about vector ConstantInts. (PR #115739)

via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 11 08:32:48 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-ir

@llvm/pr-subscribers-llvm-transforms

Author: Paul Walker (paulwalker-arm)

<details>
<summary>Changes</summary>

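The existing logic mostly works, with the main changes being:
 * Use getScalarSizeInBits instead of IntegerType::getBitWidth
 * Use ConstantInt::get(Type *, ...) instead of ConstantInt::get(LLVMContext &, ...)
 * Drop the Trunc fold's early bail-out for vector types
 * Teach Constant::getAggregateElement to return the splatted element of a vector ConstantInt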

---
Full diff: https://github.com/llvm/llvm-project/pull/115739.diff


10 Files Affected:

- (modified) llvm/lib/IR/ConstantFold.cpp (+23-27) 
- (modified) llvm/lib/IR/Constants.cpp (+7) 
- (modified) llvm/test/Transforms/InstCombine/add.ll (+1) 
- (modified) llvm/test/Transforms/InstCombine/div.ll (+1) 
- (modified) llvm/test/Transforms/InstCombine/mul.ll (+1) 
- (modified) llvm/test/Transforms/InstCombine/or.ll (+11-5) 
- (modified) llvm/test/Transforms/InstCombine/rotate.ll (+1) 
- (modified) llvm/test/Transforms/InstCombine/shift.ll (+1) 
- (modified) llvm/test/Transforms/InstCombine/xor-ashr.ll (+2) 
- (modified) llvm/test/Transforms/InstSimplify/bitcast-vector-fold.ll (+1-2) 


``````````diff
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index cfe87937c372cd..2dbc6785c08b9d 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -231,26 +231,20 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
     return nullptr;
   case Instruction::ZExt:
     if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
-      uint32_t BitWidth = cast<IntegerType>(DestTy)->getBitWidth();
-      return ConstantInt::get(V->getContext(),
-                              CI->getValue().zext(BitWidth));
+      uint32_t BitWidth = DestTy->getScalarSizeInBits();
+      return ConstantInt::get(DestTy, CI->getValue().zext(BitWidth));
     }
     return nullptr;
   case Instruction::SExt:
     if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
-      uint32_t BitWidth = cast<IntegerType>(DestTy)->getBitWidth();
-      return ConstantInt::get(V->getContext(),
-                              CI->getValue().sext(BitWidth));
+      uint32_t BitWidth = DestTy->getScalarSizeInBits();
+      return ConstantInt::get(DestTy, CI->getValue().sext(BitWidth));
     }
     return nullptr;
   case Instruction::Trunc: {
-    if (V->getType()->isVectorTy())
-      return nullptr;
-
-    uint32_t DestBitWidth = cast<IntegerType>(DestTy)->getBitWidth();
     if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
-      return ConstantInt::get(V->getContext(),
-                              CI->getValue().trunc(DestBitWidth));
+      uint32_t BitWidth = DestTy->getScalarSizeInBits();
+      return ConstantInt::get(DestTy, CI->getValue().trunc(BitWidth));
     }
 
     return nullptr;
@@ -807,44 +801,44 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1,
       default:
         break;
       case Instruction::Add:
-        return ConstantInt::get(CI1->getContext(), C1V + C2V);
+        return ConstantInt::get(C1->getType(), C1V + C2V);
       case Instruction::Sub:
-        return ConstantInt::get(CI1->getContext(), C1V - C2V);
+        return ConstantInt::get(C1->getType(), C1V - C2V);
       case Instruction::Mul:
-        return ConstantInt::get(CI1->getContext(), C1V * C2V);
+        return ConstantInt::get(C1->getType(), C1V * C2V);
       case Instruction::UDiv:
         assert(!CI2->isZero() && "Div by zero handled above");
-        return ConstantInt::get(CI1->getContext(), C1V.udiv(C2V));
+        return ConstantInt::get(CI1->getType(), C1V.udiv(C2V));
       case Instruction::SDiv:
         assert(!CI2->isZero() && "Div by zero handled above");
         if (C2V.isAllOnes() && C1V.isMinSignedValue())
           return PoisonValue::get(CI1->getType());   // MIN_INT / -1 -> poison
-        return ConstantInt::get(CI1->getContext(), C1V.sdiv(C2V));
+        return ConstantInt::get(CI1->getType(), C1V.sdiv(C2V));
       case Instruction::URem:
         assert(!CI2->isZero() && "Div by zero handled above");
-        return ConstantInt::get(CI1->getContext(), C1V.urem(C2V));
+        return ConstantInt::get(C1->getType(), C1V.urem(C2V));
       case Instruction::SRem:
         assert(!CI2->isZero() && "Div by zero handled above");
         if (C2V.isAllOnes() && C1V.isMinSignedValue())
-          return PoisonValue::get(CI1->getType());   // MIN_INT % -1 -> poison
-        return ConstantInt::get(CI1->getContext(), C1V.srem(C2V));
+          return PoisonValue::get(C1->getType()); // MIN_INT % -1 -> poison
+        return ConstantInt::get(C1->getType(), C1V.srem(C2V));
       case Instruction::And:
-        return ConstantInt::get(CI1->getContext(), C1V & C2V);
+        return ConstantInt::get(C1->getType(), C1V & C2V);
       case Instruction::Or:
-        return ConstantInt::get(CI1->getContext(), C1V | C2V);
+        return ConstantInt::get(C1->getType(), C1V | C2V);
       case Instruction::Xor:
-        return ConstantInt::get(CI1->getContext(), C1V ^ C2V);
+        return ConstantInt::get(C1->getType(), C1V ^ C2V);
       case Instruction::Shl:
         if (C2V.ult(C1V.getBitWidth()))
-          return ConstantInt::get(CI1->getContext(), C1V.shl(C2V));
+          return ConstantInt::get(C1->getType(), C1V.shl(C2V));
         return PoisonValue::get(C1->getType()); // too big shift is poison
       case Instruction::LShr:
         if (C2V.ult(C1V.getBitWidth()))
-          return ConstantInt::get(CI1->getContext(), C1V.lshr(C2V));
+          return ConstantInt::get(C1->getType(), C1V.lshr(C2V));
         return PoisonValue::get(C1->getType()); // too big shift is poison
       case Instruction::AShr:
         if (C2V.ult(C1V.getBitWidth()))
-          return ConstantInt::get(CI1->getContext(), C1V.ashr(C2V));
+          return ConstantInt::get(C1->getType(), C1V.ashr(C2V));
         return PoisonValue::get(C1->getType()); // too big shift is poison
       }
     }
@@ -877,7 +871,9 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1,
         return ConstantFP::get(C1->getContext(), C3V);
       }
     }
-  } else if (auto *VTy = dyn_cast<VectorType>(C1->getType())) {
+  }
+
+  if (auto *VTy = dyn_cast<VectorType>(C1->getType())) {
     // Fast path for splatted constants.
     if (Constant *C2Splat = C2->getSplatValue()) {
       if (Instruction::isIntDivRem(Opcode) && C2Splat->isNullValue())
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index 7ae397871bdea2..3d6c4ad780dc24 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -441,6 +441,13 @@ Constant *Constant::getAggregateElement(unsigned Elt) const {
                ? CAZ->getElementValue(Elt)
                : nullptr;
 
+  if (const auto *CI = dyn_cast<ConstantInt>(this))
+    return Elt < cast<VectorType>(getType())
+                       ->getElementCount()
+                       .getKnownMinValue()
+               ? ConstantInt::get(getContext(), CI->getValue())
+               : nullptr;
+
   // FIXME: getNumElements() will fail for non-fixed vector types.
   if (isa<ScalableVectorType>(getType()))
     return nullptr;
diff --git a/llvm/test/Transforms/InstCombine/add.ll b/llvm/test/Transforms/InstCombine/add.ll
index 4b1159cf07e710..4825e588aa0856 100644
--- a/llvm/test/Transforms/InstCombine/add.ll
+++ b/llvm/test/Transforms/InstCombine/add.ll
@@ -1,5 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
 ; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+; RUN: opt < %s -passes=instcombine -use-constant-int-for-fixed-length-splat -S | FileCheck %s
 
 declare void @use(i8)
 declare void @use_i1(i1)
diff --git a/llvm/test/Transforms/InstCombine/div.ll b/llvm/test/Transforms/InstCombine/div.ll
index 33a8e12dfa1a68..6344966d6cac3b 100644
--- a/llvm/test/Transforms/InstCombine/div.ll
+++ b/llvm/test/Transforms/InstCombine/div.ll
@@ -1,5 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
 ; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+; RUN: opt < %s -passes=instcombine -use-constant-int-for-fixed-length-splat -S | FileCheck %s
 
 declare void @use(i32)
 
diff --git a/llvm/test/Transforms/InstCombine/mul.ll b/llvm/test/Transforms/InstCombine/mul.ll
index e38ab1b9622b2c..e3108fc54c4f4c 100644
--- a/llvm/test/Transforms/InstCombine/mul.ll
+++ b/llvm/test/Transforms/InstCombine/mul.ll
@@ -1,5 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
 ; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+; RUN: opt < %s -passes=instcombine -use-constant-int-for-fixed-length-splat -S | FileCheck %s
 
 declare i32 @llvm.abs.i32(i32, i1)
 
diff --git a/llvm/test/Transforms/InstCombine/or.ll b/llvm/test/Transforms/InstCombine/or.ll
index 4a886afd78a5f0..95f89e4ce11cd5 100644
--- a/llvm/test/Transforms/InstCombine/or.ll
+++ b/llvm/test/Transforms/InstCombine/or.ll
@@ -1,5 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
-; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s --check-prefixes=CHECK,CONSTVEC
+; RUN: opt < %s -passes=instcombine -S -use-constant-int-for-fixed-length-splat | FileCheck %s --check-prefixes=CHECK,CONSTSPLAT
 
 target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:32:64-f32:32:32-f64:32:64-v64:64:64-v128:128:128-a0:0:64-f80:128:128-n32:64"
 declare void @use(i32)
@@ -399,10 +400,15 @@ define i32 @test30(i32 %A) {
 }
 
 define <2 x i32> @test30vec(<2 x i32> %A) {
-; CHECK-LABEL: @test30vec(
-; CHECK-NEXT:    [[TMP1:%.*]] = and <2 x i32> [[A:%.*]], splat (i32 -58312)
-; CHECK-NEXT:    [[E:%.*]] = or disjoint <2 x i32> [[TMP1]], splat (i32 32962)
-; CHECK-NEXT:    ret <2 x i32> [[E]]
+; CONSTVEC-LABEL: @test30vec(
+; CONSTVEC-NEXT:    [[TMP1:%.*]] = and <2 x i32> [[A:%.*]], splat (i32 -58312)
+; CONSTVEC-NEXT:    [[E:%.*]] = or disjoint <2 x i32> [[TMP1]], splat (i32 32962)
+; CONSTVEC-NEXT:    ret <2 x i32> [[E]]
+;
+; CONSTSPLAT-LABEL: @test30vec(
+; CONSTSPLAT-NEXT:    [[D:%.*]] = and <2 x i32> [[A:%.*]], splat (i32 -58312)
+; CONSTSPLAT-NEXT:    [[E:%.*]] = or disjoint <2 x i32> [[D]], splat (i32 32962)
+; CONSTSPLAT-NEXT:    ret <2 x i32> [[E]]
 ;
   %B = or <2 x i32> %A, <i32 32962, i32 32962>
   %C = and <2 x i32> %A, <i32 -65536, i32 -65536>
diff --git a/llvm/test/Transforms/InstCombine/rotate.ll b/llvm/test/Transforms/InstCombine/rotate.ll
index ea7c471594da0a..bae50736de0c33 100644
--- a/llvm/test/Transforms/InstCombine/rotate.ll
+++ b/llvm/test/Transforms/InstCombine/rotate.ll
@@ -1,5 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
 ; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+; RUN: opt < %s -passes=instcombine -use-constant-int-for-fixed-length-splat -S | FileCheck %s
 
 target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:32:64-f32:32:32-f64:32:64-v64:64:64-v128:128:128-a0:0:64-f80:128:128"
 
diff --git a/llvm/test/Transforms/InstCombine/shift.ll b/llvm/test/Transforms/InstCombine/shift.ll
index d2ee97f39123b0..d72a1849c7dfd6 100644
--- a/llvm/test/Transforms/InstCombine/shift.ll
+++ b/llvm/test/Transforms/InstCombine/shift.ll
@@ -1,5 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
 ; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+; RUN: opt < %s -passes=instcombine -use-constant-int-for-fixed-length-splat -S | FileCheck %s
 
 declare void @use(i64)
 declare void @use_i32(i32)
diff --git a/llvm/test/Transforms/InstCombine/xor-ashr.ll b/llvm/test/Transforms/InstCombine/xor-ashr.ll
index 0c0554adcf1230..f5ccdeef2f382b 100644
--- a/llvm/test/Transforms/InstCombine/xor-ashr.ll
+++ b/llvm/test/Transforms/InstCombine/xor-ashr.ll
@@ -1,5 +1,7 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
 ; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+; RUN: opt < %s -passes=instcombine -use-constant-int-for-fixed-length-splat -S | FileCheck %s
+
 target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:16:32:64"
 
 declare void @use16(i16)
diff --git a/llvm/test/Transforms/InstSimplify/bitcast-vector-fold.ll b/llvm/test/Transforms/InstSimplify/bitcast-vector-fold.ll
index 3f1672d66abf0d..b475b8199541d5 100644
--- a/llvm/test/Transforms/InstSimplify/bitcast-vector-fold.ll
+++ b/llvm/test/Transforms/InstSimplify/bitcast-vector-fold.ll
@@ -81,8 +81,7 @@ define <1 x i1> @test10() {
 ; CONSTVEC-NEXT:    ret <1 x i1> [[RET]]
 ;
 ; CONSTSPLAT-LABEL: @test10(
-; CONSTSPLAT-NEXT:    [[RET:%.*]] = icmp eq <1 x i64> splat (i64 -1), zeroinitializer
-; CONSTSPLAT-NEXT:    ret <1 x i1> [[RET]]
+; CONSTSPLAT-NEXT:    ret <1 x i1> zeroinitializer
 ;
   %ret = icmp eq <1 x i64> <i64 bitcast (<1 x double> <double 0xFFFFFFFFFFFFFFFF> to i64)>, zeroinitializer
   ret <1 x i1> %ret

``````````

</details>


https://github.com/llvm/llvm-project/pull/115739


More information about the llvm-commits mailing list