[Mlir-commits] [mlir] 5fa0b35 - [mlir][linalg][bufferize] Implement equivalence analysis

Matthias Springer llvmlistbot at llvm.org
Fri Dec 3 20:00:32 PST 2021


Author: Matthias Springer
Date: 2021-12-04T11:52:04+09:00
New Revision: 5fa0b3561a541e992486e29205388e6976c5d77f

URL: https://github.com/llvm/llvm-project/commit/5fa0b3561a541e992486e29205388e6976c5d77f
DIFF: https://github.com/llvm/llvm-project/commit/5fa0b3561a541e992486e29205388e6976c5d77f.diff

LOG: [mlir][linalg][bufferize] Implement equivalence analysis

Instead of checking buffer equivalence during bufferization, gather buffer equivalence information right after the analysis. This is in preparation for decoupling bufferization from BufferizationAliasInfo.

This change also fixes equivalence analysis for scf.if op results, which was not fully implemented. scf.if op results are equivalent to their corresponding yield values if both yield values are equivalent to each other.

Differential Revision: https://reviews.llvm.org/D114774

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index 1b3b3ff2f12e1..2d690b5f1045e 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -245,10 +245,6 @@ bool bufferizesToAliasOnly(OpOperand &opOperand);
 /// themselves (e.g., ExtractSliceOp).
 bool isValueRead(Value value);
 
-/// Return the relationship between the operand and the its corresponding
-/// OpResult that it may alias with. Return None if the op is not bufferizable.
-BufferRelation bufferRelation(OpOperand &opOperand);
-
 /// Starting from `value`, follow the use-def chain in reverse, always selecting
 /// the aliasing OpOperands. Find and return Values for which `condition`
 /// evaluates to true. OpOperands of such matching Values are not traversed any
@@ -426,7 +422,8 @@ struct AllocationHoistingBarrierOnly
     return OpResult();
   }
 
-  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const BufferizationAliasInfo &aliasInfo) const {
     return BufferRelation::None;
   }
 

diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
index d792a768b1e9c..cf083eb1986d9 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
@@ -167,19 +167,23 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
       >,
       InterfaceMethod<
         /*desc=*/[{
-          Return the buffer relation between the given OpOperand and its
-          aliasing OpResult when bufferized in-place. Most OpOperands have an
-          "equivalence" relation.
+          Return the buffer relation between the given OpResult and its aliasing
+          OpOperands when bufferized in-place. Most OpOperands have an
+          "equivalence" relation. This method will never be called on OpResults
+          that do not have a tensor type. It will also never be called on
+          OpResults that do not have at least one aliasing OpOperand.
 
           TODO: Support other relations such as "OpOperand is included in
           OpResult".
         }],
         /*retType=*/"BufferRelation",
         /*methodName=*/"bufferRelation",
-        /*args=*/(ins "OpOperand &":$opOperand),
+        /*args=*/(ins "OpResult":$opResult,
+                      "const BufferizationAliasInfo &":$aliasInfo),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
-          // Does not have to be implemented for ops without tensor OpOperands.
+          // Does not have to be implemented for ops without tensor OpResults
+          // that have an aliasing OpOperand.
           llvm_unreachable("bufferRelation not implemented");
         }]
       >,

diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
index 97ae11e5d27b6..0a4b140a1f961 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
@@ -9,6 +9,8 @@
 #ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_SCF_INTERFACE_IMPL_H
 #define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_SCF_INTERFACE_IMPL_H
 
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+
 namespace mlir {
 
 class DialectRegistry;

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index 4348fe4d5ad28..03ea6bdd63b97 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -167,8 +167,6 @@ void BufferizationAliasInfo::bufferizeInPlace(OpResult result,
 
   markInPlace(result);
   aliasInfo.unionSets(result, operand.get());
-  if (bufferRelation(operand) == BufferRelation::Equivalent)
-    equivalentInfo.unionSets(result, operand.get());
 }
 
 /// Set the inPlace bufferization spec to false.
@@ -303,19 +301,6 @@ bool mlir::linalg::comprehensive_bufferize::isValueRead(Value value) {
   return false;
 }
 
-/// Return the relationship between the operand and the its corresponding
-/// OpResult that it may alias with. Return None if the op is not bufferizable.
-BufferRelation
-mlir::linalg::comprehensive_bufferize::bufferRelation(OpOperand &opOperand) {
-  if (auto bufferizableOp =
-          dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
-    return bufferizableOp.bufferRelation(opOperand);
-
-  // Unknown op that returns a tensor. The inplace analysis does not support it.
-  // Conservatively return None.
-  return BufferRelation::None;
-}
-
 // Starting from `value`, follow the use-def chain in reverse, always selecting
 // the aliasing OpOperands. Find and return Values for which `condition`
 // evaluates to true. OpOperands of such matching Values are not traversed any

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index 6407febd8b6f9..23da486b34f69 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -640,6 +640,40 @@ static LogicalResult inPlaceAnalysis(Operation *op,
   return inPlaceAnalysis(ops, aliasInfo, domInfo, analysisFuzzerSeed);
 }
 
+/// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
+static void equivalenceAnalysis(SmallVector<Operation *> &ops,
+                                BufferizationAliasInfo &aliasInfo) {
+  for (Operation *op : ops)
+    if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
+      for (OpResult opResult : op->getOpResults())
+        if (opResult.getType().isa<TensorType>())
+          if (aliasInfo.isInPlace(opResult)) {
+            SmallVector<OpOperand *> opOperands =
+                bufferizableOp.getAliasingOpOperand(opResult);
+            if (!opOperands.empty())
+              if (bufferizableOp.bufferRelation(opResult, aliasInfo) ==
+                  BufferRelation::Equivalent)
+                for (OpOperand *opOperand : opOperands)
+                  aliasInfo.unionEquivalenceClasses(opResult, opOperand->get());
+          }
+}
+
+/// Analyze equivalence of tied OpResult/OpOperand pairs of all ops contained
+/// in `op`.
+static void equivalenceAnalysis(Operation *op,
+                                BufferizationAliasInfo &aliasInfo) {
+  // Traverse ops in PostOrder: Nested ops first, then enclosing ops.
+  SmallVector<Operation *> ops;
+  op->walk<WalkOrder::PostOrder>([&](Operation *op) {
+    // No tensors => no buffers.
+    if (none_of(op->getResultTypes(), isaTensor))
+      return;
+    ops.push_back(op);
+  });
+
+  equivalenceAnalysis(ops, aliasInfo);
+}
+
 /// Assert that the current bufferization decisions are consistent.
 static LogicalResult
 checkAliasInfoConsistency(FuncOp funcOp, const DominanceInfo &domInfo,
@@ -708,6 +742,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
   if (failed(
           inPlaceAnalysis(op, aliasInfo, domInfo, options.analysisFuzzerSeed)))
     return failure();
+  equivalenceAnalysis(op, aliasInfo);
 
   for (const std::unique_ptr<PostAnalysisStep> &step :
        options.postAnalysisSteps) {
@@ -717,6 +752,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
     // Analyze ops that were created by the PostAnalysisStep.
     if (failed(inPlaceAnalysis(newOps, aliasInfo, domInfo)))
       return failure();
+    equivalenceAnalysis(newOps, aliasInfo);
   }
 
   // Annotate operations if we only want to report the analysis.

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index 3f14975386a51..6a3a25c935829 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -132,7 +132,8 @@ struct LinalgOpInterface
     return genericOp->getResult(outputOperandIndex - numOutputBuffers);
   }
 
-  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const BufferizationAliasInfo &aliasInfo) const {
     return BufferRelation::Equivalent;
   }
 
@@ -187,7 +188,8 @@ struct TiledLoopOpInterface
     return tiledLoopOp.getTiedOpResult(opOperand);
   }
 
-  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const BufferizationAliasInfo &aliasInfo) const {
     return BufferRelation::Equivalent;
   }
 
@@ -409,8 +411,9 @@ LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
             // TODO: Support cases such as extract_slice(init_tensor).
             SmallVector<OpOperand *> opOperands =
                 getAliasingOpOperand(opResult);
-            if (!llvm::all_of(opOperands, [](OpOperand *operand) {
-                  return bufferRelation(*operand) == BufferRelation::Equivalent;
+            if (!llvm::all_of(opOperands, [&](OpOperand *operand) {
+                  return aliasInfo.areEquivalentBufferizedValues(operand->get(),
+                                                                 opResult);
                 }))
               return true;
             return false;

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index 30f51d5d2ca38..4b014c9198dc8 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -454,10 +454,6 @@ struct CallOpInterface
     return OpResult();
   }
 
-  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
-    return BufferRelation::Equivalent;
-  }
-
   /// In a first approximation, all the function arguments of a FuncOp are
   /// marked inplaceable. For now, it is the responsibility of the `callOp`
   /// bufferization to allow FuncOp that are inplaceable to write inPlace.

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
index 5dc434335e1ef..c41457d7da76f 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -67,6 +67,11 @@ struct ExecuteRegionOpInterface
           "scf.execute_region with tensor result not supported");
     return comprehensive_bufferize::bufferize(&executeRegionOp.region(), state);
   }
