[Mlir-commits] [mlir] f24d931 - [mlir][linalg][bufferize][NFC] Specify bufferize traversal in `bufferize`

Matthias Springer llvmlistbot at llvm.org
Tue Nov 23 04:36:40 PST 2021


Author: Matthias Springer
Date: 2021-11-23T21:33:19+09:00
New Revision: f24d9313cc9fe9f6cd70f606c1dc8f8213587468

URL: https://github.com/llvm/llvm-project/commit/f24d9313cc9fe9f6cd70f606c1dc8f8213587468
DIFF: https://github.com/llvm/llvm-project/commit/f24d9313cc9fe9f6cd70f606c1dc8f8213587468.diff

LOG: [mlir][linalg][bufferize][NFC] Specify bufferize traversal in `bufferize`

The interface method `bufferize` controls how (and in what order) nested ops are traversed. This simplifies bufferization of scf::ForOps and scf::IfOps, which used to need special rules in scf::YieldOp.

Differential Revision: https://reviews.llvm.org/D114057

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index 491f8a56eb609..881f1edb11c47 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -297,10 +297,16 @@ struct BufferizationState {
 /// bufferization is necessary.
 Value getResultBuffer(OpBuilder &b, OpResult result, BufferizationState &state);
 
+/// Bufferize all ops in the given region.
+LogicalResult bufferize(Region *region, BufferizationState &state);
+
+/// Bufferize all ops in the given block.
+LogicalResult bufferize(Block *block, BufferizationState &state);
+
 /// Bufferize the given op. If the op has no tensor OpOperands/OpResults, this
 /// function returns immediately. Otherwise, it calls the `bufferize` interface
 /// method of `BufferizableOpInterface`.
-LogicalResult bufferizeOp(Operation *op, BufferizationState &state);
+LogicalResult bufferize(Operation *op, BufferizationState &state);
 
 /// PostAnalysisSteps can be registered with `BufferizationOptions` and are
 /// executed after the analysis, but before bufferization. They can be used

diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
index 757eca50eb5f0..ca0454549bcab 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
@@ -163,8 +163,13 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
       InterfaceMethod<
         /*desc=*/[{
           Bufferize this op, i.e., rewrite it into a memref-based equivalent.
-          `bvm` maps tensor values to memref values and this method should map
-          tensor results to memref results after creating/modifying ops.
+          Tensor values should be mapped to buffer values using `state`.
+
+          Implementations are required to bufferize nested ops
+          before returning. Otherwise, nested ops will not be bufferized.
+
+          This method will never be called on ops that do not have at least one
+          tensor operand or result.
         }],
         /*retType=*/"LogicalResult",
         /*methodName=*/"bufferize",

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index 3897734a1898c..fc9f414f7cd9d 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -392,8 +392,26 @@ Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
 }
 
 LogicalResult
