[Mlir-commits] [mlir] 48107ea - [mlir][linalg][bufferize][NFC] Move SCF interface impl to new build target

Matthias Springer llvmlistbot at llvm.org
Thu Nov 25 02:04:06 PST 2021


Author: Matthias Springer
Date: 2021-11-25T19:00:17+09:00
New Revision: 48107eaa07e26f9bc5b24af2d5351793cc64db46

URL: https://github.com/llvm/llvm-project/commit/48107eaa07e26f9bc5b24af2d5351793cc64db46
DIFF: https://github.com/llvm/llvm-project/commit/48107eaa07e26f9bc5b24af2d5351793cc64db46.diff

LOG: [mlir][linalg][bufferize][NFC] Move SCF interface impl to new build target

This makes ComprehensiveBufferize entirely independent of the SCF dialect.

Differential Revision: https://reviews.llvm.org/D114221

Added: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp

Modified: 
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
new file mode 100644
index 0000000000000..97ae11e5d27b6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
@@ -0,0 +1,27 @@
+//===- SCFInterfaceImpl.h - SCF Impl. of BufferizableOpInterface ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_SCF_INTERFACE_IMPL_H
+#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_SCF_INTERFACE_IMPL_H
+
+namespace mlir {
+
+// Forward declaration only: this header does not include any dialect headers.
+class DialectRegistry;
+
+namespace linalg {
+namespace comprehensive_bufferize {
+namespace scf_ext {
+
+/// Register the BufferizableOpInterface external models for the SCF dialect
+/// ops scf.execute_region, scf.for, scf.if and scf.yield with `registry`.
+/// scf.parallel is registered as an allocation hoisting barrier only.
+void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
+
+} // namespace scf_ext
+} // namespace comprehensive_bufferize
+} // namespace linalg
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_SCF_INTERFACE_IMPL_H

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
index b94694a3e53cd..7eadebd09aeba 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
@@ -3,6 +3,7 @@ set(LLVM_OPTIONAL_SOURCES
   BufferizableOpInterface.cpp
   ComprehensiveBufferize.cpp
   LinalgInterfaceImpl.cpp
+  SCFInterfaceImpl.cpp
   TensorInterfaceImpl.cpp
   VectorInterfaceImpl.cpp
 )
@@ -39,6 +40,15 @@ add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl
   MLIRTensor
 )
 
