[llvm] [Matrix] Adjust lifetime.ends during multiply fusion. (PR #84914)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 12 06:54:54 PDT 2024


https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/84914

>From 01dac795801ce6594b35d7b44c011dd18ae0ec78 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Tue, 12 Mar 2024 13:26:38 +0000
Subject: [PATCH] [Matrix] Adjust lifetime.ends during multiply fusion.

At the moment, loads introduced by multiply fusion may be placed after
an object's lifetime has been terminated by lifetime.end. This introduces
reads of dead objects.

To avoid this, first collect all lifetime.end calls in the function.
During fusion, we deal with any lifetime.end calls that may alias any of
the loads.

Such lifetime.end calls are either sunk after the store when possible
(i.e. when the lifetime.end and the store are in the same block) or deleted.
---
 .../Scalar/LowerMatrixIntrinsics.cpp          | 40 +++++++++++++++++--
 .../multiply-fused-lifetime-ends.ll           | 21 +++-------
 2 files changed, 43 insertions(+), 18 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 67c011b747acfd..98ca11b4838f7e 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -990,12 +990,15 @@ class LowerMatrixIntrinsics {
     bool Changed = false;
     SmallVector<CallInst *, 16> MaybeFusableInsts;
     SmallVector<Instruction *, 16> MatrixInsts;
+    SmallSetVector<IntrinsicInst *, 16> LifetimeEnds;
 
     // First, collect all instructions with shape information and candidates for
     // fusion (currently only matrix multiplies).
     ReversePostOrderTraversal<Function *> RPOT(&Func);
     for (auto *BB : RPOT)
       for (Instruction &I : *BB) {
+        if (match(&I, m_Intrinsic<Intrinsic::lifetime_end>()))
+          LifetimeEnds.insert(cast<IntrinsicInst>(&I));
         if (ShapeMap.find(&I) == ShapeMap.end())
           continue;
         if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
@@ -1010,7 +1013,7 @@ class LowerMatrixIntrinsics {
 
     // Third, try to fuse candidates.
     for (CallInst *CI : MaybeFusableInsts)
-      LowerMatrixMultiplyFused(CI, FusedInsts);
+      LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds);
 
     Changed = !FusedInsts.empty();
 
@@ -1856,8 +1859,10 @@ class LowerMatrixIntrinsics {
   ///
   /// Call finalizeLowering on lowered instructions.  Instructions that are
   /// completely eliminated by fusion are added to \p FusedInsts.
-  void LowerMatrixMultiplyFused(CallInst *MatMul,
-                                SmallPtrSetImpl<Instruction *> &FusedInsts) {
+  void
+  LowerMatrixMultiplyFused(CallInst *MatMul,
+                           SmallPtrSetImpl<Instruction *> &FusedInsts,
+                           SmallSetVector<IntrinsicInst *, 16> &LifetimeEnds) {
     if (!FuseMatrix || !DT)
       return;
 
@@ -1946,6 +1951,35 @@ class LowerMatrixIntrinsics {
       for (Instruction *I : ToHoist)
         I->moveBefore(MatMul);
 
+      // Deal with lifetime.end calls that might be between Load0/Load1 and the
+      // store. To avoid introducing loads to dead objects (i.e. after their
+      // lifetime has been terminated by @llvm.lifetime.end), either sink them
+      // after the store if in the same block, or remove the lifetime.end marker
+      // otherwise. This might pessimize further optimizations, by extending the
+      // lifetime of the object until the function returns, but should be
+      // conservatively correct.
+      MemoryLocation Load0Loc = MemoryLocation::get(LoadOp0);
+      MemoryLocation Load1Loc = MemoryLocation::get(LoadOp1);
+      for (IntrinsicInst *End : make_early_inc_range(LifetimeEnds)) {
+        if (DT->dominates(Store, End))
+          continue;
+        MemoryLocation EndLoc = MemoryLocation::getForArgument(End, 1, nullptr);
+        if (AA->isNoAlias(Load0Loc, EndLoc) && AA->isNoAlias(Load1Loc, EndLoc))
+          continue;
+
+        // If both lifetime.end and the store are in the same block, extend the
+        // lifetime until after the store, so the new lifetime covers the loads
+        // we introduce later.
+        if (Store->getParent() == End->getParent()) {
+          End->moveAfter(Store);
+          continue;
+        }
+
+        // Otherwise remove the conflicting lifetime.end marker.
+        ToRemove.push_back(End);
+        LifetimeEnds.remove(End);
+      }
+
       emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
       return;
     }
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-fused-lifetime-ends.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-fused-lifetime-ends.ll
index ef8665b7969097..9c2b75f5d5756a 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-fused-lifetime-ends.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/multiply-fused-lifetime-ends.ll
@@ -6,15 +6,11 @@ target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128"
 ; Tests to make sure no loads are introduced after a lifetime.end by multiply
 ; fusion.
 
-; FIXME: Currently the tests are mis-compiled, with loads being introduced after
-;       llvm.lifetime.end calls.
-
 define void @lifetime_for_first_arg_before_multiply(ptr noalias %B, ptr noalias %C) {
 ; CHECK-LABEL: @lifetime_for_first_arg_before_multiply(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[A:%.*]] = alloca <4 x double>, align 32
 ; CHECK-NEXT:    call void @init(ptr [[A]])
-; CHECK-NEXT:    call void @llvm.lifetime.end.p0(i64 -1, ptr [[A]])
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr double, ptr [[A]], i64 0
 ; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x double>, ptr [[TMP0]], align 8
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[TMP0]], i64 2
@@ -77,6 +73,7 @@ define void @lifetime_for_first_arg_before_multiply(ptr noalias %B, ptr noalias
 ; CHECK-NEXT:    store <2 x double> [[TMP13]], ptr [[TMP26]], align 8
 ; CHECK-NEXT:    [[VEC_GEP28:%.*]] = getelementptr double, ptr [[TMP26]], i64 2
 ; CHECK-NEXT:    store <2 x double> [[TMP25]], ptr [[VEC_GEP28]], align 8
+; CHECK-NEXT:    call void @llvm.lifetime.end.p0(i64 -1, ptr [[A]])
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -95,7 +92,6 @@ define void @lifetime_for_second_arg_before_multiply(ptr noalias %A, ptr noalias
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[B:%.*]] = alloca <4 x double>, align 32
 ; CHECK-NEXT:    call void @init(ptr [[B]])
-; CHECK-NEXT:    call void @llvm.lifetime.end.p0(i64 -1, ptr [[B]])
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr double, ptr [[A:%.*]], i64 0
 ; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x double>, ptr [[TMP0]], align 8
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[TMP0]], i64 2
@@ -158,6 +154,7 @@ define void @lifetime_for_second_arg_before_multiply(ptr noalias %A, ptr noalias
 ; CHECK-NEXT:    store <2 x double> [[TMP13]], ptr [[TMP26]], align 8
 ; CHECK-NEXT:    [[VEC_GEP28:%.*]] = getelementptr double, ptr [[TMP26]], i64 2
 ; CHECK-NEXT:    store <2 x double> [[TMP25]], ptr [[VEC_GEP28]], align 8
+; CHECK-NEXT:    call void @llvm.lifetime.end.p0(i64 -1, ptr [[B]])
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -177,7 +174,6 @@ define void @lifetime_for_first_arg_before_multiply_load_from_offset(ptr noalias
 ; CHECK-NEXT:    [[A:%.*]] = alloca <8 x double>, align 64
 ; CHECK-NEXT:    call void @init(ptr [[A]])
 ; CHECK-NEXT:    [[GEP_8:%.*]] = getelementptr i8, ptr [[A]], i64 8
-; CHECK-NEXT:    call void @llvm.lifetime.end.p0(i64 -1, ptr [[A]])
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr double, ptr [[GEP_8]], i64 0
 ; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x double>, ptr [[TMP0]], align 8
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[TMP0]], i64 2
@@ -240,6 +236,7 @@ define void @lifetime_for_first_arg_before_multiply_load_from_offset(ptr noalias
 ; CHECK-NEXT:    store <2 x double> [[TMP13]], ptr [[TMP26]], align 8
 ; CHECK-NEXT:    [[VEC_GEP28:%.*]] = getelementptr double, ptr [[TMP26]], i64 2
 ; CHECK-NEXT:    store <2 x double> [[TMP25]], ptr [[VEC_GEP28]], align 8
+; CHECK-NEXT:    call void @llvm.lifetime.end.p0(i64 -1, ptr [[A]])
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -261,7 +258,6 @@ define void @lifetime_for_first_arg_before_multiply_lifetime_does_not_dominate(p
 ; CHECK-NEXT:    call void @init(ptr [[A]])
 ; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[EXIT:%.*]]
 ; CHECK:       then:
-; CHECK-NEXT:    call void @llvm.lifetime.end.p0(i64 -1, ptr [[A]])
 ; CHECK-NEXT:    br label [[EXIT]]
 ; CHECK:       exit:
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr double, ptr [[A]], i64 0
@@ -352,7 +348,6 @@ define void @lifetime_for_second_arg_before_multiply_lifetime_does_not_dominate(
 ; CHECK-NEXT:    call void @init(ptr [[B]])
 ; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[EXIT:%.*]]
 ; CHECK:       then:
-; CHECK-NEXT:    call void @llvm.lifetime.end.p0(i64 -1, ptr [[B]])
 ; CHECK-NEXT:    br label [[EXIT]]
 ; CHECK:       exit:
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr double, ptr [[A:%.*]], i64 0
@@ -441,10 +436,9 @@ define void @lifetime_for_ptr_first_arg_before_multiply(ptr noalias %A, ptr noal
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[EXIT:%.*]]
 ; CHECK:       then:
-; CHECK-NEXT:    call void @llvm.lifetime.end.p0(i64 -1, ptr [[A:%.*]])
 ; CHECK-NEXT:    br label [[EXIT]]
 ; CHECK:       exit:
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr double, ptr [[A]], i64 0
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr double, ptr [[A:%.*]], i64 0
 ; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x double>, ptr [[TMP0]], align 8
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[TMP0]], i64 2
 ; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x double>, ptr [[VEC_GEP]], align 8
@@ -528,15 +522,13 @@ define void @lifetime_for_both_ptr_args_before_multiply(ptr noalias %A, ptr noal
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[EXIT:%.*]]
 ; CHECK:       then:
-; CHECK-NEXT:    call void @llvm.lifetime.end.p0(i64 -1, ptr [[B:%.*]])
-; CHECK-NEXT:    call void @llvm.lifetime.end.p0(i64 -1, ptr [[A:%.*]])
 ; CHECK-NEXT:    br label [[EXIT]]
 ; CHECK:       exit:
-; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr double, ptr [[A]], i64 0
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr double, ptr [[A:%.*]], i64 0
 ; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x double>, ptr [[TMP0]], align 8
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr double, ptr [[TMP0]], i64 2
 ; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x double>, ptr [[VEC_GEP]], align 8
-; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr double, ptr [[B]], i64 0
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr double, ptr [[B:%.*]], i64 0
 ; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x double>, ptr [[TMP1]], align 8
 ; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr double, ptr [[TMP1]], i64 2
 ; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <2 x double>, ptr [[VEC_GEP3]], align 8
@@ -618,7 +610,6 @@ define void @lifetime_for_ptr_select_before_multiply(ptr noalias %A, ptr noalias
 ; CHECK-NEXT:    [[P:%.*]] = select i1 [[C_0:%.*]], ptr [[A:%.*]], ptr [[B:%.*]]
 ; CHECK-NEXT:    br i1 [[C_1:%.*]], label [[THEN:%.*]], label [[EXIT:%.*]]
 ; CHECK:       then:
-; CHECK-NEXT:    call void @llvm.lifetime.end.p0(i64 -1, ptr [[P]])
 ; CHECK-NEXT:    br label [[EXIT]]
 ; CHECK:       exit:
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr double, ptr [[P]], i64 0



More information about the llvm-commits mailing list