[Mlir-commits] [mlir] 9bae20b - [mlir] Add shape.func

Jacques Pienaar llvmlistbot at llvm.org
Fri Apr 22 11:37:16 PDT 2022


Author: Jacques Pienaar
Date: 2022-04-22T11:35:35-07:00
New Revision: 9bae20b52822994a74e3017722f4b445e09e993b

URL: https://github.com/llvm/llvm-project/commit/9bae20b52822994a74e3017722f4b445e09e993b
DIFF: https://github.com/llvm/llvm-project/commit/9bae20b52822994a74e3017722f4b445e09e993b.diff

LOG: [mlir] Add shape.func

Add a shape func op, primarily for use inside the shape function_library op.
It allows setting a default dialect, which simplifies authoring. This is a
minimal version of the ops needed.

Differential Revision: https://reviews.llvm.org/D124055

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Dialect/Shape/IR/CMakeLists.txt
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Analysis/test-shape-fn-report.mlir
    mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 44df58301bc13..22a348230b04b 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -14,11 +14,13 @@
 #define SHAPE_OPS
 
 include "mlir/Dialect/Shape/IR/ShapeBase.td"
+include "mlir/Interfaces/CallInterfaces.td"
 include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpAsmInterface.td"
+include "mlir/IR/FunctionInterfaces.td"
 include "mlir/IR/SymbolInterfaces.td"
 
 //===----------------------------------------------------------------------===//
