[Mlir-commits] [mlir] 6dd9596 - [mlir] Add a shape function library op

Jacques Pienaar llvmlistbot at llvm.org
Sat Nov 28 15:54:22 PST 2020


Author: Jacques Pienaar
Date: 2020-11-28T15:53:59-08:00
New Revision: 6dd9596b19d7679c562f8e866be6d0c3d7c21994

URL: https://github.com/llvm/llvm-project/commit/6dd9596b19d7679c562f8e866be6d0c3d7c21994
DIFF: https://github.com/llvm/llvm-project/commit/6dd9596b19d7679c562f8e866be6d0c3d7c21994.diff

LOG: [mlir] Add a shape function library op

Op with a mapping from ops to the corresponding shape functions for those
ops in the library, and a mechanism to associate shape functions with
functions. The mapping of operation to shape function is kept separate from
the shape functions themselves, as the operation is associated with the shape
function and not vice versa, and one could have a common library of shape
functions that can be used in different contexts.

Use fully qualified names, require a name for shape function library ops for
now, and use an explicit printer/parser (based on the generated one and the
GPU module op's).

Differential Revision: https://reviews.llvm.org/D91672

Added: 
    mlir/test/Analysis/test-shape-fn-report.mlir
    mlir/test/lib/Dialect/Shape/CMakeLists.txt
    mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/Shape.h
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/lib/Dialect/CMakeLists.txt
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/tools/mlir-opt/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
index f40d6154544a..cb5ed56e16a2 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h
+++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
@@ -14,6 +14,7 @@
 #ifndef MLIR_SHAPE_IR_SHAPE_H
 #define MLIR_SHAPE_IR_SHAPE_H
 
+#include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"

diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index a852d900cf69..52768e49001d 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -18,6 +18,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpAsmInterface.td"
+include "mlir/IR/SymbolInterfaces.td"
 
 //===----------------------------------------------------------------------===//
 // Shape op definitions
@@ -492,7 +493,7 @@ def Shape_WithOp : Shape_Op<"with_shape", [NoSideEffect]> {
 }
 
 def Shape_YieldOp : Shape_Op<"yield",
-    [HasParent<"ReduceOp">,
+    [HasParent<"ReduceOp, FunctionLibraryOp">,
      NoSideEffect,
      ReturnLike,
      Terminator]> {
@@ -780,4 +781,62 @@ def Shape_CstrRequireOp : Shape_Op<"cstr_require", []> {
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Shape collection ops.
+//===----------------------------------------------------------------------===//
+
+def Shape_FunctionLibraryOp : Shape_Op<"function_library",
+    [AffineScope, IsolatedFromAbove, NoRegionArguments, SymbolTable, Symbol,
+     SingleBlockImplicitTerminator<"ShapeFunctionLibraryTerminatorOp">]> {
+  let summary = "Represents shape functions and corresponding ops";
+  let description = [{
+    Represents a list of shape functions and the ops whose shape transfer
+    functions they represent.
+
+    The op is a named symbol table. Its body holds the shape functions, and
+    the `mapping` dictionary associates an operation name with the symbol of
+    the function in this library that describes that operation's shapes.
+
+    Example:
+
+    ```mlir
+    shape.function_library @shape_lib {
+      func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape {
+        %0 = shape.shape_of %arg : !shape.value_shape -> !shape.shape
+        return %0 : !shape.shape
+      }
+    } mapping {
+      std.atan = @same_result_shape
+    }
+    ```
+  }];
+
+  // Keep every attribute in a single `arguments` list: a repeated
+  // `let arguments` replaces the earlier list instead of extending it.
+  let arguments = (ins SymbolNameAttr:$sym_name,
+                       OptionalAttr<StrAttr>:$sym_visibility,
+                       DictionaryAttr:$mapping);
+  let regions = (region AnyRegion:$body);
+
+  let extraClassDeclaration = [{
+    /// Returns the shape function mapped to `op`'s name in `mapping`, or null
+    /// if the op is unmapped or the symbol is not a FuncOp in this library.
+    FuncOp getShapeFunction(Operation *op);
+  }];
+
+  // Builds a library whose symbol name is `name`.
+  let builders = [OpBuilderDAG<(ins "StringRef":$name)>];
+  let skipDefaultBuilders = 1;
+
+  // Custom syntax: `shape.function_library @name (attributes {...})? {body}
+  // mapping {op.name = @fn, ...}`.
+  let printer = [{ ::print(p, *this); }];
+  let parser = [{ return ::parse$cppClass(parser, result); }];
+}
+
+//===----------------------------------------------------------------------===//
+// ShapeFunctionLibraryTerminatorOp
+//===----------------------------------------------------------------------===//
+
+def ShapeFunctionLibraryTerminatorOp : Shape_Op<"fn_lib_terminator",
+    [Terminator, HasParent<"FunctionLibraryOp">]> {
+  let summary = "A pseudo op that marks the end of a shape function library";
+  let description = [{
+    `shape.fn_lib_terminator` is a special pseudo terminator operation for the
+    shape function library. It has no semantic meaning beyond keeping the body
+    well-formed, and it is inserted implicitly into `shape.function_library`.
+  }];
+  // No operands or results; only an optional attribute dictionary.
+  let assemblyFormat = "attr-dict";
+}
+
 #endif // SHAPE_OPS

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index cfac2abae3e6..d8c7f4c6736d 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Traits.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Function.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Transforms/InliningUtils.h"
@@ -558,6 +559,65 @@ OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
   return builder.getIndexTensorAttr(extents);
 }
 
+//===----------------------------------------------------------------------===//
+// FunctionLibraryOp
+//===----------------------------------------------------------------------===//
+
+/// Builds an empty shape function library with symbol name `name`. The body
+/// region receives its implicit terminator, and `mapping` starts out as an
+/// empty dictionary because it is a required attribute of the op.
+void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
+                              StringRef name) {
+  ensureTerminator(*result.addRegion(), builder, result.location);
+  result.addAttribute(::mlir::SymbolTable::getSymbolAttrName(),
+                      builder.getStringAttr(name));
+  // Without this the built op would fail verification on the missing
+  // required `mapping` attribute.
+  result.addAttribute("mapping", builder.getDictionaryAttr({}));
+}
+
+/// Resolves `op` through the `mapping` dictionary: the op's name is the key
+/// and the value is expected to be a flat symbol reference to a function in
+/// this library. Yields a null FuncOp when no usable entry exists.
+FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
+  Identifier opName = op->getName().getIdentifier();
+  Attribute entry = mapping().get(opName);
+  auto fnRef = entry.dyn_cast_or_null<FlatSymbolRefAttr>();
+  return fnRef ? lookupSymbol<FuncOp>(fnRef) : FuncOp();
+}
+
+/// Parses the custom form of a shape function library:
+///
+///   shape.function_library @name (`attributes` attr-dict)? region
+///       `mapping` dictionary-attr
+///
+/// The body region gets its implicit terminator if the text omitted it.
+static ParseResult parseFunctionLibraryOp(OpAsmParser &parser,
+                                          OperationState &result) {
+  // Parse the op name into the symbol name attribute.
+  StringAttr nameAttr;
+  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
+                             result.attributes))
+    return failure();
+
+  // Extra attributes must be introduced with the `attributes` keyword.
+  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
+    return failure();
+
+  auto *bodyRegion = result.addRegion();
+  if (parser.parseRegion(*bodyRegion))
+    return failure();
+
+  // The printer elides the terminator, so add it back here.
+  FunctionLibraryOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
+                                      result.location);
+  if (parser.parseKeyword("mapping"))
+    return failure();
+
+  // Parse the mapping dictionary with an explicit (none) type so that no
+  // trailing `: type` is expected.
+  DictionaryAttr mappingAttr;
+  if (parser.parseAttribute(mappingAttr,
+                            parser.getBuilder().getType<NoneType>(), "mapping",
+                            result.attributes))
+    return failure();
+  return success();
+}
+
+/// Prints the custom form read by parseFunctionLibraryOp. The symbol name and
+/// `mapping` have dedicated syntax and are elided from the attribute dict,
+/// and the block terminator of the body is not printed.
+static void print(OpAsmPrinter &p, FunctionLibraryOp op) {
+  p << op.getOperationName() << ' ';
+  p.printSymbolName(op.getName());
+  p.printOptionalAttrDictWithKeyword(
+      op.getAttrs(), {SymbolTable::getSymbolAttrName(), "mapping"});
+  p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false,
+                /*printBlockTerminators=*/false);
+  p << " mapping ";
+  p.printAttributeWithoutType(op.mappingAttr());
+}
+
 //===----------------------------------------------------------------------===//
 // GetExtentOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Analysis/test-shape-fn-report.mlir b/mlir/test/Analysis/test-shape-fn-report.mlir
