[llvm] Re apply 130577 narrow math for and operand (PR #133896)

via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 10 20:29:01 PDT 2025


================
@@ -1559,6 +1559,86 @@ void AMDGPUCodeGenPrepareImpl::expandDivRem64(BinaryOperator &I) const {
   llvm_unreachable("not a division");
 }
 
+/*
+This transformation causes an inconsistency for non-byte loads, for example:
+```
+    %load = load i1, ptr addrspace(4) %arg, align 4
+    %zext = zext i1 %load to i64
+    %add = add i64 %zext, ...
+```
+Instead of creating `s_and_b32 s0, s0, 1`,
+it will create `s_and_b32 s0, s0, 0xff`.
+We accept this change since the non-byte load assumes the upper bits
+within the byte are all 0.
+*/
+static bool tryNarrowMathIfNoOverflow(Instruction *I,
+                                      const SITargetLowering *TLI,
+                                      const TargetTransformInfo &TTI,
+                                      const DataLayout &DL) {
+  unsigned Opc = I->getOpcode();
+  Type *OldType = I->getType();
+
+  if (Opc != Instruction::Add && Opc != Instruction::Mul)
+    return false;
+
+  unsigned OrigBit = OldType->getScalarSizeInBits();
+  unsigned MaxBitsNeeded = OrigBit;
+
+  switch (Opc) {
+  case Instruction::Add:
+    MaxBitsNeeded = KnownBits::add(computeKnownBits(I->getOperand(0), DL),
+                                   computeKnownBits(I->getOperand(1), DL))
+                        .countMaxActiveBits();
+    break;
+  case Instruction::Mul:
+    MaxBitsNeeded = KnownBits::mul(computeKnownBits(I->getOperand(0), DL),
+                                   computeKnownBits(I->getOperand(1), DL))
+                        .countMaxActiveBits();
+    break;
+  default:
+    llvm_unreachable("Unexpected opcode, only valid for Instruction::Add and "
+                     "Instruction::Mul.");
+  }
+
+  MaxBitsNeeded = std::max<unsigned>(bit_ceil(MaxBitsNeeded), 8);
+  Type *NewType =
+      DL.getSmallestLegalIntType(I->getType()->getContext(), MaxBitsNeeded);
+  if (!NewType)
+    return false;
+  unsigned NewBit = NewType->getIntegerBitWidth();
+  if (NewBit >= OrigBit)
+    return false;
+  NewType = I->getType()->getWithNewBitWidth(NewBit);
----------------
Shoreshen wrote:

Hi, this check is mainly there to prevent NewBit == OrigBit, which causes a fatal error when creating the trunc instruction.

https://github.com/llvm/llvm-project/pull/133896


More information about the llvm-commits mailing list