@@ -995,7 +997,7 @@ def Shape_CstrRequireOp : Shape_Op<"cstr_require", []> {
 
 def Shape_FunctionLibraryOp : Shape_Op<"function_library",
     [AffineScope, IsolatedFromAbove, NoRegionArguments, SymbolTable, Symbol,
-     NoTerminator, SingleBlock]> {
+     NoTerminator, OpAsmOpInterface, SingleBlock]> {
   let summary = "Represents shape functions and corresponding ops";
   let description = [{
     Represents a list of shape functions and the ops whose shape transfer
@@ -1005,8 +1007,8 @@ def Shape_FunctionLibraryOp : Shape_Op<"function_library",
 
     ```mlir
     shape.function_library {
-      func.func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape {
-        %0 = shape.shape_of %arg : !shape.value_shape -> !shape.shape
+      func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape {
+        %0 = shape_of %arg : !shape.value_shape -> !shape.shape
         return %0 : !shape.shape
       }
     } mapping {
@@ -1022,7 +1024,15 @@ def Shape_FunctionLibraryOp : Shape_Op<"function_library",
 
   let extraClassDeclaration = [{
     /// Returns an associated shape function for an operation if defined.
-    func::FuncOp getShapeFunction(Operation *op);
+    FuncOp getShapeFunction(Operation *op);
+
+    //===------------------------------------------------------------------===//
+    // OpAsmOpInterface
+    //===------------------------------------------------------------------===//
+
+    // This will filter the `shape.` prefix in front of operations inside the
+    // func body.
+    static StringRef getDefaultDialect() { return "shape";}
   }];
 
   let builders = [OpBuilder<(ins "StringRef":$name)>];
@@ -1030,4 +1040,75 @@ def Shape_FunctionLibraryOp : Shape_Op<"function_library",
   let hasCustomAssemblyFormat = 1;
 }
 
+def Shape_FuncOp : Shape_Op<"func",
+    [AffineScope, AutomaticAllocationScope, CallableOpInterface,
+     FunctionOpInterface, IsolatedFromAbove, OpAsmOpInterface, Symbol]> {
+  let summary = "Shape function";
+  let description = [{
+    An operation with a name containing a single `SSACFG` region which
+    represents a shape transfer function or helper function for shape transfer
+    function.
+  }];
+
+  let arguments = (ins SymbolNameAttr:$sym_name,
+                       TypeAttrOf<FunctionType>:$function_type,
+                       OptionalAttr<StrAttr>:$sym_visibility);
+  let regions = (region AnyRegion:$body);
+
+  let extraClassDeclaration = [{
+    //===------------------------------------------------------------------===//
+    // CallableOpInterface
+    //===------------------------------------------------------------------===//
+
+    /// Returns the region on the current operation that is callable. This may
+    /// return null in the case of an external callable object, e.g. an external
+    /// function.
+    ::mlir::Region *getCallableRegion() { return isExternal() ? nullptr : &getBody(); }
+
+    /// Returns the results types that the callable region produces when
+    /// executed.
+    ArrayRef<Type> getCallableResults() { return getFunctionType().getResults(); }
+
+    //===------------------------------------------------------------------===//
+    // FunctionOpInterface Methods
+    //===------------------------------------------------------------------===//
+
+    /// Returns the argument types of this function.
+    ArrayRef<Type> getArgumentTypes() { return getFunctionType().getInputs(); }
+
+    /// Returns the result types of this function.
+    ArrayRef<Type> getResultTypes() { return getFunctionType().getResults(); }
+
+    //===------------------------------------------------------------------===//
+    // OpAsmOpInterface
+    //===------------------------------------------------------------------===//
+
+    // This will filter the `shape.` prefix in front of operations inside the
+    // func body.
+    static StringRef getDefaultDialect() { return "shape";}
+
+    //===------------------------------------------------------------------===//
+    // SymbolOpInterface Methods
+    //===------------------------------------------------------------------===//
+
+    bool isDeclaration() { return isExternal(); }
+  }];
+  let hasCustomAssemblyFormat = 1;
+}
+
+def Shape_ReturnOp : Shape_Op<"return",
+    [NoSideEffect, HasParent<"FuncOp">, ReturnLike, Terminator]> {
+  let summary = "Shape function return operation";
+  let description = [{
+    The `shape.return` operation represents a return operation within a function.
+    The operation takes variable number of operands and produces no results.
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$operands);
+
+  let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
+
+  // TODO: Tighten verification.
+}
+
 #endif // SHAPE_OPS

diff  --git a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt
index c860834c3c0fa..c18160e70f340 100644
--- a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRShape
 
   LINK_LIBS PUBLIC
   MLIRArithmetic
+  MLIRCallInterfaces
   MLIRCastInterfaces
   MLIRControlFlowInterfaces
   MLIRDialect

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 9f8f290559efd..bbfefea3ea674 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/FunctionImplementation.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -1190,13 +1191,13 @@ void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
       ::mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)));
 }
 
-func::FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
+FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
   auto attr = getMapping()
                   .get(op->getName().getIdentifier())
                   .dyn_cast_or_null<FlatSymbolRefAttr>();
   if (!attr)
     return nullptr;
-  return lookupSymbol<func::FuncOp>(attr);
+  return lookupSymbol<FuncOp>(attr);
 }
 
 ParseResult FunctionLibraryOp::parse(OpAsmParser &parser,
@@ -1237,6 +1238,24 @@ void FunctionLibraryOp::print(OpAsmPrinter &p) {
   p.printAttributeWithoutType(getMappingAttr());
 }
 
+//===----------------------------------------------------------------------===//
+// FuncOp
+//===----------------------------------------------------------------------===//
+
+ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
+  auto buildFuncType =
+      [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
+         function_interface_impl::VariadicFlag,
+         std::string &) { return builder.getFunctionType(argTypes, results); };
+
+  return function_interface_impl::parseFunctionOp(
+      parser, result, /*allowVariadic=*/false, buildFuncType);
+}
+
+void FuncOp::print(OpAsmPrinter &p) {
+  function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
+}
+
 //===----------------------------------------------------------------------===//
 // GetExtentOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Analysis/test-shape-fn-report.mlir b/mlir/test/Analysis/test-shape-fn-report.mlir
index a8559b7b9e81f..54d97483b14dc 100644
--- a/mlir/test/Analysis/test-shape-fn-report.mlir
+++ b/mlir/test/Analysis/test-shape-fn-report.mlir
@@ -15,8 +15,8 @@ func.func @tanh(%arg: tensor<10x20xf32>) -> tensor<10x20xf32>
 // The shape function library with some local functions.
 shape.function_library @shape_lib {
   // Test shape function that returns the shape of input arg as result shape.
-  func.func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape {
-    %0 = shape.shape_of %arg : !shape.value_shape -> !shape.shape
+  func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape {
+    %0 = shape_of %arg : !shape.value_shape -> !shape.shape
     return %0 : !shape.shape
   }
 } mapping {

diff  --git a/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp b/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp
index e4708f15edc77..449a3e92b7da9 100644
--- a/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp
+++ b/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp
@@ -8,6 +8,7 @@
 
 #include <queue>
 
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
@@ -46,7 +47,8 @@ void ReportShapeFnPass::runOnOperation() {
       return true;
     }
     if (auto symbol = op->getAttrOfType<SymbolRefAttr>(shapeFnId)) {
-      auto fn = cast<func::FuncOp>(SymbolTable::lookupSymbolIn(module, symbol));
+      auto fn =
+          cast<shape::FuncOp>(SymbolTable::lookupSymbolIn(module, symbol));
       op->emitRemark() << "associated shape function: " << fn.getName();
       return true;
     }

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 3d3254a529b3c..58b63c464f9d4 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2653,8 +2653,10 @@ td_library(
     ],
     includes = ["include"],
     deps = [
+        ":CallInterfacesTdFiles",
         ":CastInterfacesTdFiles",
         ":ControlFlowInterfacesTdFiles",
+        ":FunctionInterfacesTdFiles",
         ":InferTypeOpInterfaceTdFiles",
         ":SideEffectInterfacesTdFiles",
     ],


        


More information about the Mlir-commits mailing list