new file mode 100644
index 000000000000..ad5c8e64a1b7
--- /dev/null
+++ b/mlir/test/Analysis/test-shape-fn-report.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt %s --test-shape-function-report -verify-diagnostics
+
+// The function carries a `shape.function` attribute pointing into the
+// library. Its body mixes an op with no shape information and an op that the
+// library maps to a shape function.
+// expected-remark@+1 {{associated shape function: same_result_shape}}
+func @tanh(%arg: tensor<10x20xf32>) -> tensor<10x20xf32>
+    attributes {shape.function = @shape_lib::@same_result_shape} {
+  // expected-remark@+1 {{no associated way}}
+  %0 = tanh %arg : tensor<10x20xf32>
+  // expected-remark@+1 {{associated shape function: same_result_shape}}
+  %1 = "test.same_operand_result_type"(%0) : (tensor<10x20xf32>) -> tensor<10x20xf32>
+  return %1 : tensor<10x20xf32>
+}
+
+// The shape function library with some local functions.
+shape.function_library @shape_lib {
+  // Test shape function that returns the shape of input arg as result shape.
+  func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape {
+    %0 = shape.shape_of %arg : !shape.value_shape -> !shape.shape
+    return %0 : !shape.shape
+  }
+} mapping {
+  test.same_operand_result_type = @same_result_shape
+}

