[llvm] [Matrix] Propagate shape information through Select insts (PR #141876)
Jon Roelofs via llvm-commits
llvm-commits at lists.llvm.org
Tue Jun 10 11:31:26 PDT 2025
https://github.com/jroelofs updated https://github.com/llvm/llvm-project/pull/141876
>From 4040d3fc777ff8d5b212e77fac604f60d997475a Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Wed, 28 May 2025 16:16:50 -0700
Subject: [PATCH 1/5] [Matrix] Propagate shape information through Select insts
---
.../Scalar/LowerMatrixIntrinsics.cpp | 49 ++++++++++++-
.../LowerMatrixIntrinsics/select.ll | 68 +++++++++++++++++++
2 files changed, 116 insertions(+), 1 deletion(-)
create mode 100644 llvm/test/Transforms/LowerMatrixIntrinsics/select.ll
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 756a72e6d97bc..6c364f057481a 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -269,6 +269,15 @@ computeShapeInfoForInst(Instruction *I,
return OpShape->second;
}
+ if (isa<SelectInst>(I)) {
+ auto OpShape = ShapeMap.find(I->getOperand(1));
+ if (OpShape != ShapeMap.end())
+ return OpShape->second;
+ OpShape = ShapeMap.find(I->getOperand(2));
+ if (OpShape != ShapeMap.end())
+ return OpShape->second;
+ }
+
if (isUniformShape(I)) {
// Find the first operand that has a known shape and use that.
for (auto &Op : I->operands()) {
@@ -623,7 +632,8 @@ class LowerMatrixIntrinsics {
default:
return false;
}
- return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
+ return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V) ||
+ isa<SelectInst>(V);
}
/// Propagate the shape information of instructions to their users.
@@ -710,6 +720,12 @@ class LowerMatrixIntrinsics {
} else if (isa<StoreInst>(V)) {
// Nothing to do. We forward-propagated to this so we would just
// backward propagate to an instruction with an already known shape.
+ } else if (auto *Select = dyn_cast<SelectInst>(V)) {
+ ShapeInfo Shape = ShapeMap[V];
+ if (setShapeInfo(Select->getOperand(1), Shape))
+ pushInstruction(Select, WorkList);
+ if (setShapeInfo(Select->getOperand(2), Shape))
+ pushInstruction(Select, WorkList);
} else if (isUniformShape(V)) {
// Propagate to all operands.
ShapeInfo Shape = ShapeMap[V];
@@ -1068,6 +1084,8 @@ class LowerMatrixIntrinsics {
Changed |= VisitBinaryOperator(BinOp);
if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
Changed |= VisitUnaryOperator(UnOp);
+ if (auto *Select = dyn_cast<SelectInst>(Inst))
+ Changed |= VisitSelectInst(Select);
if (match(Inst, m_Load(m_Value(Op1))))
Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
@@ -2198,6 +2216,35 @@ class LowerMatrixIntrinsics {
return true;
}
+ /// Lower selects, if shape information is available.
+ bool VisitSelectInst(SelectInst *Inst) {
+ auto I = ShapeMap.find(Inst);
+ if (I == ShapeMap.end())
+ return false;
+
+ Value *Cond = Inst->getOperand(0);
+ Value *OpA = Inst->getOperand(1);
+ Value *OpB = Inst->getOperand(2);
+
+ IRBuilder<> Builder(Inst);
+ ShapeInfo &Shape = I->second;
+
+ MatrixTy Result;
+ MatrixTy A = getMatrix(OpA, Shape, Builder);
+ MatrixTy B = getMatrix(OpB, Shape, Builder);
+
+ for (unsigned I = 0; I < Shape.getNumVectors(); ++I) {
+ auto *Sel = Builder.CreateSelect(Cond, A.getVector(I), B.getVector(I));
+ Result.addVector(Sel);
+ }
+
+ finalizeLowering(Inst,
+ Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
+ Result.getNumVectors()),
+ Builder);
+ return true;
+ }
+
/// Helper to linearize a matrix expression tree into a string. Currently
/// matrix expressions are linarized by starting at an expression leaf and
/// linearizing bottom up.
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll
new file mode 100644
index 0000000000000..507b02a04f47f
--- /dev/null
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll
@@ -0,0 +1,68 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes='lower-matrix-intrinsics' -S < %s | FileCheck %s
+
+define void @select_2x2_bot(i1 %cond, ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @select_2x2_bot(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 16
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 16
+; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr float, ptr [[RHS]], i64 2
+; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <2 x float>, ptr [[VEC_GEP3]], align 8
+; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[COND:%.*]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD2]]
+; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[COND]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD4]]
+; CHECK-NEXT: store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 4
+; CHECK-NEXT: [[VEC_GEP5:%.*]] = getelementptr float, ptr [[OUT]], i64 2
+; CHECK-NEXT: store <2 x float> [[TMP2]], ptr [[VEC_GEP5]], align 4
+; CHECK-NEXT: ret void
+;
+ %lhsv = load <4 x float>, ptr %lhs
+ %rhsv = load <4 x float>, ptr %rhs
+ %op = select i1 %cond, <4 x float> %lhsv, <4 x float> %rhsv
+ call void @llvm.matrix.column.major.store(<4 x float> %op, ptr %out, i64 2, i1 false, i32 2, i32 2)
+ ret void
+}
+
+define void @select_2x2_lhs(i1 %cond, ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @select_2x2_lhs(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 4
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 4
+; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 16
+; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr float, ptr [[RHS]], i64 2
+; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <2 x float>, ptr [[VEC_GEP3]], align 8
+; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[COND:%.*]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD2]]
+; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[COND]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD4]]
+; CHECK-NEXT: store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT: [[VEC_GEP5:%.*]] = getelementptr float, ptr [[OUT]], i64 2
+; CHECK-NEXT: store <2 x float> [[TMP2]], ptr [[VEC_GEP5]], align 8
+; CHECK-NEXT: ret void
+;
+ %lhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %lhs, i64 2, i1 false, i32 2, i32 2)
+ %rhsv = load <4 x float>, ptr %rhs
+ %op = select i1 %cond, <4 x float> %lhsv, <4 x float> %rhsv
+ store <4 x float> %op, ptr %out
+ ret void
+}
+
+define void @select_2x2_rhs(i1 %cond, ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @select_2x2_rhs(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 16
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[RHS]], i64 2
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x float>, ptr [[RHS1:%.*]], align 4
+; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr float, ptr [[RHS1]], i64 2
+; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <2 x float>, ptr [[VEC_GEP3]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[COND:%.*]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD2]]
+; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[COND]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD4]]
+; CHECK-NEXT: store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT: [[VEC_GEP2:%.*]] = getelementptr float, ptr [[OUT]], i64 2
+; CHECK-NEXT: store <2 x float> [[TMP2]], ptr [[VEC_GEP2]], align 8
+; CHECK-NEXT: ret void
+;
+ %lhsv = load <4 x float>, ptr %lhs
+ %rhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %rhs, i64 2, i1 false, i32 2, i32 2)
+ %op = select i1 %cond, <4 x float> %lhsv, <4 x float> %rhsv
+ store <4 x float> %op, ptr %out
+ ret void
+}
>From c0c63f392205e42a1421c475c8d120d49ba9bf1d Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Thu, 5 Jun 2025 13:46:58 -0700
Subject: [PATCH 2/5] select with mismatched shape
---
.../Scalar/LowerMatrixIntrinsics.cpp | 14 ++--
.../LowerMatrixIntrinsics/select.ll | 66 +++++++++++++------
2 files changed, 56 insertions(+), 24 deletions(-)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index ef57faf911e31..7da55a2a9a355 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -271,7 +271,9 @@ computeShapeInfoForInst(Instruction *I,
}
if (auto *Select = dyn_cast<SelectInst>(I)) {
- for (Use &Op : Select->getCondition()->getType()->isVectorTy() ? I->operands() : drop_begin(I->operands())) {
+ Type *CondTy = Select->getCondition()->getType();
+ for (Use &Op : CondTy->isVectorTy() ? Select->operands()
+ : drop_begin(Select->operands())) {
auto OpShape = ShapeMap.find(Op);
if (OpShape != ShapeMap.end())
return OpShape->second;
@@ -719,10 +721,12 @@ class LowerMatrixIntrinsics {
// backward propagate to an instruction with an already known shape.
} else if (auto *Select = dyn_cast<SelectInst>(V)) {
ShapeInfo Shape = ShapeMap[V];
- if (setShapeInfo(Select->getOperand(1), Shape))
- pushInstruction(Select, WorkList);
- if (setShapeInfo(Select->getOperand(2), Shape))
- pushInstruction(Select, WorkList);
+ Type *CondTy = Select->getCondition()->getType();
+ for (Use &Op : CondTy->isVectorTy() ? Select->operands()
+ : drop_begin(Select->operands())) {
+ if (setShapeInfo(Op, Shape))
+ pushInstruction(Select, WorkList);
+ }
} else if (isUniformShape(V)) {
// Propagate to all operands.
ShapeInfo Shape = ShapeMap[V];
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll
index 31c34e24c540d..56dca7bb985d3 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll
@@ -67,39 +67,41 @@ define void @select_2x2_rhs(i1 %cond, ptr %lhs, ptr %rhs, ptr %out) {
ret void
}
-define void @select_2x2_vcond(<4 x i1> %cond, ptr %lhs, ptr %rhs, ptr %out) {
-; CHECK-LABEL: @select_2x2_vcond(
+define void @select_2x2_vcond_shape1(ptr %cond, ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @select_2x2_vcond_shape1(
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 16
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
-; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 4
-; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr float, ptr [[RHS]], i64 2
-; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <2 x float>, ptr [[VEC_GEP3]], align 4
-; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <4 x i1> [[COND:%.*]], <4 x i1> poison, <2 x i32> <i32 0, i32 1>
-; CHECK-NEXT: [[SPLIT5:%.*]] = shufflevector <4 x i1> [[COND]], <4 x i1> poison, <2 x i32> <i32 2, i32 3>
-; CHECK-NEXT: [[TMP1:%.*]] = select <2 x i1> [[SPLIT]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD2]]
-; CHECK-NEXT: [[TMP2:%.*]] = select <2 x i1> [[SPLIT5]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD4]]
+; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x i1>, ptr [[RHS:%.*]], align 1
+; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr i1, ptr [[RHS]], i64 2
+; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <2 x i1>, ptr [[VEC_GEP3]], align 1
+; CHECK-NEXT: [[COL_LOAD5:%.*]] = load <2 x float>, ptr [[RHS1:%.*]], align 4
+; CHECK-NEXT: [[VEC_GEP6:%.*]] = getelementptr float, ptr [[RHS1]], i64 2
+; CHECK-NEXT: [[COL_LOAD7:%.*]] = load <2 x float>, ptr [[VEC_GEP6]], align 4
+; CHECK-NEXT: [[TMP1:%.*]] = select <2 x i1> [[COL_LOAD2]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD5]]
+; CHECK-NEXT: [[TMP2:%.*]] = select <2 x i1> [[COL_LOAD4]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD7]]
; CHECK-NEXT: store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
-; CHECK-NEXT: [[VEC_GEP6:%.*]] = getelementptr float, ptr [[OUT]], i64 2
-; CHECK-NEXT: store <2 x float> [[TMP2]], ptr [[VEC_GEP6]], align 8
+; CHECK-NEXT: [[VEC_GEP8:%.*]] = getelementptr float, ptr [[OUT]], i64 2
+; CHECK-NEXT: store <2 x float> [[TMP2]], ptr [[VEC_GEP8]], align 8
; CHECK-NEXT: ret void
;
%lhsv = load <4 x float>, ptr %lhs
+ %condv = load <4 x i1>, ptr %cond
%rhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %rhs, i64 2, i1 false, i32 2, i32 2)
- %op = select <4 x i1> %cond, <4 x float> %lhsv, <4 x float> %rhsv
+ %op = select <4 x i1> %condv, <4 x float> %lhsv, <4 x float> %rhsv
store <4 x float> %op, ptr %out
ret void
}
-define void @select_2x2_vcond_shape(ptr %lhs, ptr %rhs, ptr %out) {
-; CHECK-LABEL: @select_2x2_vcond_shape(
+define void @select_2x2_vcond_shape2(ptr %cond, ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @select_2x2_vcond_shape2(
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 16
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
-; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x i1>, ptr [[RHS:%.*]], align 1
-; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr i1, ptr [[RHS]], i64 2
+; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x i1>, ptr [[COND:%.*]], align 1
+; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr i1, ptr [[COND]], i64 2
; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <2 x i1>, ptr [[VEC_GEP3]], align 1
-; CHECK-NEXT: [[COL_LOAD5:%.*]] = load <2 x float>, ptr [[RHS]], align 4
+; CHECK-NEXT: [[COL_LOAD5:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 4
; CHECK-NEXT: [[VEC_GEP6:%.*]] = getelementptr float, ptr [[RHS]], i64 2
; CHECK-NEXT: [[COL_LOAD7:%.*]] = load <2 x float>, ptr [[VEC_GEP6]], align 4
; CHECK-NEXT: [[TMP1:%.*]] = select <2 x i1> [[COL_LOAD2]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD5]]
@@ -110,9 +112,35 @@ define void @select_2x2_vcond_shape(ptr %lhs, ptr %rhs, ptr %out) {
; CHECK-NEXT: ret void
;
%lhsv = load <4 x float>, ptr %lhs
- %cond = call <4 x i1> @llvm.matrix.column.major.load(ptr %rhs, i64 2, i1 false, i32 2, i32 2)
+ %condv = call <4 x i1> @llvm.matrix.column.major.load(ptr %cond, i64 2, i1 false, i32 2, i32 2)
+ %rhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %rhs, i64 2, i1 false, i32 2, i32 2)
+ %op = select <4 x i1> %condv, <4 x float> %lhsv, <4 x float> %rhsv
+ store <4 x float> %op, ptr %out
+ ret void
+}
+
+define void @select_2x2_vcond_shape3(ptr %cond, ptr %lhs, ptr %rhs, ptr %out) {
+; CHECK-LABEL: @select_2x2_vcond_shape3(
+; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 16
+; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2
+; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
+; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <4 x i1>, ptr [[COND:%.*]], align 1
+; CHECK-NEXT: [[COL_LOAD3:%.*]] = load <2 x float>, ptr [[RHS:%.*]], align 4
+; CHECK-NEXT: [[VEC_GEP4:%.*]] = getelementptr float, ptr [[RHS]], i64 2
+; CHECK-NEXT: [[COL_LOAD5:%.*]] = load <2 x float>, ptr [[VEC_GEP4]], align 4
+; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <4 x i1> [[COL_LOAD2]], <4 x i1> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT: [[SPLIT6:%.*]] = shufflevector <4 x i1> [[COL_LOAD2]], <4 x i1> poison, <2 x i32> <i32 2, i32 3>
+; CHECK-NEXT: [[TMP1:%.*]] = select <2 x i1> [[SPLIT]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD3]]
+; CHECK-NEXT: [[TMP2:%.*]] = select <2 x i1> [[SPLIT6]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD5]]
+; CHECK-NEXT: store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
+; CHECK-NEXT: [[VEC_GEP7:%.*]] = getelementptr float, ptr [[OUT]], i64 2
+; CHECK-NEXT: store <2 x float> [[TMP2]], ptr [[VEC_GEP7]], align 8
+; CHECK-NEXT: ret void
+;
+ %lhsv = load <4 x float>, ptr %lhs
+ %condv = call <4 x i1> @llvm.matrix.column.major.load(ptr %cond, i64 4, i1 false, i32 4, i32 1)
%rhsv = call <4 x float> @llvm.matrix.column.major.load(ptr %rhs, i64 2, i1 false, i32 2, i32 2)
- %op = select <4 x i1> %cond, <4 x float> %lhsv, <4 x float> %rhsv
+ %op = select <4 x i1> %condv, <4 x float> %lhsv, <4 x float> %rhsv
store <4 x float> %op, ptr %out
ret void
}
>From 5a4a6a507f57da9fc4d4081385a9f8339929da87 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Thu, 5 Jun 2025 13:48:15 -0700
Subject: [PATCH 3/5] no return value
---
llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 7da55a2a9a355..deff2908b4902 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1087,7 +1087,7 @@ class LowerMatrixIntrinsics {
else if (CallInst *CInst = dyn_cast<CallInst>(Inst))
VisitCallInst(CInst);
else if (auto *Select = dyn_cast<SelectInst>(Inst))
- Changed |= VisitSelectInst(Select, SI);
+ VisitSelectInst(Select, SI);
else if (match(Inst, m_Load(m_Value(Op1))))
VisitLoad(cast<LoadInst>(Inst), SI, Op1, Builder);
else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
@@ -2200,7 +2200,7 @@ class LowerMatrixIntrinsics {
}
/// Lower selects.
- bool VisitSelectInst(SelectInst *Inst, const ShapeInfo &Shape) {
+ void VisitSelectInst(SelectInst *Inst, const ShapeInfo &Shape) {
Value *Cond = Inst->getOperand(0);
Value *OpA = Inst->getOperand(1);
Value *OpB = Inst->getOperand(2);
@@ -2228,7 +2228,6 @@ class LowerMatrixIntrinsics {
Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
Result.getNumVectors()),
Builder);
- return true;
}
/// Helper to linearize a matrix expression tree into a string. Currently
>From 6693f56f75e8d67524e1b2cb8cbf6455d2fcf17d Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Thu, 5 Jun 2025 13:52:54 -0700
Subject: [PATCH 4/5] clean up condition lookup
---
.../Transforms/Scalar/LowerMatrixIntrinsics.cpp | 17 +++++++++--------
1 file changed, 9 insertions(+), 8 deletions(-)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index deff2908b4902..3b4f1bf7ff67d 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -2211,19 +2211,20 @@ class LowerMatrixIntrinsics {
MatrixTy A = getMatrix(OpA, Shape, Builder);
MatrixTy B = getMatrix(OpB, Shape, Builder);
+ Value *CondV[2];
if (isa<FixedVectorType>(Cond->getType())) {
MatrixTy C = getMatrix(Cond, Shape, Builder);
- for (unsigned I = 0; I < Shape.getNumVectors(); ++I) {
- auto *Sel = Builder.CreateSelect(C.getVector(I), A.getVector(I), B.getVector(I));
- Result.addVector(Sel);
- }
+ CondV[0] = C.getVector(0);
+ CondV[1] = C.getVector(1);
} else {
- for (unsigned I = 0; I < Shape.getNumVectors(); ++I) {
- auto *Sel = Builder.CreateSelect(Cond, A.getVector(I), B.getVector(I));
- Result.addVector(Sel);
- }
+ CondV[0] = Cond;
+ CondV[1] = Cond;
}
+ for (unsigned I = 0, E = Shape.getNumVectors(); I != E; ++I)
+ Result.addVector(
+ Builder.CreateSelect(CondV[I], A.getVector(I), B.getVector(I)));
+
finalizeLowering(Inst,
Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
Result.getNumVectors()),
>From 1a4faa4c9e43aa47319ea9d917b65a72f2bb67ce Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Tue, 10 Jun 2025 11:16:14 -0700
Subject: [PATCH 5/5] review feedback: don't take shape info from select conditions
---
.../Scalar/LowerMatrixIntrinsics.cpp | 30 +++++--------------
.../LowerMatrixIntrinsics/select.ll | 6 ++--
2 files changed, 11 insertions(+), 25 deletions(-)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 3b4f1bf7ff67d..3a5b0f8fdb415 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -270,19 +270,11 @@ computeShapeInfoForInst(Instruction *I,
return OpShape->second;
}
- if (auto *Select = dyn_cast<SelectInst>(I)) {
- Type *CondTy = Select->getCondition()->getType();
- for (Use &Op : CondTy->isVectorTy() ? Select->operands()
- : drop_begin(Select->operands())) {
- auto OpShape = ShapeMap.find(Op);
- if (OpShape != ShapeMap.end())
- return OpShape->second;
- }
- }
-
- if (isUniformShape(I)) {
+ if (isUniformShape(I) || isa<SelectInst>(I)) {
+ auto Ops = I->operands();
+ auto ShapedOps = isa<SelectInst>(I) ? drop_begin(Ops) : Ops;
// Find the first operand that has a known shape and use that.
- for (auto &Op : I->operands()) {
+ for (auto &Op : ShapedOps) {
auto OpShape = ShapeMap.find(Op.get());
if (OpShape != ShapeMap.end())
return OpShape->second;
@@ -719,18 +711,12 @@ class LowerMatrixIntrinsics {
} else if (isa<StoreInst>(V)) {
// Nothing to do. We forward-propagated to this so we would just
// backward propagate to an instruction with an already known shape.
- } else if (auto *Select = dyn_cast<SelectInst>(V)) {
- ShapeInfo Shape = ShapeMap[V];
- Type *CondTy = Select->getCondition()->getType();
- for (Use &Op : CondTy->isVectorTy() ? Select->operands()
- : drop_begin(Select->operands())) {
- if (setShapeInfo(Op, Shape))
- pushInstruction(Select, WorkList);
- }
- } else if (isUniformShape(V)) {
+ } else if (isUniformShape(V) || isa<SelectInst>(V)) {
+ auto Ops = cast<Instruction>(V)->operands();
+ auto ShapedOps = isa<SelectInst>(V) ? drop_begin(Ops) : Ops;
// Propagate to all operands.
ShapeInfo Shape = ShapeMap[V];
- for (Use &U : cast<Instruction>(V)->operands()) {
+ for (Use &U : ShapedOps) {
if (setShapeInfo(U.get(), Shape))
pushInstruction(U.get(), WorkList);
}
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll
index 56dca7bb985d3..70b0dfdb3e7e8 100644
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/select.ll
@@ -72,12 +72,12 @@ define void @select_2x2_vcond_shape1(ptr %cond, ptr %lhs, ptr %rhs, ptr %out) {
; CHECK-NEXT: [[COL_LOAD:%.*]] = load <2 x float>, ptr [[LHS:%.*]], align 16
; CHECK-NEXT: [[VEC_GEP:%.*]] = getelementptr float, ptr [[LHS]], i64 2
; CHECK-NEXT: [[COL_LOAD1:%.*]] = load <2 x float>, ptr [[VEC_GEP]], align 8
-; CHECK-NEXT: [[COL_LOAD2:%.*]] = load <2 x i1>, ptr [[RHS:%.*]], align 1
-; CHECK-NEXT: [[VEC_GEP3:%.*]] = getelementptr i1, ptr [[RHS]], i64 2
-; CHECK-NEXT: [[COL_LOAD4:%.*]] = load <2 x i1>, ptr [[VEC_GEP3]], align 1
+; CHECK-NEXT: [[CONDV:%.*]] = load <4 x i1>, ptr [[COND:%.*]], align 1
; CHECK-NEXT: [[COL_LOAD5:%.*]] = load <2 x float>, ptr [[RHS1:%.*]], align 4
; CHECK-NEXT: [[VEC_GEP6:%.*]] = getelementptr float, ptr [[RHS1]], i64 2
; CHECK-NEXT: [[COL_LOAD7:%.*]] = load <2 x float>, ptr [[VEC_GEP6]], align 4
+; CHECK-NEXT: [[COL_LOAD2:%.*]] = shufflevector <4 x i1> [[CONDV]], <4 x i1> poison, <2 x i32> <i32 0, i32 1>
+; CHECK-NEXT: [[COL_LOAD4:%.*]] = shufflevector <4 x i1> [[CONDV]], <4 x i1> poison, <2 x i32> <i32 2, i32 3>
; CHECK-NEXT: [[TMP1:%.*]] = select <2 x i1> [[COL_LOAD2]], <2 x float> [[COL_LOAD]], <2 x float> [[COL_LOAD5]]
; CHECK-NEXT: [[TMP2:%.*]] = select <2 x i1> [[COL_LOAD4]], <2 x float> [[COL_LOAD1]], <2 x float> [[COL_LOAD7]]
; CHECK-NEXT: store <2 x float> [[TMP1]], ptr [[OUT:%.*]], align 16
More information about the llvm-commits mailing list