+
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const BufferizationAliasInfo &aliasInfo) const {
+    return BufferRelation::Equivalent;
+  }
 };
 
 struct IfOpInterface
@@ -148,6 +153,19 @@ struct IfOpInterface
 
     return success();
   }
+
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const BufferizationAliasInfo &aliasInfo) const {
+    // IfOp results are equivalent to their corresponding yield values if both
+    // yield values are equivalent to each other.
+    auto bufferizableOp = cast<BufferizableOpInterface>(op);
+    SmallVector<OpOperand *> yieldValues =
+        bufferizableOp.getAliasingOpOperand(opResult);
+    assert(yieldValues.size() == 2 && "expected 2 yield values");
+    bool equivalentYields = aliasInfo.areEquivalentBufferizedValues(
+        yieldValues[0]->get(), yieldValues[1]->get());
+    return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None;
+  }
 };
 
 struct ForOpInterface
@@ -174,8 +192,17 @@ struct ForOpInterface
     return forOp.getResultForOpOperand(opOperand);
   }
 
-  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
-    return BufferRelation::Equivalent;
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const BufferizationAliasInfo &aliasInfo) const {
+    // ForOp results are equivalent to their corresponding init_args if the
+    // corresponding iter_args and yield values are equivalent.
+    auto forOp = cast<scf::ForOp>(op);
+    OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
+    auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
+    auto yieldOp = cast<scf::YieldOp>(&forOp.getLoopBody().front().back());
+    bool equivalentYield = aliasInfo.areEquivalentBufferizedValues(
+        bbArg, yieldOp->getOperand(opResult.getResultNumber()));
+    return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
   }
 
   bool isWritable(Operation *op, Value value) const {
@@ -230,10 +257,8 @@ struct ForOpInterface
       OpOperand &forOperand = forOp.getOpOperandForResult(
           forOp->getResult(operand.getOperandNumber()));
       auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
-      Value yieldedBuffer = state.lookupBuffer(operand.get());
-      Value bbArgBuffer = state.lookupBuffer(bbArg);
-      if (!state.aliasInfo.areEquivalentBufferizedValues(yieldedBuffer,
-                                                         bbArgBuffer)) {
+      if (!state.aliasInfo.areEquivalentBufferizedValues(operand.get(),
+                                                         bbArg)) {
         // TODO: this could get resolved with copies but it can also turn into
         // swaps so we need to be careful about order of copies.
         return yieldOp->emitError()
@@ -265,10 +290,6 @@ struct YieldOpInterface
     return OpResult();
   }
 
-  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
-    return BufferRelation::Equivalent;
-  }
-
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BufferizationState &state) const {
     auto yieldOp = cast<scf::YieldOp>(op);

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index 11333807dd7a8..f595de42b7e7e 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -52,7 +52,8 @@ struct CastOpInterface
     return op->getResult(0);
   }
 
-  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const BufferizationAliasInfo &aliasInfo) const {
     return BufferRelation::Equivalent;
   }
 
@@ -129,7 +130,8 @@ struct ExtractSliceOpInterface
                : OpResult();
   }
 
-  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const BufferizationAliasInfo &aliasInfo) const {
     return BufferRelation::None;
   }
 
@@ -235,7 +237,8 @@ struct InsertOpInterface
     return success();
   }
 
-  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const BufferizationAliasInfo &aliasInfo) const {
     return BufferRelation::Equivalent;
   }
 };
@@ -307,7 +310,8 @@ struct InsertSliceOpInterface
                : OpResult();
   }
 
-  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const BufferizationAliasInfo &aliasInfo) const {
     return BufferRelation::Equivalent;
   }
 

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
index 3fafa75aa79bf..c2f33b876fff3 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
@@ -70,7 +70,8 @@ struct TransferWriteOpInterface
     return op->getOpResult(0);
   }
 
-  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const BufferizationAliasInfo &aliasInfo) const {
     return BufferRelation::Equivalent;
   }
 


        


More information about the Mlir-commits mailing list