[llvm] [ArgPromotion] Remove redundant logic from recursive argpromotion code (PR #98657)

Vedant Paranjape via llvm-commits llvm-commits at lists.llvm.org
Fri Jul 12 09:24:15 PDT 2024


https://github.com/vedantparanjape-amd created https://github.com/llvm/llvm-project/pull/98657

This patch further cleans up the implementation by removing some redundant checks and replacing a cast<> with a get() call. It also adds a check that the function type of the call matches the type of the called function.

This contribution is based on the discussion in #78735.

>From db01c69129b2d5aa3b789315c79441e326c80d75 Mon Sep 17 00:00:00 2001
From: Vedant Paranjape <vedant.paranjape at amd.com>
Date: Fri, 12 Jul 2024 12:28:08 +0000
Subject: [PATCH] [ArgPromotion] Remove redundant logic from recursive
 argpromotion code

This patch further cleans up the implementation by removing some
redundant checks and replacing a cast<> with a get() call. It also adds
a check that the function type of the call matches the type of the
called function.
---
 llvm/lib/Transforms/IPO/ArgumentPromotion.cpp |  8 ++-
 .../recursion/recursion-diff-call-types.ll    | 68 +++++++++++++++++++
 2 files changed, 73 insertions(+), 3 deletions(-)
 create mode 100644 llvm/test/Transforms/ArgumentPromotion/recursion/recursion-diff-call-types.ll

diff --git a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
index 77dbf349df0df..78805f7fc9554 100644
--- a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
+++ b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
@@ -640,8 +640,10 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
     }
 
     auto *CB = dyn_cast<CallBase>(V);
-    Value *PtrArg = cast<Value>(U);
-    if (CB && PtrArg && CB->getCalledFunction() == CB->getFunction()) {
+    Value *PtrArg = U->get();
+    if (CB && CB->getCalledFunction() == CB->getFunction() &&
+        CB->getCalledFunction()->getReturnType() ==
+            CB->getFunction()->getReturnType()) {
       if (PtrArg != Arg) {
         LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                           << "pointer offset is not equal to zero\n");
@@ -649,7 +651,7 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
       }
 
       unsigned int ArgNo = Arg->getArgNo();
-      if (CB->getArgOperand(ArgNo) != Arg || U->getOperandNo() != ArgNo) {
+      if (U->getOperandNo() != ArgNo) {
         LLVM_DEBUG(dbgs() << "ArgPromotion of " << *Arg << " failed: "
                           << "arg position is different in callee\n");
         return false;
diff --git a/llvm/test/Transforms/ArgumentPromotion/recursion/recursion-diff-call-types.ll b/llvm/test/Transforms/ArgumentPromotion/recursion/recursion-diff-call-types.ll
new file mode 100644
index 0000000000000..a4ee73727108a
--- /dev/null
+++ b/llvm/test/Transforms/ArgumentPromotion/recursion/recursion-diff-call-types.ll
@@ -0,0 +1,68 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes=argpromotion < %s | FileCheck %s
+define internal i32 @foo(ptr %x, i32 %n, i32 %m) { ; %x is only loaded, or passed unchanged as argument 0 of the recursive calls below
+; CHECK-LABEL: define internal i32 @foo(
+; CHECK-SAME: ptr [[X:%.*]], i32 [[N:%.*]], i32 [[M:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i32 [[N]], 0
+; CHECK-NEXT:    br i1 [[CMP]], label %[[COND_TRUE:.*]], label %[[COND_FALSE:.*]]
+; CHECK:       [[COND_TRUE]]:
+; CHECK-NEXT:    [[VAL:%.*]] = load i32, ptr [[X]], align 4
+; CHECK-NEXT:    br label %[[RETURN:.*]]
+; CHECK:       [[COND_FALSE]]:
+; CHECK-NEXT:    [[VAL2:%.*]] = load i32, ptr [[X]], align 4
+; CHECK-NEXT:    [[SUBVAL:%.*]] = sub i32 [[N]], 1
+; CHECK-NEXT:    [[CALLRET0:%.*]] = call float @foo(ptr [[X]], i32 [[SUBVAL]], i32 [[VAL2]])
+; CHECK-NEXT:    [[CALLRET1:%.*]] = call i32 @foo(ptr [[X]], i32 [[SUBVAL]], i32 [[VAL2]])
+; CHECK-NEXT:    [[SUBVAL2:%.*]] = sub i32 [[N]], 2
+; CHECK-NEXT:    [[CALLRET2:%.*]] = call i32 @foo(ptr [[X]], i32 [[SUBVAL2]], i32 [[M]])
+; CHECK-NEXT:    [[CMP2:%.*]] = add i32 [[CALLRET1]], [[CALLRET2]]
+; CHECK-NEXT:    br label %[[RETURN]]
+; CHECK:       [[COND_NEXT:.*]]:
+; CHECK-NEXT:    br label %[[RETURN]]
+; CHECK:       [[RETURN]]:
+; CHECK-NEXT:    [[RETVAL_0:%.*]] = phi i32 [ [[VAL]], %[[COND_TRUE]] ], [ [[CMP2]], %[[COND_FALSE]] ], [ poison, %[[COND_NEXT]] ]
+; CHECK-NEXT:    ret i32 [[RETVAL_0]]
+;
+entry:
+  %cmp = icmp ne i32 %n, 0
+  br i1 %cmp, label %cond_true, label %cond_false
+
+cond_true:                                        ; preds = %entry
+  %val = load i32, ptr %x, align 4
+  br label %return
+
+cond_false:                                       ; preds = %entry
+  %val2 = load i32, ptr %x, align 4
+  %subval = sub i32 %n, 1
+  %callret0 = call float @foo(ptr %x, i32 %subval, i32 %val2) ; call type (float return) differs from @foo's declared i32 return type; CHECK-SAME above expects %x to stay a ptr (not promoted)
+  %callret1 = call i32 @foo(ptr %x, i32 %subval, i32 %val2) ; well-typed self-call forwarding %x unchanged as argument 0
+  %subval2 = sub i32 %n, 2
+  %callret2 = call i32 @foo(ptr %x, i32 %subval2, i32 %m) ; well-typed self-call forwarding %x unchanged as argument 0
+  %cmp2 = add i32 %callret1, %callret2
+  br label %return
+
+cond_next:                                        ; No predecessors!
+  br label %return
+
+return:                                           ; preds = %cond_next, %cond_false, %cond_true
+  %retval.0 = phi i32 [ %val, %cond_true ], [ %cmp2, %cond_false ], [ poison, %cond_next ] ; poison only flows in from %cond_next, which has no predecessors
+  ret i32 %retval.0
+}
+
+define i32 @bar(ptr align(4) dereferenceable(4) %x, i32 %n, i32 %m) { ; non-internal caller of internal @foo; passes its aligned, dereferenceable %x straight through
+; CHECK-LABEL: define i32 @bar(
+; CHECK-SAME: ptr align 4 dereferenceable(4) [[X:%.*]], i32 [[N:%.*]], i32 [[M:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[CALLRET3:%.*]] = call i32 @foo(ptr [[X]], i32 [[N]], i32 [[M]])
+; CHECK-NEXT:    br label %[[RETURN:.*]]
+; CHECK:       [[RETURN]]:
+; CHECK-NEXT:    ret i32 [[CALLRET3]]
+;
+entry:
+  %callret3 = call i32 @foo(ptr %x, i32 %n, i32 %m) ; CHECK lines above expect this call to still pass %x as a ptr
+  br label %return
+
+return:                                           ; preds = %entry
+  ret i32 %callret3
+}



More information about the llvm-commits mailing list