-mlir::linalg::comprehensive_bufferize::bufferizeOp(Operation *op,
-                                                   BufferizationState &state) {
+mlir::linalg::comprehensive_bufferize::bufferize(Region *region,
+                                                 BufferizationState &state) {
+  for (Block &block : *region)
+    if (failed(bufferize(&block, state)))
+      return failure();
+  return success();
+}
+
+LogicalResult
+mlir::linalg::comprehensive_bufferize::bufferize(Block *block,
+                                                 BufferizationState &state) {
+  for (Operation &op : *block)
+    if (failed(bufferize(&op, state)))
+      return failure();
+  return success();
+}
+
+LogicalResult
+mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
+                                                 BufferizationState &state) {
   OpBuilder b(op->getContext());
 
   // Skip BufferCast and TensorLoad ops.
@@ -404,15 +422,22 @@ mlir::linalg::comprehensive_bufferize::bufferizeOp(Operation *op,
   auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
   bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
   bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
-  if (!hasTensorResult && !hasTensorOperand)
+
+  // No tensor results or operands: Simply bufferize all nested ops.
+  if (!hasTensorResult && !hasTensorOperand) {
+    for (Region &region : op->getRegions())
+      if (failed(bufferize(&region, state)))
+        return failure();
     return success();
+  }
 
-  // Bufferize using `BufferizableOpInterface`.
+  // Bufferize using `BufferizableOpInterface`. Interface implementations are
+  // responsible for bufferizing nested ops.
   b.setInsertionPoint(op);
   if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
     return bufferizableOp.bufferize(b, state);
 
-  // Other op with tensors. No bufferization method specified.
+  // Emit error if tensor op is not bufferizable.
   return op->emitError() << "unsupported op with tensors";
 }
 

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index ac67d86ac9b89..d062bbab4ad0c 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -784,11 +784,12 @@ static Value createNewAllocDeallocPairForShapedValue(
 //===----------------------------------------------------------------------===//
 
 /// FuncOp always creates TensorToMemRef ops.
-static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
-                               BufferizationState &state) {
+static LogicalResult bufferizeFuncOp(FuncOp funcOp, BufferizationState &state) {
   // Take a guard before anything else.
-  OpBuilder::InsertionGuard g(b);
+  OpBuilder b(funcOp->getContext());
   b.setInsertionPointToStart(&funcOp.body().front());
+
+  // Create BufferCastOps for function args.
   for (auto bbArg : funcOp.getArguments()) {
     auto tensorType = bbArg.getType().dyn_cast<TensorType>();
     if (!tensorType)
@@ -804,7 +805,9 @@ static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
     state.aliasInfo.insertNewBufferEquivalence(bufferCast, bbArg);
     state.mapBuffer(bbArg, bufferCast);
   }
-  return success();
+
+  // Bufferize function body.
+  return bufferize(&funcOp.body(), state);
 }
 
 //===----------------------------------------------------------------------===//
@@ -923,37 +926,6 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
   return res;
 }
 
-//===----------------------------------------------------------------------===//
-// Bufferization entry-point for functions.
-//===----------------------------------------------------------------------===//
-
-static LogicalResult bufferizeFuncOpInternals(FuncOp funcOp,
-                                              BufferizationState &state) {
-  LLVM_DEBUG(llvm::dbgs() << "\n\n");
-  LDBG("Begin BufferizeFuncOpInternals:\n" << funcOp << '\n');
-  OpBuilder b(funcOp->getContext());
-
-  // Start by bufferizing `funcOp` arguments.
-  if (failed(bufferize(b, funcOp, state)))
-    return failure();
-
-  auto walkFunc = [&](Operation *op) -> WalkResult {
-    if (failed(bufferizeOp(op, state)))
-      return failure();
-    return success();
-  };
-
-  // Bufferize ops pre-order, i.e., bufferize ops first, then their children.
-  // This is needed for ops with blocks that have BlockArguments. These must be
-  // mapped before bufferizing the children.
-  if (funcOp.walk<WalkOrder::PreOrder>(walkFunc).wasInterrupted())
-    return failure();
-
-  LDBG("End BufferizeFuncOpInternals:\n" << funcOp << '\n');
-
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // Bufferization entry-point for modules.
 //===----------------------------------------------------------------------===//
@@ -1380,7 +1352,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
     // Bufferization phase.
     if (!options.testAnalysisOnly) {
       // Bufferize all ops in funcOp.
-      if (failed(bufferizeFuncOpInternals(funcOp, state)))
+      if (failed(bufferizeFuncOp(funcOp, state)))
         return failure();
 
       // Erase all obsolete ops.
@@ -1547,12 +1519,13 @@ struct ExecuteRegionOpInterface
                           BufferizationState &state) const {
     // TODO: Add bufferization support when needed. scf.execute_region should be
     // bufferized similar to scf.if.
+    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
     bool hasTensorReturnType = any_of(
         op->getResultTypes(), [](Type t) { return t.isa<TensorType>(); });
     if (hasTensorReturnType)
       return op->emitError(
           "scf.execute_region with tensor result not supported");
-    return success();
+    return comprehensive_bufferize::bufferize(&executeRegionOp.region(), state);
   }
 };
 
@@ -1609,37 +1582,33 @@ struct IfOpInterface
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BufferizationState &state) const {
-    // scf::IfOp is bufferized after scf::YieldOp in the else branch.
-    return success();
-  }
-};
+    auto ifOp = cast<scf::IfOp>(op);
 
-/// Bufferize the scf::IfOp. This function is called after the YieldOp was
-/// bufferized.
-static LogicalResult bufferizeIfOp(scf::IfOp ifOp, OpBuilder &b,
-                                   BufferizationState &state) {
-  // Take a guard before anything else.
-  OpBuilder::InsertionGuard g(b);
-  b.setInsertionPoint(ifOp);
+    // Bufferize then/else blocks.
+    if (failed(comprehensive_bufferize::bufferize(ifOp.thenBlock(), state)))
+      return failure();
+    if (failed(comprehensive_bufferize::bufferize(ifOp.elseBlock(), state)))
+      return failure();
 
-  for (OpResult opResult : ifOp->getResults()) {
-    if (!opResult.getType().isa<TensorType>())
-      continue;
-    // TODO: Atm we bail on unranked TensorType because we don't know how to
-    // alloc an UnrankedMemRefType + its underlying ranked MemRefType.
-    assert(opResult.getType().isa<RankedTensorType>() &&
-           "unsupported unranked tensor");
+    for (OpResult opResult : ifOp->getResults()) {
+      if (!opResult.getType().isa<TensorType>())
+        continue;
+      // TODO: Atm we bail on unranked TensorType because we don't know how to
+      // alloc an UnrankedMemRefType + its underlying ranked MemRefType.
+      assert(opResult.getType().isa<RankedTensorType>() &&
+             "unsupported unranked tensor");
 
-    Value resultBuffer = getResultBuffer(b, opResult, state);
-    if (!resultBuffer)
-      return failure();
+      Value resultBuffer = getResultBuffer(b, opResult, state);
+      if (!resultBuffer)
+        return failure();
 
-    state.aliasInfo.createAliasInfoEntry(resultBuffer);
-    state.mapBuffer(opResult, resultBuffer);
-  }
+      state.aliasInfo.createAliasInfoEntry(resultBuffer);
+      state.mapBuffer(opResult, resultBuffer);
+    }
 
-  return success();
-}
+    return success();
+  }
+};
 
 struct ForOpInterface
     : public BufferizableOpInterface::ExternalModel<ForOpInterface,
