[Mlir-commits] [mlir] 9bae20b - [mlir] Add shape.func
Jacques Pienaar
llvmlistbot at llvm.org
Fri Apr 22 11:37:16 PDT 2022
Author: Jacques Pienaar
Date: 2022-04-22T11:35:35-07:00
New Revision: 9bae20b52822994a74e3017722f4b445e09e993b
URL: https://github.com/llvm/llvm-project/commit/9bae20b52822994a74e3017722f4b445e09e993b
DIFF: https://github.com/llvm/llvm-project/commit/9bae20b52822994a74e3017722f4b445e09e993b.diff
LOG: [mlir] Add shape.func
Add a `shape.func` op, intended primarily for use inside the `shape.function_library`
op. Both ops set `shape` as the default dialect, which makes shape functions
simpler to author. This is a minimal version of the ops needed.
Differential Revision: https://reviews.llvm.org/D124055
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/CMakeLists.txt
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Analysis/test-shape-fn-report.mlir
mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 44df58301bc13..22a348230b04b 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -14,11 +14,13 @@
#define SHAPE_OPS
include "mlir/Dialect/Shape/IR/ShapeBase.td"
+include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
+include "mlir/IR/FunctionInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
//===----------------------------------------------------------------------===//
@@ -995,7 +997,7 @@ def Shape_CstrRequireOp : Shape_Op<"cstr_require", []> {
def Shape_FunctionLibraryOp : Shape_Op<"function_library",
[AffineScope, IsolatedFromAbove, NoRegionArguments, SymbolTable, Symbol,
- NoTerminator, SingleBlock]> {
+ NoTerminator, OpAsmOpInterface, SingleBlock]> {
let summary = "Represents shape functions and corresponding ops";
let description = [{
Represents a list of shape functions and the ops whose shape transfer
@@ -1005,8 +1007,8 @@ def Shape_FunctionLibraryOp : Shape_Op<"function_library",
```mlir
shape.function_library {
- func.func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape {
- %0 = shape.shape_of %arg : !shape.value_shape -> !shape.shape
+ func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape {
+ %0 = shape_of %arg : !shape.value_shape -> !shape.shape
return %0 : !shape.shape
}
} mapping {
@@ -1022,7 +1024,15 @@ def Shape_FunctionLibraryOp : Shape_Op<"function_library",
let extraClassDeclaration = [{
/// Returns an associated shape function for an operation if defined.
- func::FuncOp getShapeFunction(Operation *op);
+ FuncOp getShapeFunction(Operation *op);
+
+ //===------------------------------------------------------------------===//
+ // OpAsmOpInterface
+ //===------------------------------------------------------------------===//
+
+ // This will filter the `shape.` prefix in front of operations inside the
+ // function library body (e.g. `shape.func` can be written as `func`).
+ static StringRef getDefaultDialect() { return "shape";}
}];
let builders = [OpBuilder<(ins "StringRef":$name)>];
@@ -1030,4 +1040,75 @@ def Shape_FunctionLibraryOp : Shape_Op<"function_library",
let hasCustomAssemblyFormat = 1;
}
+def Shape_FuncOp : Shape_Op<"func",
+ [AffineScope, AutomaticAllocationScope, CallableOpInterface,
+ FunctionOpInterface, IsolatedFromAbove, OpAsmOpInterface, Symbol]> {
+ let summary = "Shape function";
+ let description = [{
+ An operation with a name containing a single `SSACFG` region that
+ represents either a shape transfer function or a helper function used by
+ shape transfer functions.
+ }];
+
+ let arguments = (ins SymbolNameAttr:$sym_name,
+ TypeAttrOf<FunctionType>:$function_type,
+ OptionalAttr<StrAttr>:$sym_visibility);
+ let regions = (region AnyRegion:$body);
+
+ let extraClassDeclaration = [{
+ //===------------------------------------------------------------------===//
+ // CallableOpInterface
+ //===------------------------------------------------------------------===//
+
+ /// Returns the region on the current operation that is callable. This may
+ /// return null in the case of an external callable object, e.g. an external
+ /// function.
+ ::mlir::Region *getCallableRegion() { return isExternal() ? nullptr : &getBody(); }
+
+ /// Returns the result types that the callable region produces when
+ /// executed.
+ ArrayRef<Type> getCallableResults() { return getFunctionType().getResults(); }
+
+ //===------------------------------------------------------------------===//
+ // FunctionOpInterface Methods
+ //===------------------------------------------------------------------===//
+
+ /// Returns the argument types of this function.
+ ArrayRef<Type> getArgumentTypes() { return getFunctionType().getInputs(); }
+
+ /// Returns the result types of this function.
+ ArrayRef<Type> getResultTypes() { return getFunctionType().getResults(); }
+
+ //===------------------------------------------------------------------===//
+ // OpAsmOpInterface
+ //===------------------------------------------------------------------===//
+
+ // This will filter the `shape.` prefix in front of operations inside the
+ // func body.
+ static StringRef getDefaultDialect() { return "shape";}
+
+ //===------------------------------------------------------------------===//
+ // SymbolOpInterface Methods
+ //===------------------------------------------------------------------===//
+
+ /// A shape.func is a symbol declaration exactly when it is external
+ /// (i.e. when `isExternal()` holds).
+ bool isDeclaration() { return isExternal(); }
+ }];
+ // The textual form is handled by FuncOp::parse / FuncOp::print in Shape.cpp.
+ let hasCustomAssemblyFormat = 1;
+}
+
+def Shape_ReturnOp : Shape_Op<"return",
+ [NoSideEffect, HasParent<"FuncOp">, ReturnLike, Terminator]> {
+ let summary = "Shape function return operation";
+ let description = [{
+ The `shape.return` operation represents a return operation within a `shape.func`.
+ The operation takes a variable number of operands and produces no results.
+ }];
+
+ let arguments = (ins Variadic<AnyType>:$operands);
+
+ // Operands and their types are optional, e.g. `return %0 : !shape.shape`.
+ let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
+
+ // TODO: Tighten verification. No verifier is declared here, so the operands
+ // are not yet checked against the enclosing function's result types.
+}
+
#endif // SHAPE_OPS
diff --git a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt
index c860834c3c0fa..c18160e70f340 100644
--- a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRShape
LINK_LIBS PUBLIC
MLIRArithmetic
+ MLIRCallInterfaces
MLIRCastInterfaces
MLIRControlFlowInterfaces
MLIRDialect
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 9f8f290559efd..bbfefea3ea674 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -17,6 +17,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
@@ -1190,13 +1191,13 @@ void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
}
-func::FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
+// Looks up `op`'s name in the library's `mapping`; returns null when there is
+// no FlatSymbolRefAttr entry or the symbol does not resolve to a shape FuncOp.
+FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
auto attr = getMapping()
.get(op->getName().getIdentifier())
.dyn_cast_or_null<FlatSymbolRefAttr>();
if (!attr)
return nullptr;
- return lookupSymbol<func::FuncOp>(attr);
+ return lookupSymbol<FuncOp>(attr);
}
ParseResult FunctionLibraryOp::parse(OpAsmParser &parser,
@@ -1237,6 +1238,24 @@ void FunctionLibraryOp::print(OpAsmPrinter &p) {
p.printAttributeWithoutType(getMappingAttr());
}
+//===----------------------------------------------------------------------===//
+// FuncOp
+//===----------------------------------------------------------------------===//
+
+ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
+  // Builds the FunctionType for the parsed signature. shape.func never
+  // accepts variadic signatures (see `allowVariadic` below), so the variadic
+  // flag and the trailing string parameter are intentionally unused.
+  auto makeFunctionType = [](Builder &b, ArrayRef<Type> inputs,
+                             ArrayRef<Type> outputs,
+                             function_interface_impl::VariadicFlag /*unused*/,
+                             std::string & /*unused*/) -> FunctionType {
+    return b.getFunctionType(inputs, outputs);
+  };
+
+  const bool allowVariadic = false;
+  return function_interface_impl::parseFunctionOp(parser, result, allowVariadic,
+                                                  makeFunctionType);
+}
+
+void FuncOp::print(OpAsmPrinter &p) {
+  // FuncOp::parse rejects variadic signatures, so the op is always printed as
+  // a non-variadic function.
+  const bool isVariadic = false;
+  function_interface_impl::printFunctionOp(p, *this, isVariadic);
+}
+
//===----------------------------------------------------------------------===//
// GetExtentOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Analysis/test-shape-fn-report.mlir b/mlir/test/Analysis/test-shape-fn-report.mlir
index a8559b7b9e81f..54d97483b14dc 100644
--- a/mlir/test/Analysis/test-shape-fn-report.mlir
+++ b/mlir/test/Analysis/test-shape-fn-report.mlir
@@ -15,8 +15,8 @@ func.func @tanh(%arg: tensor<10x20xf32>) -> tensor<10x20xf32>
// The shape function library with some local functions.
shape.function_library @shape_lib {
// Test shape function that returns the shape of input arg as result shape.
- func.func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape {
- %0 = shape.shape_of %arg : !shape.value_shape -> !shape.shape
+ func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape {
+ %0 = shape_of %arg : !shape.value_shape -> !shape.shape
return %0 : !shape.shape
}
} mapping {
diff --git a/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp b/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp
index e4708f15edc77..449a3e92b7da9 100644
--- a/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp
+++ b/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp
@@ -8,6 +8,7 @@
#include <queue>
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
@@ -46,7 +47,8 @@ void ReportShapeFnPass::runOnOperation() {
return true;
}
if (auto symbol = op->getAttrOfType<SymbolRefAttr>(shapeFnId)) {
- auto fn = cast<func::FuncOp>(SymbolTable::lookupSymbolIn(module, symbol));
+ auto fn =
+ cast<shape::FuncOp>(SymbolTable::lookupSymbolIn(module, symbol));
op->emitRemark() << "associated shape function: " << fn.getName();
return true;
}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 3d3254a529b3c..58b63c464f9d4 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2653,8 +2653,10 @@ td_library(
],
includes = ["include"],
deps = [
+ ":CallInterfacesTdFiles",
":CastInterfacesTdFiles",
":ControlFlowInterfacesTdFiles",
+ ":FunctionInterfacesTdFiles",
":InferTypeOpInterfaceTdFiles",
":SideEffectInterfacesTdFiles",
],
More information about the Mlir-commits
mailing list