[Mlir-commits] [mlir] 61352a5 - [mlir] Introduce ml_program dialect.

Stella Laurenzo llvmlistbot at llvm.org
Wed Apr 13 21:39:06 PDT 2022


Author: Stella Laurenzo
Date: 2022-04-13T21:38:14-07:00
New Revision: 61352a580a1f8e5818a6e5445517058d959bb86f

URL: https://github.com/llvm/llvm-project/commit/61352a580a1f8e5818a6e5445517058d959bb86f
DIFF: https://github.com/llvm/llvm-project/commit/61352a580a1f8e5818a6e5445517058d959bb86f.diff

LOG: [mlir] Introduce ml_program dialect.

Differential Revision: https://reviews.llvm.org/D120203

Added: 
    mlir/include/mlir/Dialect/MLProgram/CMakeLists.txt
    mlir/include/mlir/Dialect/MLProgram/IR/CMakeLists.txt
    mlir/include/mlir/Dialect/MLProgram/IR/MLProgram.h
    mlir/include/mlir/Dialect/MLProgram/IR/MLProgramBase.td
    mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
    mlir/lib/Dialect/MLProgram/CMakeLists.txt
    mlir/lib/Dialect/MLProgram/IR/CMakeLists.txt
    mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp
    mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
    mlir/test/Dialect/MLProgram/invalid.mlir
    mlir/test/Dialect/MLProgram/ops.mlir

Modified: 
    mlir/include/mlir/Dialect/CMakeLists.txt
    mlir/include/mlir/InitAllDialects.h
    mlir/lib/Dialect/CMakeLists.txt
    mlir/test/mlir-opt/commandline.mlir
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
index d61773dc416c6..a0b5209838bc4 100644
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -15,6 +15,7 @@ add_subdirectory(Math)
 add_subdirectory(Linalg)
 add_subdirectory(LLVMIR)
 add_subdirectory(MemRef)
+add_subdirectory(MLProgram)
 add_subdirectory(OpenACC)
 add_subdirectory(OpenMP)
 add_subdirectory(PDL)

diff  --git a/mlir/include/mlir/Dialect/MLProgram/CMakeLists.txt b/mlir/include/mlir/Dialect/MLProgram/CMakeLists.txt
new file mode 100644
index 0000000000000..f33061b2d87cf
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MLProgram/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)

diff  --git a/mlir/include/mlir/Dialect/MLProgram/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/MLProgram/IR/CMakeLists.txt
new file mode 100644
index 0000000000000..fce18e65e952e
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MLProgram/IR/CMakeLists.txt
@@ -0,0 +1,3 @@
+set(LLVM_TARGET_DEFINITIONS MLProgramOps.td)
+add_mlir_dialect(MLProgramOps ml_program)
+add_mlir_doc(MLProgramOps MLProgramOps Dialects/ -gen-dialect-doc)

diff  --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgram.h b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgram.h
new file mode 100644
index 0000000000000..fad8cbcf1c669
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgram.h
@@ -0,0 +1,34 @@
+//===- MLProgram.h - MLProgram dialect ----------------------------*- C++-*-==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAM_H_
+#define MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAM_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/FunctionInterfaces.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/RegionKindInterface.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+//===----------------------------------------------------------------------===//
+// MLProgramDialect
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MLProgram/IR/MLProgramOpsDialect.h.inc"
+
+//===----------------------------------------------------------------------===//
+// MLProgram Dialect Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/MLProgram/IR/MLProgramOps.h.inc"
+
+#endif // MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAM_H_