+# BufferizableOpInterface external models for SCF dialect ops. Kept in a
+# separate library so that MLIRComprehensiveBufferize does not link MLIRSCF.
+add_mlir_dialect_library(MLIRSCFBufferizableOpInterfaceImpl
+  SCFInterfaceImpl.cpp
+
+  LINK_LIBS PUBLIC
+  MLIRBufferizableOpInterface
+  MLIRIR
+  MLIRSCF
+)
+
 add_mlir_dialect_library(MLIRTensorBufferizableOpInterfaceImpl
   TensorInterfaceImpl.cpp
 
@@ -67,7 +77,6 @@ add_mlir_dialect_library(MLIRComprehensiveBufferize
   MLIRInferTypeOpInterface
   MLIRIR
   MLIRMemRef
-  MLIRSCF
   MLIRStandard
   MLIRStandardOpsTransforms
 )

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index 2b55d22ad4962..fda155c89c7b1 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -112,7 +112,6 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/BlockAndValueMapping.h"
@@ -1287,267 +1286,6 @@ BufferizationOptions::BufferizationOptions()
 namespace mlir {
 namespace linalg {
 namespace comprehensive_bufferize {
-namespace scf_ext {
-
-struct ExecuteRegionOpInterface
-    : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
-                                                    scf::ExecuteRegionOp> {
-  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
-                                                OpResult opResult) const {
-    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
-    // any SSA value that is in scope. To allow for use-def chain traversal
-    // through ExecuteRegionOps in the analysis, the corresponding yield value
-    // is considered to be aliasing with the result.
-    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
-    size_t resultNum = std::distance(op->getOpResults().begin(),
-                                     llvm::find(op->getOpResults(), opResult));
-    assert(executeRegionOp.region().getBlocks().size() == 1 &&
-           "expected exactly 1 block");
-    auto yieldOp = dyn_cast<scf::YieldOp>(
-        executeRegionOp.region().front().getTerminator());
-    assert(yieldOp && "expected scf.yield terminator in scf.execute_region");
-    return {&yieldOp->getOpOperand(resultNum)};
-  }
-
-  bool mustBufferizeInPlace(Operation *op, OpResult opResult) const {
-    // ExecuteRegionOp results always bufferize in-place. Since they have no
-    // OpOperands, they are mostly ignored by the analysis once alias sets are
-    // set up.
-    return true;
-  }
-
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BufferizationState &state) const {
-    // TODO: Add bufferization support when needed. scf.execute_region should be
-    // bufferized similar to scf.if.
-    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
-    bool hasTensorReturnType = any_of(
-        op->getResultTypes(), [](Type t) { return t.isa<TensorType>(); });
-    if (hasTensorReturnType)
-      return op->emitError(
-          "scf.execute_region with tensor result not supported");
-    return comprehensive_bufferize::bufferize(&executeRegionOp.region(), state);
-  }
-};
-
-struct IfOpInterface
-    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
-  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
-                                                OpResult opResult) const {
-    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
-    // value that is in scope. To allow for use-def chain traversal through
-    // IfOps in the analysis, both corresponding yield values from the then/else
-    // branches are considered to be aliasing with the result.
-    auto ifOp = cast<scf::IfOp>(op);
-    size_t resultNum = std::distance(op->getOpResults().begin(),
-                                     llvm::find(op->getOpResults(), opResult));
-    return {&ifOp.thenYield()->getOpOperand(resultNum),
-            &ifOp.elseYield()->getOpOperand(resultNum)};
-  }
-
-  // TODO: For better bufferization results, this could return `true` only if
-  // there is a memory write in one (or both) of the branches. Since this is not
-  // allowed at the moment, we should never encounter scf.ifs that yield
-  // unmodified tensors. Such scf.yield ops could just fold away.
-  bool isMemoryWrite(Operation *op, OpResult opResult) const {
-    // IfOp results are always considered memory writes in the analysis. This
-    // design decision simplifies the analysis considerably. E.g., consider the
-    // following test case:
-    //
-    // %0 = "some_writing_op" : tensor<?xf32>
-    // %r = scf.if %c -> (tensor<?xf32>) {
-    //   scf.yield %0
-    // } else {
-    //   %1 = "another_writing_op"(%0) : tensor<?xf32>
-    // }
-    // "some_reading_op"(%r)
-    //
-    // "another_writing_op" in the above example should be able to bufferize
-    // inplace in the absence of another read of %0. However, if the scf.if op
-    // would not be considered a "write", the analysis would detect the
-    // following conflict:
-    //
-    // * read = some_reading_op
-    // * lastWrite = %0  (Note: The last write of %r would be a set: {%0, %1}.)
-    // * conflictingWrite = %1
-    //
-    // For more details, check the "scf.IfOp" section of the design document.
-    return true;
-  }
-
-  bool mustBufferizeInPlace(Operation *op, OpResult opResult) const {
-    // IfOp results always bufferize in-place. Since they have no OpOperands,
-    // they are mostly ignored by the analysis once alias sets are set up.
-    return true;
-  }
-
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BufferizationState &state) const {
-    auto ifOp = cast<scf::IfOp>(op);
-
-    // Bufferize then/else blocks.
-    if (failed(comprehensive_bufferize::bufferize(ifOp.thenBlock(), state)))
-      return failure();
-    if (failed(comprehensive_bufferize::bufferize(ifOp.elseBlock(), state)))
-      return failure();
-
-    for (OpResult opResult : ifOp->getResults()) {
-      if (!opResult.getType().isa<TensorType>())
-        continue;
-      // TODO: Atm we bail on unranked TensorType because we don't know how to
-      // alloc an UnrankedMemRefType + its underlying ranked MemRefType.
-      assert(opResult.getType().isa<RankedTensorType>() &&
-             "unsupported unranked tensor");
-
-      Value resultBuffer = getResultBuffer(b, opResult, state);
-      if (!resultBuffer)
-        return failure();
-
-      state.aliasInfo.createAliasInfoEntry(resultBuffer);
-      state.mapBuffer(opResult, resultBuffer);
-    }
-
-    return success();
-  }
-};
-
-struct ForOpInterface
-    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
-                                                    scf::ForOp> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
-    // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of
-    // its matching bbArg may.
-    auto forOp = cast<scf::ForOp>(op);
-    return isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
-  }
-
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
-    // Tensor iter_args of scf::ForOps are always considered as a write. This is
-    // to simplify the analysis.
-    // TODO: Consider doing sth. like isValueWritten.
-    return true;
-  }
-
-  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
-                                                OpResult opResult) const {
-    auto forOp = cast<scf::ForOp>(op);
-    return {&forOp.getIterOpOperands()[opResult.getResultNumber()]};
-  }
-
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
-    auto forOp = cast<scf::ForOp>(op);
-    if (!opOperand.get().getType().isa<RankedTensorType>())
-      return OpResult();
-    return forOp.getResultForOpOperand(opOperand);
-  }
-
-  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
-    return BufferRelation::Equivalent;
-  }
-
-  bool isWritable(Operation *op, Value value) const {
-    // Interestingly, scf::ForOp's bbArg can **always** be viewed
-    // inplace from the perspective of ops nested under:
-    //   1. Either the matching iter operand is not bufferized inplace and an
-    //      alloc + optional copy makes the bbArg itself inplaceable.
-    //   2. Or the matching iter operand is bufferized inplace and bbArg just
-    //      bufferizes to that too.
-    return true;
-  }
-
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BufferizationState &state) const {
-    auto forOp = cast<scf::ForOp>(op);
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-
-    for (OpResult opResult : forOp->getResults()) {
-      if (!opResult.getType().isa<TensorType>())
-        continue;
-      // TODO: Atm we bail on unranked TensorType because we don't know how to
-      // alloc an UnrankedMemRefType + its underlying ranked MemRefType.
-      assert(opResult.getType().isa<RankedTensorType>() &&
-             "unsupported unranked tensor");
-
-      // TODO: More general: Matching bbArg does not bufferize to a read.
-      Value resultBuffer = getResultBuffer(b, opResult, state);
-      if (!resultBuffer)
-        return failure();
-
-      OpOperand &opOperand = forOp.getOpOperandForResult(opResult);
-      BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand);
-      state.aliasInfo.createAliasInfoEntry(resultBuffer);
-      state.aliasInfo.insertNewBufferEquivalence(bbArg, resultBuffer);
-      state.mapBuffer(bbArg, resultBuffer);
-      state.mapBuffer(opResult, resultBuffer);
-    }
-
-    // Bufferize loop body.
-    if (failed(comprehensive_bufferize::bufferize(&forOp.region(), state)))
-      return failure();
-
-    // Finish bufferizing scf::ForOp.
-    auto yieldOp = cast<scf::YieldOp>(&forOp.region().front().back());
-    for (OpOperand &operand : yieldOp->getOpOperands()) {
-      auto tensorType = operand.get().getType().dyn_cast<TensorType>();
-      if (!tensorType)
-        continue;
-
-      OpOperand &forOperand = forOp.getOpOperandForResult(
-          forOp->getResult(operand.getOperandNumber()));
-      auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
-      Value yieldedBuffer = state.lookupBuffer(operand.get());
-      Value bbArgBuffer = state.lookupBuffer(bbArg);
-      if (!state.aliasInfo.areEquivalentBufferizedValues(yieldedBuffer,
-                                                         bbArgBuffer)) {
-        // TODO: this could get resolved with copies but it can also turn into
-        // swaps so we need to be careful about order of copies.
-        return yieldOp->emitError()
-               << "Yield operand #" << operand.getOperandNumber()
-               << " does not bufferize to an equivalent buffer to the matching"
-               << " enclosing scf::for operand";
-      }
-
-      // Buffers are equivalent so the work is already done and we just yield
-      // the bbArg so that it later canonicalizes away.
-      operand.set(bbArg);
-    }
-    return success();
-  }
-};
-
-struct YieldOpInterface
-    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
-                                                    scf::YieldOp> {
-  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
-    return true;
-  }
-
-  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
-    return false;
-  }
-
-  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
-    return OpResult();
-  }
-
-  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
-    return BufferRelation::Equivalent;
-  }
-
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BufferizationState &state) const {
-    auto yieldOp = cast<scf::YieldOp>(op);
-    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp>(
-            yieldOp->getParentOp()))
-      return yieldOp->emitError("unsupported scf::YieldOp parent");
-    return success();
-  }
-};
-
-} // namespace scf_ext
-
 namespace std_ext {
 
 struct CallOpInterface
@@ -1767,18 +1505,11 @@ struct ReturnOpInterface
 } // namespace std_ext
 
 void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
