[llvm] [Matrix] (NFC) Refactor sharing of shape information (PR #164774)

Nathan Corbyn via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 23 01:29:04 PDT 2025


https://github.com/cofibrant created https://github.com/llvm/llvm-project/pull/164774

This patch removes duplicated logic in the `LowerMatrixIntrinsics` pass. That logic determines whether an instruction's result preserves the matrix shape of its operands and, if so, which of those operands share shape information with the result.

>From e59f655975af37e6b92138f33ea4923c07703208 Mon Sep 17 00:00:00 2001
From: Nathan Corbyn <n_corbyn at apple.com>
Date: Thu, 23 Oct 2025 09:16:22 +0100
Subject: [PATCH] [Matrix] (NFC) Refactor sharing of shape information

---
 .../Scalar/LowerMatrixIntrinsics.cpp          | 30 ++++++++++++-------
 1 file changed, 20 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 3487e812a68a3..7e70ba274f161 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -245,11 +245,14 @@ raw_ostream &operator<<(raw_ostream &OS, ShapeInfo SI) {
 
 } // namespace
 
-static bool isUniformShape(Value *V) {
+static bool isShapePreserving(Value *V) {
   Instruction *I = dyn_cast<Instruction>(V);
   if (!I)
     return true;
 
+  if (isa<SelectInst>(I))
+    return true;
+
   if (I->isBinaryOp())
     return true;
 
@@ -300,6 +303,16 @@ static bool isUniformShape(Value *V) {
   }
 }
 
+/// Return an iterator over the operands of \p I that should share shape
+/// information with \p I.
+static iterator_range<Use *> getShapedOperandsForInst(Instruction *I) {
+  assert(isShapePreserving(I) &&
+         "Can't retrieve shaped operands for an instruction that does not "
+         "preserve shape information");
+  auto Ops = I->operands();
+  return isa<SelectInst>(I) ? drop_begin(Ops) : Ops;
+}
+
 /// Return the ShapeInfo for the result of \p I, if it can be determined.
 static std::optional<ShapeInfo>
 computeShapeInfoForInst(Instruction *I,
@@ -329,9 +342,8 @@ computeShapeInfoForInst(Instruction *I,
       return OpShape->second;
   }
 
-  if (isUniformShape(I) || isa<SelectInst>(I)) {
-    auto Ops = I->operands();
-    auto ShapedOps = isa<SelectInst>(I) ? drop_begin(Ops) : Ops;
+  if (isShapePreserving(I)) {
+    auto ShapedOps = getShapedOperandsForInst(I);
     // Find the first operand that has a known shape and use that.
     for (auto &Op : ShapedOps) {
       auto OpShape = ShapeMap.find(Op.get());
@@ -710,10 +722,9 @@ class LowerMatrixIntrinsics {
       case Intrinsic::matrix_column_major_store:
         return true;
       default:
-        return isUniformShape(II);
+        break;
       }
-    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V) ||
-           isa<SelectInst>(V);
+    return isShapePreserving(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
   }
 
   /// Propagate the shape information of instructions to their users.
@@ -800,9 +811,8 @@ class LowerMatrixIntrinsics {
       } else if (isa<StoreInst>(V)) {
         // Nothing to do.  We forward-propagated to this so we would just
         // backward propagate to an instruction with an already known shape.
-      } else if (isUniformShape(V) || isa<SelectInst>(V)) {
-        auto Ops = cast<Instruction>(V)->operands();
-        auto ShapedOps = isa<SelectInst>(V) ? drop_begin(Ops) : Ops;
+      } else if (isShapePreserving(V)) {
+        auto ShapedOps = getShapedOperandsForInst(cast<Instruction>(V));
         // Propagate to all operands.
         ShapeInfo Shape = ShapeMap[V];
         for (Use &U : ShapedOps) {



More information about the llvm-commits mailing list