[Mlir-commits] [mlir] [mlir][linalg] Add more precise memory effects to linalg op (PR #92079)
donald chen
llvmlistbot at llvm.org
Tue May 14 01:11:02 PDT 2024
https://github.com/cxy-1993 created https://github.com/llvm/llvm-project/pull/92079
This patch adds more precise memory effects to linalg ops, including the following:
1. Remove the read side effect on init (output) operands whose value is not used by the payload.
2. Mark all emitted side effects as applying to the full operand ("full" effects).
>From b81c63af192ce41b122b8e08c834b61d91547797 Mon Sep 17 00:00:00 2001
From: cxy <chenxunyu1993 at gmail.com>
Date: Tue, 14 May 2024 08:03:42 +0000
Subject: [PATCH] [mlir][linalg] Add more precise memory effects to linalg op
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 38 +++++++++++++------
.../mlir-linalg-ods-yaml-gen.cpp | 2 +-
2 files changed, 27 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e5f83331baf81..5958e1a0f3206 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1103,20 +1103,38 @@ ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
 static void getGenericEffectsImpl(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects,
-    ValueRange results, const ValueRange inputOperands,
+    LinalgOp linalgOp, ValueRange results, const ValueRange inputOperands,
     ValueRange outputOperands) {
   for (auto operand : inputOperands) {
     if (!llvm::isa<MemRefType>(operand.getType()))
       continue;
-    effects.emplace_back(MemoryEffects::Read::get(), operand,
+    effects.emplace_back(MemoryEffects::Read::get(), /*stage=*/0,
+                         /*effectOnFullRegion=*/true, operand,
                          SideEffects::DefaultResource::get());
   }
-  for (auto operand : outputOperands) {
+  // Only inits counted by getOpOperandsMatchingBBargs() are queried for
+  // payload use; this assumes those operands are all inputs followed by a
+  // prefix of the inits. The subtraction is unsigned, so clamp it at zero.
+  unsigned inputOperandSize = inputOperands.size();
+  unsigned numBBArgOperands = linalgOp.getOpOperandsMatchingBBargs().size();
+  unsigned usedOutputSize = numBBArgOperands > inputOperandSize
+                                ? numBBArgOperands - inputOperandSize
+                                : 0;
+
+  for (auto [index, operand] : llvm::enumerate(outputOperands)) {
     if (!llvm::isa<MemRefType>(operand.getType()))
       continue;
-    effects.emplace_back(MemoryEffects::Read::get(), operand,
-                         SideEffects::DefaultResource::get());
-    effects.emplace_back(MemoryEffects::Write::get(), operand,
+    // Read the init only when the payload uses its incoming value; the write
+    // is always reported.
+    if (index < usedOutputSize &&
+        linalgOp.payloadUsesValueFromOperand(
+            &linalgOp->getOpOperand(index + inputOperandSize))) {
+      effects.emplace_back(MemoryEffects::Read::get(), /*stage=*/0,
+                           /*effectOnFullRegion=*/true, operand,
+                           SideEffects::DefaultResource::get());
+    }
+    effects.emplace_back(MemoryEffects::Write::get(), /*stage=*/0,
+                         /*effectOnFullRegion=*/true, operand,
                          SideEffects::DefaultResource::get());
   }
 }
@@ -1124,7 +1132,8 @@ static void getGenericEffectsImpl(
void GenericOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
- getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()),
+ getOperation()->getResults(), getDpsInputs(),
getDpsInits());
}
@@ -1473,7 +1482,8 @@ ArrayAttr MapOp::getIndexingMaps() {
void MapOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
- getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()),
+ getOperation()->getResults(), getDpsInputs(),
getDpsInits());
}
@@ -1542,7 +1552,8 @@ ArrayAttr ReduceOp::getIndexingMaps() {
void ReduceOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
- getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()),
+ getOperation()->getResults(), getDpsInputs(),
getDpsInits());
}
@@ -1827,7 +1838,8 @@ ArrayAttr TransposeOp::getIndexingMaps() {
void TransposeOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
- getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()),
+ getOperation()->getResults(), getDpsInputs(),
getDpsInits());
}
@@ -1965,7 +1977,8 @@ ArrayAttr BroadcastOp::getIndexingMaps() {
void BroadcastOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
- getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()),
+ getOperation()->getResults(), getDpsInputs(),
getDpsInits());
}
@@ -2494,6 +2507,25 @@ SoftmaxOp::reifyResultShapes(OpBuilder &b,
 void SoftmaxOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
-  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
-                        getDpsInits());
+  // SoftmaxOp does not implement the LinalgOp interface, so
+  // cast<LinalgOp>(getOperation()) would assert and getGenericEffectsImpl
+  // cannot be reused. Report the effects directly and conservatively: every
+  // memref input is read, every memref init is read and written.
+  for (Value operand : getDpsInputs()) {
+    if (!llvm::isa<MemRefType>(operand.getType()))
+      continue;
+    effects.emplace_back(MemoryEffects::Read::get(), /*stage=*/0,
+                         /*effectOnFullRegion=*/true, operand,
+                         SideEffects::DefaultResource::get());
+  }
+  for (Value operand : getDpsInits()) {
+    if (!llvm::isa<MemRefType>(operand.getType()))
+      continue;
+    effects.emplace_back(MemoryEffects::Read::get(), /*stage=*/0,
+                         /*effectOnFullRegion=*/true, operand,
+                         SideEffects::DefaultResource::get());
+    effects.emplace_back(MemoryEffects::Write::get(), /*stage=*/0,
+                         /*effectOnFullRegion=*/true, operand,
+                         SideEffects::DefaultResource::get());
+  }
 }
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index fe6ad15041126..f3071b81e21cb 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -659,7 +659,7 @@ LogicalResult {0}::fold(FoldAdaptor,
void {0}::getEffects(SmallVectorImpl<
SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
if (hasPureTensorSemantics()) return;
- getGenericEffectsImpl(effects,
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()),
getOperation()->getResults(), getDpsInputs(), getDpsInits());
}
)FMT";
More information about the Mlir-commits
mailing list