[Mlir-commits] [mlir] [mlir][linalg] Add more precise memory effects to linalg op (PR #92079)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 14 01:11:34 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: donald chen (cxy-1993)

<details>
<summary>Changes</summary>

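This patch adds more precise memory effects to linalg ops. It includes the following changes:
1. Remove the read effect on init (output) operands whose values are not used by the op's payload; input operands still always get a read effect.
2. Mark every memory effect (reads and writes) as applying to the full operand, i.e. construct each effect instance with `effectOnFullRegion = true`.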

---
Full diff: https://github.com/llvm/llvm-project/pull/92079.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+26-12) 
- (modified) mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp (+1-1) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e5f83331baf81..5958e1a0f3206 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1103,20 +1103,28 @@ ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
 static void getGenericEffectsImpl(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects,
-    ValueRange results, const ValueRange inputOperands,
+    LinalgOp linalgOp, ValueRange results, const ValueRange inputOperands,
     ValueRange outputOperands) {
   for (auto operand : inputOperands) {
     if (!llvm::isa<MemRefType>(operand.getType()))
       continue;
-    effects.emplace_back(MemoryEffects::Read::get(), operand,
+    effects.emplace_back(MemoryEffects::Read::get(), 0, true, operand,
                          SideEffects::DefaultResource::get());
   }
-  for (auto operand : outputOperands) {
+  unsigned inputOperandSize = inputOperands.size();
+  unsigned usedOutputSize =
+      linalgOp.getOpOperandsMatchingBBargs().size() - inputOperandSize;
+
+  for (auto [index, operand] : llvm::enumerate(outputOperands)) {
     if (!llvm::isa<MemRefType>(operand.getType()))
       continue;
-    effects.emplace_back(MemoryEffects::Read::get(), operand,
-                         SideEffects::DefaultResource::get());
-    effects.emplace_back(MemoryEffects::Write::get(), operand,
+    if (index < usedOutputSize &&
+        linalgOp.payloadUsesValueFromOperand(
+            &linalgOp->getOpOperand(index + inputOperandSize))) {
+      effects.emplace_back(MemoryEffects::Read::get(), 0, true, operand,
+                           SideEffects::DefaultResource::get());
+    }
+    effects.emplace_back(MemoryEffects::Write::get(), 0, true, operand,
                          SideEffects::DefaultResource::get());
   }
 }
@@ -1124,7 +1132,8 @@ static void getGenericEffectsImpl(
 void GenericOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
-  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()),
+                        getOperation()->getResults(), getDpsInputs(),
                         getDpsInits());
 }
 
@@ -1473,7 +1482,8 @@ ArrayAttr MapOp::getIndexingMaps() {
 void MapOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
-  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()),
+                        getOperation()->getResults(), getDpsInputs(),
                         getDpsInits());
 }
 
@@ -1542,7 +1552,8 @@ ArrayAttr ReduceOp::getIndexingMaps() {
 void ReduceOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
-  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()),
+                        getOperation()->getResults(), getDpsInputs(),
                         getDpsInits());
 }
 
@@ -1827,7 +1838,8 @@ ArrayAttr TransposeOp::getIndexingMaps() {
 void TransposeOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
-  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()),
+                        getOperation()->getResults(), getDpsInputs(),
                         getDpsInits());
 }
 
@@ -1965,7 +1977,8 @@ ArrayAttr BroadcastOp::getIndexingMaps() {
 void BroadcastOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
-  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()),
+                        getOperation()->getResults(), getDpsInputs(),
                         getDpsInits());
 }
 
@@ -2494,7 +2507,8 @@ SoftmaxOp::reifyResultShapes(OpBuilder &b,
 void SoftmaxOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
-  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()),
+                        getOperation()->getResults(), getDpsInputs(),
                         getDpsInits());
 }
 
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index fe6ad15041126..f3071b81e21cb 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -659,7 +659,7 @@ LogicalResult {0}::fold(FoldAdaptor,
 void {0}::getEffects(SmallVectorImpl<
     SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
       if (hasPureTensorSemantics()) return;
-      getGenericEffectsImpl(effects,
+      getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()),
         getOperation()->getResults(), getDpsInputs(), getDpsInits());
 }
 )FMT";

``````````

</details>


https://github.com/llvm/llvm-project/pull/92079


More information about the Mlir-commits mailing list