[Mlir-commits] [mlir] [mlir][linalg] Add more precise memory effects to linalg op (PR #92079)

donald chen llvmlistbot at llvm.org
Tue May 14 23:37:49 PDT 2024


https://github.com/cxy-1993 updated https://github.com/llvm/llvm-project/pull/92079

>From 9ccde8b8ade06f0e5f5f22183f11d2eeb3829ce3 Mon Sep 17 00:00:00 2001
From: cxy <chenxunyu1993 at gmail.com>
Date: Tue, 14 May 2024 08:03:42 +0000
Subject: [PATCH] [mlir][linalg] Add more precise memory effects to linalg op

---
 .../Dialect/Linalg/IR/LinalgInterfaces.td     |  3 +
 .../Dialect/Linalg/IR/LinalgStructuredOps.td  |  3 +
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      | 56 ++++++++++++-------
 .../mlir-linalg-ods-yaml-gen.cpp              |  3 +-
 4 files changed, 42 insertions(+), 23 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index fbf3f19cde0e9..a94c30d4708a0 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -322,6 +322,9 @@ def LinalgStructuredInterface
       /*args=*/(ins "OpOperand *":$opOperand),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
+        if ($_op.getOperation()->getRegion(0).empty()) {
+          return true;
+        }
         unsigned bbArgNumber = opOperand->getOperandNumber();
         // Init tensors have uses.
         return !getBlock()->getArgument(bbArgNumber).use_empty();
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 5ee363ed32572..8162926dad6ad 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -289,6 +289,9 @@ def MapOp : LinalgStructuredBase_Op<"map", [
 
     bool payloadUsesValueFromOperand(OpOperand * opOperand) {
       if (isDpsInit(opOperand)) return false;
+      if (getOperation()->getRegion(0).empty()) {
+        return true;
+      }
       return !getMatchingBlockArgument(opOperand).use_empty();
     }
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e5f83331baf81..7e802fe15e5b3 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1103,20 +1103,27 @@ ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
 static void getGenericEffectsImpl(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects,
-    ValueRange results, const ValueRange inputOperands,
-    ValueRange outputOperands) {
-  for (auto operand : inputOperands) {
+    LinalgOp linalgOp) {
+  SmallVector<Value> inputOperands = linalgOp.getDpsInputs();
+  for (auto [index, operand] : llvm::enumerate(inputOperands)) {
     if (!llvm::isa<MemRefType>(operand.getType()))
       continue;
-    effects.emplace_back(MemoryEffects::Read::get(), operand,
-                         SideEffects::DefaultResource::get());
+    if (linalgOp.payloadUsesValueFromOperand(&linalgOp->getOpOperand(index))) {
+      effects.emplace_back(MemoryEffects::Read::get(), operand, 0, true,
+                           SideEffects::DefaultResource::get());
+    }
   }
-  for (auto operand : outputOperands) {
+  unsigned inputOperandSize = inputOperands.size();
+
+  for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInits())) {
     if (!llvm::isa<MemRefType>(operand.getType()))
       continue;
-    effects.emplace_back(MemoryEffects::Read::get(), operand,
-                         SideEffects::DefaultResource::get());
-    effects.emplace_back(MemoryEffects::Write::get(), operand,
+    if (linalgOp.payloadUsesValueFromOperand(
+            &linalgOp->getOpOperand(index + inputOperandSize))) {
+      effects.emplace_back(MemoryEffects::Read::get(), operand, 0, true,
+                           SideEffects::DefaultResource::get());
+    }
+    effects.emplace_back(MemoryEffects::Write::get(), operand, 0, true,
                          SideEffects::DefaultResource::get());
   }
 }
@@ -1124,8 +1131,7 @@ static void getGenericEffectsImpl(
 void GenericOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
-  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
-                        getDpsInits());
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
 }
 
 LogicalResult GenericOp::verify() { return success(); }
@@ -1473,8 +1479,7 @@ ArrayAttr MapOp::getIndexingMaps() {
 void MapOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
-  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
-                        getDpsInits());
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
 }
 
 //===----------------------------------------------------------------------===//
@@ -1542,8 +1547,7 @@ ArrayAttr ReduceOp::getIndexingMaps() {
 void ReduceOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
-  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
-                        getDpsInits());
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
 }
 
 static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
@@ -1827,8 +1831,7 @@ ArrayAttr TransposeOp::getIndexingMaps() {
 void TransposeOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
-  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
-                        getDpsInits());
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
 }
 
 LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
@@ -1965,8 +1968,7 @@ ArrayAttr BroadcastOp::getIndexingMaps() {
 void BroadcastOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
-  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
-                        getDpsInits());
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
 }
 
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -2494,8 +2496,20 @@ SoftmaxOp::reifyResultShapes(OpBuilder &b,
 void SoftmaxOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
-  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
-                        getDpsInits());
+  for (auto operand : getDpsInputs()) {
+    if (!llvm::isa<MemRefType>(operand.getType()))
+      continue;
+    effects.emplace_back(MemoryEffects::Read::get(), operand,
+                         SideEffects::DefaultResource::get());
+  }
+  for (auto operand : getDpsInits()) {
+    if (!llvm::isa<MemRefType>(operand.getType()))
+      continue;
+    effects.emplace_back(MemoryEffects::Read::get(), operand,
+                         SideEffects::DefaultResource::get());
+    effects.emplace_back(MemoryEffects::Write::get(), operand,
+                         SideEffects::DefaultResource::get());
+  }
 }
 
 // Helper functions for softmax decomposition.
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index fe6ad15041126..882000ee0969a 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -659,8 +659,7 @@ LogicalResult {0}::fold(FoldAdaptor,
 void {0}::getEffects(SmallVectorImpl<
     SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
       if (hasPureTensorSemantics()) return;
-      getGenericEffectsImpl(effects,
-        getOperation()->getResults(), getDpsInputs(), getDpsInits());
+      getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
 }
 )FMT";
 



More information about the Mlir-commits mailing list