[llvm] 1712ae6 - [AArch64] Improve cost of umull from known bits

David Green via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 12 05:13:10 PDT 2023


Author: David Green
Date: 2023-07-12T13:13:06+01:00
New Revision: 1712ae670915fc73ce136508497a727f1b9f4d85

URL: https://github.com/llvm/llvm-project/commit/1712ae670915fc73ce136508497a727f1b9f4d85
DIFF: https://github.com/llvm/llvm-project/commit/1712ae670915fc73ce136508497a727f1b9f4d85.diff

LOG: [AArch64] Improve cost of umull from known bits

As in D140287, we can now generate umull from mul(zext(x), y) in cases where we
know that the top bits of y are zero. This teaches the cost model about that,
adjusting how isWideningInstruction detects mul operations that can extend both
operands. This helps for constants and other cases where the operands of the
mul are known to be extended but are not themselves extend instructions.

Differential Revision: https://reviews.llvm.org/D154936

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
    llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
    llvm/test/Analysis/CostModel/AArch64/arith-widening.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 1c591d4931a749..7947ccf0a71b92 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -1944,9 +1944,8 @@ AArch64TTIImpl::getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
 }
 
 bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
-                                           ArrayRef<Type *> SrcTys,
-                                           ArrayRef<const Value *> Args) {
-
+                                           ArrayRef<const Value *> Args,
+                                           Type *SrcOverrideTy) {
   // A helper that returns a vector type from the given type. The number of
   // elements in type Ty determines the vector width.
   auto toVectorTy = [&](Type *ArgTy) {
@@ -1954,12 +1953,14 @@ bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
                            cast<VectorType>(DstTy)->getElementCount());
   };
 
-  // Exit early if DstTy is not a vector type whose elements are at least
-  // 16-bits wide. SVE doesn't generally have the same set of instructions to
+  // Exit early if DstTy is not a vector type whose elements are one of [i16,
+  // i32, i64]. SVE doesn't generally have the same set of instructions to
   // perform an extend with the add/sub/mul. There are SMULLB style
   // instructions, but they operate on top/bottom, requiring some sort of lane
   // interleaving to be used with zext/sext.
-  if (!useNeonVector(DstTy) || DstTy->getScalarSizeInBits() < 16)
+  unsigned DstEltSize = DstTy->getScalarSizeInBits();
+  if (!useNeonVector(DstTy) || Args.size() != 2 ||
+      (DstEltSize != 16 && DstEltSize != 32 && DstEltSize != 64))
     return false;
 
   // Determine if the operation has a widening variant. We consider both the
@@ -1969,42 +1970,55 @@ bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
   // TODO: Add additional widening operations (e.g., shl, etc.) once we
   //       verify that their extending operands are eliminated during code
   //       generation.
+  Type *SrcTy = SrcOverrideTy;
   switch (Opcode) {
   case Instruction::Add: // UADDL(2), SADDL(2), UADDW(2), SADDW(2).
   case Instruction::Sub: // USUBL(2), SSUBL(2), USUBW(2), SSUBW(2).
-  case Instruction::Mul: // SMULL(2), UMULL(2)
+    // The second operand needs to be an extend
+    if (isa<SExtInst>(Args[1]) || isa<ZExtInst>(Args[1])) {
+      if (!SrcTy)
+        SrcTy =
+            toVectorTy(cast<Instruction>(Args[1])->getOperand(0)->getType());
+    } else
+      return false;
+    break;
+  case Instruction::Mul: { // SMULL(2), UMULL(2)
+    // Both operands need to be extends of the same type.
+    if ((isa<SExtInst>(Args[0]) && isa<SExtInst>(Args[1])) ||
+        (isa<ZExtInst>(Args[0]) && isa<ZExtInst>(Args[1]))) {
+      if (!SrcTy)
+        SrcTy =
+            toVectorTy(cast<Instruction>(Args[0])->getOperand(0)->getType());
+    } else if (isa<ZExtInst>(Args[0]) || isa<ZExtInst>(Args[1])) {
+      // If one of the operands is a Zext and the other has enough zero bits to
+      // be treated as unsigned, we can still generate a umull, meaning the zext
+      // is free.
+      KnownBits Known =
+          computeKnownBits(isa<ZExtInst>(Args[0]) ? Args[1] : Args[0], DL);
+      if (Args[0]->getType()->getScalarSizeInBits() -
+              Known.Zero.countLeadingOnes() >
+          DstTy->getScalarSizeInBits() / 2)
+        return false;
+      if (!SrcTy)
+        SrcTy = toVectorTy(Type::getIntNTy(DstTy->getContext(),
+                                           DstTy->getScalarSizeInBits() / 2));
+    } else
+      return false;
     break;
+  }
   default:
     return false;
   }
 
