[llvm] [AArch64] Guard against getRegisterBitWidth returning zero in vector instr cost. (PR #117749)

David Green via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 26 09:15:46 PST 2024


https://github.com/davemgreen created https://github.com/llvm/llvm-project/pull/117749

If getRegisterBitWidth returns zero (such as in SME streaming functions), then we could hit a division-by-zero crash when evaluating `% RegWidth`.

It took a while to figure out what was going wrong, so there are a few other minor cleanups here too.

>From 8abe6b1d93f414db44d9366aa04602c258c8af46 Mon Sep 17 00:00:00 2001
From: David Green <david.green at arm.com>
Date: Tue, 26 Nov 2024 17:09:21 +0000
Subject: [PATCH] [AArch64] Guard against getRegisterBitWidth returning zero in
 vector instr cost.

If the getRegisterBitWidth is zero (such as in sme streaming functions), then
we could hit a crash from using % RegWidth.

It took a while to figure out what was going wrong so there are a few other
minor cleanups here too.
---
 .../AArch64/AArch64TargetTransformInfo.cpp     | 11 ++++++-----
 .../AArch64/extract_float_streaming.ll         | 18 ++++++++++++++++++
 2 files changed, 24 insertions(+), 5 deletions(-)
 create mode 100644 llvm/test/Analysis/CostModel/AArch64/extract_float_streaming.ll

diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 7a1e401bca18cb..a6b595d71bfe04 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3248,19 +3248,18 @@ InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(
     // Check if the extractelement user is scalar fmul.
     auto IsUserFMulScalarTy = [](const Value *EEUser) {
       // Check if the user is scalar fmul.
-      const auto *BO = dyn_cast_if_present<BinaryOperator>(EEUser);
+      const auto *BO = dyn_cast<BinaryOperator>(EEUser);
       return BO && BO->getOpcode() == BinaryOperator::FMul &&
              !BO->getType()->isVectorTy();
     };
 
     // Check if the extract index is from lane 0 or lane equivalent to 0 for a
     // certain scalar type and a certain vector register width.
-    auto IsExtractLaneEquivalentToZero = [&](const unsigned &Idx,
-                                             const unsigned &EltSz) {
+    auto IsExtractLaneEquivalentToZero = [&](unsigned Idx, unsigned EltSz) {
       auto RegWidth =
           getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
               .getFixedValue();
-      return (Idx == 0 || (Idx * EltSz) % RegWidth == 0);
+      return RegWidth != 0 && (Idx == 0 || (Idx * EltSz) % RegWidth == 0);
     };
 
     // Check if the type constraints on input vector type and result scalar type
@@ -3277,13 +3276,15 @@ InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(
         // important.
         UserToExtractIdx[U];
       }
+      if (UserToExtractIdx.empty())
+        return false;
       for (auto &[S, U, L] : ScalarUserAndIdx) {
         for (auto *U : S->users()) {
           if (UserToExtractIdx.find(U) != UserToExtractIdx.end()) {
             auto *FMul = cast<BinaryOperator>(U);
             auto *Op0 = FMul->getOperand(0);
             auto *Op1 = FMul->getOperand(1);
-            if ((Op0 == S && Op1 == S) || (Op0 != S) || (Op1 != S)) {
+            if ((Op0 == S && Op1 == S) || Op0 != S || Op1 != S) {
               UserToExtractIdx[U] = L;
               break;
             }
diff --git a/llvm/test/Analysis/CostModel/AArch64/extract_float_streaming.ll b/llvm/test/Analysis/CostModel/AArch64/extract_float_streaming.ll
new file mode 100644
index 00000000000000..84502abceed3b0
--- /dev/null
+++ b/llvm/test/Analysis/CostModel/AArch64/extract_float_streaming.ll
@@ -0,0 +1,18 @@
+; NOTE: Assertions have been autogenerated by utils/update_analyze_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes="print<cost-model>" 2>&1 -disable-output -mtriple=aarch64-unknown-linux -mattr=+sme | FileCheck %s
+
+define double @extract_case7(<4 x double> %a) "aarch64_pstate_sm_enabled" {
+; CHECK-LABEL: 'extract_case7'
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %0 = extractelement <4 x double> %a, i32 1
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: %1 = extractelement <4 x double> %a, i32 2
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 2 for instruction: %res = fmul double %0, %1
+; CHECK-NEXT:  Cost Model: Found an estimated cost of 0 for instruction: ret double %res
+;
+entry:
+  %1 = extractelement <4 x double> %a, i32 1
+  %2 = extractelement <4 x double> %a, i32 2
+  %res = fmul double %1, %2
+  ret double %res
+}
+
+declare void @foo(double)



More information about the llvm-commits mailing list