[Mlir-commits] [mlir] a0fdb38 - [mlir][linalg] Add more precise memory effects to linalg op (#92079)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon May 27 23:39:22 PDT 2024


Author: donald chen
Date: 2024-05-28T08:39:19+02:00
New Revision: a0fdb38a7648f4e2b7c86e2212d7887ac996a57a

URL: https://github.com/llvm/llvm-project/commit/a0fdb38a7648f4e2b7c86e2212d7887ac996a57a
DIFF: https://github.com/llvm/llvm-project/commit/a0fdb38a7648f4e2b7c86e2212d7887ac996a57a.diff

LOG: [mlir][linalg] Add more precise memory effects to linalg op (#92079)

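This patch adds more precise memory effects to linalg ops, including the
following points:
1. Remove the read side effects for operands whose values are not used by
   the op's payload.
2. Mark all memory effects as applying to the operand's full region
   (`effectOnFullRegion = true`).
3. When bufferizing destination-style ops, construct the new op together
   with its region before inserting it, so the rewriter never tracks an
   incomplete operation.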

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 6a5f25a7605f1..0b403e2142c53 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1122,20 +1122,30 @@ ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
 static void getGenericEffectsImpl(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects,
-    ValueRange results, const ValueRange inputOperands,
-    ValueRange outputOperands) {
-  for (auto operand : inputOperands) {
+    LinalgOp linalgOp) {
+  SmallVector<Value> inputOperands = linalgOp.getDpsInputs();
+  for (auto [index, operand] : llvm::enumerate(inputOperands)) {
     if (!llvm::isa<MemRefType>(operand.getType()))
       continue;
-    effects.emplace_back(MemoryEffects::Read::get(), operand,
-                         SideEffects::DefaultResource::get());
+    if (linalgOp.payloadUsesValueFromOperand(&linalgOp->getOpOperand(index))) {
+      effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0,
+                           /*effectOnFullRegion=*/true,
+                           SideEffects::DefaultResource::get());
+    }
   }
-  for (auto operand : outputOperands) {
+  unsigned inputOperandSize = inputOperands.size();
+
+  for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInits())) {
     if (!llvm::isa<MemRefType>(operand.getType()))
       continue;
-    effects.emplace_back(MemoryEffects::Read::get(), operand,
-                         SideEffects::DefaultResource::get());
-    effects.emplace_back(MemoryEffects::Write::get(), operand,
+    if (linalgOp.payloadUsesValueFromOperand(
+            &linalgOp->getOpOperand(index + inputOperandSize))) {
+      effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0,
+                           /*effectOnFullRegion=*/true,
+                           SideEffects::DefaultResource::get());
+    }
+    effects.emplace_back(MemoryEffects::Write::get(), operand, /*stage=*/0,
+                         /*effectOnFullRegion=*/true,
                          SideEffects::DefaultResource::get());
   }
 }
@@ -1143,8 +1153,7 @@ static void getGenericEffectsImpl(
 void GenericOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
-  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
-                        getDpsInits());
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
 }
 
 LogicalResult GenericOp::verify() { return success(); }
@@ -1492,8 +1501,7 @@ ArrayAttr MapOp::getIndexingMaps() {
 void MapOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
-  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
-                        getDpsInits());
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
 }
 
 //===----------------------------------------------------------------------===//
@@ -1561,8 +1569,7 @@ ArrayAttr ReduceOp::getIndexingMaps() {
 void ReduceOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
-  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
-                        getDpsInits());
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
 }
 
 static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
@@ -1846,8 +1853,7 @@ ArrayAttr TransposeOp::getIndexingMaps() {
 void TransposeOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
-  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
-                        getDpsInits());
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
 }
 
 LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
@@ -1984,8 +1990,7 @@ ArrayAttr BroadcastOp::getIndexingMaps() {
 void BroadcastOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
-  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
-                        getDpsInits());
+  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
 }
 
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -2513,8 +2518,23 @@ SoftmaxOp::reifyResultShapes(OpBuilder &b,
 void SoftmaxOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
-  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
-                        getDpsInits());
+  for (Value operand : getDpsInputs()) {
+    if (!llvm::isa<MemRefType>(operand.getType()))
+      continue;
+    effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0,
+                         /*effectOnFullRegion=*/true,
+                         SideEffects::DefaultResource::get());
+  }
+  for (Value operand : getDpsInits()) {
+    if (!llvm::isa<MemRefType>(operand.getType()))
+      continue;
+    effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0,
+                         /*effectOnFullRegion=*/true,
+                         SideEffects::DefaultResource::get());
+    effects.emplace_back(MemoryEffects::Write::get(), operand, /*stage=*/0,
+                         /*effectOnFullRegion=*/true,
+                         SideEffects::DefaultResource::get());
+  }
 }
 
 // Helper functions for softmax decomposition.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index 899b8c87d0df7..81a5398dabcb7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -76,10 +76,16 @@ bufferizeDestinationStyleOpInterface(RewriterBase &rewriter,
   // new op. Since the new op does not have any tensor results, it does not
   // return anything.
   assert(op->getNumRegions() == 1 && "expected that op has 1 region");
-  auto newOp = cast<DestinationStyleOpInterface>(cloneWithoutRegions(
-      rewriter, op, /*newResultTypes=*/TypeRange{}, newOperands));
-  rewriter.inlineRegionBefore(op->getRegion(0), newOp->getRegion(0),
-                              newOp->getRegion(0).begin());
+  OperationState state(op->getLoc(), op->getName(), newOperands, TypeRange{},
+                       op->getAttrs());
+  state.addRegion();
+  Operation *newOp = Operation::create(state);
+  newOp->getRegion(0).getBlocks().splice(newOp->getRegion(0).begin(),
+                                         op->getRegion(0).getBlocks());
+
+  // We don't want the rewriter tracks an incomplete operation, so insert new
+  // operation after op was fully constructed.
+  rewriter.insert(newOp);
 
   // Replace the results of the old op with the new output buffers.
   replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers);

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index 37240164c377e..7311cdd39d075 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -667,8 +667,7 @@ LogicalResult {0}::fold(FoldAdaptor,
 void {0}::getEffects(SmallVectorImpl<
     SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
       if (hasPureTensorSemantics()) return;
-      getGenericEffectsImpl(effects,
-        getOperation()->getResults(), getDpsInputs(), getDpsInits());
+      getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
 }
 )FMT";
 


        


More information about the Mlir-commits mailing list