-  // To be a widening instruction (either the "wide" or "long" versions), the
-  // second operand must be a sign- or zero extend.
-  if (Args.size() != 2 ||
-      (!isa<SExtInst>(Args[1]) && !isa<ZExtInst>(Args[1])))
-    return false;
-  auto *Extend = cast<CastInst>(Args[1]);
-  auto *Arg0 = dyn_cast<CastInst>(Args[0]);
-
-  // A mul only has a mull version (not like addw). Both operands need to be
-  // extending and the same type.
-  if (Opcode == Instruction::Mul &&
-      (!Arg0 || Arg0->getOpcode() != Extend->getOpcode() ||
-       (SrcTys.size() == 2 && SrcTys[0] != SrcTys[1])))
-    return false;
-
   // Legalize the destination type and ensure it can be used in a widening
   // operation.
   auto DstTyL = getTypeLegalizationCost(DstTy);
-  unsigned DstElTySize = DstTyL.second.getScalarSizeInBits();
-  if (!DstTyL.second.isVector() || DstElTySize != DstTy->getScalarSizeInBits())
+  if (!DstTyL.second.isVector() || DstEltSize != DstTy->getScalarSizeInBits())
     return false;
 
   // Legalize the source type and ensure it can be used in a widening
   // operation.
-  Type *SrcTy =
-      SrcTys.size() > 0 ? SrcTys.back() : toVectorTy(Extend->getSrcTy());
-
+  assert(SrcTy && "Expected some SrcTy");
   auto SrcTyL = getTypeLegalizationCost(SrcTy);
   unsigned SrcElTySize = SrcTyL.second.getScalarSizeInBits();
   if (!SrcTyL.second.isVector() || SrcElTySize != SrcTy->getScalarSizeInBits())
@@ -2018,7 +2032,7 @@ bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
 
   // Return true if the legalized types have the same number of vector elements
   // and the destination element type size is twice that of the source type.
-  return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstElTySize;
+  return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstEltSize;
 }
 
 InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
