[llvm] [Matrix] Propagate shape information through Select insts (PR #141876)

Jon Roelofs via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 10 11:31:26 PDT 2025


https://github.com/jroelofs updated https://github.com/llvm/llvm-project/pull/141876

>From 4040d3fc777ff8d5b212e77fac604f60d997475a Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Wed, 28 May 2025 16:16:50 -0700
Subject: [PATCH 1/5] [Matrix] Propagate shape information through Select insts

---
 .../Scalar/LowerMatrixIntrinsics.cpp          | 49 ++++++++++++-
 .../LowerMatrixIntrinsics/select.ll           | 68 +++++++++++++++++++
 2 files changed, 116 insertions(+), 1 deletion(-)
 create mode 100644 llvm/test/Transforms/LowerMatrixIntrinsics/select.ll

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 756a72e6d97bc..6c364f057481a 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -269,6 +269,15 @@ computeShapeInfoForInst(Instruction *I,
       return OpShape->second;
   }
 
+  if (isa<SelectInst>(I)) {
+    auto OpShape = ShapeMap.find(I->getOperand(1));
+    if (OpShape != ShapeMap.end())
+      return OpShape->second;
+    OpShape = ShapeMap.find(I->getOperand(2));
+    if (OpShape != ShapeMap.end())
+      return OpShape->second;
+  }
+
   if (isUniformShape(I)) {
     // Find the first operand that has a known shape and use that.
     for (auto &Op : I->operands()) {
@@ -623,7 +632,8 @@ class LowerMatrixIntrinsics {
       default:
         return false;
       }
-    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
+    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V) ||
+           isa<SelectInst>(V);
   }
 
   /// Propagate the shape information of instructions to their users.
@@ -710,6 +720,12 @@ class LowerMatrixIntrinsics {
       } else if (isa<StoreInst>(V)) {
         // Nothing to do.  We forward-propagated to this so we would just
         // backward propagate to an instruction with an already known shape.
+      } else if (auto *Select = dyn_cast<SelectInst>(V)) {
+        ShapeInfo Shape = ShapeMap[V];
+        if (setShapeInfo(Select->getOperand(1), Shape))
+          pushInstruction(Select, WorkList);
+        if (setShapeInfo(Select->getOperand(2), Shape))
+          pushInstruction(Select, WorkList);
       } else if (isUniformShape(V)) {
         // Propagate to all operands.
         ShapeInfo Shape = ShapeMap[V];
@@ -1068,6 +1084,8 @@ class LowerMatrixIntrinsics {
         Changed |= VisitBinaryOperator(BinOp);
       if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
         Changed |= VisitUnaryOperator(UnOp);
+      if (auto *Select = dyn_cast<SelectInst>(Inst))
+        Changed |= VisitSelectInst(Select);
       if (match(Inst, m_Load(m_Value(Op1))))
         Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
       else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
@@ -2198,6 +2216,35 @@ class LowerMatrixIntrinsics {
     return true;
   }
 
+  /// Lower a select whose shape is recorded in ShapeMap into one select per matrix vector; returns true iff the select was lowered.
+  bool VisitSelectInst(SelectInst *Inst) {
+    auto I = ShapeMap.find(Inst);
+    if (I == ShapeMap.end()) // No shape recorded for this select: leave it untouched.
+      return false;
+
+    Value *Cond = Inst->getOperand(0); // Select condition.
+    Value *OpA = Inst->getOperand(1); // Value produced when Cond is true.
+    Value *OpB = Inst->getOperand(2); // Value produced when Cond is false.
+
+    IRBuilder<> Builder(Inst); // Replacement instructions are emitted in front of Inst.
+    ShapeInfo &Shape = I->second; // Bound before the loop below reuses the name I.
+
+    MatrixTy Result;
+    MatrixTy A = getMatrix(OpA, Shape, Builder); // Both value operands are read with the
+    MatrixTy B = getMatrix(OpB, Shape, Builder); // select's own shape.
+
+    for (unsigned I = 0; I < Shape.getNumVectors(); ++I) { // NB: this index shadows the iterator I above.
+      auto *Sel = Builder.CreateSelect(Cond, A.getVector(I), B.getVector(I)); // Same Cond for every vector.
+      Result.addVector(Sel); // NOTE(review): reusing Cond as-is fits the scalar i1 conditions in this patch's tests; a full-width vector condition would presumably need splitting per vector - confirm.
+    }
+
+    finalizeLowering(Inst, // Op count passed along: getNumOps(vector type) * number of vectors.
+                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
+                                             Result.getNumVectors()),
+                     Builder);
+    return true;
+  }
+
   /// Helper to linearize a matrix expression tree into a string. Currently
   /// matrix expressions are linarized by starting at an expression leaf and
   /// linearizing bottom up.
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll
new file mode 100644
index 0000000000000..507b02a04f47f
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll
@@ -0,0 +1,68 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
+
+define void @select_2x2_bot(i1 %cond, ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @select_2x2_bot(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr float, ptr [[RHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <2 x float>, ptr [[VEC_GEP3]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = select i1 [[COND:%.*]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = select i1 [[COND]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD4]]
+; CHECK-NEXT:    store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr float, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x float> [[TMP2]], ptr [[VEC_GEP5]], align 4
+; CHECK-NEXT:    ret void
+;
+  %lhsv = load <4 x float>, ptr %lhs ; plain load, carries no matrix shape itself
+  %rhsv = load <4 x float>, ptr %rhs ; plain load, carries no matrix shape itself
+  %op = select i1 %cond, <4 x float> %lhsv, <4 x float> %rhsv ; scalar condition; expected output has one select per <2 x float> column
+  call void @llvm.matrix.column.major.store(<4 x float> %op, ptr %out, i64 2, i1 false, i32 2, i32 2) ; the only matrix intrinsic here: the 2x2 shape reaches the select and both loads from this user
+  ret void
+}
+
+define void @select_2x2_lhs(i1 %cond, ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @select_2x2_lhs(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr float, ptr [[RHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <2 x float>, ptr [[VEC_GEP3]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = select i1 [[COND:%.*]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = select i1 [[COND]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD4]]
+; CHECK-NEXT:    store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP5:%.*]] = getelementptr float, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x float> [[TMP2]], ptr [[VEC_GEP5]], align 8
+; CHECK-NEXT:    ret void
+;
+  %lhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %lhs, i64 2, i1 false, i32 2, i32 2) ; the only matrix intrinsic here: 2x2 shape enters through the select's true operand
+  %rhsv = load <4 x float>, ptr %rhs ; plain load; expected output splits it into two <2 x float> column loads
+  %op = select i1 %cond, <4 x float> %lhsv, <4 x float> %rhsv ; scalar condition; expected output has one select per <2 x float> column
+  store <4 x float> %op, ptr %out ; plain store; expected output splits it into two <2 x float> column stores
+  ret void
+}
+
+define void @select_2x2_rhs(i1 %cond, ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @select_2x2_rhs(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[RHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x float>, ptr [[RHS1:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr float, ptr [[RHS1]], i64 2
+; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <2 x float>, ptr [[VEC_GEP3]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = select i1 [[COND:%.*]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = select i1 [[COND]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD4]]
+; CHECK-NEXT:    store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP2:%.*]] = getelementptr float, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x float> [[TMP2]], ptr [[VEC_GEP2]], align 8
+; CHECK-NEXT:    ret void
+;
+  %lhsv = load <4 x float>, ptr %lhs ; plain load; expected output splits it into two <2 x float> column loads
+  %rhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %rhs, i64 2, i1 false, i32 2, i32 2) ; the only matrix intrinsic here: 2x2 shape enters through the select's false operand
+  %op = select i1 %cond, <4 x float> %lhsv, <4 x float> %rhsv ; scalar condition; expected output has one select per <2 x float> column
+  store <4 x float> %op, ptr %out ; plain store; expected output splits it into two <2 x float> column stores
+  ret void
+}

>From c0c63f392205e42a1421c475c8d120d49ba9bf1d Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Thu, 5 Jun 2025 13:46:58 -0700
Subject: [PATCH 2/5] select with mismatched shape

---
 .../Scalar/LowerMatrixIntrinsics.cpp          | 14 ++--
 .../LowerMatrixIntrinsics/select.ll           | 66 +++++++++++++------
 2 files changed, 56 insertions(+), 24 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index ef57faf911e31..7da55a2a9a355 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -271,7 +271,9 @@ computeShapeInfoForInst(Instruction *I,
   }
 
   if (auto *Select = dyn_cast<SelectInst>(I)) {
-    for (Use &Op : Select->getCondition()->getType()->isVectorTy() ? I->operands() : drop_begin(I->operands())) {
+    Type *CondTy = Select->getCondition()->getType();
+    for (Use &Op : CondTy->isVectorTy() ? Select->operands()
+                                        : drop_begin(Select->operands())) {
       auto OpShape = ShapeMap.find(Op);
       if (OpShape != ShapeMap.end())
         return OpShape->second;
@@ -719,10 +721,12 @@ class LowerMatrixIntrinsics {
         // backward propagate to an instruction with an already known shape.
       } else if (auto *Select = dyn_cast<SelectInst>(V)) {
         ShapeInfo Shape = ShapeMap[V];
-        if (setShapeInfo(Select->getOperand(1), Shape))
-          pushInstruction(Select, WorkList);
-        if (setShapeInfo(Select->getOperand(2), Shape))
-          pushInstruction(Select, WorkList);
+        Type *CondTy = Select->getCondition()->getType();
+        for (Use &Op : CondTy->isVectorTy() ? Select->operands()
+                                            : drop_begin(Select->operands())) {
+          if (setShapeInfo(Op, Shape))
+            pushInstruction(Select, WorkList);
+        }
       } else if (isUniformShape(V)) {
         // Propagate to all operands.
         ShapeInfo Shape = ShapeMap[V];
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll
index 31c34e24c540d..56dca7bb985d3 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll
@@ -67,39 +67,41 @@ define void @select_2x2_rhs(i1 %cond, ptr %lhs, ptr %rhs, ptr %out) {
   ret void
 }
 
-define void @select_2x2_vcond(<4 x i1> %cond, ptr %lhs, ptr %rhs, ptr %out) {
-; CHECK-LABEL: @select_2x2_vcond(
+define void @select_2x2_vcond_shape1(ptr %cond, ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @select_2x2_vcond_shape1(
 ; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 16
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2
 ; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
-; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 4
-; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr float, ptr [[RHS]], i64 2
-; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <2 x float>, ptr [[VEC_GEP3]], align 4
-; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <4 x i1> [[COND:%.*]], <4 x i1> poison, <2 x i32> <i32 0, i32 1>
-; CHECK-NEXT:    [[SPLIT5:%.*]] = shufflevector <4 x i1> [[COND]], <4 x i1> poison, <2 x i32> <i32 2, i32 3>
-; CHECK-NEXT:    [[TMP1:%.*]] = select <2 x i1> [[SPLIT]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD2]]
-; CHECK-NEXT:    [[TMP2:%.*]] = select <2 x i1> [[SPLIT5]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD4]]
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x i1>, ptr [[RHS:%.*]], align 1
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr i1, ptr [[RHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <2 x i1>, ptr [[VEC_GEP3]], align 1
+; CHECK-NEXT:    [[COL_LOAD5:%.*]] = load <2 x float>, ptr [[RHS1:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP6:%.*]] = getelementptr float, ptr [[RHS1]], i64 2
+; CHECK-NEXT:    [[COL_LOAD7:%.*]] = load <2 x float>, ptr [[VEC_GEP6]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = select <2 x i1> [[COL_LOAD2]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD5]]
+; CHECK-NEXT:    [[TMP2:%.*]] = select <2 x i1> [[COL_LOAD4]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD7]]
 ; CHECK-NEXT:    store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
-; CHECK-NEXT:    [[VEC_GEP6:%.*]] = getelementptr float, ptr [[OUT]], i64 2
-; CHECK-NEXT:    store <2 x float> [[TMP2]], ptr [[VEC_GEP6]], align 8
+; CHECK-NEXT:    [[VEC_GEP8:%.*]] = getelementptr float, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x float> [[TMP2]], ptr [[VEC_GEP8]], align 8
 ; CHECK-NEXT:    ret void
 ;
   %lhsv = load <4 x float>, ptr %lhs
+  %condv = load <4 x i1>, ptr %cond
   %rhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %rhs, i64 2, i1 false, i32 2, i32 2)
-  %op = select <4 x i1> %cond, <4 x float> %lhsv, <4 x float> %rhsv
+  %op = select <4 x i1> %condv, <4 x float> %lhsv, <4 x float> %rhsv
   store <4 x float> %op, ptr %out
   ret void
 }
 
-define void @select_2x2_vcond_shape(ptr %lhs, ptr %rhs, ptr %out) {
-; CHECK-LABEL: @select_2x2_vcond_shape(
+define void @select_2x2_vcond_shape2(ptr %cond, ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @select_2x2_vcond_shape2(
 ; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 16
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2
 ; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
-; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x i1>, ptr [[RHS:%.*]], align 1
-; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr i1, ptr [[RHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x i1>, ptr [[COND:%.*]], align 1
+; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr i1, ptr [[COND]], i64 2
 ; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <2 x i1>, ptr [[VEC_GEP3]], align 1
-; CHECK-NEXT:    [[COL_LOAD5:%.*]] = load <2 x float>, ptr [[RHS]], align 4
+; CHECK-NEXT:    [[COL_LOAD5:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 4
 ; CHECK-NEXT:    [[VEC_GEP6:%.*]] = getelementptr float, ptr [[RHS]], i64 2
 ; CHECK-NEXT:    [[COL_LOAD7:%.*]] = load <2 x float>, ptr [[VEC_GEP6]], align 4
 ; CHECK-NEXT:    [[TMP1:%.*]] = select <2 x i1> [[COL_LOAD2]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD5]]
@@ -110,9 +112,35 @@ define void @select_2x2_vcond_shape(ptr %lhs, ptr %rhs, ptr %out) {
 ; CHECK-NEXT:    ret void
 ;
   %lhsv = load <4 x float>, ptr %lhs
-  %cond = call <4 x i1> @llvm.matrix.column.major.load(ptr %rhs, i64 2, i1 false, i32 2, i32 2)
+  %condv = call <4 x i1> @llvm.matrix.column.major.load(ptr %cond, i64 2, i1 false, i32 2, i32 2)
+  %rhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %rhs, i64 2, i1 false, i32 2, i32 2)
+  %op = select <4 x i1> %condv, <4 x float> %lhsv, <4 x float> %rhsv
+  store <4 x float> %op, ptr %out
+  ret void
+}
+
+define void @select_2x2_vcond_shape3(ptr %cond, ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @select_2x2_vcond_shape3(
+; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <4 x i1>, ptr [[COND:%.*]], align 1
+; CHECK-NEXT:    [[COL_LOAD3:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 4
+; CHECK-NEXT:    [[VEC_GEP4:%.*]] = getelementptr float, ptr [[RHS]], i64 2
+; CHECK-NEXT:    [[COL_LOAD5:%.*]] = load <2 x float>, ptr [[VEC_GEP4]], align 4
+; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <4 x i1> [[COL_LOAD2]], <4 x i1> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[SPLIT6:%.*]] = shufflevector <4 x i1> [[COL_LOAD2]], <4 x i1> poison, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT:    [[TMP1:%.*]] = select <2 x i1> [[SPLIT]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD3]]
+; CHECK-NEXT:    [[TMP2:%.*]] = select <2 x i1> [[SPLIT6]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD5]]
+; CHECK-NEXT:    store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT:    [[VEC_GEP7:%.*]] = getelementptr float, ptr [[OUT]], i64 2
+; CHECK-NEXT:    store <2 x float> [[TMP2]], ptr [[VEC_GEP7]], align 8
+; CHECK-NEXT:    ret void
+;
+  %lhsv = load <4 x float>, ptr %lhs
+  %condv = call <4 x i1> @llvm.matrix.column.major.load(ptr %cond, i64 4, i1 false, i32 4, i32 1)
   %rhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %rhs, i64 2, i1 false, i32 2, i32 2)
-  %op = select <4 x i1> %cond, <4 x float> %lhsv, <4 x float> %rhsv
+  %op = select <4 x i1> %condv, <4 x float> %lhsv, <4 x float> %rhsv
   store <4 x float> %op, ptr %out
   ret void
 }

>From 5a4a6a507f57da9fc4d4081385a9f8339929da87 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Thu, 5 Jun 2025 13:48:15 -0700
Subject: [PATCH 3/5] no return value

---
 llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 7da55a2a9a355..deff2908b4902 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1087,7 +1087,7 @@ class LowerMatrixIntrinsics {
       else if (CallInst *CInst = dyn_cast<CallInst>(Inst))
         VisitCallInst(CInst);
       else if (auto *Select = dyn_cast<SelectInst>(Inst))
-        Changed |= VisitSelectInst(Select, SI);
+        VisitSelectInst(Select, SI);
       else if (match(Inst, m_Load(m_Value(Op1))))
         VisitLoad(cast<LoadInst>(Inst), SI, Op1, Builder);
       else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
@@ -2200,7 +2200,7 @@ class LowerMatrixIntrinsics {
   }
 
   /// Lower selects.
-  bool VisitSelectInst(SelectInst *Inst, const ShapeInfo &Shape) {
+  void VisitSelectInst(SelectInst *Inst, const ShapeInfo &Shape) {
     Value *Cond = Inst->getOperand(0);
     Value *OpA = Inst->getOperand(1);
     Value *OpB = Inst->getOperand(2);
@@ -2228,7 +2228,6 @@ class LowerMatrixIntrinsics {
                      Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                              Result.getNumVectors()),
                      Builder);
-    return true;
   }
 
   /// Helper to linearize a matrix expression tree into a string. Currently

>From 6693f56f75e8d67524e1b2cb8cbf6455d2fcf17d Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Thu, 5 Jun 2025 13:52:54 -0700
Subject: [PATCH 4/5] clean up condition lookup

---
 .../Transforms/Scalar/LowerMatrixIntrinsics.cpp | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index deff2908b4902..3b4f1bf7ff67d 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -2211,19 +2211,20 @@ class LowerMatrixIntrinsics {
     MatrixTy A = getMatrix(OpA, Shape, Builder);
     MatrixTy B = getMatrix(OpB, Shape, Builder);
 
+    Value *CondV[2];
     if (isa<FixedVectorType>(Cond->getType())) {
       MatrixTy C = getMatrix(Cond, Shape, Builder);
-      for (unsigned I = 0; I < Shape.getNumVectors(); ++I) {
-        auto *Sel = Builder.CreateSelect(C.getVector(I), A.getVector(I), B.getVector(I));
-        Result.addVector(Sel);
-      }
+      CondV[0] = C.getVector(0);
+      CondV[1] = C.getVector(1);
     } else {
-      for (unsigned I = 0; I < Shape.getNumVectors(); ++I) {
-        auto *Sel = Builder.CreateSelect(Cond, A.getVector(I), B.getVector(I));
-        Result.addVector(Sel);
-      }
+      CondV[0] = Cond;
+      CondV[1] = Cond;
     }
 
+    for (unsigned I = 0, E = Shape.getNumVectors(); I != E; ++I)
+      Result.addVector(
+          Builder.CreateSelect(CondV[I], A.getVector(I), B.getVector(I)));
+
     finalizeLowering(Inst,
                      Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
                                              Result.getNumVectors()),

>From 1a4faa4c9e43aa47319ea9d917b65a72f2bb67ce Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Tue, 10 Jun 2025 11:16:14 -0700
Subject: [PATCH 5/5] review feedback: don't take shape info from select
 conditions

---
 .../Scalar/LowerMatrixIntrinsics.cpp          | 30 +++++--------------
 .../LowerMatrixIntrinsics/select.ll           |  6 ++--
 2 files changed, 11 insertions(+), 25 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 3b4f1bf7ff67d..3a5b0f8fdb415 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -270,19 +270,11 @@ computeShapeInfoForInst(Instruction *I,
       return OpShape->second;
   }
 
-  if (auto *Select = dyn_cast<SelectInst>(I)) {
-    Type *CondTy = Select->getCondition()->getType();
-    for (Use &Op : CondTy->isVectorTy() ? Select->operands()
-                                        : drop_begin(Select->operands())) {
-      auto OpShape = ShapeMap.find(Op);
-      if (OpShape != ShapeMap.end())
-        return OpShape->second;
-    }
-  }
-
-  if (isUniformShape(I)) {
+  if (isUniformShape(I) || isa<SelectInst>(I)) {
+    auto Ops = I->operands();
+    auto ShapedOps = isa<SelectInst>(I) ? drop_begin(Ops) : Ops;
     // Find the first operand that has a known shape and use that.
-    for (auto &Op : I->operands()) {
+    for (auto &Op : ShapedOps) {
       auto OpShape = ShapeMap.find(Op.get());
       if (OpShape != ShapeMap.end())
         return OpShape->second;
@@ -719,18 +711,12 @@ class LowerMatrixIntrinsics {
       } else if (isa<StoreInst>(V)) {
         // Nothing to do.  We forward-propagated to this so we would just
         // backward propagate to an instruction with an already known shape.
-      } else if (auto *Select = dyn_cast<SelectInst>(V)) {
-        ShapeInfo Shape = ShapeMap[V];
-        Type *CondTy = Select->getCondition()->getType();
-        for (Use &Op : CondTy->isVectorTy() ? Select->operands()
-                                            : drop_begin(Select->operands())) {
-          if (setShapeInfo(Op, Shape))
-            pushInstruction(Select, WorkList);
-        }
-      } else if (isUniformShape(V)) {
+      } else if (isUniformShape(V) || isa<SelectInst>(V)) {
+        auto Ops = cast<Instruction>(V)->operands();
+        auto ShapedOps = isa<SelectInst>(V) ? drop_begin(Ops) : Ops;
         // Propagate to all operands.
         ShapeInfo Shape = ShapeMap[V];
-        for (Use &U : cast<Instruction>(V)->operands()) {
+        for (Use &U : ShapedOps) {
           if (setShapeInfo(U.get(), Shape))
             pushInstruction(U.get(), WorkList);
         }
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll
index 56dca7bb985d3..70b0dfdb3e7e8 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll
@@ -72,12 +72,12 @@ define void @select_2x2_vcond_shape1(ptr %cond, ptr %lhs, ptr %rhs, ptr %out) {
 ; CHECK-NEXT:    [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 16
 ; CHECK-NEXT:    [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2
 ; CHECK-NEXT:    [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
-; CHECK-NEXT:    [[COL_LOAD2:%.*]] = load <2 x i1>, ptr [[RHS:%.*]], align 1
-; CHECK-NEXT:    [[VEC_GEP3:%.*]] = getelementptr i1, ptr [[RHS]], i64 2
-; CHECK-NEXT:    [[COL_LOAD4:%.*]] = load <2 x i1>, ptr [[VEC_GEP3]], align 1
+; CHECK-NEXT:    [[CONDV:%.*]] = load <4 x i1>, ptr [[COND:%.*]], align 1
 ; CHECK-NEXT:    [[COL_LOAD5:%.*]] = load <2 x float>, ptr [[RHS1:%.*]], align 4
 ; CHECK-NEXT:    [[VEC_GEP6:%.*]] = getelementptr float, ptr [[RHS1]], i64 2
 ; CHECK-NEXT:    [[COL_LOAD7:%.*]] = load <2 x float>, ptr [[VEC_GEP6]], align 4
+; CHECK-NEXT:    [[COL_LOAD2:%.*]] = shufflevector <4 x i1> [[CONDV]], <4 x i1> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT:    [[COL_LOAD4:%.*]] = shufflevector <4 x i1> [[CONDV]], <4 x i1> poison, <2 x i32> <i32 2, i32 3>
 ; CHECK-NEXT:    [[TMP1:%.*]] = select <2 x i1> [[COL_LOAD2]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD5]]
 ; CHECK-NEXT:    [[TMP2:%.*]] = select <2 x i1> [[COL_LOAD4]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD7]]
 ; CHECK-NEXT:    store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16



More information about the llvm-commits mailing list