[Mlir-commits] [mlir] 4b27825 - [mlir-opt] Support parsing operations other than 'builtin.module' as top-level

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 27 18:13:55 PDT 2022


Author: rkayaith
Date: 2022-09-27T21:13:47-04:00
New Revision: 4b27825ba36d55779561c0a2c3c2f89f52d81303

URL: https://github.com/llvm/llvm-project/commit/4b27825ba36d55779561c0a2c3c2f89f52d81303
DIFF: https://github.com/llvm/llvm-project/commit/4b27825ba36d55779561c0a2c3c2f89f52d81303.diff

LOG: [mlir-opt] Support parsing operations other than 'builtin.module' as top-level

This adds a `--no-implicit-module` option, which disables the insertion
of a top-level `builtin.module` during parsing. In this mode any op may
be top-level, however it's required that there be exactly one top-level
op in the source.

`parseSource{File,String}` now support `Operation *` as the container op
type, which disables the top-level-op-insertion behaviour.

Following patches will add the same option to the other tools as well.

Depends on D133644

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D133645

Added: 
    mlir/include/mlir/Tools/ParseUtilties.h
    mlir/test/IR/top-level.mlir
    mlir/test/Pass/pipeline-invalid.mlir

Modified: 
    mlir/include/mlir/IR/OwningOpRef.h
    mlir/include/mlir/Parser/Parser.h
    mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
    mlir/lib/Tools/mlir-opt/MlirOptMain.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OwningOpRef.h b/mlir/include/mlir/IR/OwningOpRef.h