@@ -2033,31 +2047,17 @@ InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
   if (I && I->hasOneUser()) {
     auto *SingleUser = cast<Instruction>(*I->user_begin());
     SmallVector<const Value *, 4> Operands(SingleUser->operand_values());
-    SmallVector<Type *, 2> SrcTys;
-    for (const Value *Op : Operands) {
-      auto *Cast = dyn_cast<CastInst>(Op);
-      if (!Cast)
-        continue;
-      // Use provided Src type for I and other casts that have the same source
-      // type.
-      if (Op == I || cast<CastInst>(I)->getSrcTy() == Cast->getSrcTy())
-        SrcTys.push_back(Src);
-      else
-        SrcTys.push_back(Cast->getSrcTy());
-    }
-    if (isWideningInstruction(Dst, SingleUser->getOpcode(), SrcTys, Operands)) {
-      // If the cast is the second operand, it is free. We will generate either
-      // a "wide" or "long" version of the widening instruction.
-      if (I == SingleUser->getOperand(1))
-        return 0;
-      // If the cast is not the second operand, it will be free if it looks the
-      // same as the second operand. In this case, we will generate a "long"
-      // version of the widening instruction.
-      if (auto *Cast = dyn_cast<CastInst>(SingleUser->getOperand(1)))
-        if (I->getOpcode() == unsigned(Cast->getOpcode()) &&
-            (Src == Cast->getSrcTy() ||
-             cast<CastInst>(I)->getSrcTy() == Cast->getSrcTy()))
+    if (isWideningInstruction(Dst, SingleUser->getOpcode(), Operands, Src)) {
+      // For adds, the extend is free if it is the second operand, or if the
+      // second operand is the same kind of extend (i.e. in add(sext, zext)
+      // only the zext is free).
+      if (SingleUser->getOpcode() == Instruction::Add) {
+        if (I == SingleUser->getOperand(1) ||
+            (isa<CastInst>(SingleUser->getOperand(1)) &&
+             cast<CastInst>(SingleUser->getOperand(1))->getOpcode() == Opcode))
           return 0;
+      } else // Others are free so long as isWideningInstruction returned true.
+        return 0;
     }
   }
 
@@ -2680,7 +2680,7 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
     // LT.first = 2 the cost is 28. If both operands are extensions it will not
     // need to scalarize so the cost can be cheaper (smull or umull).
     // so the cost can be cheaper (smull or umull).
-    if (LT.second != MVT::v2i64 || isWideningInstruction(Ty, Opcode, {}, Args))
+    if (LT.second != MVT::v2i64 || isWideningInstruction(Ty, Opcode, Args))
       return LT.first;
     return LT.first * 14;
   case ISD::ADD:

diff  --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index e261892254da10..7cb49126d09107 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -58,8 +58,8 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
   };
 
   bool isWideningInstruction(Type *DstTy, unsigned Opcode,
-                             ArrayRef<Type *> SrcTys,
-                             ArrayRef<const Value *> Args);
+                             ArrayRef<const Value *> Args,
+                             Type *SrcOverrideTy = nullptr);
 
   // A helper function called by 'getVectorInstrCost'.
   //

diff  --git a/llvm/test/Analysis/CostModel/AArch64/arith-widening.ll b/llvm/test/Analysis/CostModel/AArch64/arith-widening.ll
index 919d1159aa9329..52f6f73525a3b9 100644
--- a/llvm/test/Analysis/CostModel/AArch64/arith-widening.ll
+++ b/llvm/test/Analysis/CostModel/AArch64/arith-widening.ll
@@ -2087,3 +2087,27 @@ define void @extmulv16(<16 x i8> %i8, <16 x i16> %i16, <16 x i32> %i32, <16 x i6
 
   ret void
 }
+
+define void @extmul_const(<8 x i8> %i8, <8 x i16> %i16, <8 x i32> %i32, <8 x i64> %i64)  {
+; CHECK-LABEL: 'extmul_const'
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %sl1_8_16 = sext <8 x i8> %i8 to <8 x i16>
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %asl_8_16 = mul <8 x i16> %sl1_8_16, <i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10>
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: %zl1_8_16 = zext <8 x i8> %i8 to <8 x i16>
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %azl_8_16 = mul <8 x i16> %zl1_8_16, <i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10>
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: %zl1_8_16b = zext <8 x i8> %i8 to <8 x i16>
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %and = and <8 x i16> %sl1_8_16, <i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255>
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 1 for instruction: %aal_8_16 = mul <8 x i16> %zl1_8_16b, %and
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: ret void
+;
+  %sl1_8_16 = sext <8 x i8> %i8 to <8 x i16>
+  %asl_8_16 = mul <8 x i16> %sl1_8_16, <i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10>
+
+  %zl1_8_16 = zext <8 x i8> %i8 to <8 x i16>
+  %azl_8_16 = mul <8 x i16> %zl1_8_16, <i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10>
+
+  %zl1_8_16b = zext <8 x i8> %i8 to <8 x i16>
+  %and = and <8 x i16> %sl1_8_16, <i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255, i16 255>
+  %aal_8_16 = mul <8 x i16> %zl1_8_16b, %and
+
+  ret void
+}


        


More information about the llvm-commits mailing list