[llvm] [Matrix] Don't update Changed based on Visit* return value (NFC). (PR #142487)
Jon Roelofs via llvm-commits
llvm-commits at lists.llvm.org
Thu Jun 5 09:44:24 PDT 2025
https://github.com/jroelofs updated https://github.com/llvm/llvm-project/pull/142487
>From afe1fb05426a62915dfa72354878bcb14ad66d94 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Mon, 2 Jun 2025 13:59:44 -0700
Subject: [PATCH 1/6] [Matrix] Don't update Changed based on Visit* return
value (NFC).
The Visit* functions always modify the IR, so remove their boolean result.
Co-authored-by: Florian Hahn <florian_hahn at apple.com>
---
.../Scalar/LowerMatrixIntrinsics.cpp | 67 +++++++------------
1 file changed, 26 insertions(+), 41 deletions(-)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index fb5e081acf7c5..124dc54b1dba8 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1056,19 +1056,24 @@ class LowerMatrixIntrinsics {
IRBuilder<> Builder(Inst);
+ const ShapeInfo &SI = ShapeMap.at(Inst);
+
if (CallInst *CInst = dyn_cast<CallInst>(Inst))
- Changed |= VisitCallInst(CInst);
+ Changed |= tryVisitCallInst(CInst);
Value *Op1;
Value *Op2;
- if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
- Changed |= VisitBinaryOperator(BinOp);
- if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
- Changed |= VisitUnaryOperator(UnOp);
if (match(Inst, m_Load(m_Value(Op1))))
- Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
+ VisitLoad(cast<LoadInst>(Inst), SI, Op1, Builder);
else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
- Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
+ VisitStore(cast<StoreInst>(Inst), SI, Op1, Op2, Builder);
+ else if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
+ VisitBinaryOperator(BinOp, SI);
+ else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
+ VisitUnaryOperator(UnOp, SI);
+ else
+ continue;
+ Changed = true;
}
if (ORE) {
@@ -1107,7 +1112,7 @@ class LowerMatrixIntrinsics {
}
/// Replace intrinsic calls
- bool VisitCallInst(CallInst *Inst) {
+ bool tryVisitCallInst(CallInst *Inst) {
if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
return false;
@@ -2105,49 +2110,36 @@ class LowerMatrixIntrinsics {
}
/// Lower load instructions, if shape information is available.
- bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
- auto I = ShapeMap.find(Inst);
- assert(I != ShapeMap.end() &&
- "must only visit instructions with shape info");
+ void VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr, IRBuilder<> &Builder) {
LowerLoad(Inst, Ptr, Inst->getAlign(),
- Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
- I->second);
- return true;
+ Builder.getInt64(SI.getStride()), Inst->isVolatile(),
+ SI);
}
- bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
+ void VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal, Value *Ptr,
IRBuilder<> &Builder) {
- auto I = ShapeMap.find(StoredVal);
- assert(I != ShapeMap.end() &&
- "must only visit instructions with shape info");
LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
- Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
- I->second);
- return true;
+ Builder.getInt64(SI.getStride()), Inst->isVolatile(),
+ SI);
}
/// Lower binary operators, if shape information is available.
- bool VisitBinaryOperator(BinaryOperator *Inst) {
- auto I = ShapeMap.find(Inst);
- assert(I != ShapeMap.end() &&
- "must only visit instructions with shape info");
-
+ void VisitBinaryOperator(BinaryOperator *Inst, const ShapeInfo &SI) {
Value *Lhs = Inst->getOperand(0);
Value *Rhs = Inst->getOperand(1);
IRBuilder<> Builder(Inst);
- ShapeInfo &Shape = I->second;
MatrixTy Result;
- MatrixTy A = getMatrix(Lhs, Shape, Builder);
- MatrixTy B = getMatrix(Rhs, Shape, Builder);
+ MatrixTy A = getMatrix(Lhs, SI, Builder);
+ MatrixTy B = getMatrix(Rhs, SI, Builder);
assert(A.isColumnMajor() == B.isColumnMajor() &&
Result.isColumnMajor() == A.isColumnMajor() &&
"operands must agree on matrix layout");
Builder.setFastMathFlags(getFastMathFlags(Inst));
- for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
+ for (unsigned I = 0; I < SI.getNumVectors(); ++I)
Result.addVector(Builder.CreateBinOp(Inst->getOpcode(), A.getVector(I),
B.getVector(I)));
@@ -2155,22 +2147,16 @@ class LowerMatrixIntrinsics {
Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
Result.getNumVectors()),
Builder);
- return true;
}
/// Lower unary operators, if shape information is available.
- bool VisitUnaryOperator(UnaryOperator *Inst) {
- auto I = ShapeMap.find(Inst);
- assert(I != ShapeMap.end() &&
- "must only visit instructions with shape info");
-
+ void VisitUnaryOperator(UnaryOperator *Inst, const ShapeInfo &SI) {
Value *Op = Inst->getOperand(0);
IRBuilder<> Builder(Inst);
- ShapeInfo &Shape = I->second;
MatrixTy Result;
- MatrixTy M = getMatrix(Op, Shape, Builder);
+ MatrixTy M = getMatrix(Op, SI, Builder);
Builder.setFastMathFlags(getFastMathFlags(Inst));
@@ -2184,14 +2170,13 @@ class LowerMatrixIntrinsics {
}
};
- for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
+ for (unsigned I = 0; I < SI.getNumVectors(); ++I)
Result.addVector(BuildVectorOp(M.getVector(I)));
finalizeLowering(Inst,
Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
Result.getNumVectors()),
Builder);
- return true;
}
/// Helper to linearize a matrix expression tree into a string. Currently
>From 9f4e20556c05838b55df10ab2589985834102e35 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Mon, 2 Jun 2025 14:03:59 -0700
Subject: [PATCH 2/6] clang-format
---
.../Transforms/Scalar/LowerMatrixIntrinsics.cpp | 15 +++++++--------
1 file changed, 7 insertions(+), 8 deletions(-)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 124dc54b1dba8..439e616254037 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -2110,17 +2110,16 @@ class LowerMatrixIntrinsics {
}
/// Lower load instructions, if shape information is available.
- void VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr, IRBuilder<> &Builder) {
- LowerLoad(Inst, Ptr, Inst->getAlign(),
- Builder.getInt64(SI.getStride()), Inst->isVolatile(),
- SI);
+ void VisitLoad(LoadInst *Inst, const ShapeInfo &SI, Value *Ptr,
+ IRBuilder<> &Builder) {
+ LowerLoad(Inst, Ptr, Inst->getAlign(), Builder.getInt64(SI.getStride()),
+ Inst->isVolatile(), SI);
}
- void VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal, Value *Ptr,
- IRBuilder<> &Builder) {
+ void VisitStore(StoreInst *Inst, const ShapeInfo &SI, Value *StoredVal,
+ Value *Ptr, IRBuilder<> &Builder) {
LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
- Builder.getInt64(SI.getStride()), Inst->isVolatile(),
- SI);
+ Builder.getInt64(SI.getStride()), Inst->isVolatile(), SI);
}
/// Lower binary operators, if shape information is available.
>From 36ec2f5eefe530064384e9b0f4b22e26c6f833eb Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Mon, 2 Jun 2025 16:41:10 +0100
Subject: [PATCH 3/6] [Matrix] Don't update Changed based on Visit* return
value (NFC).
The Visit* functions always modify the IR, so remove their boolean result.
Depends on https://github.com/llvm/llvm-project/pull/142416.
---
.../Scalar/LowerMatrixIntrinsics.cpp | 23 +++++++++----------
1 file changed, 11 insertions(+), 12 deletions(-)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 1c3d22924795b..3158332b9c993 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1062,13 +1062,16 @@ class LowerMatrixIntrinsics {
Value *Op1;
Value *Op2;
if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
- Changed |= VisitBinaryOperator(BinOp);
+ VisitBinaryOperator(BinOp);
if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
- Changed |= VisitUnaryOperator(UnOp);
+ VisitUnaryOperator(UnOp);
if (match(Inst, m_Load(m_Value(Op1))))
- Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
+ VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
- Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
+ VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
+ else
+ continue;
+ Changed = true;
}
if (ORE) {
@@ -2105,17 +2108,16 @@ class LowerMatrixIntrinsics {
}
/// Lower load instructions, if shape information is available.
- bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
+ void VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
auto I = ShapeMap.find(Inst);
assert(I != ShapeMap.end() &&
"must only visit instructions with shape info");
LowerLoad(Inst, Ptr, Inst->getAlign(),
Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
I->second);
- return true;
}
- bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
+ void VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
IRBuilder<> &Builder) {
auto I = ShapeMap.find(Inst);
assert(I != ShapeMap.end() &&
@@ -2123,11 +2125,10 @@ class LowerMatrixIntrinsics {
LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
I->second);
- return true;
}
/// Lower binary operators, if shape information is available.
- bool VisitBinaryOperator(BinaryOperator *Inst) {
+ void VisitBinaryOperator(BinaryOperator *Inst) {
auto I = ShapeMap.find(Inst);
assert(I != ShapeMap.end() &&
"must only visit instructions with shape info");
@@ -2155,11 +2156,10 @@ class LowerMatrixIntrinsics {
Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
Result.getNumVectors()),
Builder);
- return true;
}
/// Lower unary operators, if shape information is available.
- bool VisitUnaryOperator(UnaryOperator *Inst) {
+ void VisitUnaryOperator(UnaryOperator *Inst) {
auto I = ShapeMap.find(Inst);
assert(I != ShapeMap.end() &&
"must only visit instructions with shape info");
@@ -2191,7 +2191,6 @@ class LowerMatrixIntrinsics {
Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
Result.getNumVectors()),
Builder);
- return true;
}
/// Helper to linearize a matrix expression tree into a string. Currently
>From 42c110b4e09ecf5f470517946777e3afbb13a30c Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Thu, 5 Jun 2025 16:43:54 +0100
Subject: [PATCH 4/6] !fixup use else if
---
llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 3158332b9c993..38f92561a917d 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1063,9 +1063,9 @@ class LowerMatrixIntrinsics {
Value *Op2;
if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
VisitBinaryOperator(BinOp);
- if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
+ else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
VisitUnaryOperator(UnOp);
- if (match(Inst, m_Load(m_Value(Op1))))
+ else if (match(Inst, m_Load(m_Value(Op1))))
VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
>From 4499f9cde210b49c84b1fa0760df29c372407c75 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Thu, 5 Jun 2025 09:38:45 -0700
Subject: [PATCH 5/6] merge fhahn's branch
---
llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 69e494f0a94f3..4d4542c1705e4 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -489,7 +489,7 @@ class LowerMatrixIntrinsics {
DenseMap<Value *, ShapeInfo> ShapeMap;
/// List of instructions to remove. While lowering, we are not replacing all
- /// users of a lowered instruction.and
+ /// users of a lowered instruction, if shape information is available and
/// those need to be removed after we finished lowering.
SmallVector<Instruction *, 16> ToRemove;
>From 506431af4213b263d060f221d802ad6241184a84 Mon Sep 17 00:00:00 2001
From: Jon Roelofs <jonathan_roelofs at apple.com>
Date: Thu, 5 Jun 2025 09:44:07 -0700
Subject: [PATCH 6/6] VisitCallInst can't fail either
---
.../Transforms/Scalar/LowerMatrixIntrinsics.cpp | 15 ++++++---------
1 file changed, 6 insertions(+), 9 deletions(-)
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 4d4542c1705e4..2b0993744ff1e 100644
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1058,15 +1058,14 @@ class LowerMatrixIntrinsics {
const ShapeInfo &SI = ShapeMap.at(Inst);
- if (CallInst *CInst = dyn_cast<CallInst>(Inst))
- Changed |= tryVisitCallInst(CInst);
-
Value *Op1;
Value *Op2;
if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
VisitBinaryOperator(BinOp, SI);
else if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
VisitUnaryOperator(UnOp, SI);
+ else if (CallInst *CInst = dyn_cast<CallInst>(Inst))
+ VisitCallInst(CInst);
else if (match(Inst, m_Load(m_Value(Op1))))
VisitLoad(cast<LoadInst>(Inst), SI, Op1, Builder);
else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
@@ -1111,10 +1110,9 @@ class LowerMatrixIntrinsics {
return Changed;
}
- /// Replace intrinsic calls
- bool tryVisitCallInst(CallInst *Inst) {
- if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
- return false;
+ /// Replace intrinsic calls.
+ void VisitCallInst(CallInst *Inst) {
+ assert(Inst->getCalledFunction() && Inst->getCalledFunction()->isIntrinsic());
switch (Inst->getCalledFunction()->getIntrinsicID()) {
case Intrinsic::matrix_multiply:
@@ -1130,9 +1128,8 @@ class LowerMatrixIntrinsics {
LowerColumnMajorStore(Inst);
break;
default:
- return false;
+ llvm_unreachable("only intrinsics supporting shape info should be seen here");
}
- return true;
}
/// Compute the alignment for a column/row \p Idx with \p Stride between them.
More information about the llvm-commits
mailing list