[Mlir-commits] [mlir] 19efe14 - [mlir][bufferize][NFC] Move scf BufferizableOpInterface impl to scf dialect

Matthias Springer llvmlistbot at llvm.org
Sun Jan 30 04:58:20 PST 2022


Author: Matthias Springer
Date: 2022-01-30T21:53:33+09:00
New Revision: 19efe141f72bdba67fa4e07e212ae65412533338

URL: https://github.com/llvm/llvm-project/commit/19efe141f72bdba67fa4e07e212ae65412533338
DIFF: https://github.com/llvm/llvm-project/commit/19efe141f72bdba67fa4e07e212ae65412533338.diff

LOG: [mlir][bufferize][NFC] Move scf BufferizableOpInterface impl to scf dialect

Differential Revision: https://reviews.llvm.org/D118557

Added: 
    mlir/include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp

Modified: 
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
    mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
    mlir/test/lib/Dialect/Linalg/CMakeLists.txt
    mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
    utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel

Removed: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h b/mlir/include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h
similarity index 69%
rename from mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
rename to mlir/include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h
index afea3fae490dd..ea8969e004686 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h
@@ -1,4 +1,4 @@
-//===- SCFInterfaceImpl.h - SCF Impl. of BufferizableOpInterface ----------===//
+//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,19 +6,15 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_SCFINTERFACEIMPL_H
-#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_SCFINTERFACEIMPL_H
+#ifndef MLIR_DIALECT_SCF_BUFFERIZABLEOPINTERFACEIMPL_H
+#define MLIR_DIALECT_SCF_BUFFERIZABLEOPINTERFACEIMPL_H
 
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
 
 namespace mlir {
-
 class DialectRegistry;
 
-namespace linalg {
-namespace comprehensive_bufferize {
-namespace scf_ext {
-
+namespace scf {
 /// Assert that yielded values of an scf.for op are aliasing their corresponding
 /// bbArgs. This is required because the i-th OpResult of an scf.for op is
 /// currently assumed to alias with the i-th iter_arg (in the absence of
@@ -30,10 +26,7 @@ struct AssertScfForAliasingProperties : public bufferization::PostAnalysisStep {
 };
 
 void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
-
-} // namespace scf_ext
-} // namespace comprehensive_bufferize
-} // namespace linalg
+} // namespace scf
 } // namespace mlir
 
-#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_SCFINTERFACEIMPL_H
+#endif // MLIR_DIALECT_SCF_BUFFERIZABLEOPINTERFACEIMPL_H

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
index caddad30c6151..10d0f72ebcfe8 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
@@ -2,7 +2,6 @@ set(LLVM_OPTIONAL_SOURCES
   AffineInterfaceImpl.cpp
   LinalgInterfaceImpl.cpp
   ModuleBufferization.cpp
-  SCFInterfaceImpl.cpp
   StdInterfaceImpl.cpp
   VectorInterfaceImpl.cpp
 )
@@ -26,16 +25,6 @@ add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl
   MLIRTensor
 )
 
-add_mlir_dialect_library(MLIRSCFBufferizableOpInterfaceImpl
-  SCFInterfaceImpl.cpp
-
-  LINK_LIBS PUBLIC
-  MLIRBufferization
-  MLIRBufferizationTransforms
-  MLIRIR
-  MLIRSCF
-)
-
 add_mlir_dialect_library(MLIRStdBufferizableOpInterfaceImpl
   StdInterfaceImpl.cpp
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 0d1874938c5fe..5a42474d49da4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -47,7 +47,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   MLIRLinalgUtils
   MLIRModuleBufferization
   MLIRSCF
-  MLIRSCFBufferizableOpInterfaceImpl
   MLIRSCFTransforms
   MLIRSCFUtils
   MLIRStdBufferizableOpInterfaceImpl

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index 314daed4f4cef..9d2e6a539a182 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -15,10 +15,10 @@
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
@@ -54,7 +54,7 @@ struct LinalgComprehensiveModuleBufferize
     affine_ext::registerBufferizableOpInterfaceExternalModels(registry);
     arith::registerBufferizableOpInterfaceExternalModels(registry);
     linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
-    scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
+    scf::registerBufferizableOpInterfaceExternalModels(registry);
     std_ext::registerModuleBufferizationExternalModels(registry);
     std_ext::registerBufferizableOpInterfaceExternalModels(registry);
     tensor::registerBufferizableOpInterfaceExternalModels(registry);
@@ -132,7 +132,7 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
   }
 
   // Only certain scf.for ops are supported by the analysis.
