[Mlir-commits] [mlir] e448c79 - [mlir][bufferize][NFC] Move std BufferizableOpInterfaceImpl to std dialect

Matthias Springer llvmlistbot at llvm.org
Sun Jan 30 05:12:34 PST 2022


Author: Matthias Springer
Date: 2022-01-30T22:12:14+09:00
New Revision: e448c793c66521ee48d0107c33b80a2ff1baaaaf

URL: https://github.com/llvm/llvm-project/commit/e448c793c66521ee48d0107c33b80a2ff1baaaaf
DIFF: https://github.com/llvm/llvm-project/commit/e448c793c66521ee48d0107c33b80a2ff1baaaaf.diff

LOG: [mlir][bufferize][NFC] Move std BufferizableOpInterfaceImpl to std dialect

Also reimplement `std-bufferize` in terms of BufferizableOpInterface-based bufferization. The old `std.select` bufferization pattern is no longer needed and has been deleted.

Differential Revision: https://reviews.llvm.org/D118559

Added: 
    mlir/include/mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h
    mlir/lib/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.cpp

Modified: 
    mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
    mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
    mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
    mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
    mlir/test/lib/Dialect/Linalg/CMakeLists.txt
    mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
    utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel

Removed: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h
deleted file mode 100644
index ae3b3db23e648..0000000000000
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h
+++ /dev/null
@@ -1,27 +0,0 @@
-//===- StdInterfaceImpl.h - Standard Impl. of BufferizableOpInterface- ----===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H
-#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H
-
-namespace mlir {
-
-class DialectRegistry;
-
-namespace linalg {
-namespace comprehensive_bufferize {
-namespace std_ext {
-
-void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
-
-} // namespace std_ext
-} // namespace comprehensive_bufferize
-} // namespace linalg
-} // namespace mlir
-
-#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H

diff  --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h
new file mode 100644
index 0000000000000..a85acbbb195a6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h
@@ -0,0 +1,18 @@
+//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_STANDARDOPS_BUFFERIZABLEOPINTERFACEIMPL_H
+#define MLIR_DIALECT_STANDARDOPS_BUFFERIZABLEOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace mlir
+
+#endif // MLIR_DIALECT_STANDARDOPS_BUFFERIZABLEOPINTERFACEIMPL_H

diff  --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
index 52bbea000d1f3..d6b8d2028e0e2 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
@@ -23,10 +23,6 @@ class BufferizeTypeConverter;
 
 class RewritePatternSet;
 
-void populateStdBufferizePatterns(
-    bufferization::BufferizeTypeConverter &typeConverter,
-    RewritePatternSet &patterns);
-
 /// Creates an instance of std bufferization pass.
 std::unique_ptr<Pass> createStdBufferizePass();
 

diff  --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
index 6bd83938346e4..3e08865c6f71a 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
@@ -14,8 +14,6 @@ include "mlir/Pass/PassBase.td"
 def StdBufferize : Pass<"std-bufferize", "FuncOp"> {
   let summary = "Bufferize the std dialect";
   let constructor = "mlir::createStdBufferizePass()";
-  let dependentDialects = ["bufferization::BufferizationDialect",
-                           "memref::MemRefDialect", "scf::SCFDialect"];
 }
 
 def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> {

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
index 10d0f72ebcfe8..5733f88c953ab 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
@@ -25,14 +25,6 @@ add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl
   MLIRTensor
 )
 
-add_mlir_dialect_library(MLIRStdBufferizableOpInterfaceImpl
-  StdInterfaceImpl.cpp
-
-  LINK_LIBS PUBLIC
-  MLIRBufferization
-  MLIRStandard
-)
-
 add_mlir_dialect_library(MLIRVectorBufferizableOpInterfaceImpl
   VectorInterfaceImpl.cpp
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 5a42474d49da4..d1418c40f035c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -49,7 +49,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   MLIRSCF
   MLIRSCFTransforms
   MLIRSCFUtils
-  MLIRStdBufferizableOpInterfaceImpl
   MLIRPass
   MLIRStandard
   MLIRStandardOpsTransforms

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index 9d2e6a539a182..f809bf35dc6fe 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -15,10 +15,10 @@
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
@@ -56,7 +56,7 @@ struct LinalgComprehensiveModuleBufferize
     linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
     scf::registerBufferizableOpInterfaceExternalModels(registry);
     std_ext::registerModuleBufferizationExternalModels(registry);
-    std_ext::registerBufferizableOpInterfaceExternalModels(registry);
+    mlir::registerBufferizableOpInterfaceExternalModels(registry);
     tensor::registerBufferizableOpInterfaceExternalModels(registry);
     vector_ext::registerBufferizableOpInterfaceExternalModels(registry);
   }

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp b/mlir/lib/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.cpp
similarity index 83%
rename from mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp
rename to mlir/lib/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.cpp
index 7941c979b09e2..b89a5372a48b6 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -1,4 +1,4 @@
-//===- StdInterfaceImpl.cpp - Standard Impl. of BufferizableOpInterface ---===//
+//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,19 +6,18 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"
+#include "mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h"
 
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Operation.h"
 
+using namespace mlir;
 using namespace mlir::bufferization;
 
 namespace mlir {
-namespace linalg {
-namespace comprehensive_bufferize {
-namespace std_ext {
+namespace {
 
 /// Bufferization of std.select. Just replace the operands.
 struct SelectOpInterface
@@ -69,12 +68,10 @@ struct SelectOpInterface
   }
 };
 
-} // namespace std_ext
-} // namespace comprehensive_bufferize
-} // namespace linalg
+} // namespace
 } // namespace mlir
 
