[llvm] 48441cb - [Matrix] Properly set Changed status when optimizing transposes.

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Sun Apr 6 09:37:09 PDT 2025


Author: Florian Hahn
Date: 2025-04-06T17:36:56+01:00
New Revision: 48441cb8a2fa3b3f9502ba4ba1242746615841cb

URL: https://github.com/llvm/llvm-project/commit/48441cb8a2fa3b3f9502ba4ba1242746615841cb
DIFF: https://github.com/llvm/llvm-project/commit/48441cb8a2fa3b3f9502ba4ba1242746615841cb.diff

LOG: [Matrix] Properly set Changed status when optimizing transposes.

Currently Changed is not updated properly when transposes are optimized,
causing missing analysis invalidation. Update optimizeTransposes to
indicate if changes have been made.

Added: 
    llvm/test/Transforms/LowerMatrixIntrinsics/analysis-invalidation.ll

Modified: 
    llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 8a30a3e8d22e2..ab16ec77be105 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -792,7 +792,8 @@ class LowerMatrixIntrinsics {
   /// This creates and erases instructions as needed, and returns the newly
   /// created instruction while updating the iterator to avoid invalidation. If
   /// this returns nullptr, no new instruction was created.
-  Instruction *sinkTranspose(Instruction &I, BasicBlock::reverse_iterator &II) {
+  Instruction *sinkTranspose(Instruction &I, BasicBlock::reverse_iterator &II,
+                             bool &Changed) {
     BasicBlock &BB = *I.getParent();
     IRBuilder<> IB(&I);
     MatrixBuilder Builder(IB);
@@ -809,6 +810,7 @@ class LowerMatrixIntrinsics {
       updateShapeAndReplaceAllUsesWith(I, TATA);
       eraseFromParentAndMove(&I, II, BB);
       eraseFromParentAndMove(TA, II, BB);
+      Changed = true;
       return nullptr;
     }
 
@@ -816,6 +818,7 @@ class LowerMatrixIntrinsics {
     if (isSplat(TA)) {
       updateShapeAndReplaceAllUsesWith(I, TA);
       eraseFromParentAndMove(&I, II, BB);
+      Changed = true;
       return nullptr;
     }
 
@@ -834,6 +837,7 @@ class LowerMatrixIntrinsics {
       updateShapeAndReplaceAllUsesWith(I, NewInst);
       eraseFromParentAndMove(&I, II, BB);
       eraseFromParentAndMove(TA, II, BB);
+      Changed = true;
       return NewInst;
     }
 
@@ -859,6 +863,7 @@ class LowerMatrixIntrinsics {
       updateShapeAndReplaceAllUsesWith(I, NewInst);
       eraseFromParentAndMove(&I, II, BB);
       eraseFromParentAndMove(TA, II, BB);
+      Changed = true;
       return NewInst;
     }
 
@@ -880,13 +885,14 @@ class LowerMatrixIntrinsics {
       updateShapeAndReplaceAllUsesWith(I, NewInst);
       eraseFromParentAndMove(&I, II, BB);
       eraseFromParentAndMove(TA, II, BB);
+      Changed = true;
       return NewInst;
     }
 
     return nullptr;
   }
 
-  void liftTranspose(Instruction &I) {
+  bool liftTranspose(Instruction &I) {
     // Erase dead Instructions after lifting transposes from binops.
     auto CleanupBinOp = [this](Instruction &T, Value *A, Value *B) {
       if (T.use_empty())
@@ -914,6 +920,7 @@ class LowerMatrixIntrinsics {
                                                            R->getZExtValue());
       updateShapeAndReplaceAllUsesWith(I, NewInst);
       CleanupBinOp(I, A, B);
+      return true;
     }
     // A^t + B ^t -> (A + B)^t. Pick rows and columns from first transpose. If
     // the shape of the second transpose is different, there's a shape conflict
@@ -940,11 +947,14 @@ class LowerMatrixIntrinsics {
                 ShapeMap[AddI] &&
             "Shape of updated addition doesn't match cached shape.");
       }
+      return true;
     }
+    return false;
   }
 
   /// Try moving transposes in order to fold them away or into multiplies.
-  void optimizeTransposes() {
+  bool optimizeTransposes() {
+    bool Changed = false;
     // First sink all transposes inside matmuls and adds, hoping that we end up
     // with NN, NT or TN variants.
     for (BasicBlock &BB : reverse(Func)) {
@@ -952,7 +962,7 @@ class LowerMatrixIntrinsics {
         Instruction &I = *II;
         // We may remove II.  By default continue on the next/prev instruction.
         ++II;
-        if (Instruction *NewInst = sinkTranspose(I, II))
+        if (Instruction *NewInst = sinkTranspose(I, II, Changed))
           II = std::next(BasicBlock::reverse_iterator(NewInst));
       }
     }
@@ -961,9 +971,10 @@ class LowerMatrixIntrinsics {
     // to fold into consuming multiply or add.
     for (BasicBlock &BB : Func) {
       for (Instruction &I : llvm::make_early_inc_range(BB)) {
-        liftTranspose(I);
+        Changed |= liftTranspose(I);
       }
     }
+    return Changed;
   }
 
   bool Visit() {
@@ -1006,15 +1017,15 @@ class LowerMatrixIntrinsics {
       WorkList = propagateShapeBackward(WorkList);
     }
 
+    bool Changed = false;
     if (!isMinimal()) {
-      optimizeTransposes();
+      Changed |= optimizeTransposes();
       if (PrintAfterTransposeOpt) {
         dbgs() << "Dump after matrix transpose optimization:\n";
         Func.print(dbgs());
       }
     }
 
-    bool Changed = false;
     SmallVector<CallInst *, 16> MaybeFusableInsts;
     SmallVector<Instruction *, 16> MatrixInsts;
     SmallVector<IntrinsicInst *, 16> LifetimeEnds;
@@ -1043,7 +1054,7 @@ class LowerMatrixIntrinsics {
       if (!FusedInsts.contains(CI))
         LowerMatrixMultiplyFused(CI, FusedInsts, LifetimeEnds);
 
-    Changed = !FusedInsts.empty();
+    Changed |= !FusedInsts.empty();
 
     // Fourth, lower remaining instructions with shape information.
     for (Instruction *Inst : MatrixInsts) {

diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/analysis-invalidation.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/analysis-invalidation.ll
new file mode 100644
index 0000000000000..a747328a71e7a
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/analysis-invalidation.ll
@@ -0,0 +1,17 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -p lower-matrix-intrinsics -verify-analysis-invalidation -S %s | FileCheck %s
+
+define <3 x float> @splat_transpose(<3 x float> %in) {
+; CHECK-LABEL: define <3 x float> @splat_transpose(
+; CHECK-SAME: <3 x float> [[IN:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <3 x float> [[IN]], <3 x float> zeroinitializer, <3 x i32> zeroinitializer
+; CHECK-NEXT:    ret <3 x float> [[SPLAT]]
+;
+entry:
+  %splat = shufflevector <3 x float> %in, <3 x float> zeroinitializer, <3 x i32> zeroinitializer
+  %r = tail call <3 x float> @llvm.matrix.transpose.v3f32(<3 x float> %splat, i32 3, i32 1)
+  ret <3 x float> %r
+}
+
+declare <3 x float> @llvm.matrix.transpose.v3f32(<3 x float>, i32 immarg, i32 immarg)


        


More information about the llvm-commits mailing list