-  registry.addOpInterface<scf::ExecuteRegionOp,
-                          scf_ext::ExecuteRegionOpInterface>();
-  registry.addOpInterface<scf::ForOp, scf_ext::ForOpInterface>();
-  registry.addOpInterface<scf::IfOp, scf_ext::IfOpInterface>();
-  registry.addOpInterface<scf::YieldOp, scf_ext::YieldOpInterface>();
   registry.addOpInterface<CallOp, std_ext::CallOpInterface>();
   registry.addOpInterface<ReturnOp, std_ext::ReturnOpInterface>();
 
   // Ops that are not bufferizable but are allocation hoisting barriers.
   registry.addOpInterface<FuncOp, AllocationHoistingBarrierOnly<FuncOp>>();
-  registry.addOpInterface<scf::ParallelOp,
-                          AllocationHoistingBarrierOnly<scf::ParallelOp>>();
   registry.addOpInterface<AffineParallelOp,
                           AllocationHoistingBarrierOnly<AffineParallelOp>>();
 }

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
new file mode 100644
index 0000000000000..811824202091c
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -0,0 +1,291 @@
+//===- SCFInterfaceImpl.cpp - SCF Impl. of BufferizableOpInterface --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+
+namespace mlir {
+namespace linalg {
+namespace comprehensive_bufferize {
+namespace scf_ext {
+
+/// BufferizableOpInterface model for scf.execute_region.
+///
+/// Ops with tensor results are rejected during bufferization. Otherwise the
+/// op's region is bufferized, and all results are required to bufferize
+/// in-place.
+struct ExecuteRegionOpInterface
+    : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
+                                                    scf::ExecuteRegionOp> {
+  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
+                                                OpResult opResult) const {
+    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
+    // any SSA value that is in scope. To allow for use-def chain traversal
+    // through ExecuteRegionOps in the analysis, the corresponding yield value
+    // is considered to be aliasing with the result.
+    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
+    assert(opResult.getOwner() == op && "expected a result of `op`");
+    // Result #i is yielded by operand #i of the terminator. The result number
+    // gives that index directly, without a linear search over the results.
+    unsigned resultNum = opResult.getResultNumber();
+    assert(executeRegionOp.region().getBlocks().size() == 1 &&
+           "expected exactly 1 block");
+    auto yieldOp = dyn_cast<scf::YieldOp>(
+        executeRegionOp.region().front().getTerminator());
+    assert(yieldOp && "expected scf.yield terminator in scf.execute_region");
+    return {&yieldOp->getOpOperand(resultNum)};
+  }
+
+  bool mustBufferizeInPlace(Operation *op, OpResult opResult) const {
+    // ExecuteRegionOp results always bufferize in-place. Since they have no
+    // OpOperands, they are mostly ignored by the analysis once alias sets are
+    // set up.
+    return true;
+  }
+
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BufferizationState &state) const {
+    // TODO: Add bufferization support when needed. scf.execute_region should be
+    // bufferized similar to scf.if.
+    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
+    // Qualified explicitly: this file has no `using namespace llvm`, so do not
+    // rely on argument-dependent lookup to find llvm::any_of.
+    bool hasTensorReturnType = llvm::any_of(
+        op->getResultTypes(), [](Type t) { return t.isa<TensorType>(); });
+    if (hasTensorReturnType)
+      return op->emitError(
+          "scf.execute_region with tensor result not supported");
+    return comprehensive_bufferize::bufferize(&executeRegionOp.region(), state);
+  }
+};
+
+/// BufferizableOpInterface model for scf.if.
+///
+/// Each tensor result is treated as aliasing the matching operands of the
+/// scf.yield ops that terminate the then and else blocks.
+struct IfOpInterface
+    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
+  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
+                                                OpResult opResult) const {
+    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
+    // value that is in scope. To allow for use-def chain traversal through
+    // IfOps in the analysis, both corresponding yield values from the then/else
+    // branches are considered to be aliasing with the result.
+    auto ifOp = cast<scf::IfOp>(op);
+    assert(opResult.getOwner() == op && "expected a result of `op`");
+    // Result #i corresponds to operand #i of both yields. The result number
+    // gives that index directly, without a linear search over the results.
+    unsigned resultNum = opResult.getResultNumber();
+    return {&ifOp.thenYield()->getOpOperand(resultNum),
+            &ifOp.elseYield()->getOpOperand(resultNum)};
+  }
+
+  // TODO: For better bufferization results, this could return `true` only if
+  // there is a memory write in one (or both) of the branches. Since this is not
+  // allowed at the moment, we should never encounter scf.ifs that yield
+  // unmodified tensors. Such scf.yield ops could just fold away.
+  bool isMemoryWrite(Operation *op, OpResult opResult) const {
+    // IfOp results are always considered memory writes in the analysis. This
+    // design decision simplifies the analysis considerably. E.g., consider the
+    // following test case:
+    //
+    // %0 = "some_writing_op" : tensor<?xf32>
+    // %r = scf.if %c -> (tensor<?xf32>) {
+    //   scf.yield %0
+    // } else {
+    //   %1 = "another_writing_op"(%0) : tensor<?xf32>
+    // }
+    // "some_reading_op"(%r)
+    //
+    // "another_writing_op" in the above example should be able to bufferize
+    // inplace in the absence of another read of %0. However, if the scf.if op
+    // would not be considered a "write", the analysis would detect the
+    // following conflict:
+    //
+    // * read = some_reading_op
+    // * lastWrite = %0  (Note: The last write of %r would be a set: {%0, %1}.)
+    // * conflictingWrite = %1
+    //
+    // For more details, check the "scf.IfOp" section of the design document.
+    return true;
+  }
+
+  bool mustBufferizeInPlace(Operation *op, OpResult opResult) const {
+    // IfOp results always bufferize in-place. Since they have no OpOperands,
+    // they are mostly ignored by the analysis once alias sets are set up.
+    return true;
+  }
+
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BufferizationState &state) const {
+    auto ifOp = cast<scf::IfOp>(op);
+
+    // Bufferize then/else blocks.
+    if (failed(comprehensive_bufferize::bufferize(ifOp.thenBlock(), state)))
+      return failure();
+    if (failed(comprehensive_bufferize::bufferize(ifOp.elseBlock(), state)))
+      return failure();
+
+    for (OpResult opResult : ifOp->getResults()) {
+      if (!opResult.getType().isa<TensorType>())
+        continue;
+      // TODO: Atm we bail on unranked TensorType because we don't know how to
+      // alloc an UnrankedMemRefType + its underlying ranked MemRefType.
+      assert(opResult.getType().isa<RankedTensorType>() &&
+             "unsupported unranked tensor");
+
+      Value resultBuffer = getResultBuffer(b, opResult, state);
+      if (!resultBuffer)
+        return failure();
+
+      state.aliasInfo.createAliasInfoEntry(resultBuffer);
+      state.mapBuffer(opResult, resultBuffer);
+    }
+
+    return success();
+  }
+};
+
+/// BufferizableOpInterface model for scf.for.
+///
+/// Each tensor result, its tied iter_arg OpOperand and the matching region
+/// iter_arg (bbArg) are related by an equivalent buffer relation. During
+/// bufferization, the result and its bbArg are mapped to the same buffer.
+/// Every yielded tensor must then bufferize to a buffer equivalent to that
+/// bbArg's buffer; otherwise an error is emitted.
+struct ForOpInterface
+    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
+                                                    scf::ForOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+    // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of
+    // its matching bbArg may.
+    // NOTE(review): getRegionIterArgForOpOperand presumably expects an
+    // iter_arg operand; verify this is never queried for other operands.
+    auto forOp = cast<scf::ForOp>(op);
+    return isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+    // Tensor iter_args of scf::ForOps are always considered as a write. This is
+    // to simplify the analysis.
+    // TODO: Consider doing sth. like isValueWritten.
+    return true;
+  }
+
+  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
+                                                OpResult opResult) const {
+    // Result #i is tied to iter_arg operand #i.
+    auto forOp = cast<scf::ForOp>(op);
+    return {&forOp.getIterOpOperands()[opResult.getResultNumber()]};
+  }
+
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+    // Operands whose type is not a ranked tensor have no aliasing OpResult.
+    auto forOp = cast<scf::ForOp>(op);
+    if (!opOperand.get().getType().isa<RankedTensorType>())
+      return OpResult();
+    return forOp.getResultForOpOperand(opOperand);
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
+    // An iter_arg operand and its tied result are considered equivalent.
+    return BufferRelation::Equivalent;
+  }
+
+  bool isWritable(Operation *op, Value value) const {
+    // Interestingly, scf::ForOp's bbArg can **always** be viewed
+    // inplace from the perspective of ops nested under:
+    //   1. Either the matching iter operand is not bufferized inplace and an
+    //      alloc + optional copy makes the bbArg itself inplaceable.
+    //   2. Or the matching iter operand is bufferized inplace and bbArg just
+    //      bufferizes to that too.
+    return true;
+  }
+
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BufferizationState &state) const {
+    auto forOp = cast<scf::ForOp>(op);
+
+    // Take a guard before anything else.
+    OpBuilder::InsertionGuard g(b);
+
+    // Obtain a buffer for every tensor result via getResultBuffer. Map both the
+    // result and its matching bbArg to that buffer, and record the
+    // bbArg/buffer equivalence, so the loop body below is bufferized against
+    // that buffer.
+    for (OpResult opResult : forOp->getResults()) {
+      if (!opResult.getType().isa<TensorType>())
+        continue;
+      // TODO: Atm we bail on unranked TensorType because we don't know how to
+      // alloc an UnrankedMemRefType + its underlying ranked MemRefType.
+      assert(opResult.getType().isa<RankedTensorType>() &&
+             "unsupported unranked tensor");
+
+      // TODO: More general: Matching bbArg does not bufferize to a read.
+      Value resultBuffer = getResultBuffer(b, opResult, state);
+      if (!resultBuffer)
+        return failure();
+
+      OpOperand &opOperand = forOp.getOpOperandForResult(opResult);
+      BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand);
+      state.aliasInfo.createAliasInfoEntry(resultBuffer);
+      state.aliasInfo.insertNewBufferEquivalence(bbArg, resultBuffer);
+      state.mapBuffer(bbArg, resultBuffer);
+      state.mapBuffer(opResult, resultBuffer);
+    }
+
+    // Bufferize loop body.
+    if (failed(comprehensive_bufferize::bufferize(&forOp.region(), state)))
+      return failure();
+
+    // Finish bufferizing scf::ForOp.
+    // Each yielded tensor must bufferize to a buffer equivalent to its bbArg's
+    // buffer. If it does, the yield operand is replaced by the bbArg itself.
+    auto yieldOp = cast<scf::YieldOp>(&forOp.region().front().back());
+    for (OpOperand &operand : yieldOp->getOpOperands()) {
+      auto tensorType = operand.get().getType().dyn_cast<TensorType>();
+      if (!tensorType)
+        continue;
+
+      OpOperand &forOperand = forOp.getOpOperandForResult(
+          forOp->getResult(operand.getOperandNumber()));
+      auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
+      Value yieldedBuffer = state.lookupBuffer(operand.get());
+      Value bbArgBuffer = state.lookupBuffer(bbArg);
+      if (!state.aliasInfo.areEquivalentBufferizedValues(yieldedBuffer,
+                                                         bbArgBuffer)) {
+        // TODO: this could get resolved with copies but it can also turn into
+        // swaps so we need to be careful about order of copies.
+        return yieldOp->emitError()
+               << "Yield operand #" << operand.getOperandNumber()
+               << " does not bufferize to an equivalent buffer to the matching"
+               << " enclosing scf::for operand";
+      }
+
+      // Buffers are equivalent so the work is already done and we just yield
+      // the bbArg so that it later canonicalizes away.
+      operand.set(bbArg);
+    }
+    return success();
+  }
+};
+
+/// BufferizableOpInterface model for scf.yield.
+///
+/// Yielded tensors are treated as read but never written, and no OpResult
+/// aliases an scf.yield operand.
+struct YieldOpInterface
+    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
+                                                    scf::YieldOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+    // Every yielded value is conservatively considered to be read.
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+    // Yielding a value does not write to its buffer.
+    return false;
+  }
+
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+    // No OpResult of scf.yield aliases its operands.
+    return OpResult();
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
+    return BufferRelation::Equivalent;
+  }
+
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BufferizationState &state) const {
+    // No IR is rewritten here. This only rejects scf.yield ops whose parent is
+    // not one of the SCF ops that have a bufferization model in this file.
+    auto yieldOp = cast<scf::YieldOp>(op);
+    Operation *parentOp = yieldOp->getParentOp();
+    if (isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp>(parentOp))
+      return success();
+    return yieldOp->emitError("unsupported scf::YieldOp parent");
+  }
+};
+
+} // namespace scf_ext
+} // namespace comprehensive_bufferize
+} // namespace linalg
+} // namespace mlir
+
+void mlir::linalg::comprehensive_bufferize::scf_ext::
+    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
+  // SCF ops with a full bufferization model. The body of this out-of-line
+  // definition is looked up in namespace scf_ext, so the models need no
+  // qualification.
+  registry.addOpInterface<scf::ExecuteRegionOp, ExecuteRegionOpInterface>();
+  registry.addOpInterface<scf::ForOp, ForOpInterface>();
+  registry.addOpInterface<scf::IfOp, IfOpInterface>();
+  registry.addOpInterface<scf::YieldOp, YieldOpInterface>();
+
+  // scf.parallel is not bufferizable; it only acts as an allocation hoisting
+  // barrier.
+  using ParallelOpModel = AllocationHoistingBarrierOnly<scf::ParallelOp>;
+  registry.addOpInterface<scf::ParallelOp, ParallelOpModel>();
+}

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 441ee9ea7e4eb..9732b48f6356d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -44,6 +44,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   MLIRLinalgBufferizableOpInterfaceImpl
   MLIRLinalgUtils
   MLIRSCF
