[llvm-branch-commits] [mlir] 8d541a1 - [mlir][shape] Add shape.lib attribute

Jacques Pienaar via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Dec 31 14:51:25 PST 2020


Author: Jacques Pienaar
Date: 2020-12-31T14:46:08-08:00
New Revision: 8d541a1fbe6d92a3fadf6d7d8e8209ed6c76e092

URL: https://github.com/llvm/llvm-project/commit/8d541a1fbe6d92a3fadf6d7d8e8209ed6c76e092
DIFF: https://github.com/llvm/llvm-project/commit/8d541a1fbe6d92a3fadf6d7d8e8209ed6c76e092.diff

LOG: [mlir][shape] Add shape.lib attribute

Enable querying shape function library ops from the module. Currently
supports a single library or an array of them (as long as no op appears in
the mappings of more than one library in the array). The preferred canonical
form would have one library, but given the invariant on the mappings, this
can easily be achieved by a simple merging pass.

Preferred the attribute approach over a naming convention, as libraries could
be added in multiple different ways.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Analysis/test-shape-fn-report.mlir
    mlir/test/Dialect/Shape/invalid.mlir
    mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
index a7868e74c65f..1cccb59dfbb9 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
@@ -37,6 +37,7 @@ def ShapeDialect : Dialect {
   let cppNamespace = "::mlir::shape";
 
   let hasConstantMaterializer = 1;
+  let hasOperationAttrVerify = 1;
 }
 
 def Shape_ShapeType : DialectType<ShapeDialect,

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 0478cb7872cc..2de60ebe1306 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -145,6 +145,56 @@ void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
       .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); });
 }
 