-  options->addPostAnalysisStep<scf_ext::AssertScfForAliasingProperties>();
+  options->addPostAnalysisStep<scf::AssertScfForAliasingProperties>();
 
   ModuleOp moduleOp = getOperation();
   applyEnablingTransformations(moduleOp);

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
similarity index 94%
rename from mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
rename to mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index cc8517f8f119f..a622a17080efc 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -1,4 +1,4 @@
-//===- SCFInterfaceImpl.cpp - SCF Impl. of BufferizableOpInterface --------===//
+//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
+#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"
+
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/SCF/SCF.h"
@@ -14,12 +15,13 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
 
+using namespace mlir;
 using namespace mlir::bufferization;
+using namespace mlir::scf;
 
 namespace mlir {
-namespace linalg {
-namespace comprehensive_bufferize {
-namespace scf_ext {
+namespace scf {
+namespace {
 
 // bufferization.to_memref is not allowed to change the rank.
 static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
@@ -384,42 +386,6 @@ struct ForOpInterface
   }
 };
 
-LogicalResult
-mlir::linalg::comprehensive_bufferize::scf_ext::AssertScfForAliasingProperties::
-    run(Operation *op, BufferizationState &state,
-        BufferizationAliasInfo &aliasInfo, SmallVector<Operation *> &newOps) {
-  LogicalResult status = success();
-
-  op->walk([&](scf::ForOp forOp) {
-    auto yieldOp =
-        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
-    for (OpOperand &operand : yieldOp->getOpOperands()) {
-      auto tensorType = operand.get().getType().dyn_cast<TensorType>();
-      if (!tensorType)
-        continue;
-
-      OpOperand &forOperand = forOp.getOpOperandForResult(
-          forOp->getResult(operand.getOperandNumber()));
-      auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
-      // Note: This is overly strict. We should check for aliasing bufferized
-      // values. But we don't have a "must-alias" analysis yet.
-      if (!aliasInfo.areEquivalentBufferizedValues(operand.get(), bbArg)) {
-        // TODO: this could get resolved with copies but it can also turn into
-        // swaps so we need to be careful about order of copies.
-        status =
-            yieldOp->emitError()
-            << "Yield operand #" << operand.getOperandNumber()
-            << " does not bufferize to a buffer that is aliasing the matching"
-            << " enclosing scf::for operand";
-        return WalkResult::interrupt();
-      }
-    }
-    return WalkResult::advance();
-  });
-
-  return status;
-}
-
 /// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
 /// this is for analysis only.
 struct YieldOpInterface
@@ -462,18 +428,51 @@ struct YieldOpInterface
   }
 };
 
-} // namespace scf_ext
-} // namespace comprehensive_bufferize
-} // namespace linalg
+} // namespace
+} // namespace scf
 } // namespace mlir
 
-void mlir::linalg::comprehensive_bufferize::scf_ext::
-    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
-  registry.addOpInterface<scf::ExecuteRegionOp,
-                          scf_ext::ExecuteRegionOpInterface>();
-  registry.addOpInterface<scf::ForOp, scf_ext::ForOpInterface>();
-  registry.addOpInterface<scf::IfOp, scf_ext::IfOpInterface>();
-  registry.addOpInterface<scf::YieldOp, scf_ext::YieldOpInterface>();
-  registry.addOpInterface<scf::ParallelOp,
-                          AllocationHoistingBarrierOnly<scf::ParallelOp>>();
+LogicalResult mlir::scf::AssertScfForAliasingProperties::run(
+    Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo,
+    SmallVector<Operation *> &newOps) {
+  LogicalResult status = success();
+
+  op->walk([&](scf::ForOp forOp) {
+    auto yieldOp =
+        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
+    for (OpOperand &operand : yieldOp->getOpOperands()) {
+      auto tensorType = operand.get().getType().dyn_cast<TensorType>();
+      if (!tensorType)
+        continue;
+
+      OpOperand &forOperand = forOp.getOpOperandForResult(
+          forOp->getResult(operand.getOperandNumber()));
+      auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
+      // Note: This is overly strict. We should check for aliasing bufferized
+      // values. But we don't have a "must-alias" analysis yet.
+      if (!aliasInfo.areEquivalentBufferizedValues(operand.get(), bbArg)) {
+        // TODO: this could get resolved with copies but it can also turn into
+        // swaps so we need to be careful about order of copies.
+        status =
+            yieldOp->emitError()
+            << "Yield operand #" << operand.getOperandNumber()
+            << " does not bufferize to a buffer that is aliasing the matching"
+            << " enclosing scf::for operand";
+        return WalkResult::interrupt();
+      }
+    }
+    return WalkResult::advance();
+  });
+
+  return status;
+}
+
+void mlir::scf::registerBufferizableOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addOpInterface<ExecuteRegionOp, ExecuteRegionOpInterface>();
+  registry.addOpInterface<ForOp, ForOpInterface>();
+  registry.addOpInterface<IfOp, IfOpInterface>();
+  registry.addOpInterface<YieldOp, YieldOpInterface>();
+  registry
+      .addOpInterface<ParallelOp, AllocationHoistingBarrierOnly<ParallelOp>>();
 }

