[Mlir-commits] [mlir] 61352a5 - [mlir] Introduce ml_program dialect.
Stella Laurenzo
llvmlistbot at llvm.org
Wed Apr 13 21:39:06 PDT 2022
Author: Stella Laurenzo
Date: 2022-04-13T21:38:14-07:00
New Revision: 61352a580a1f8e5818a6e5445517058d959bb86f
URL: https://github.com/llvm/llvm-project/commit/61352a580a1f8e5818a6e5445517058d959bb86f
DIFF: https://github.com/llvm/llvm-project/commit/61352a580a1f8e5818a6e5445517058d959bb86f.diff
LOG: [mlir] Introduce ml_program dialect.
Differential Revision: https://reviews.llvm.org/D120203
Added:
mlir/include/mlir/Dialect/MLProgram/CMakeLists.txt
mlir/include/mlir/Dialect/MLProgram/IR/CMakeLists.txt
mlir/include/mlir/Dialect/MLProgram/IR/MLProgram.h
mlir/include/mlir/Dialect/MLProgram/IR/MLProgramBase.td
mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
mlir/lib/Dialect/MLProgram/CMakeLists.txt
mlir/lib/Dialect/MLProgram/IR/CMakeLists.txt
mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp
mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
mlir/test/Dialect/MLProgram/invalid.mlir
mlir/test/Dialect/MLProgram/ops.mlir
Modified:
mlir/include/mlir/Dialect/CMakeLists.txt
mlir/include/mlir/InitAllDialects.h
mlir/lib/Dialect/CMakeLists.txt
mlir/test/mlir-opt/commandline.mlir
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
index d61773dc416c6..a0b5209838bc4 100644
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -15,6 +15,7 @@ add_subdirectory(Math)
add_subdirectory(Linalg)
add_subdirectory(LLVMIR)
add_subdirectory(MemRef)
+add_subdirectory(MLProgram)
add_subdirectory(OpenACC)
add_subdirectory(OpenMP)
add_subdirectory(PDL)
diff --git a/mlir/include/mlir/Dialect/MLProgram/CMakeLists.txt b/mlir/include/mlir/Dialect/MLProgram/CMakeLists.txt
new file mode 100644
index 0000000000000..f33061b2d87cf
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MLProgram/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/MLProgram/IR/CMakeLists.txt
new file mode 100644
index 0000000000000..fce18e65e952e
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MLProgram/IR/CMakeLists.txt
@@ -0,0 +1,3 @@
+set(LLVM_TARGET_DEFINITIONS MLProgramOps.td)
+add_mlir_dialect(MLProgramOps ml_program)
+add_mlir_doc(MLProgramOps MLProgramOps Dialects/ -gen-dialect-doc)
diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgram.h b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgram.h
new file mode 100644
index 0000000000000..fad8cbcf1c669
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgram.h
@@ -0,0 +1,34 @@
+//===- MLProgram.h - MLProgram dialect ----------------------------*- C++-*-==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAM_H_
+#define MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAM_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/FunctionInterfaces.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/RegionKindInterface.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+//===----------------------------------------------------------------------===//
+// MLProgramDialect
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MLProgram/IR/MLProgramOpsDialect.h.inc"
+
+//===----------------------------------------------------------------------===//
+// MLProgram Dialect Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/MLProgram/IR/MLProgramOps.h.inc"
+
+#endif // MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAM_H_
diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramBase.td b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramBase.td
new file mode 100644
index 0000000000000..b670bc89204c2
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramBase.td
@@ -0,0 +1,33 @@
+//===- MLProgramBase.td - Base defs for ml_program dialect --*- tablegen -*-==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLPROGRAM_BASE
+#define MLPROGRAM_BASE
+
+include "mlir/IR/OpBase.td"
+
+def MLProgram_Dialect : Dialect {
+ let name = "ml_program";
+ let cppNamespace = "::mlir::ml_program";
+ let description = [{
+ The MLProgram dialect contains structural operations and types for
+ defining a compiled Machine-Learning program, as created from common
+ ML frameworks, such as TensorFlow, PyTorch, JAX, etc. It does not itself
+ define computation ops common to such frameworks but establishes a common
+ programming model for modules, functions, globals and memory model
+ components appropriate for such an abstract level of detail.
+
+ This dialect is under active development, and while stability is an
+ eventual goal, it is not guaranteed at this juncture. Given the early state,
+ it is recommended to inquire further prior to using this dialect.
+ }];
+
+ let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
+}
+
+#endif // MLPROGRAM_BASE
diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
new file mode 100644
index 0000000000000..b096c3dd53e8c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
@@ -0,0 +1,218 @@
+//===- MLProgramOps.td - Structural ML Program Ops ---------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLPROGRAM_OPS
+#define MLPROGRAM_OPS
+
+include "mlir/Dialect/MLProgram/IR/MLProgramBase.td"
+include "mlir/Interfaces/CallInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/FunctionInterfaces.td"
+include "mlir/IR/RegionKindInterface.td"
+include "mlir/IR/SymbolInterfaces.td"
+
+class MLProgram_Op<string mnemonic, list<Trait> traits = []>
+ : Op<MLProgram_Dialect, mnemonic, traits>;
+
+//===----------------------------------------------------------------------===//
+// FuncOp
+//===----------------------------------------------------------------------===//
+
+def MLProgram_FuncOp : MLProgram_Op<"func", [
+ CallableOpInterface, FunctionOpInterface,
+ IsolatedFromAbove, RegionKindInterface, Symbol
+ ]> {
+ let summary = "Function containing a single `SSACFG` region";
+ let description = [{
+ This simple function container represents callables in an ML program where
+ the body is an `SSACFG` region. It must be terminated by a `return` op which
+ yields values with the same arity and types as the `FunctionType` results
+ of the containing `func`.
+
+ This op is a `Symbol` but does not introduce a new `SymbolTable`. As such,
+ it cannot represent nested symbols.
+
+ Example:
+
+ ```mlir
+ ml_program.func private @some_extern(i32) -> i32
+ ml_program.func @compute(%arg0 : i32) -> i32 {
+ ml_program.return %arg0 : i32
+ }
+ ```
+ }];
+
+ let arguments = (ins SymbolNameAttr:$sym_name,
+ TypeAttrOf<FunctionType>:$function_type,
+ OptionalAttr<StrAttr>:$sym_visibility);
+ let regions = (region AnyRegion:$body);
+
+ let hasCustomAssemblyFormat = 1;
+
+ let extraClassDeclaration = [{
+ //===------------------------------------------------------------------===//
+ // CallableOpInterface Methods
+ //===------------------------------------------------------------------===//
+
+ /// Returns the body region that is entered when this function is called.
+ /// External declarations (`isExternal()`) have no body, so null is
+ /// returned for them.
+ ::mlir::Region *getCallableRegion() {
+ return isExternal() ? nullptr : &getBody();
+ }
+
+ /// Returns the types produced by the callable region, taken from the
+ /// results of the declared `FunctionType`.
+ ArrayRef<Type> getCallableResults() { return getFunctionType().getResults(); }
+
+ //===------------------------------------------------------------------===//
+ // FunctionOpInterface Methods
+ //===------------------------------------------------------------------===//
+
+ /// Returns the input types of the declared `FunctionType`.
+ ArrayRef<Type> getArgumentTypes() { return getFunctionType().getInputs(); }
+
+ /// Returns the result types of the declared `FunctionType`.
+ ArrayRef<Type> getResultTypes() { return getFunctionType().getResults(); }
+
+ //===------------------------------------------------------------------===//
+ // RegionKindInterface Methods
+ //===------------------------------------------------------------------===//
+ static ::mlir::RegionKind getRegionKind(unsigned index) {
+ return ::mlir::RegionKind::SSACFG;
+ }
+
+ //===------------------------------------------------------------------===//
+ // SymbolOpInterface Methods
+ //===------------------------------------------------------------------===//
+
+ bool isDeclaration() { return isExternal(); }
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// SubgraphOp
+//===----------------------------------------------------------------------===//
+
+def MLProgram_SubgraphOp : MLProgram_Op<"subgraph", [
+ CallableOpInterface, FunctionOpInterface, HasOnlyGraphRegion,
+ IsolatedFromAbove, RegionKindInterface, SingleBlock, Symbol
+ ]> {
+ let summary = "Function containing a single `Graph` region";
+ let description = [{
+ This simple function container represents callables in an ML program where
+ the body is a `Graph` region containing a single block. It must be
+ terminated by an `output` op which yields values with the same arity and
+ types as the `FunctionType` results of the containing `subgraph`.
+
+ This op is a `Symbol` but does not introduce a new `SymbolTable`. As such,
+ it cannot represent nested symbols.
+
+ Example:
+
+ ```mlir
+ ml_program.subgraph private @some_extern(i32) -> i32
+ ml_program.subgraph @compute(%arg0 : i32) -> i32 {
+ ml_program.output %arg0 : i32
+ }
+ ```
+ }];
+
+ let arguments = (ins SymbolNameAttr:$sym_name,
+ TypeAttrOf<FunctionType>:$function_type,
+ OptionalAttr<StrAttr>:$sym_visibility);
+ let regions = (region AnyRegion:$body);
+
+ let extraClassDeclaration = [{
+ //===------------------------------------------------------------------===//
+ // CallableOpInterface Methods
+ //===------------------------------------------------------------------===//
+
+ /// Returns the callable body region, or null for an external declaration
+ /// (`isExternal()`), which has no body.
+ ::mlir::Region *getCallableRegion() {
+ return isExternal() ? nullptr : &getBody();
+ }
+
+ /// Returns the result types the callable region produces when executed.
+ ArrayRef<Type> getCallableResults() { return getFunctionType().getResults(); }
+
+ //===------------------------------------------------------------------===//
+ // FunctionOpInterface Methods
+ //===------------------------------------------------------------------===//
+
+ /// Returns the argument types of this function.
+ ArrayRef<Type> getArgumentTypes() { return getFunctionType().getInputs(); }
+
+ /// Returns the result types of this function.
+ ArrayRef<Type> getResultTypes() { return getFunctionType().getResults(); }
+
+ //===------------------------------------------------------------------===//
+ // SymbolOpInterface Methods
+ //===------------------------------------------------------------------===//
+
+ bool isDeclaration() { return isExternal(); }
+ }];
+
+ let hasCustomAssemblyFormat = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// OutputOp
+//===----------------------------------------------------------------------===//
+
+def MLProgram_OutputOp : MLProgram_Op<"output", [
+ NoSideEffect, HasParent<"SubgraphOp">, ReturnLike, Terminator
+ ]> {
+ let summary = "Outputs values from a subgraph function";
+ let description = [{
+ The `output` operation terminates a subgraph by yielding values
+ to the caller.
+ The operation takes a variable number of operands and produces no results.
+ The operand number and types must match the result types of the
+ enclosing `subgraph` signature.
+ }];
+
+ let arguments = (ins Variadic<AnyType>:$operands);
+
+ let builders = [OpBuilder<(ins), [{
+ build($_builder, $_state, llvm::None);
+ }]>];
+
+ let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
+ let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// ReturnOp
+//===----------------------------------------------------------------------===//
+
+def MLProgram_ReturnOp : MLProgram_Op<"return", [
+ NoSideEffect, HasParent<"FuncOp">, ReturnLike, Terminator
+ ]> {
+ let summary = "Returns values from a `func` function";
+ let description = [{
+ The `return` operation terminates a `func` function by yielding values
+ to the caller.
+ The operation takes a variable number of operands and produces no results.
+ The operand number and types must match the result types of the
+ enclosing `func` signature.
+ }];
+
+ let arguments = (ins Variadic<AnyType>:$operands);
+
+ let builders = [OpBuilder<(ins), [{
+ build($_builder, $_state, llvm::None);
+ }]>];
+
+ let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
+ let hasVerifier = 1;
+}
+
+#endif // MLPROGRAM_OPS
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 9487876ef32af..7f370cd16bf7f 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -33,6 +33,7 @@
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
@@ -77,6 +78,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
linalg::LinalgDialect,
math::MathDialect,
memref::MemRefDialect,
+ ml_program::MLProgramDialect,
scf::SCFDialect,
omp::OpenMPDialect,
pdl::PDLDialect,
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index 766dd263e7a28..6ffc8c3085a34 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -15,6 +15,7 @@ add_subdirectory(Linalg)
add_subdirectory(LLVMIR)
add_subdirectory(Math)
add_subdirectory(MemRef)
+add_subdirectory(MLProgram)
add_subdirectory(OpenACC)
add_subdirectory(OpenMP)
add_subdirectory(PDL)
diff --git a/mlir/lib/Dialect/MLProgram/CMakeLists.txt b/mlir/lib/Dialect/MLProgram/CMakeLists.txt
new file mode 100644
index 0000000000000..f33061b2d87cf
--- /dev/null
+++ b/mlir/lib/Dialect/MLProgram/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/lib/Dialect/MLProgram/IR/CMakeLists.txt b/mlir/lib/Dialect/MLProgram/IR/CMakeLists.txt
new file mode 100644
index 0000000000000..61dba7539908b
--- /dev/null
+++ b/mlir/lib/Dialect/MLProgram/IR/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_dialect_library(MLIRMLProgram
+  MLProgramOps.cpp
+  MLProgramDialect.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MLProgram
+
+  DEPENDS
+  MLIRMLProgramOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRCallInterfaces
+  MLIRControlFlowInterfaces
+  MLIRDialect
+  MLIRInferTypeOpInterface
+  MLIRIR
+  MLIRSideEffectInterfaces
+  )
diff --git a/mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp b/mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp
new file mode 100644
index 0000000000000..eb012bac3984a
--- /dev/null
+++ b/mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp
@@ -0,0 +1,23 @@
+//===- MLProgramDialect.cpp - MLProgram dialect implementation ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
+
+using namespace mlir;
+using namespace mlir::ml_program;
+
+#include "mlir/Dialect/MLProgram/IR/MLProgramOpsDialect.cpp.inc"
+
+/// Registers with this dialect every op in the tblgen-generated
+/// GET_OP_LIST section of MLProgramOps.cpp.inc.
+void MLProgramDialect::initialize() {
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc"
+      >();
+}
diff --git a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
new file mode 100644
index 0000000000000..4d8038c21f17f
--- /dev/null
+++ b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
@@ -0,0 +1,121 @@
+//===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/FunctionImplementation.h"
+
+using namespace mlir;
+using namespace mlir::ml_program;
+
+//===----------------------------------------------------------------------===//
+// TableGen'd op method definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// Shared helpers
+//===----------------------------------------------------------------------===//
+
+/// Parses the custom assembly shared by the function-like ops (`func` and
+/// `subgraph`) via `function_interface_impl::parseFunctionOp`. Variadic
+/// argument lists are disallowed, so the type builder ignores its
+/// `VariadicFlag` and `std::string &` parameters.
+static ParseResult parseFunctionLikeOp(OpAsmParser &parser,
+                                       OperationState &result) {
+  auto buildFuncType =
+      [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
+         function_interface_impl::VariadicFlag,
+         std::string &) { return builder.getFunctionType(argTypes, results); };
+
+  return function_interface_impl::parseFunctionOp(
+      parser, result, /*allowVariadic=*/false, buildFuncType);
+}
+
+/// Verifies that the operands of the terminator `op` match, in number and
+/// type, the results declared by the `FunctionType` of the enclosing
+/// `function`.
+///
+/// `resultsVerb` (e.g. "returns") and `operandKind` (e.g. "return") only
+/// affect the wording of diagnostics. Both kinds of mismatch go through
+/// `emitOpError`, so they are reported in the same way.
+template <typename FunctionOpT>
+static LogicalResult verifyTerminatorOperands(Operation *op,
+                                              FunctionOpT function,
+                                              StringRef resultsVerb,
+                                              StringRef operandKind) {
+  ArrayRef<Type> results = function.getFunctionType().getResults();
+  if (op->getNumOperands() != results.size())
+    return op->emitOpError("has ")
+           << op->getNumOperands() << " operands, but enclosing function (@"
+           << function.getName() << ") " << resultsVerb << " "
+           << results.size();
+
+  for (unsigned i = 0, e = results.size(); i != e; ++i) {
+    Type operandType = op->getOperand(i).getType();
+    if (operandType != results[i])
+      return op->emitOpError()
+             << "type of " << operandKind << " operand " << i << " ("
+             << operandType << ") doesn't match function result type ("
+             << results[i] << ") in function @" << function.getName();
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// FuncOp
+//===----------------------------------------------------------------------===//
+
+ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
+  return parseFunctionLikeOp(parser, result);
+}
+
+void FuncOp::print(OpAsmPrinter &p) {
+  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
+}
+
+//===----------------------------------------------------------------------===//
+// SubgraphOp
+//===----------------------------------------------------------------------===//
+
+ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) {
+  return parseFunctionLikeOp(parser, result);
+}
+
+void SubgraphOp::print(OpAsmPrinter &p) {
+  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
+}
+
+//===----------------------------------------------------------------------===//
+// OutputOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult OutputOp::verify() {
+  // The op's HasParent<"SubgraphOp"> trait constrains the parent, which is
+  // what makes this cast valid.
+  auto function = cast<SubgraphOp>((*this)->getParentOp());
+  return verifyTerminatorOperands(getOperation(), function,
+                                  /*resultsVerb=*/"outputs",
+                                  /*operandKind=*/"output");
+}
+
+//===----------------------------------------------------------------------===//
+// ReturnOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ReturnOp::verify() {
+  // The op's HasParent<"FuncOp"> trait constrains the parent, which is what
+  // makes this cast valid.
+  auto function = cast<FuncOp>((*this)->getParentOp());
+  return verifyTerminatorOperands(getOperation(), function,
+                                  /*resultsVerb=*/"returns",
+                                  /*operandKind=*/"return");
+}
diff --git a/mlir/test/Dialect/MLProgram/invalid.mlir b/mlir/test/Dialect/MLProgram/invalid.mlir
new file mode 100644
index 0000000000000..851998a8326a2
--- /dev/null
+++ b/mlir/test/Dialect/MLProgram/invalid.mlir
@@ -0,0 +1,45 @@
+// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -verify-diagnostics %s
+
+ml_program.func @ssa_enforced(%arg0 : i32) -> i32 {
+ // expected-error @+1 {{does not dominate this use}}
+ %1 = "unregistered.dummy"(%0) : (i32) -> i32
+ // expected-note @+1 {{operand defined here}}
+ %0 = "unregistered.dummy"(%arg0) : (i32) -> i32
+ ml_program.return %0 : i32
+}
+
+// -----
+ml_program.func @return_arity_match(%arg0 : i32) -> i32 {
+ // expected-error @+1 {{enclosing function (@return_arity_match) returns 1}}
+ ml_program.return %arg0, %arg0 : i32, i32
+}
+
+// -----
+ml_program.func @return_type_match(%arg0 : i64) -> i32 {
+ // expected-error @+1 {{doesn't match function result}}
+ ml_program.return %arg0 : i64
+}
+
+// -----
+ml_program.subgraph @output_arity_match(%arg0 : i32) -> i32 {
+ // expected-error @+1 {{enclosing function (@output_arity_match) outputs 1}}
+ ml_program.output %arg0, %arg0 : i32, i32
+}
+
+// -----
+ml_program.subgraph @output_type_match(%arg0 : i64) -> i32 {
+ // expected-error @+1 {{doesn't match function result}}
+ ml_program.output %arg0 : i64
+}
+
+// -----
+ml_program.subgraph @return_outside_func(%arg0 : i32) -> i32 {
+ // expected-error @+1 {{expects parent op 'ml_program.func'}}
+ ml_program.return %arg0 : i32
+}
+
+// -----
+ml_program.func @output_outside_subgraph(%arg0 : i32) -> i32 {
+ // expected-error @+1 {{expects parent op 'ml_program.subgraph'}}
+ ml_program.output %arg0 : i32
+}
diff --git a/mlir/test/Dialect/MLProgram/ops.mlir b/mlir/test/Dialect/MLProgram/ops.mlir
new file mode 100644
index 0000000000000..24f5f8af1be7c
--- /dev/null
+++ b/mlir/test/Dialect/MLProgram/ops.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt %s --allow-unregistered-dialect | mlir-opt --allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s --allow-unregistered-dialect --mlir-print-op-generic | mlir-opt --allow-unregistered-dialect | FileCheck %s
+
+// CHECK-LABEL: ml_program.func private @extern_func
+ml_program.func private @extern_func(i32) -> i32
+
+// CHECK-LABEL: ml_program.func @defined_func
+ml_program.func @defined_func(%arg0 : i32) -> i32 {
+ // CHECK: ml_program.return %arg0 : i32
+ ml_program.return %arg0 : i32
+}
+
+// CHECK-LABEL: ml_program.subgraph private @extern_subgraph
+ml_program.subgraph private @extern_subgraph(i32) -> i32
+
+// CHECK-LABEL: ml_program.subgraph @compute_subgraph
+ml_program.subgraph @compute_subgraph(%arg0 : i32) -> i32 {
+ // Graph regions permit uses before definitions; that order must round-trip.
+ // CHECK: "unregistered.dummy"(%[[V1:.+]]) : (i32) -> i32
+ // CHECK: %[[V1]] = "unregistered.dummy"(%arg0) : (i32) -> i32
+ // CHECK: ml_program.output %[[V1]] : i32
+ %1 = "unregistered.dummy"(%0) : (i32) -> i32
+ %0 = "unregistered.dummy"(%arg0) : (i32) -> i32
+ ml_program.output %0 : i32
+}
diff --git a/mlir/test/mlir-opt/commandline.mlir b/mlir/test/mlir-opt/commandline.mlir
index 29cef56b7ac68..5ea1fd6e61394 100644
--- a/mlir/test/mlir-opt/commandline.mlir
+++ b/mlir/test/mlir-opt/commandline.mlir
@@ -19,6 +19,7 @@
// CHECK-NEXT: llvm
// CHECK-NEXT: math
// CHECK-NEXT: memref
+// CHECK-NEXT: ml_program
// CHECK-NEXT: nvvm
// CHECK-NEXT: omp
// CHECK-NEXT: pdl
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index ba87c8e499600..85d94212fd726 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -5939,6 +5939,7 @@ cc_library(
":LinalgToSPIRV",
":LinalgToStandard",
":LinalgTransforms",
+ ":MLProgramDialect",
":MathDialect",
":MathToLLVM",
":MathToLibm",
@@ -8114,6 +8115,79 @@ cc_library(
],
)
+##---------------------------------------------------------------------------##
+# MLProgram dialect
+##---------------------------------------------------------------------------##
+
+td_library(
+ name = "MLProgramOpsTdFiles",
+ srcs = [
+ "include/mlir/Dialect/MLProgram/IR/MLProgramBase.td",
+ "include/mlir/Dialect/MLProgram/IR/MLProgramOps.td",
+ ],
+ includes = ["include"],
+ deps = [
+ ":CallInterfacesTdFiles",
+ ":ControlFlowInterfacesTdFiles",
+ ":FunctionInterfacesTdFiles",
+ ":OpBaseTdFiles",
+ ":RegionKindInterfaceIncGen",
+ ":SideEffectInterfacesTdFiles",
+ ],
+)
+
+gentbl_cc_library(
+ name = "MLProgramOpsIncGen",
+ strip_include_prefix = "include",
+ tbl_outs = [
+ (
+ ["-gen-op-decls"],
+ "include/mlir/Dialect/MLProgram/IR/MLProgramOps.h.inc",
+ ),
+ (
+ ["-gen-op-defs"],
+ "include/mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc",
+ ),
+ (
+ ["-gen-dialect-decls"],
+ "include/mlir/Dialect/MLProgram/IR/MLProgramOpsDialect.h.inc",
+ ),
+ (
+ ["-gen-dialect-defs"],
+ "include/mlir/Dialect/MLProgram/IR/MLProgramOpsDialect.cpp.inc",
+ ),
+ ],
+ tblgen = ":mlir-tblgen",
+ td_file = "include/mlir/Dialect/MLProgram/IR/MLProgramOps.td",
+ deps = [":MLProgramOpsTdFiles"],
+)
+
+cc_library(
+ name = "MLProgramDialect",
+ srcs = glob([
+ "lib/Dialect/MLProgram/IR/*.cpp",
+ "lib/Dialect/MLProgram/IR/*.h",
+ ]),
+ hdrs = glob([
+ "include/mlir/Dialect/MLProgram/IR/*.h",
+ ]),
+ includes = ["include"],
+ deps = [
+ ":CallOpInterfaces",
+ ":ControlFlowInterfaces",
+ ":IR",
+ ":MLProgramOpsIncGen",
+ ":Pass",
+ ":SideEffectInterfaces",
+ ":Support",
+ "//llvm:Support",
+ ],
+)
+
+##---------------------------------------------------------------------------##
+# Allocation interfaces
+##---------------------------------------------------------------------------##
+
td_library(
name = "AllocationOpInterfaceTdFiles",
srcs = ["include/mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"],
More information about the Mlir-commits
mailing list