+/// Verifies `shape.lib` (other attributes are accepted): it must sit on a
+/// SymbolTable op and name FunctionLibraryOps with non-overlapping mappings.
+LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
+                                                     NamedAttribute attribute) {
+  if (attribute.first == "shape.lib") {
+    if (!op->hasTrait<OpTrait::SymbolTable>())
+      return op->emitError(
+          "shape.lib attribute may only be on op implementing SymbolTable");
+
+    if (auto symbolRef = attribute.second.dyn_cast<SymbolRefAttr>()) {
+      auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
+      if (!symbol)
+        return op->emitError("shape function library ")
+               << symbolRef << " not found";
+      return isa<shape::FunctionLibraryOp>(symbol)
+                 ? success()
+                 : op->emitError()
+                       << symbolRef << " required to be shape function library";
+    }
+
+    if (auto arr = attribute.second.dyn_cast<ArrayAttr>()) {
+      // Each entry must resolve to a FunctionLibraryOp (the lookup returns null
+      // for an unknown symbol) and no op may be mapped by more than one entry.
+      DenseSet<Identifier> key;
+      for (auto it : arr) {
+        if (!it.isa<SymbolRefAttr>())
+          return op->emitError(
+              "only SymbolRefAttr allowed in shape.lib attribute array");
+
+        auto shapeFnLib = dyn_cast_or_null<shape::FunctionLibraryOp>(
+            SymbolTable::lookupSymbolIn(op, it.cast<SymbolRefAttr>()));
+        if (!shapeFnLib)
+          return op->emitError()
+                 << it << " does not refer to FunctionLibraryOp";
+        for (auto mapping : shapeFnLib.mapping()) {
+          if (!key.insert(mapping.first).second)
+            return op->emitError("only one op to shape mapping allowed, found "
+                                 "multiple for `")
+                   << mapping.first << "`";
+        }
+      }
+      return success();
+    }
+
+    return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
+                         "allowed as shape.lib attribute");
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // AnyOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Analysis/test-shape-fn-report.mlir b/mlir/test/Analysis/test-shape-fn-report.mlir
index ad5c8e64a1b7..b01593531502 100644
--- a/mlir/test/Analysis/test-shape-fn-report.mlir
+++ b/mlir/test/Analysis/test-shape-fn-report.mlir
@@ -1,5 +1,7 @@
 // RUN: mlir-opt %s --test-shape-function-report -verify-diagnostics
 
+module attributes {shape.lib = [@shape_lib]} {
+
 // expected-remark@+1 {{associated shape function: same_result_shape}}
 func @tanh(%arg: tensor<10x20xf32>) -> tensor<10x20xf32>
     attributes {shape.function = @shape_lib::@same_result_shape} {
@@ -20,3 +22,5 @@ shape.function_library @shape_lib {
 } mapping {
   test.same_operand_result_type = @same_result_shape
 }
+
+}

diff  --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
index eb0ae5ae05a9..d2f5af2f7b30 100644
--- a/mlir/test/Dialect/Shape/invalid.mlir
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -154,3 +154,95 @@ func @broadcast(%arg0 : !shape.shape, %arg1 : tensor<?xindex>) -> tensor<?xindex
       : !shape.shape, tensor<?xindex> -> tensor<?xindex>
   return %result : tensor<?xindex>
 }
+
+// -----
+
+// Test using an unsupported shape.lib attribute type.
+
+// expected-error@+1 {{only SymbolRefAttr allowed in shape.lib attribute array}}
+module attributes {shape.lib = [@shape_lib, "shape_lib"]} {
+
+shape.function_library @shape_lib {
+  // Test shape function that returns the shape of input arg as result shape.
+  func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape {
+    %0 = shape.shape_of %arg : !shape.value_shape -> !shape.shape
+    return %0 : !shape.shape
+  }
+} mapping {
+  test.same_operand_result_type = @same_result_shape
+}
+
+}
+
+// -----
+
+// Test that duplicate op to shape function mappings are flagged; this uses
+// the same library twice for easy overlap.
+
+// expected-error@+1 {{only one op to shape mapping allowed}}
+module attributes {shape.lib = [@shape_lib, @shape_lib]} {
+
+shape.function_library @shape_lib {
+  // Test shape function that returns the shape of input arg as result shape.
+  func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape {
+    %0 = shape.shape_of %arg : !shape.value_shape -> !shape.shape
+    return %0 : !shape.shape
+  }
+} mapping {
+  test.same_operand_result_type = @same_result_shape
+}
+
+}
+
+// -----
+
+// Test that duplicate op to shape function mappings are flagged (this is
+// more an invariant of using the dictionary attribute here than anything
+// specific to function library op).
+
+module attributes {shape.lib = [@shape_lib]} {
+
+shape.function_library @shape_lib {
+  // Test shape function that returns the shape of input arg as result shape.
+  func @same_result_shape(%arg: !shape.value_shape) -> !shape.shape {
+    %0 = shape.shape_of %arg : !shape.value_shape -> !shape.shape
+    return %0 : !shape.shape
+  }
+} mapping {
+  // expected-error @+2 {{duplicate key}}
+  test.same_operand_result_type = @same_result_shape,
+  test.same_operand_result_type = @same_result_shape
+}
+
+}
+
+// -----
+
+// Test that op referred to by shape lib is a shape function library.
+
+// expected-error@+1 {{required to be shape function library}}
+module attributes {shape.lib = @fn} {
+
+func @fn(%arg: !shape.value_shape) -> !shape.shape {
+  %0 = shape.shape_of %arg : !shape.value_shape -> !shape.shape
+  return %0 : !shape.shape
+}
+
+}
+
+// -----
+
+// Test that shape.lib is only allowed on ops implementing SymbolTable.
+
+func @fn(%arg: !shape.value_shape) -> !shape.shape {
+  // expected-error@+1 {{SymbolTable}}
+  %0 = shape.shape_of %arg {shape.lib = @fn} : !shape.value_shape -> !shape.shape
+  return %0 : !shape.shape
+}
+
+// -----
+
+// Test that shape function library is defined.
+
+// expected-error@+1 {{@fn not found}}
+module attributes {shape.lib = @fn} { }

diff  --git a/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp b/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp
index b7127c5edf32..4477eb1eda6d 100644
--- a/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp
+++ b/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp
@@ -26,41 +26,57 @@ struct ReportShapeFnPass
 void ReportShapeFnPass::runOnOperation() {
   auto module = getOperation();

-  // Lookup shape function library.
-  shape::FunctionLibraryOp shapeFnLib = nullptr;
-  for (auto lib : module.getOps<shape::FunctionLibraryOp>()) {
-    if (shapeFnLib) {
-      lib.emitError("duplicate shape library op")
-              .attachNote(shapeFnLib.getLoc())
-          << "previous mapping";
-      return signalPassFailure();
-    }
-    shapeFnLib = lib;
-  };
+  // Lookup the shape function libraries named by the module's `shape.lib`
+  // attribute (a single SymbolRefAttr or an array of them). The shape dialect
+  // verifier checks that every entry names a FunctionLibraryOp, so the cast in
+  // `lookup` is expected to hold on verified IR.
+  SmallVector<shape::FunctionLibraryOp, 4> libraries;
+  if (Attribute attr = module.getAttr("shape.lib")) {
+    auto lookup = [&](Attribute symbol) {
+      return cast<shape::FunctionLibraryOp>(
+          SymbolTable::lookupSymbolIn(module, symbol.cast<SymbolRefAttr>()));
+    };
+    if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
+      libraries.reserve(arrayAttr.size());
+      for (Attribute symbol : arrayAttr)
+        libraries.push_back(lookup(symbol));
+    } else {
+      libraries.push_back(lookup(attr));
+    }
+  }

   // Report the shape function available to refine the op.
+  // Precedence: InferType interface, then library mappings (in `shape.lib`
+  // order), then an explicit `shape.function` attribute on the op. Ops with
+  // none of these are reported as not refinable.
   auto shapeFnId = Identifier::get("shape.function", &getContext());
   auto remarkShapeFn = [&](Operation *op) {
     if (op->isKnownTerminator())
       return;
-    if (auto typeInterface = dyn_cast<InferTypeOpInterface>(op)) {
+    if (isa<InferTypeOpInterface>(op)) {
       op->emitRemark() << "implements InferType op interface";
-    } else if (auto fn = shapeFnLib.getShapeFunction(op)) {
-      op->emitRemark() << "associated shape function: " << fn.getName();
-    } else if (auto symbol = op->getAttrOfType<SymbolRefAttr>(shapeFnId)) {
+      return;
+    }
+    for (shape::FunctionLibraryOp lib : libraries) {
+      if (auto fn = lib.getShapeFunction(op)) {
+        op->emitRemark() << "associated shape function: " << fn.getName();
+        return;
+      }
+    }
+    if (auto symbol = op->getAttrOfType<SymbolRefAttr>(shapeFnId)) {
       auto fn = cast<FuncOp>(SymbolTable::lookupSymbolIn(module, symbol));
       op->emitRemark() << "associated shape function: " << fn.getName();
-    } else {
-      op->emitRemark() << "no associated way to refine shape";
+      return;
     }
+    op->emitRemark() << "no associated way to refine shape";
   };

   module.getBodyRegion().walk([&](FuncOp func) {
     // Skip ops in the shape function library.
     if (isa<shape::FunctionLibraryOp>(func->getParentOp()))
       return;

     func.walk([&](Operation *op) { remarkShapeFn(op); });
   });
 }
 


        


More information about the llvm-branch-commits mailing list