+  MLIRSCFBufferizableOpInterfaceImpl
   MLIRSCFTransforms
   MLIRPass
   MLIRStandard

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index dd5dd127bdeaa..d6ab48ae63680 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/Passes.h"
@@ -42,6 +43,7 @@ struct LinalgComprehensiveModuleBufferize
     registerBufferizableOpInterfaceExternalModels(registry);
     arith_ext::registerBufferizableOpInterfaceExternalModels(registry);
     linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
+    scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
     tensor_ext::registerBufferizableOpInterfaceExternalModels(registry);
     vector_ext::registerBufferizableOpInterfaceExternalModels(registry);
   }

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 93fbedc2e890f..f91bb5dc959e9 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6346,6 +6346,24 @@ cc_library(
     ],
 )
 
+# BufferizableOpInterface external models for SCF dialect ops. Split out so
+# that the ComprehensiveBufferize library no longer depends on :SCFDialect.
+cc_library(
+    name = "SCFBufferizableOpInterfaceImpl",
+    srcs = [
+        "lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp",
+    ],
+    hdrs = [
+        "include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h",
+    ],
+    includes = ["include"],
+    deps = [
+        ":BufferizableOpInterface",
+        ":IR",
+        ":SCFDialect",
+        ":Support",
+        "//llvm:Support",
+    ],
+)
+
 cc_library(
     name = "TensorBufferizableOpInterfaceImpl",
     srcs = [
@@ -6598,6 +6616,7 @@ cc_library(
         ":MathDialect",
         ":MemRefDialect",
         ":Pass",
+        ":SCFBufferizableOpInterfaceImpl",
         ":SCFDialect",
         ":SCFTransforms",
         ":StandardOps",
@@ -6631,7 +6650,6 @@ cc_library(
         ":InferTypeOpInterface",
         ":MemRefDialect",
         ":Pass",
-        ":SCFDialect",
         ":StandardOps",
         ":Support",
         "//llvm:Support",


        


More information about the Mlir-commits mailing list