-void mlir::linalg::comprehensive_bufferize::std_ext::
-    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
-  registry.addOpInterface<SelectOp, std_ext::SelectOpInterface>();
+void mlir::registerBufferizableOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addOpInterface<SelectOp, SelectOpInterface>();
 }

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
index 9851784d4951c..64f9d040a71ca 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
@@ -12,64 +12,34 @@
 
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
 #include "PassDetail.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/StandardOps/Transforms/Passes.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/Transforms/DialectConversion.h"
 
 using namespace mlir;
-
-namespace {
-class BufferizeSelectOp : public OpConversionPattern<SelectOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(SelectOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    if (!op.getCondition().getType().isa<IntegerType>())
-      return rewriter.notifyMatchFailure(op, "requires scalar condition");
-
-    rewriter.replaceOpWithNewOp<SelectOp>(op, adaptor.getCondition(),
-                                          adaptor.getTrueValue(),
-                                          adaptor.getFalseValue());
-    return success();
-  }
-};
-} // namespace
-
-void mlir::populateStdBufferizePatterns(
-    bufferization::BufferizeTypeConverter &typeConverter,
-    RewritePatternSet &patterns) {
-  patterns.add<BufferizeSelectOp>(typeConverter, patterns.getContext());
-}
+using namespace mlir::bufferization;
 
 namespace {
 struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
   void runOnOperation() override {
-    auto *context = &getContext();
-    bufferization::BufferizeTypeConverter typeConverter;
-    RewritePatternSet patterns(context);
-    ConversionTarget target(*context);
-
-    target.addLegalDialect<scf::SCFDialect, StandardOpsDialect,
-                           memref::MemRefDialect>();
+    std::unique_ptr<BufferizationOptions> options =
+        getPartialBufferizationOptions();
+    options->addToDialectFilter<StandardOpsDialect>();
 
-    populateStdBufferizePatterns(typeConverter, patterns);
-    // We only bufferize the case of tensor selected type and scalar condition,
-    // as that boils down to a select over memref descriptors (don't need to
-    // touch the data).
-    target.addDynamicallyLegalOp<SelectOp>([&](SelectOp op) {
-      return typeConverter.isLegal(op.getType()) ||
-             !op.getCondition().getType().isa<IntegerType>();
-    });
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
+    if (failed(bufferizeOp(getOperation(), *options)))
       signalPassFailure();
   }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
+                    StandardOpsDialect, scf::SCFDialect>();
+    mlir::registerBufferizableOpInterfaceExternalModels(registry);
+  }
 };
 } // namespace
 

diff  --git a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
index d5869ce207cf8..7db425fdc361d 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRStandardOpsTransforms
+  BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   DecomposeCallGraphTypes.cpp
   FuncBufferize.cpp
@@ -13,6 +14,7 @@ add_mlir_dialect_library(MLIRStandardOpsTransforms
   LINK_LIBS PUBLIC
   MLIRAffine
   MLIRArithmeticTransforms
+  MLIRBufferization
   MLIRBufferizationTransforms
   MLIRIR
   MLIRMemRef

diff  --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index 7ad34948bcebd..4fa6070921806 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -27,8 +27,8 @@ add_mlir_library(MLIRLinalgTestPasses
   MLIRPass
   MLIRSCF
   MLIRSCFTransforms
-  MLIRStdBufferizableOpInterfaceImpl
   MLIRStandard
+  MLIRStandardOpsTransforms
   MLIRTensor
   MLIRTensorTransforms
   MLIRTransformUtils

diff  --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
index 5cde7cf2ac094..b074043b7a3ab 100644
--- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
@@ -18,12 +18,12 @@
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Pass/PassManager.h"
@@ -62,7 +62,7 @@ struct TestComprehensiveFunctionBufferize
     arith::registerBufferizableOpInterfaceExternalModels(registry);
     linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
     scf::registerBufferizableOpInterfaceExternalModels(registry);
-    std_ext::registerBufferizableOpInterfaceExternalModels(registry);
+    mlir::registerBufferizableOpInterfaceExternalModels(registry);
     tensor::registerBufferizableOpInterfaceExternalModels(registry);
     vector_ext::registerBufferizableOpInterfaceExternalModels(registry);
   }

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 185998a0bf947..e786e82d4f5e6 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6658,24 +6658,6 @@ cc_library(
     ],
 )
 
-cc_library(
-    name = "StdBufferizableOpInterfaceImpl",
-    srcs = [
-        "lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp",
-    ],
-    hdrs = [
-        "include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h",
-    ],
-    includes = ["include"],
-    deps = [
-        ":BufferizationDialect",
-        ":IR",
-        ":StandardOps",
-        ":Support",
-        "//llvm:Support",
-    ],
-)
-
 cc_library(
     name = "VectorBufferizableOpInterfaceImpl",
     srcs = [
@@ -6916,7 +6898,6 @@ cc_library(
         ":SCFUtils",
         ":StandardOps",
         ":StandardOpsTransforms",
-        ":StdBufferizableOpInterfaceImpl",
         ":Support",
         ":TensorDialect",
         ":TensorTransforms",

diff  --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index d23ca654fd09b..4292c258c051c 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -403,7 +403,7 @@ cc_library(
         "//mlir:SCFDialect",
         "//mlir:SCFTransforms",
         "//mlir:StandardOps",
-        "//mlir:StdBufferizableOpInterfaceImpl",
+        "//mlir:StandardOpsTransforms",
         "//mlir:TensorDialect",
         "//mlir:TensorTransforms",
         "//mlir:TransformUtils",


        


More information about the Mlir-commits mailing list