[Mlir-commits] [mlir] d3bb4fe - [mlir][linalg][bufferize][NFC] Move arith interface impl to new build target

Matthias Springer llvmlistbot at llvm.org
Wed Nov 24 17:21:32 PST 2021


Author: Matthias Springer
Date: 2021-11-25T10:21:02+09:00
New Revision: d3bb4fec2a5a698773b6b4a3758f166113aa2b8c

URL: https://github.com/llvm/llvm-project/commit/d3bb4fec2a5a698773b6b4a3758f166113aa2b8c
DIFF: https://github.com/llvm/llvm-project/commit/d3bb4fec2a5a698773b6b4a3758f166113aa2b8c.diff

LOG: [mlir][linalg][bufferize][NFC] Move arith interface impl to new build target

This makes ComprehensiveBufferize entirely independent of the arith dialect.

Differential Revision: https://reviews.llvm.org/D114219

Added: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp

Modified: 
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h
new file mode 100644
index 0000000000000..6a2139d98c7c9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h
@@ -0,0 +1,27 @@
+//===- ArithInterfaceImpl.h - Arith Impl. of BufferizableOpInterface ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_ARITH_INTERFACE_IMPL_H
+#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_ARITH_INTERFACE_IMPL_H
+
+namespace mlir {
+// Forward declaration only, so this header needs no #include.
+class DialectRegistry;
+
+namespace linalg {
+namespace comprehensive_bufferize {
+namespace arith_ext {
+/// Register the BufferizableOpInterface external model for arith::ConstantOp.
+void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
+
+} // namespace arith_ext
+} // namespace comprehensive_bufferize
+} // namespace linalg
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_ARITH_INTERFACE_IMPL_H

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
new file mode 100644
index 0000000000000..ec69832bcf0dd
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
@@ -0,0 +1,73 @@
+//===- ArithInterfaceImpl.cpp - Arith Impl. of BufferizableOpInterface ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
+
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Transforms/BufferUtils.h"
+
+namespace mlir {
+namespace linalg {
+namespace comprehensive_bufferize {
+namespace arith_ext {
+/// BufferizableOpInterface external model for arith::ConstantOp.
+struct ConstantOpInterface
+    : public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
+                                                    arith::ConstantOp> {
+  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
+                                                OpResult opResult) const {
+    return {}; // The result aliases no OpOperand.
+  }
+  /// Bufferize a tensor constant into a memref.get_global of a module global.
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BufferizationState &state) const {
+    auto constantOp = cast<arith::ConstantOp>(op);
+    if (!constantOp.getResult().getType().isa<TensorType>())
+      return success(); // Non-tensor constants are left untouched.
+    assert(constantOp.getType().dyn_cast<RankedTensorType>() &&
+           "not a constant ranked tensor");
+    auto moduleOp = constantOp->getParentOfType<ModuleOp>();
+    if (!moduleOp) { // GlobalCreator needs the enclosing module.
+      return constantOp.emitError(
+          "cannot bufferize constants not within builtin.module op");
+    }
+    GlobalCreator globalCreator(moduleOp);
+
+    // Guard b's insertion point so it is restored when this function returns.
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPoint(constantOp);
+    // Read the constant's global via memref.get_global; map the result to it.
+    auto globalMemref = globalCreator.getGlobalFor(constantOp);
+    Value memref = b.create<memref::GetGlobalOp>(
+        constantOp.getLoc(), globalMemref.type(), globalMemref.getName());
+    state.aliasInfo.insertNewBufferEquivalence(memref, constantOp.getResult());
+    state.mapBuffer(constantOp, memref);
+
+    return success();
+  }
+
+  bool isWritable(Operation *op, Value value) const {
+    // Memory locations returned by memref::GetGlobalOp may not be written to.
+    assert(value.isa<OpResult>());
+    return false;
+  }
+};
+
+} // namespace arith_ext
+} // namespace comprehensive_bufferize
+} // namespace linalg
+} // namespace mlir
+/// Attach arith_ext::ConstantOpInterface to arith::ConstantOp in `registry`.
+void mlir::linalg::comprehensive_bufferize::arith_ext::
+    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
+  registry.addOpInterface<arith::ConstantOp, arith_ext::ConstantOpInterface>();
+}

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
index e8bbc9a33c22f..b94694a3e53cd 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
@@ -1,4 +1,5 @@
 set(LLVM_OPTIONAL_SOURCES
+  ArithInterfaceImpl.cpp
   BufferizableOpInterface.cpp
   ComprehensiveBufferize.cpp
   LinalgInterfaceImpl.cpp
@@ -17,6 +18,17 @@ add_mlir_dialect_library(MLIRBufferizableOpInterface
   MLIRMemRef
 )
 
+add_mlir_dialect_library(MLIRArithBufferizableOpInterfaceImpl
+  ArithInterfaceImpl.cpp
+
+  LINK_LIBS PUBLIC
+  MLIRArithmetic
+  MLIRBufferizableOpInterface
+  MLIRIR
+  MLIRMemRef
+  MLIRStandardOpsTransforms
+)
+
 add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl
   LinalgInterfaceImpl.cpp
 

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index 22f5bf3b06e52..c07adcf3fbc03 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -116,16 +116,17 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
-#include "mlir/Transforms/BufferUtils.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
 #include "llvm/Support/FormatVariadic.h"
 
 #define DEBUG_TYPE "comprehensive-module-bufferize"
@@ -1287,52 +1288,6 @@ BufferizationOptions::BufferizationOptions()
 namespace mlir {
 namespace linalg {
 namespace comprehensive_bufferize {
-namespace arith_ext {
-
-struct ConstantOpInterface
-    : public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
-                                                    arith::ConstantOp> {
-  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
-                                                OpResult opResult) const {
-    return {};
-  }
-
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BufferizationState &state) const {
-    auto constantOp = cast<arith::ConstantOp>(op);
-    if (!isaTensor(constantOp.getResult().getType()))
-      return success();
-    assert(constantOp.getType().dyn_cast<RankedTensorType>() &&
-           "not a constant ranked tensor");
-    auto moduleOp = constantOp->getParentOfType<ModuleOp>();
-    if (!moduleOp) {
-      return constantOp.emitError(
-          "cannot bufferize constants not within builtin.module op");
-    }
-    GlobalCreator globalCreator(moduleOp);
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(constantOp);
-
-    auto globalMemref = globalCreator.getGlobalFor(constantOp);
-    Value memref = b.create<memref::GetGlobalOp>(
-        constantOp.getLoc(), globalMemref.type(), globalMemref.getName());
-    state.aliasInfo.insertNewBufferEquivalence(memref, constantOp.getResult());
-    state.mapBuffer(constantOp, memref);
-
-    return success();
-  }
-
-  bool isWritable(Operation *op, Value value) const {
-    // Memory locations returned by memref::GetGlobalOp may not be written to.
-    assert(value.isa<OpResult>());
-    return false;
-  }
-};
-
-} // namespace arith_ext
-
 namespace scf_ext {
 
 struct ExecuteRegionOpInterface
@@ -1813,7 +1768,6 @@ struct ReturnOpInterface
 } // namespace std_ext
 
 void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
-  registry.addOpInterface<arith::ConstantOp, arith_ext::ConstantOpInterface>();
   registry.addOpInterface<scf::ExecuteRegionOp,
                           scf_ext::ExecuteRegionOpInterface>();
   registry.addOpInterface<scf::ForOp, scf_ext::ForOpInterface>();

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 185068ee2ad61..441ee9ea7e4eb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -31,6 +31,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   MLIRAffine
   MLIRAffineUtils
   MLIRAnalysis
+  MLIRArithBufferizableOpInterfaceImpl
   MLIRArithmetic
   MLIRBufferizableOpInterface
   MLIRComplex

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index 9acb2c1d5fc3d..dd5dd127bdeaa 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "PassDetail.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
@@ -39,6 +40,7 @@ struct LinalgComprehensiveModuleBufferize
                 tensor::TensorDialect, vector::VectorDialect, scf::SCFDialect,
                 arith::ArithmeticDialect, StandardOpsDialect, AffineDialect>();
     registerBufferizableOpInterfaceExternalModels(registry);
+    arith_ext::registerBufferizableOpInterfaceExternalModels(registry);
     linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
     tensor_ext::registerBufferizableOpInterfaceExternalModels(registry);
     vector_ext::registerBufferizableOpInterfaceExternalModels(registry);

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 40269f27f6b2d..93fbedc2e890f 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6306,6 +6306,26 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "ArithBufferizableOpInterfaceImpl",
+    srcs = [
+        "lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp",
+    ],
+    hdrs = [
+        "include/mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h",
+    ],
+    includes = ["include"],
+    deps = [
+        ":ArithmeticDialect",
+        ":BufferizableOpInterface",
+        ":IR",
+        ":MemRefDialect",
+        ":Support",
+        ":TransformUtils",
+        "//llvm:Support",
+    ],
+)
+
 cc_library(
     name = "LinalgBufferizableOpInterfaceImpl",
     srcs = [
@@ -6563,6 +6583,7 @@ cc_library(
         ":Affine",
         ":AffineUtils",
         ":Analysis",
+        ":ArithBufferizableOpInterfaceImpl",
         ":ArithmeticDialect",
         ":BufferizableOpInterface",
         ":ComplexDialect",
@@ -6604,7 +6625,6 @@ cc_library(
     includes = ["include"],
     deps = [
         ":Affine",
-        ":ArithmeticDialect",
         ":BufferizableOpInterface",
         ":DialectUtils",
         ":IR",
@@ -6614,7 +6634,6 @@ cc_library(
         ":SCFDialect",
         ":StandardOps",
         ":Support",
-        ":TransformUtils",
         "//llvm:Support",
     ],
 )


        


More information about the Mlir-commits mailing list