[Mlir-commits] [mlir] [mlir][linalg] Add more precise memory effects to linalg op (PR #92079)
donald chen
llvmlistbot at llvm.org
Fri May 17 06:17:59 PDT 2024
https://github.com/cxy-1993 updated https://github.com/llvm/llvm-project/pull/92079
From 67321b7c92651e16dcfe2bef86dc69d6e2e4b396 Mon Sep 17 00:00:00 2001
From: cxy <chenxunyu1993 at gmail.com>
Date: Tue, 14 May 2024 08:03:42 +0000
Subject: [PATCH] [mlir][linalg] Add more precise memory effects to linalg op
---
.../Dialect/Linalg/IR/LinalgInterfaces.td | 3 +
.../Dialect/Linalg/IR/LinalgStructuredOps.td | 3 +
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 62 ++++++++++++-------
.../mlir-linalg-ods-yaml-gen.cpp | 3 +-
4 files changed, 48 insertions(+), 23 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index fbf3f19cde0e9..a94c30d4708a0 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -322,6 +322,9 @@ def LinalgStructuredInterface
/*args=*/(ins "OpOperand *":$opOperand),
/*methodBody=*/"",
/*defaultImplementation=*/[{
+ if ($_op.getOperation()->getRegion(0).empty()) {
+ return true;
+ }
unsigned bbArgNumber = opOperand->getOperandNumber();
// Init tensors have uses.
return !getBlock()->getArgument(bbArgNumber).use_empty();
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 5ee363ed32572..8162926dad6ad 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -289,6 +289,9 @@ def MapOp : LinalgStructuredBase_Op<"map", [
bool payloadUsesValueFromOperand(OpOperand * opOperand) {
if (isDpsInit(opOperand)) return false;
+ if (getOperation()->getRegion(0).empty()) {
+ return true;
+ }
return !getMatchingBlockArgument(opOperand).use_empty();
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e5f83331baf81..3cdd606297091 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1103,20 +1103,30 @@ ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
static void getGenericEffectsImpl(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects,
- ValueRange results, const ValueRange inputOperands,
- ValueRange outputOperands) {
- for (auto operand : inputOperands) {
+ LinalgOp linalgOp) {
+ SmallVector<Value> inputOperands = linalgOp.getDpsInputs();
+ for (auto [index, operand] : llvm::enumerate(inputOperands)) {
if (!llvm::isa<MemRefType>(operand.getType()))
continue;
- effects.emplace_back(MemoryEffects::Read::get(), operand,
- SideEffects::DefaultResource::get());
+ if (linalgOp.payloadUsesValueFromOperand(&linalgOp->getOpOperand(index))) {
+ effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0,
+ /*effectOnFullRegion=*/true,
+ SideEffects::DefaultResource::get());
+ }
}
- for (auto operand : outputOperands) {
+ unsigned inputOperandSize = inputOperands.size();
+
+ for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInits())) {
if (!llvm::isa<MemRefType>(operand.getType()))
continue;
- effects.emplace_back(MemoryEffects::Read::get(), operand,
- SideEffects::DefaultResource::get());
- effects.emplace_back(MemoryEffects::Write::get(), operand,
+ if (linalgOp.payloadUsesValueFromOperand(
+ &linalgOp->getOpOperand(index + inputOperandSize))) {
+ effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0,
+ /*effectOnFullRegion=*/true,
+ SideEffects::DefaultResource::get());
+ }
+ effects.emplace_back(MemoryEffects::Write::get(), operand, /*stage=*/0,
+ /*effectOnFullRegion=*/true,
SideEffects::DefaultResource::get());
}
}
@@ -1124,8 +1134,7 @@ static void getGenericEffectsImpl(
void GenericOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
- getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
- getDpsInits());
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
LogicalResult GenericOp::verify() { return success(); }
@@ -1473,8 +1482,7 @@ ArrayAttr MapOp::getIndexingMaps() {
void MapOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
- getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
- getDpsInits());
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
//===----------------------------------------------------------------------===//
@@ -1542,8 +1550,7 @@ ArrayAttr ReduceOp::getIndexingMaps() {
void ReduceOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
- getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
- getDpsInits());
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
@@ -1827,8 +1834,7 @@ ArrayAttr TransposeOp::getIndexingMaps() {
void TransposeOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
- getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
- getDpsInits());
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
@@ -1965,8 +1971,7 @@ ArrayAttr BroadcastOp::getIndexingMaps() {
void BroadcastOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
- getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
- getDpsInits());
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -2494,8 +2499,23 @@ SoftmaxOp::reifyResultShapes(OpBuilder &b,
void SoftmaxOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
- getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
- getDpsInits());
+ for (Value operand : getDpsInputs()) {
+ if (!llvm::isa<MemRefType>(operand.getType()))
+ continue;
+ effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0,
+ /*effectOnFullRegion=*/true,
+ SideEffects::DefaultResource::get());
+ }
+ for (Value operand : getDpsInits()) {
+ if (!llvm::isa<MemRefType>(operand.getType()))
+ continue;
+ effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0,
+ /*effectOnFullRegion=*/true,
+ SideEffects::DefaultResource::get());
+ effects.emplace_back(MemoryEffects::Write::get(), operand, /*stage=*/0,
+ /*effectOnFullRegion=*/true,
+ SideEffects::DefaultResource::get());
+ }
}
// Helper functions for softmax decomposition.
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index fe6ad15041126..882000ee0969a 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -659,8 +659,7 @@ LogicalResult {0}::fold(FoldAdaptor,
void {0}::getEffects(SmallVectorImpl<
SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
if (hasPureTensorSemantics()) return;
- getGenericEffectsImpl(effects,
- getOperation()->getResults(), getDpsInputs(), getDpsInits());
+ getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
)FMT";
More information about the Mlir-commits mailing list