index e509800c83aaa..378722c4d9962 100644
--- a/mlir/include/mlir/IR/OwningOpRef.h
+++ b/mlir/include/mlir/IR/OwningOpRef.h
@@ -29,7 +29,7 @@ class OwningOpRef {
   /// The underlying operation type stored in this reference.
   using OperationT = OpTy;
 
-  OwningOpRef(std::nullptr_t = nullptr) {}
+  OwningOpRef(std::nullptr_t = nullptr) : op(nullptr) {}
   OwningOpRef(OpTy op) : op(op) {}
   OwningOpRef(OwningOpRef &&other) : op(other.release()) {}
   ~OwningOpRef() {
@@ -53,7 +53,7 @@ class OwningOpRef {
 
   /// Release the referenced op.
   OpTy release() {
-    OpTy released;
+    OpTy released(nullptr);
     std::swap(released, op);
     return released;
   }

diff  --git a/mlir/include/mlir/Parser/Parser.h b/mlir/include/mlir/Parser/Parser.h
index beec5037e5296..f014b2bf7e62e 100644
--- a/mlir/include/mlir/Parser/Parser.h
+++ b/mlir/include/mlir/Parser/Parser.h
@@ -37,38 +37,48 @@ namespace detail {
 template <typename ContainerOpT>
 inline OwningOpRef<ContainerOpT> constructContainerOpForParserIfNecessary(
     Block *parsedBlock, MLIRContext *context, Location sourceFileLoc) {
-  static_assert(
-      ContainerOpT::template hasTrait<OpTrait::OneRegion>() &&
-          (ContainerOpT::template hasTrait<OpTrait::NoTerminator>() ||
-           OpTrait::template hasSingleBlockImplicitTerminator<
-               ContainerOpT>::value),
-      "Expected `ContainerOpT` to have a single region with a single "
-      "block that has an implicit terminator or does not require one");
 
   // Check to see if we parsed a single instance of this operation.
   if (llvm::hasSingleElement(*parsedBlock)) {
-    if (ContainerOpT op = dyn_cast<ContainerOpT>(parsedBlock->front())) {
+    if (ContainerOpT op = dyn_cast<ContainerOpT>(&parsedBlock->front())) {
       op->remove();
       return op;
     }
   }
 
-  // If not, then build a new one to contain the parsed operations.
-  OpBuilder builder(context);
-  ContainerOpT op = builder.create<ContainerOpT>(sourceFileLoc);
-  OwningOpRef<ContainerOpT> opRef(op);
-  assert(op->getNumRegions() == 1 && llvm::hasSingleElement(op->getRegion(0)) &&
-         "expected generated operation to have a single region with a single "
-         "block");
-  Block *opBlock = &op->getRegion(0).front();
-  opBlock->getOperations().splice(opBlock->begin(),
-                                  parsedBlock->getOperations());
-
-  // After splicing, verify just this operation to ensure it can properly
-  // contain the operations inside of it.
-  if (failed(op.verifyInvariants()))
-    return OwningOpRef<ContainerOpT>();
-  return opRef;
+  // If not, then build a new top-level op if a concrete operation type was
+  // specified.
+  if constexpr (std::is_same_v<ContainerOpT, Operation *>) {
+    return emitError(sourceFileLoc)
+               << "source must contain a single top-level operation, found: "
+               << parsedBlock->getOperations().size(),
+           nullptr;
+  } else {
+    static_assert(
+        ContainerOpT::template hasTrait<OpTrait::OneRegion>() &&
+            (ContainerOpT::template hasTrait<OpTrait::NoTerminator>() ||
+             OpTrait::template hasSingleBlockImplicitTerminator<
+                 ContainerOpT>::value),
+        "Expected `ContainerOpT` to have a single region with a single "
+        "block that has an implicit terminator or does not require one");
+
+    OpBuilder builder(context);
+    ContainerOpT op = builder.create<ContainerOpT>(sourceFileLoc);
+    OwningOpRef<ContainerOpT> opRef(op);
+    assert(op->getNumRegions() == 1 &&
+           llvm::hasSingleElement(op->getRegion(0)) &&
+           "expected generated operation to have a single region with a single "
+           "block");
+    Block *opBlock = &op->getRegion(0).front();
+    opBlock->getOperations().splice(opBlock->begin(),
+                                    parsedBlock->getOperations());
+
+    // After splicing, verify just this operation to ensure it can properly
+    // contain the operations inside of it.
+    if (failed(op.verifyInvariants()))
+      return OwningOpRef<ContainerOpT>();
+    return opRef;
+  }
 }
 } // namespace detail
 
@@ -141,7 +151,7 @@ inline OwningOpRef<ContainerOpT> parseSourceFile(const ParserConfig &config,
 /// failure is returned. `ContainerOpT` is required to have a single region
 /// containing a single block, and must implement the
 /// `SingleBlockImplicitTerminator` trait.
-template <typename ContainerOpT>
+template <typename ContainerOpT = Operation *>
 inline OwningOpRef<ContainerOpT>
 parseSourceFile(const llvm::SourceMgr &sourceMgr, const ParserConfig &config) {
   return detail::parseSourceFile<ContainerOpT>(config, sourceMgr);
@@ -155,7 +165,7 @@ parseSourceFile(const llvm::SourceMgr &sourceMgr, const ParserConfig &config) {
 /// failure is returned. `ContainerOpT` is required to have a single region
 /// containing a single block, and must implement the
 /// `SingleBlockImplicitTerminator` trait.
-template <typename ContainerOpT>
+template <typename ContainerOpT = Operation *>
 inline OwningOpRef<ContainerOpT> parseSourceFile(StringRef filename,
                                                  const ParserConfig &config) {
   return detail::parseSourceFile<ContainerOpT>(config, filename);
@@ -169,7 +179,7 @@ inline OwningOpRef<ContainerOpT> parseSourceFile(StringRef filename,
 /// registered in the context, and failure is returned. `ContainerOpT` is
 /// required to have a single region containing a single block, and must
 /// implement the `SingleBlockImplicitTerminator` trait.
-template <typename ContainerOpT>
+template <typename ContainerOpT = Operation *>
 inline OwningOpRef<ContainerOpT> parseSourceFile(llvm::StringRef filename,
                                                  llvm::SourceMgr &sourceMgr,
                                                  const ParserConfig &config) {
@@ -184,7 +194,7 @@ inline OwningOpRef<ContainerOpT> parseSourceFile(llvm::StringRef filename,
 /// failure is returned. `ContainerOpT` is required to have a single region
 /// containing a single block, and must implement the
 /// `SingleBlockImplicitTerminator` trait.
-template <typename ContainerOpT>
+template <typename ContainerOpT = Operation *>
 inline OwningOpRef<ContainerOpT> parseSourceString(llvm::StringRef sourceStr,
                                                    const ParserConfig &config) {
   LocationAttr sourceFileLoc;

diff  --git a/mlir/include/mlir/Tools/ParseUtilties.h b/mlir/include/mlir/Tools/ParseUtilties.h
new file mode 100644
index 0000000000000..98e2ba049b9cf
--- /dev/null
+++ b/mlir/include/mlir/Tools/ParseUtilties.h
@@ -0,0 +1,39 @@
+//===- ParseUtilities.h - MLIR Tool Parse Utilities -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains common utilities for implementing the file-parsing
+// behaviour for MLIR tools.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_PARSEUTILITIES_H
+#define MLIR_TOOLS_PARSEUTILITIES_H
+
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Parser/Parser.h"
+
+namespace mlir {
+/// This parses the file specified by the indicated SourceMgr. If parsing was
+/// not successful, null is returned and an error message is emitted through the
+/// error handler registered in the context.
+/// If 'insertImplicitModule' is true a top-level 'builtin.module' op will be
+/// inserted that contains the parsed IR, unless one exists already.
+inline OwningOpRef<Operation *>
+parseSourceFileForTool(llvm::SourceMgr &sourceMgr, const ParserConfig &config,
+                       bool insertImplicitModule) {
+  if (insertImplicitModule) {
+    // TODO: Move implicit module logic out of 'parseSourceFile' and into here.
+    return parseSourceFile<ModuleOp>(sourceMgr, config)
+        .release()
+        .getOperation();
+  }
+  return parseSourceFile(sourceMgr, config);
+}
+} // namespace mlir
+
+#endif // MLIR_TOOLS_PARSEUTILITIES_H

diff  --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
index e32c7e786d3b0..1d710f430b25c 100644
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -51,26 +51,24 @@ using PassPipelineFn = llvm::function_ref<LogicalResult(PassManager &pm)>;
 ///   dialects from the global registry in the MLIRContext. This option is
 ///   deprecated and will be removed soon.
 /// - emitBytecode will generate bytecode output instead of text.
-LogicalResult MlirOptMain(llvm::raw_ostream &outputStream,
-                          std::unique_ptr<llvm::MemoryBuffer> buffer,
-                          const PassPipelineCLParser &passPipeline,
-                          DialectRegistry &registry, bool splitInputFile,
-                          bool verifyDiagnostics, bool verifyPasses,
-                          bool allowUnregisteredDialects,
-                          bool preloadDialectsInContext = false,
-                          bool emitBytecode = false);
+/// - implicitModule will enable implicit addition of a top-level
+/// 'builtin.module' if one doesn't already exist.
+LogicalResult MlirOptMain(
+    llvm::raw_ostream &outputStream, std::unique_ptr<llvm::MemoryBuffer> buffer,
+    const PassPipelineCLParser &passPipeline, DialectRegistry &registry,
+    bool splitInputFile, bool verifyDiagnostics, bool verifyPasses,
+    bool allowUnregisteredDialects, bool preloadDialectsInContext = false,
+    bool emitBytecode = false, bool implicitModule = false);
 
 /// Support a callback to setup the pass manager.
 /// - passManagerSetupFn is the callback invoked to setup the pass manager to
 ///   apply on the loaded IR.
-LogicalResult MlirOptMain(llvm::raw_ostream &outputStream,
-                          std::unique_ptr<llvm::MemoryBuffer> buffer,
-                          PassPipelineFn passManagerSetupFn,
-                          DialectRegistry &registry, bool splitInputFile,
-                          bool verifyDiagnostics, bool verifyPasses,
-                          bool allowUnregisteredDialects,
-                          bool preloadDialectsInContext = false,
-                          bool emitBytecode = false);
+LogicalResult MlirOptMain(
+    llvm::raw_ostream &outputStream, std::unique_ptr<llvm::MemoryBuffer> buffer,
+    PassPipelineFn passManagerSetupFn, DialectRegistry &registry,
+    bool splitInputFile, bool verifyDiagnostics, bool verifyPasses,
+    bool allowUnregisteredDialects, bool preloadDialectsInContext = false,
+    bool emitBytecode = false, bool implicitModule = false);
 
 /// Implementation for tools like `mlir-opt`.
 /// - toolName is used for the header displayed by `--help`.

diff  --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 043a90a79a052..5d02d853b7d98 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -27,6 +27,7 @@
 #include "mlir/Support/FileUtilities.h"
 #include "mlir/Support/Timing.h"
 #include "mlir/Support/ToolUtilities.h"
+#include "mlir/Tools/ParseUtilties.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/FileUtilities.h"
 #include "llvm/Support/InitLLVM.h"
@@ -49,7 +50,7 @@ static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
                                     bool verifyPasses, SourceMgr &sourceMgr,
                                     MLIRContext *context,
                                     PassPipelineFn passManagerSetupFn,
-                                    bool emitBytecode) {
+                                    bool emitBytecode, bool implicitModule) {
   DefaultTimingManager tm;
   applyDefaultTimingManagerCLOptions(tm);
   TimingScope timing = tm.getRootScope();
@@ -70,15 +71,16 @@ static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
 
   // Parse the input file and reset the context threading state.
   TimingScope parserTiming = timing.nest("Parser");
-  OwningOpRef<ModuleOp> module(parseSourceFile<ModuleOp>(sourceMgr, config));
+  OwningOpRef<Operation *> op =
+      parseSourceFileForTool(sourceMgr, config, implicitModule);
   context->enableMultithreading(wasThreadingEnabled);
-  if (!module)
+  if (!op)
     return failure();
   parserTiming.stop();
 
   // Prepare the pass manager, applying command-line and reproducer options.
   PassManager pm(context, OpPassManager::Nesting::Implicit,
-                 module->getOperationName());
+                 op.get()->getName().getStringRef());
   pm.enableVerifier(verifyPasses);
   applyPassManagerCLOptions(pm);
   pm.enableTiming(timing);
@@ -86,18 +88,18 @@ static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
     return failure();
 
   // Run the pipeline.
-  if (failed(pm.run(*module)))
+  if (failed(pm.run(*op)))
     return failure();
 
   // Print the output.
   TimingScope outputTiming = timing.nest("Output");
   if (emitBytecode) {
     BytecodeWriterConfig writerConfig(fallbackResourceMap);
-    writeBytecodeToFile(module->getOperation(), os, writerConfig);
+    writeBytecodeToFile(op.get(), os, writerConfig);
   } else {
-    AsmState asmState(*module, OpPrintingFlags(), /*locationMap=*/nullptr,
+    AsmState asmState(op.get(), OpPrintingFlags(), /*locationMap=*/nullptr,
                       &fallbackResourceMap);
-    module->print(os, asmState);
+    op.get()->print(os, asmState);
     os << '\n';
   }
   return success();
@@ -109,8 +111,9 @@ static LogicalResult
 processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
               bool verifyDiagnostics, bool verifyPasses,
               bool allowUnregisteredDialects, bool preloadDialectsInContext,
-              bool emitBytecode, PassPipelineFn passManagerSetupFn,
-              DialectRegistry &registry, llvm::ThreadPool *threadPool) {
+              bool emitBytecode, bool implicitModule,
+              PassPipelineFn passManagerSetupFn, DialectRegistry &registry,
+              llvm::ThreadPool *threadPool) {
   // Tell sourceMgr about this buffer, which is what the parser will pick up.
   SourceMgr sourceMgr;
   sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
@@ -134,7 +137,8 @@ processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
   if (!verifyDiagnostics) {
     SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
     return performActions(os, verifyDiagnostics, verifyPasses, sourceMgr,
-                          &context, passManagerSetupFn, emitBytecode);
+                          &context, passManagerSetupFn, emitBytecode,
+                          implicitModule);
   }
 
   SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context);
@@ -143,7 +147,7 @@ processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
   // these actions succeed or fail, we only care what diagnostics they produce
   // and whether they match our expectations.
   (void)performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, &context,
-                       passManagerSetupFn, emitBytecode);
+                       passManagerSetupFn, emitBytecode, implicitModule);
 
   // Verify the diagnostic handler to make sure that each of the diagnostics
   // matched.
@@ -157,7 +161,7 @@ LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
                                 bool verifyDiagnostics, bool verifyPasses,
                                 bool allowUnregisteredDialects,
                                 bool preloadDialectsInContext,
-                                bool emitBytecode) {
+                                bool emitBytecode, bool implicitModule) {
   // The split-input-file mode is a very specific mode that slices the file
   // up into small pieces and checks each independently.
   // We use an explicit threadpool to avoid creating and joining/destroying
@@ -176,7 +180,7 @@ LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
                      raw_ostream &os) {
     return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics,
                          verifyPasses, allowUnregisteredDialects,
-                         preloadDialectsInContext, emitBytecode,
+                         preloadDialectsInContext, emitBytecode, implicitModule,
                          passManagerSetupFn, registry, threadPool);
   };
   return splitAndProcessBuffer(std::move(buffer), chunkFn, outputStream,
@@ -190,7 +194,7 @@ LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
                                 bool verifyDiagnostics, bool verifyPasses,
                                 bool allowUnregisteredDialects,
                                 bool preloadDialectsInContext,
-                                bool emitBytecode) {
+                                bool emitBytecode, bool implicitModule) {
   auto passManagerSetupFn = [&](PassManager &pm) {
     auto errorHandler = [&](const Twine &msg) {
       emitError(UnknownLoc::get(pm.getContext())) << msg;
@@ -201,7 +205,7 @@ LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
   return MlirOptMain(outputStream, std::move(buffer), passManagerSetupFn,
                      registry, splitInputFile, verifyDiagnostics, verifyPasses,
                      allowUnregisteredDialects, preloadDialectsInContext,
-                     emitBytecode);
+                     emitBytecode, implicitModule);
 }
 
 LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
@@ -243,6 +247,12 @@ LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
       "emit-bytecode", cl::desc("Emit bytecode when generating output"),
       cl::init(false));
 
+  static cl::opt<bool> noImplicitModule{
+      "no-implicit-module",
+      cl::desc(
+          "Disable implicit addition of a top-level module op during parsing"),
+      cl::init(false)};
+
   InitLLVM y(argc, argv);
 
   // Register any command line options.
@@ -288,7 +298,7 @@ LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
   if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, registry,
                          splitInputFile, verifyDiagnostics, verifyPasses,
                          allowUnregisteredDialects, preloadDialectsInContext,
-                         emitBytecode)))
+                         emitBytecode, /*implicitModule=*/!noImplicitModule)))
     return failure();
 
   // Keep the output file if the invocation of MlirOptMain was successful.