diff  --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index b220d0d81632..adee9f8a1514 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_subdirectory(Affine)
+add_subdirectory(Shape)
 add_subdirectory(SPIRV)
 add_subdirectory(Test)
 add_subdirectory(Tosa)

diff  --git a/mlir/test/lib/Dialect/Shape/CMakeLists.txt b/mlir/test/lib/Dialect/Shape/CMakeLists.txt
new file mode 100644
index 000000000000..6c041ab9c371
--- /dev/null
+++ b/mlir/test/lib/Dialect/Shape/CMakeLists.txt
@@ -0,0 +1,16 @@
+# Test passes for the Shape dialect (shape function reporting).
+# Exclude tests from libMLIR.so
+add_mlir_library(MLIRShapeTestPasses
+  TestShapeFunctions.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Shape
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
+
+  # TestShapeFunctions.cpp uses InferTypeOpInterface directly, so link it
+  # explicitly rather than relying on a transitive dependency.
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRInferTypeOpInterface
+  MLIRPass
+  MLIRShape
+  MLIRSupport
+  )

diff  --git a/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp b/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp
new file mode 100644
index 000000000000..688f24e5ec47
--- /dev/null
+++ b/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp
@@ -0,0 +1,73 @@
+//===- TestShapeFunctions.cpp - Passes to test shape function  ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include <queue>
+
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// Module-level test pass: for every non-terminator op inside a function that
+/// is not itself part of a shape function library, it emits a remark naming
+/// the mechanism (if any) through which the op's shape can be refined.
+class ReportShapeFnPass
+    : public PassWrapper<ReportShapeFnPass, OperationPass<ModuleOp>> {
+public:
+  void runOnOperation() override;
+};
+} // namespace
+
+/// Emits one remark per non-terminator op nested in a func (funcs whose
+/// parent is a shape function library are skipped), checked in this order:
+///   1. the op implements InferTypeOpInterface;
+///   2. the module's shape function library maps the op's name to a function;
+///   3. the op carries a `shape.function` symbol reference attribute;
+///   4. otherwise, no way to refine the shape is known.
+/// At most one top-level shape.function_library is allowed in the module.
+void ReportShapeFnPass::runOnOperation() {
+  ModuleOp module = getOperation();
+
+  // Look up the (optional) shape function library, rejecting duplicates.
+  shape::FunctionLibraryOp shapeFnLib = nullptr;
+  for (auto lib : module.getOps<shape::FunctionLibraryOp>()) {
+    if (shapeFnLib) {
+      lib.emitError("duplicate shape library op")
+              .attachNote(shapeFnLib.getLoc())
+          << "previous mapping";
+      return signalPassFailure();
+    }
+    shapeFnLib = lib;
+  }
+
+  // Report the shape function available to refine the op.
+  auto shapeFnId = Identifier::get("shape.function", &getContext());
+  auto remarkShapeFn = [&](Operation *op) {
+    if (op->isKnownTerminator())
+      return;
+    if (isa<InferTypeOpInterface>(op)) {
+      op->emitRemark() << "implements InferType op interface";
+      return;
+    }
+    // The library is optional; only consult it when the module has one.
+    if (shapeFnLib) {
+      if (FuncOp fn = shapeFnLib.getShapeFunction(op)) {
+        op->emitRemark() << "associated shape function: " << fn.getName();
+        return;
+      }
+    }
+    if (auto symbol = op->getAttrOfType<SymbolRefAttr>(shapeFnId)) {
+      // The reference may name a missing symbol or a non-function; diagnose
+      // that instead of asserting in an unchecked cast.
+      auto fn = dyn_cast_or_null<FuncOp>(
+          SymbolTable::lookupSymbolIn(module, symbol));
+      if (!fn) {
+        op->emitError() << "unable to resolve shape function " << symbol;
+        signalPassFailure();
+        return;
+      }
+      op->emitRemark() << "associated shape function: " << fn.getName();
+      return;
+    }
+    op->emitRemark() << "no associated way to refine shape";
+  };
+
+  module.getBodyRegion().walk([&](FuncOp func) {
+    // Skip the shape functions themselves, which live in the library.
+    if (isa<shape::FunctionLibraryOp>(func.getParentOp()))
+      return;
+
+    func.walk([&](Operation *op) { remarkShapeFn(op); });
+  });
+}
+
+namespace mlir {
+/// Registers the `test-shape-function-report` pass. This is invoked from
+/// mlir-opt's registerTestPasses().
+void registerShapeFunctionTestPasses() {
+  PassRegistration<ReportShapeFnPass> reportShapeFnRegistration(
+      "test-shape-function-report",
+      "Test pass to report associated shape functions");
+}
+} // namespace mlir

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index aef5b5166ae2..5a17eebfd32c 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -134,6 +134,12 @@ def VariadicWithSameOperandsResult :
   let results = (outs AnySignlessInteger:$result);
 }
 