diff  --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramBase.td b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramBase.td
new file mode 100644
index 0000000000000..b670bc89204c2
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramBase.td
@@ -0,0 +1,33 @@
+//===- MLProgramBase.td - Base defs for ml_program dialect --*- tablegen -*-==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLPROGRAM_BASE
+#define MLPROGRAM_BASE
+
+include "mlir/IR/OpBase.td"
+
+def MLProgram_Dialect : Dialect {
+  let name = "ml_program";
+  let cppNamespace = "::mlir::ml_program";
+  let description = [{
+    The MLProgram dialect contains structural operations and types for
+    defining a compiled Machine-Learning program, as created from common
+    ML frameworks, such as TensorFlow, PyTorch, JAX, etc. It does not itself
+    define computation ops common to such frameworks but establishes a common
+    programming model for establishing modules, functions, globals and
+    memory model components appropriate for such an abstract level of detail.
+
+    This dialect is under active development, and while stability is an
+    eventual goal, it is not guaranteed at this juncture. Given the early state,
+    it is recommended to inquire further prior to using this dialect.
+  }];
+
+  let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
+}
+
+#endif // MLPROGRAM_BASE

diff  --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
new file mode 100644
index 0000000000000..b096c3dd53e8c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
@@ -0,0 +1,218 @@
+//===- MLProgramOps.td - Structural ML Program Ops ---------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLPROGRAM_OPS
+#define MLPROGRAM_OPS
+
+include "mlir/Dialect/MLProgram/IR/MLProgramBase.td"
+include "mlir/Interfaces/CallInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/FunctionInterfaces.td"
+include "mlir/IR/RegionKindInterface.td"
+include "mlir/IR/SymbolInterfaces.td"
+
+class MLProgram_Op<string mnemonic, list<Trait> traits = []> :
+    Op<MLProgram_Dialect, mnemonic, traits>;
+
+//===----------------------------------------------------------------------===//
+// FuncOp
+//===----------------------------------------------------------------------===//
+
+def MLProgram_FuncOp : MLProgram_Op<"func", [
+    CallableOpInterface, FunctionOpInterface, IsolatedFromAbove,
+    RegionKindInterface, Symbol
+  ]> {
+  let summary = "Function containing a single `SSACFG` region";
+  let description = [{
+    This simple function container represents callables in an ML program where
+    the body is an `SSACFG` region. It must be terminated by a `return` op which
+    yields values with the same arity and types as the `FunctionType` results
+    of the containing `func`.
+
+    This op is a `Symbol` but does not introduce a new `SymbolTable`. As such,
+    it cannot represent nested symbols.
+
+    Example:
+
+    ```mlir
+    ml_program.func private @some_extern(i32) -> i32
+    ml_program.func @compute(%arg0 : i32) -> i32 {
+      ml_program.return %arg0 : i32
+    }
+    ```
+  }];
+
+  let arguments = (ins SymbolNameAttr:$sym_name,
+                       TypeAttrOf<FunctionType>:$function_type,
+                       OptionalAttr<StrAttr>:$sym_visibility);
+  let regions = (region AnyRegion:$body);
+
+  let extraClassDeclaration = [{
+    //===------------------------------------------------------------------===//
+    // CallableOpInterface
+    //===------------------------------------------------------------------===//
+
+    /// Returns the region on the current operation that is callable. This may
+    /// return null in the case of an external callable object, e.g. an external
+    /// function.
+    ::mlir::Region *getCallableRegion() {
+      return isExternal() ? nullptr : &getBody();
+    }
+
+    /// Returns the result types that the callable region produces when
+    /// executed.
+    ArrayRef<Type> getCallableResults() { return getFunctionType().getResults(); }
+
+    //===------------------------------------------------------------------===//
+    // FunctionOpInterface Methods
+    //===------------------------------------------------------------------===//
+
+    /// Returns the argument types of this function.
+    ArrayRef<Type> getArgumentTypes() { return getFunctionType().getInputs(); }
+
+    /// Returns the result types of this function.
+    ArrayRef<Type> getResultTypes() { return getFunctionType().getResults(); }
+
+    //===------------------------------------------------------------------===//
+    // RegionKindInterface Methods
+    //===------------------------------------------------------------------===//
+    static ::mlir::RegionKind getRegionKind(unsigned index) {
+      return ::mlir::RegionKind::SSACFG;
+    }
+
+    //===------------------------------------------------------------------===//
+    // SymbolOpInterface Methods
+    //===------------------------------------------------------------------===//
+
+    bool isDeclaration() { return isExternal(); }
+  }];
+
+  let hasCustomAssemblyFormat = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// SubgraphOp
+//===----------------------------------------------------------------------===//
+
+def MLProgram_SubgraphOp : MLProgram_Op<"subgraph", [
+    CallableOpInterface, FunctionOpInterface, HasOnlyGraphRegion,
+    IsolatedFromAbove, RegionKindInterface, SingleBlock, Symbol
+  ]> {
+  let summary = "A function containing a single `Graph` region";
+  let description = [{
+    This simple function container represents callables in an ML program where
+    the body is a `Graph` region containing a single block. It must be
+    terminated by an `output` op which yields values with the same arity and
+    types as the `FunctionType` results of the containing `subgraph`.
+
+    This op is a `Symbol` but does not introduce a new `SymbolTable`. As such,
+    it cannot represent nested symbols.
+
+    Example:
+
+    ```mlir
+    ml_program.subgraph private @some_extern(i32) -> i32
+    ml_program.subgraph @compute(%arg0 : i32) -> i32 {
+      ml_program.output %arg0 : i32
+    }
+    ```
+  }];
+
+  let arguments = (ins SymbolNameAttr:$sym_name,
+                       TypeAttrOf<FunctionType>:$function_type,
+                       OptionalAttr<StrAttr>:$sym_visibility);
+  let regions = (region AnyRegion:$body);
+
+  let extraClassDeclaration = [{
+    //===------------------------------------------------------------------===//
+    // CallableOpInterface
+    //===------------------------------------------------------------------===//
+
+    /// Returns the region on the current operation that is callable. This may
+    /// return null in the case of an external callable object, e.g. an external
+    /// function.
+    ::mlir::Region *getCallableRegion() { return isExternal() ? nullptr : &getBody(); }
+
+    /// Returns the result types that the callable region produces when
+    /// executed.
+    ArrayRef<Type> getCallableResults() { return getFunctionType().getResults(); }
+
+    //===------------------------------------------------------------------===//
+    // FunctionOpInterface Methods
+    //===------------------------------------------------------------------===//
+
+    /// Returns the argument types of this function.
+    ArrayRef<Type> getArgumentTypes() { return getFunctionType().getInputs(); }
+
+    /// Returns the result types of this function.
+    ArrayRef<Type> getResultTypes() { return getFunctionType().getResults(); }
+
+    //===------------------------------------------------------------------===//
+    // SymbolOpInterface Methods
+    //===------------------------------------------------------------------===//
+
+    bool isDeclaration() { return isExternal(); }
+  }];
+
+  let hasCustomAssemblyFormat = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// OutputOp
+//===----------------------------------------------------------------------===//
+
+def MLProgram_OutputOp : MLProgram_Op<"output", [
+    NoSideEffect, HasParent<"SubgraphOp">, ReturnLike, Terminator
+  ]> {
+  let summary = "Outputs values from a subgraph function";
+  let description = [{
+    The `output` operation terminates a subgraph by yielding values
+    to the caller.
+    The operation takes a variable number of operands and produces no results.
+    The operand number and types must match the signature of the function
+    that contains the operation.
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$operands);
+
+  let builders = [OpBuilder<(ins), [{
+    build($_builder, $_state, llvm::None);
+  }]>];
+
+  let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
+  let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// ReturnOp
+//===----------------------------------------------------------------------===//
+
+def MLProgram_ReturnOp : MLProgram_Op<"return", [
+    NoSideEffect, HasParent<"FuncOp">, ReturnLike, Terminator
+  ]> {
+  let summary = "Returns values from a `func` function";
+  let description = [{
+    The `return` operation terminates a `func` function by yielding values
+    to the caller.
+    The operation takes a variable number of operands and produces no results.
+    The operand number and types must match the signature of the function
+    that contains the operation.
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$operands);
+
+  let builders = [OpBuilder<(ins), [{
+    build($_builder, $_state, llvm::None);
+  }]>];
+
+  let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
+  let hasVerifier = 1;
+}
+
+#endif // MLPROGRAM_OPS

diff  --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 9487876ef32af..7f370cd16bf7f 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -33,6 +33,7 @@
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
@@ -77,6 +78,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
                   linalg::LinalgDialect,
                   math::MathDialect,
                   memref::MemRefDialect,
+                  ml_program::MLProgramDialect,
                   scf::SCFDialect,
                   omp::OpenMPDialect,
                   pdl::PDLDialect,

diff  --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index 766dd263e7a28..6ffc8c3085a34 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -15,6 +15,7 @@ add_subdirectory(Linalg)
 add_subdirectory(LLVMIR)
 add_subdirectory(Math)
 add_subdirectory(MemRef)
+add_subdirectory(MLProgram)
 add_subdirectory(OpenACC)
 add_subdirectory(OpenMP)
 add_subdirectory(PDL)

diff  --git a/mlir/lib/Dialect/MLProgram/CMakeLists.txt b/mlir/lib/Dialect/MLProgram/CMakeLists.txt
new file mode 100644
index 0000000000000..f33061b2d87cf
--- /dev/null
+++ b/mlir/lib/Dialect/MLProgram/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)

diff  --git a/mlir/lib/Dialect/MLProgram/IR/CMakeLists.txt b/mlir/lib/Dialect/MLProgram/IR/CMakeLists.txt
new file mode 100644
index 0000000000000..61dba7539908b
--- /dev/null
+++ b/mlir/lib/Dialect/MLProgram/IR/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_mlir_dialect_library(MLIRMLProgram
+  MLProgramOps.cpp
+  MLProgramDialect.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MLProgram
+
+  DEPENDS
+  MLIRMLProgramOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRDialect
+  MLIRInferTypeOpInterface
+  MLIRIR
+  )

diff  --git a/mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp b/mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp
new file mode 100644
index 0000000000000..eb012bac3984a
--- /dev/null
+++ b/mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp
@@ -0,0 +1,21 @@
+//===- MLProgramDialect.cpp - MLProgram dialect implementation ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
+
+using namespace mlir;
+using namespace mlir::ml_program;
+
+#include "mlir/Dialect/MLProgram/IR/MLProgramOpsDialect.cpp.inc"
+
+void ml_program::MLProgramDialect::initialize() {
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc"
+      >();
+}

diff  --git a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
new file mode 100644
index 0000000000000..4d8038c21f17f
--- /dev/null
+++ b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
@@ -0,0 +1,107 @@
+//===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/FunctionImplementation.h"
+
+using namespace mlir;
+using namespace mlir::ml_program;
+
+//===----------------------------------------------------------------------===//
+// TableGen'd op method definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// FuncOp
+//===----------------------------------------------------------------------===//
+
+ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
+  auto buildFuncType =
+      [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
+         function_interface_impl::VariadicFlag,
+         std::string &) { return builder.getFunctionType(argTypes, results); };
+
+  return function_interface_impl::parseFunctionOp(
+      parser, result, /*allowVariadic=*/false, buildFuncType);
+}
+
+void FuncOp::print(OpAsmPrinter &p) {
+  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
+}
+
+//===----------------------------------------------------------------------===//
+// SubgraphOp
+//===----------------------------------------------------------------------===//
+
+ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) {
+  auto buildFuncType =
+      [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
+         function_interface_impl::VariadicFlag,
+         std::string &) { return builder.getFunctionType(argTypes, results); };
+
+  return function_interface_impl::parseFunctionOp(
+      parser, result, /*allowVariadic=*/false, buildFuncType);
+}
+
+void SubgraphOp::print(OpAsmPrinter &p) {
+  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
+}
+
+//===----------------------------------------------------------------------===//
+// OutputOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult OutputOp::verify() {
+  auto function = cast<SubgraphOp>((*this)->getParentOp());
+
+  // The operand number and types must match the function signature.
+  const auto &results = function.getFunctionType().getResults();
+  if (getNumOperands() != results.size())
+    return emitOpError("has ")
+           << getNumOperands() << " operands, but enclosing function (@"
+           << function.getName() << ") outputs " << results.size();
+
+  for (unsigned i = 0, e = results.size(); i != e; ++i)
+    if (getOperand(i).getType() != results[i])
+      return emitError() << "type of output operand " << i << " ("
+                         << getOperand(i).getType()
+                         << ") doesn't match function result type ("
+                         << results[i] << ")"
+                         << " in function @" << function.getName();
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ReturnOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ReturnOp::verify() {
+  auto function = cast<FuncOp>((*this)->getParentOp());
+
+  // The operand number and types must match the function signature.
+  const auto &results = function.getFunctionType().getResults();
+  if (getNumOperands() != results.size())
+    return emitOpError("has ")
+           << getNumOperands() << " operands, but enclosing function (@"
+           << function.getName() << ") returns " << results.size();
+
+  for (unsigned i = 0, e = results.size(); i != e; ++i)
+    if (getOperand(i).getType() != results[i])
+      return emitError() << "type of return operand " << i << " ("
+                         << getOperand(i).getType()
+                         << ") doesn't match function result type ("
+                         << results[i] << ")"
+                         << " in function @" << function.getName();
+
+  return success();
+}

diff  --git a/mlir/test/Dialect/MLProgram/invalid.mlir b/mlir/test/Dialect/MLProgram/invalid.mlir
new file mode 100644
index 0000000000000..851998a8326a2
--- /dev/null
+++ b/mlir/test/Dialect/MLProgram/invalid.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -verify-diagnostics %s
+
+ml_program.func @ssa_enforced(%arg0 : i32) -> i32 {
+  // expected-error @+1 {{does not dominate this use}}
+  %1 = "unregistered.dummy"(%0) : (i32) -> i32
+  // expected-note @+1 {{operand defined here}}
+  %0 = "unregistered.dummy"(%arg0) : (i32) -> i32
+  ml_program.return %0 : i32
+}
+
+// -----
+ml_program.func @return_arity_match(%arg0 : i32) -> i32 {
+  // expected-error @+1 {{enclosing function (@return_arity_match) returns 1}}
+  ml_program.return %arg0, %arg0 : i32, i32
+}
+
+// -----
+ml_program.func @return_type_match(%arg0 : i64) -> i32 {
+  // expected-error @+1 {{doesn't match function result}}
+  ml_program.return %arg0 : i64
+}
+
+// -----
+ml_program.subgraph @output_arity_match(%arg0 : i32) -> i32 {
+  // expected-error @+1 {{enclosing function (@output_arity_match) outputs 1}}
+  ml_program.output %arg0, %arg0 : i32, i32
+}
+
+// -----
+ml_program.subgraph @output_type_match(%arg0 : i64) -> i32 {
+  // expected-error @+1 {{doesn't match function result}}
+  ml_program.output %arg0 : i64
+}

diff  --git a/mlir/test/Dialect/MLProgram/ops.mlir b/mlir/test/Dialect/MLProgram/ops.mlir
new file mode 100644
index 0000000000000..24f5f8af1be7c
--- /dev/null
+++ b/mlir/test/Dialect/MLProgram/ops.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-opt %s --allow-unregistered-dialect | mlir-opt --allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s --allow-unregistered-dialect --mlir-print-op-generic | mlir-opt --allow-unregistered-dialect | FileCheck %s
+
+// CHECK-LABEL: ml_program.func private @extern_func
+ml_program.func private @extern_func(i32) -> i32
+
+// CHECK-LABEL: ml_program.func @defined_func
+ml_program.func @defined_func(%arg0 : i32) -> i32 {
+  ml_program.return %arg0 : i32
+}
+
+// CHECK-LABEL: ml_program.subgraph private @extern_subgraph
+ml_program.subgraph private @extern_subgraph(i32) -> i32
+
+// CHECK-LABEL: ml_program.subgraph @compute_subgraph
+ml_program.subgraph @compute_subgraph(%arg0 : i32) -> i32 {
+  %1 = "unregistered.dummy"(%0) : (i32) -> i32
+  %0 = "unregistered.dummy"(%arg0) : (i32) -> i32
+  ml_program.output %0 : i32
+}

diff  --git a/mlir/test/mlir-opt/commandline.mlir b/mlir/test/mlir-opt/commandline.mlir
index 29cef56b7ac68..5ea1fd6e61394 100644
--- a/mlir/test/mlir-opt/commandline.mlir
+++ b/mlir/test/mlir-opt/commandline.mlir
@@ -19,6 +19,7 @@
 // CHECK-NEXT: llvm
 // CHECK-NEXT: math
 // CHECK-NEXT: memref
+// CHECK-NEXT: ml_program
 // CHECK-NEXT: nvvm
 // CHECK-NEXT: omp
 // CHECK-NEXT: pdl

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index ba87c8e499600..85d94212fd726 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -5939,6 +5939,7 @@ cc_library(
         ":LinalgToSPIRV",
         ":LinalgToStandard",
         ":LinalgTransforms",
+        ":MLProgramDialect",
         ":MathDialect",
         ":MathToLLVM",
         ":MathToLibm",
@@ -8114,6 +8115,77 @@ cc_library(
     ],
 )
 
+##---------------------------------------------------------------------------##
+# MLProgram dialect
+##---------------------------------------------------------------------------##
+
+td_library(
+    name = "MLProgramOpsTdFiles",
+    srcs = [
+        "include/mlir/Dialect/MLProgram/IR/MLProgramBase.td",
+        "include/mlir/Dialect/MLProgram/IR/MLProgramOps.td",
+    ],
+    includes = ["include"],
+    deps = [
+        ":CallInterfacesTdFiles",
+        ":ControlFlowInterfacesTdFiles",
+        ":FunctionInterfacesTdFiles",
+        ":OpBaseTdFiles",
+        ":RegionKindInterfaceIncGen",
+        ":SideEffectInterfacesTdFiles",
+    ],
+)
+
+gentbl_cc_library(
+    name = "MLProgramOpsIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            ["-gen-op-decls"],
+            "include/mlir/Dialect/MLProgram/IR/MLProgramOps.h.inc",
+        ),
+        (
+            ["-gen-op-defs"],
+            "include/mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc",
+        ),
+        (
+            ["-gen-dialect-decls"],
+            "include/mlir/Dialect/MLProgram/IR/MLProgramOpsDialect.h.inc",
+        ),
+        (
+            ["-gen-dialect-defs"],
+            "include/mlir/Dialect/MLProgram/IR/MLProgramOpsDialect.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/MLProgram/IR/MLProgramOps.td",
+    deps = [":MLProgramOpsTdFiles"],
+)
+
+cc_library(
+    name = "MLProgramDialect",
+    srcs = glob([
+        "lib/Dialect/MLProgram/IR/*.cpp",
+        "lib/Dialect/MLProgram/IR/*.h",
+    ]),
+    hdrs = glob([
+        "include/mlir/Dialect/MLProgram/IR/*.h",
+    ]),
+    includes = ["include"],
+    deps = [
+        ":ControlFlowInterfaces",
+        ":IR",
+        ":MLProgramOpsIncGen",
+        ":Pass",
+        ":Support",
+        "//llvm:Support",
+    ],
+)
+
+##---------------------------------------------------------------------------##
+# Allocation interfaces
+##---------------------------------------------------------------------------##
+
 td_library(
     name = "AllocationOpInterfaceTdFiles",
     srcs = ["include/mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"],


        


More information about the Mlir-commits mailing list