[llvm] [Matrix] Add a Remark when matrices get flattened (PR #142078)

Jon Roelofs via llvm-commits llvm-commits at lists.llvm.org
Thu May 29 19:47:53 PDT 2025


https://github.com/jroelofs created https://github.com/llvm/llvm-project/pull/142078

Flattening a matrix is a potential source of overhead that we might be able to alleviate in some cases, for example static element extracts or shuffles that pluck out a specific row.

From 232fe4ea704ffff3a66a6e4bde4622a0600e6edf Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Thu, 29 May 2025 19:45:04 -0700
Subject: [PATCH] [Matrix] Add a Remark when matrices get flattened

Flattening a matrix is a potential source of overhead that we might be able
to alleviate in some cases, for example static element extracts or shuffles
that pluck out a specific row.
---
 .../Scalar/LowerMatrixIntrinsics.cpp          | 25 +++++++++++---
 .../LowerMatrixIntrinsics/flatten.ll          | 34 +++++++++++++++++++
 2 files changed, 55 insertions(+), 4 deletions(-)
 create mode 100644 llvm/test/Transforms/LowerMatrixIntrinsics/flatten.ll

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 756a72e6d97bc..c2111b6223b22 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1351,11 +1351,28 @@ class LowerMatrixIntrinsics {
     ToRemove.push_back(Inst);
     Value *Flattened = nullptr;
     for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
-      if (!ShapeMap.contains(U.getUser())) {
-        if (!Flattened)
-          Flattened = Matrix.embedInVector(Builder);
-        U.set(Flattened);
+      if (ShapeMap.contains(U.getUser()))
+        continue;
+
+      if (!Flattened) {
+        Flattened = Matrix.embedInVector(Builder);
+        if (ORE) {
+          if (Instruction *User = dyn_cast<Instruction>(U.getUser())) {
+            std::string Str;
+            llvm::raw_string_ostream OS(Str);
+            OS << *User;
+            ORE->emit(OptimizationRemarkMissed(DEBUG_TYPE, "unknown-shape-lowering", User)
+              << "flattening a "
+              << ore::NV("Rows", Matrix.getNumRows()) << "x"
+              << ore::NV("Cols", Matrix.getNumColumns())
+              << " matrix because we do not have a shape-aware lowering for its user:"
+              << ore::NV("Instr", OS.str())
+              << ore::setExtraArgs()
+              << ore::NV("Opcode", User->getOpcodeName()));
+          }
+        }
       }
+      U.set(Flattened);
     }
   }
 
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/flatten.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/flatten.ll
new file mode 100644
index 0000000000000..4ff8211b040c4
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/flatten.ll
@@ -0,0 +1,34 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
+; RUN: opt -passes=lower-matrix-intrinsics -pass-remarks-missed=lower-matrix-intrinsics < %s -pass-remarks-output=%t -disable-output && FileCheck --input-file %t %s --check-prefix=REMARK
+
+define void @diag_3x3(ptr %in, ptr %out) {
+; REMARK-LABEL: Name:            unknown-shape-lowering
+; REMARK-NEXT: Function:        diag_3x3
+; REMARK-NEXT: Args:
+; REMARK-NEXT:   - String:          'flattening a '
+; REMARK-NEXT:   - Rows:            '3'
+; REMARK-NEXT:   - String:          x
+; REMARK-NEXT:   - Cols:            '3'
+; REMARK-NEXT:   - String:          ' matrix because we do not have a shape-aware lowering for its user:'
+; REMARK-NEXT:   - Instr:           '  %diag = shufflevector <9 x float> %inv, <9 x float> poison, <3 x i32> <i32 0, i32 4, i32 8>'
+; REMARK-NEXT:   - Opcode:          shufflevector
+; REMARK-NEXT: ...
+; CHECK-LABEL: @diag_3x3(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <3 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[IN]], i64 3
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <3 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr float, ptr [[IN]], i64 6
+; CHECK-NEXT:    [[COL_LOAD3:%.*]] = load <3 x float>, ptr [[VEC_GEP2]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <3 x float> [[COL_LOAD]], <3 x float> [[COL_LOAD1]], <6 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <3 x float> [[COL_LOAD3]], <3 x float> poison, <6 x i32> <i32 0, i32 1, i32 2, i32 poison, i32 poison, i32 poison>
+; CHECK-NEXT:    [[TMP3:%.*]] = shufflevector <6 x float> [[TMP1]], <6 x float> [[TMP2]], <9 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8>
+; CHECK-NEXT:    [[DIAG:%.*]] = shufflevector <9 x float> [[TMP3]], <9 x float> poison, <3 x i32> <i32 0, i32 4, i32 8>
+; CHECK-NEXT:    store <3 x float> [[DIAG]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    ret void
+;
+  %inv = call <9 x float> @llvm.matrix.column.major.load(ptr %in, i64 3, i1 false, i32 3, i32 3) ; 3x3 column-major load with stride 3 (lowered to three <3 x float> column loads)
+  %diag = shufflevector <9 x float> %inv, <9 x float> poison, <3 x i32> <i32 0, i32 4, i32 8> ; picks elements 0, 4, 8 (the diagonal); this user has no shape-aware lowering, so %inv is flattened back into a <9 x float>
+  store <3 x float> %diag, ptr %out
+  ret void
+}



More information about the llvm-commits mailing list