@@ -1687,9 +1656,6 @@ struct ForOpInterface
 
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BufferizationState &state) const {
-    // Note: This method is just setting up the mappings for the block arguments
-    // and the result buffer. The op is bufferized after the scf::YieldOp.
-
     auto forOp = cast<scf::ForOp>(op);
 
     // Take a guard before anything else.
@@ -1716,41 +1682,39 @@ struct ForOpInterface
       state.mapBuffer(opResult, resultBuffer);
     }
 
-    return success();
-  }
-};
+    // Bufferize loop body.
+    if (failed(comprehensive_bufferize::bufferize(&forOp.region(), state)))
+      return failure();
 
-/// Bufferize the scf::ForOp. This function is called after the YieldOp was
-/// bufferized.
-static LogicalResult bufferizeForOp(scf::ForOp forOp, OpBuilder &b,
-                                    BufferizationState &state) {
-  auto yieldOp = cast<scf::YieldOp>(&forOp.region().front().back());
-  for (OpOperand &operand : yieldOp->getOpOperands()) {
-    auto tensorType = operand.get().getType().dyn_cast<TensorType>();
-    if (!tensorType)
-      continue;
+    // Finish bufferizing scf::ForOp.
+    auto yieldOp = cast<scf::YieldOp>(&forOp.region().front().back());
+    for (OpOperand &operand : yieldOp->getOpOperands()) {
+      auto tensorType = operand.get().getType().dyn_cast<TensorType>();
+      if (!tensorType)
+        continue;
 
-    OpOperand &forOperand = forOp.getOpOperandForResult(
-        forOp->getResult(operand.getOperandNumber()));
-    auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
-    Value yieldedBuffer = state.lookupBuffer(operand.get());
-    Value bbArgBuffer = state.lookupBuffer(bbArg);
-    if (!state.aliasInfo.areEquivalentBufferizedValues(yieldedBuffer,
-                                                       bbArgBuffer)) {
-      // TODO: this could get resolved with copies but it can also turn into
-      // swaps so we need to be careful about order of copies.
-      return yieldOp->emitError()
-             << "Yield operand #" << operand.getOperandNumber()
-             << " does not bufferize to an equivalent buffer to the matching"
-             << " enclosing scf::for operand";
-    }
+      OpOperand &forOperand = forOp.getOpOperandForResult(
+          forOp->getResult(operand.getOperandNumber()));
+      auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
+      Value yieldedBuffer = state.lookupBuffer(operand.get());
+      Value bbArgBuffer = state.lookupBuffer(bbArg);
+      if (!state.aliasInfo.areEquivalentBufferizedValues(yieldedBuffer,
+                                                         bbArgBuffer)) {
+        // TODO: this could get resolved with copies but it can also turn into
+        // swaps so we need to be careful about order of copies.
+        return yieldOp->emitError()
+               << "Yield operand #" << operand.getOperandNumber()
+               << " does not bufferize to an equivalent buffer to the matching"
+               << " enclosing scf::for operand";
+      }
 
-    // Buffers are equivalent so the work is already done and we just yield
-    // the bbArg so that it later canonicalizes away.
-    operand.set(bbArg);
+      // Buffers are equivalent so the work is already done and we just yield
+      // the bbArg so that it later canonicalizes away.
+      operand.set(bbArg);
+    }
+    return success();
   }
-  return success();
-}
+};
 
 struct YieldOpInterface
     : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
@@ -1774,27 +1738,10 @@ struct YieldOpInterface
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BufferizationState &state) const {
     auto yieldOp = cast<scf::YieldOp>(op);
-
-    if (auto execOp = dyn_cast<scf::ExecuteRegionOp>(yieldOp->getParentOp())) {
-      if (execOp->getNumResults() != 0)
-        return execOp->emitError(
-            "expected result-less scf.execute_region containing op");
-      return success();
-    }
-
-    // Bufferize scf::IfOp after bufferizing the scf::YieldOp in the else
-    // branch.
-    if (auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp())) {
-      if (ifOp.elseYield() != yieldOp)
-        return success();
-      return bufferizeIfOp(ifOp, b, state);
-    }
-
-    // Bufferize scf::ForOp after bufferizing the scf::YieldOp.
-    if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp()))
-      return bufferizeForOp(forOp, b, state);
-
-    return yieldOp->emitError("expected scf::ForOp parent for scf::YieldOp");
+    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp>(
+            yieldOp->getParentOp()))
+      return yieldOp->emitError("unsupported scf::YieldOp parent");
+    return success();
   }
 };
 

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index 83fde817ef842..52fed305d06cd 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -340,7 +340,8 @@ struct TiledLoopOpInterface
              static_cast<int>(oldInputs.size()) + numNewInputBuffers,
              static_cast<int>(oldOutputs.size()) + numNewOutputBuffers}));
 
-    return success();
+    // Bufferize loop body.
+    return comprehensive_bufferize::bufferize(&tiledLoopOp.region(), state);
   }
 };
 


        


More information about the Mlir-commits mailing list