+// Single-tensor op whose operand and result must share a type. It does not
+// implement InferTypeOpInterface; test-shape-fn-report.mlir instead maps it to
+// a shape function through a shape.function_library.
+def SameOperandsResultType : TEST_Op<
+    "same_operand_result_type", [SameOperandsAndResultType]> {
+  let results = (outs AnyTensor:$result);
+  let arguments = (ins AnyTensor:$operand);
+}
+
 //===----------------------------------------------------------------------===//
 // Test Results
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index 483dcfec0c0f..e8b0842a9e33 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -13,6 +13,7 @@ set(LLVM_LINK_COMPONENTS
 if(MLIR_INCLUDE_TESTS)
   set(test_libs
     MLIRAffineTransformsTestPasses
+    MLIRShapeTestPasses
     MLIRSPIRVTestPasses
     MLIRTestDialect
     MLIRTestIR

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index a0e36cf82534..4095cc21cbaf 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -32,6 +32,7 @@ namespace mlir {
 void registerConvertToTargetEnvPass();
 void registerPassManagerTestPass();
 void registerPrintOpAvailabilityPass();
+void registerShapeFunctionTestPasses();
 void registerSideEffectTestPasses();
 void registerSliceAnalysisTestPass();
 void registerSymbolTestPasses();
@@ -98,6 +99,7 @@ void registerTestPasses() {
   registerConvertToTargetEnvPass();
   registerPassManagerTestPass();
   registerPrintOpAvailabilityPass();
+  registerShapeFunctionTestPasses();
   registerSideEffectTestPasses();
   registerSliceAnalysisTestPass();
   registerSymbolTestPasses();


        


More information about the Mlir-commits mailing list