[llvm] [Matrix] Add a Remark when matrices get flattened (PR #142078)

Jon Roelofs via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 9 13:40:53 PDT 2025


https://github.com/jroelofs updated https://github.com/llvm/llvm-project/pull/142078

>From ce446d015f734b7e24309ca1feb3b694c87eb7d4 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Thu, 29 May 2025 19:45:04 -0700
Subject: [PATCH] [Matrix] Add -debug-only prints when matrices get flattened

Flattening matrices is a potential source of overhead, which we might be able
to alleviate in some cases: for example, static element extracts, or shuffles
that pluck out a specific row. Since these diagnostics are highly specific to
the pass itself and not immediately actionable for compiler users, these prints
don't make a whole lot of sense as Remarks.
---
 .../Scalar/LowerMatrixIntrinsics.cpp          | 47 +++++++++++++++++--
 1 file changed, 43 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 5a518244a80ca..143a3006c18f0 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -40,6 +40,7 @@
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/Support/Alignment.h"
 #include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Compiler.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/LoopUtils.h"
@@ -221,7 +222,16 @@ struct ShapeInfo {
 
   /// Returns the transposed shape.
   ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
+
+  friend raw_ostream &operator<<(raw_ostream &OS, ShapeInfo SI);
+
+  LLVM_DUMP_METHOD void dump() const { dbgs() << *this << '\n'; }
 };
+
+raw_ostream &operator<<(raw_ostream &OS, ShapeInfo SI) {
+  return OS << SI.NumRows << 'x' << SI.NumColumns;
+}
+
 } // namespace
 
 static bool isUniformShape(Value *V) {
@@ -466,6 +476,8 @@ class LowerMatrixIntrinsics {
       return getNumColumns();
     }
 
+    ShapeInfo shape() const { return {getNumRows(), getNumColumns()}; }
+
     /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
     /// matrix is column-major, the result vector is extracted from a column
     /// vector, otherwise from a row vector.
@@ -578,6 +590,25 @@ class LowerMatrixIntrinsics {
       SplitVecs.push_back(V);
     }
 
+    LLVM_DEBUG(if (Instruction *Inst = dyn_cast<Instruction>(MatrixVal)) {
+      if (Found != Inst2ColumnMatrix.end()) {
+        // FIXME: re: "at least": SplitVecs.size() doesn't count the shuffles
+        // that embedInVector created.
+        dbgs() << "matrix reshape from " << Found->second.shape() << " to "
+               << SI << " using at least " << SplitVecs.size()
+               << " shuffles on behalf of " << *Inst << '\n';
+      } else if (!ShapeMap.contains(MatrixVal)) {
+        dbgs() << "splitting a " << SI << " matrix with " << SplitVecs.size()
+               << " shuffles because we do not have a shape-aware lowering for "
+                  "its def: "
+               << *Inst << '\n';
+      } else {
+        // The ShapeMap has it, so it's a case where we're being lowered
+        // before the def, and we expect that InstCombine will clean things up
+        // afterward.
+      }
+    });
+
     return {SplitVecs};
   }
 
@@ -1386,11 +1417,19 @@ class LowerMatrixIntrinsics {
     ToRemove.push_back(Inst);
     Value *Flattened = nullptr;
     for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
-      if (!ShapeMap.contains(U.getUser())) {
-        if (!Flattened)
-          Flattened = Matrix.embedInVector(Builder);
-        U.set(Flattened);
+      if (ShapeMap.contains(U.getUser()))
+        continue;
+
+      if (!Flattened) {
+        Flattened = Matrix.embedInVector(Builder);
+        LLVM_DEBUG(
+            if (Instruction *User = dyn_cast<Instruction>(U.getUser())) dbgs()
+                << "flattening a " << Matrix.shape() << " matrix " << *Inst
+                << " because we do not have a shape-aware lowering for its "
+                   "user: "
+                << *User << '\n';);
       }
+      U.set(Flattened);
     }
   }
 



More information about the llvm-commits mailing list