diff  --git a/mlir/test/IR/top-level.mlir b/mlir/test/IR/top-level.mlir
new file mode 100644
index 0000000000000..b571d944928c8
--- /dev/null
+++ b/mlir/test/IR/top-level.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt --no-implicit-module --verify-diagnostics --split-input-file %s | FileCheck %s
+
+// CHECK-NOT: module
+//     CHECK: func.func
+func.func private @foo()
+
+// -----
+
+// expected-error@-3 {{source must contain a single top-level operation, found: 2}}
+func.func private @bar()
+func.func private @baz()
+
+// -----
+
+// expected-error@-3 {{source must contain a single top-level operation, found: 0}}

diff  --git a/mlir/test/Pass/pipeline-invalid.mlir b/mlir/test/Pass/pipeline-invalid.mlir
new file mode 100644
index 0000000000000..da39e4c3d4351
--- /dev/null
+++ b/mlir/test/Pass/pipeline-invalid.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt --no-implicit-module --canonicalize --verify-diagnostics --split-input-file %s
+
+// expected-error@below {{trying to schedule a pass on an operation not marked as 'IsolatedFromAbove'}}
+arith.constant 0
+
+// -----
+
+// expected-error@below {{trying to schedule a pass on an unregistered operation}}
+"test.op"() : () -> ()


        


More information about the Mlir-commits mailing list