diff  --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index 95483b140d3f3..4858954343c31 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRSCFTransforms
+  BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   ForToWhile.cpp
   LoopCanonicalization.cpp
@@ -20,6 +21,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
   MLIRAffine
   MLIRAffineAnalysis
   MLIRArithmetic
+  MLIRBufferization
   MLIRBufferizationTransforms
   MLIRDialectUtils
   MLIRIR

diff  --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index 701aadc25d5e5..7ad34948bcebd 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -26,7 +26,7 @@ add_mlir_library(MLIRLinalgTestPasses
   MLIRMemRef
   MLIRPass
   MLIRSCF
-  MLIRSCFBufferizableOpInterfaceImpl
+  MLIRSCFTransforms
   MLIRStdBufferizableOpInterfaceImpl
   MLIRStandard
   MLIRTensor

diff  --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
index fe2698c63b40d..5cde7cf2ac094 100644
--- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
@@ -18,11 +18,11 @@
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
@@ -61,7 +61,7 @@ struct TestComprehensiveFunctionBufferize
     affine_ext::registerBufferizableOpInterfaceExternalModels(registry);
     arith::registerBufferizableOpInterfaceExternalModels(registry);
     linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
-    scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
+    scf::registerBufferizableOpInterfaceExternalModels(registry);
     std_ext::registerBufferizableOpInterfaceExternalModels(registry);
     tensor::registerBufferizableOpInterfaceExternalModels(registry);
     vector_ext::registerBufferizableOpInterfaceExternalModels(registry);
@@ -106,7 +106,7 @@ void TestComprehensiveFunctionBufferize::runOnOperation() {
   auto options = std::make_unique<AnalysisBufferizationOptions>();
 
   if (!allowReturnMemref)
-    options->addPostAnalysisStep<scf_ext::AssertScfForAliasingProperties>();
+    options->addPostAnalysisStep<scf::AssertScfForAliasingProperties>();
 
   options->allowReturnMemref = allowReturnMemref;
   options->allowUnknownOps = allowUnknownOps;

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 4a7e104d354a3..185998a0bf947 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1782,6 +1782,7 @@ cc_library(
         "lib/Dialect/SCF/Transforms/*.h",
     ]),
     hdrs = [
+        "include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h",
         "include/mlir/Dialect/SCF/Passes.h",
         "include/mlir/Dialect/SCF/Transforms.h",
     ],
@@ -2435,6 +2436,7 @@ cc_library(
             "include/mlir/Dialect/SCF/*.h",
         ],
         exclude = [
+            "include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h",
             "include/mlir/Dialect/SCF/Transforms.h",
         ],
     ),
@@ -6656,25 +6658,6 @@ cc_library(
     ],
 )
 
-cc_library(
-    name = "SCFBufferizableOpInterfaceImpl",
-    srcs = [
-        "lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp",
-    ],
-    hdrs = [
-        "include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h",
-    ],
-    includes = ["include"],
-    deps = [
-        ":BufferizationDialect",
-        ":BufferizationTransforms",
-        ":IR",
-        ":SCFDialect",
-        ":Support",
-        "//llvm:Support",
-    ],
-)
-
 cc_library(
     name = "StdBufferizableOpInterfaceImpl",
     srcs = [
@@ -6928,7 +6911,6 @@ cc_library(
         ":MemRefDialect",
         ":ModuleBufferization",
         ":Pass",
-        ":SCFBufferizableOpInterfaceImpl",
         ":SCFDialect",
         ":SCFTransforms",
         ":SCFUtils",

diff  --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index 6345343440c39..d23ca654fd09b 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -400,7 +400,6 @@ cc_library(
         "//mlir:LinalgTransforms",
         "//mlir:MemRefDialect",
         "//mlir:Pass",
-        "//mlir:SCFBufferizableOpInterfaceImpl",
         "//mlir:SCFDialect",
         "//mlir:SCFTransforms",
         "//mlir:StandardOps",


        


More information about the Mlir-commits mailing list