[Mlir-commits] [mlir] 9ba37b3 - [mlir][ods] Add materialize derived attribute method

Jacques Pienaar llvmlistbot at llvm.org
Mon Apr 20 13:14:08 PDT 2020


Author: Jacques Pienaar
Date: 2020-04-20T13:13:04-07:00
New Revision: 9ba37b3bf294d5478aa4c1f6a96def7204181de6

URL: https://github.com/llvm/llvm-project/commit/9ba37b3bf294d5478aa4c1f6a96def7204181de6
DIFF: https://github.com/llvm/llvm-project/commit/9ba37b3bf294d5478aa4c1f6a96def7204181de6.diff

LOG: [mlir][ods] Add materialize derived attribute method

Summary:
Generate a method that returns a DictionaryAttr containing the values of
an op's derived attributes. If no conversion from a derived attribute's
C++ type back to an Attribute is defined, attempting to materialize that
op's derived attributes results in a runtime failure.

This allows derived attributes and regular attributes of an op to be
treated more uniformly where needed. The derived attributes are not added
to the operation; instead, they are returned as a new attribute.

Differential Revision: https://reviews.llvm.org/D78302

Added: 
    mlir/test/mlir-tblgen/op-derived-attribute.mlir

Modified: 
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/Interfaces/DerivedAttributeOpInterface.td
    mlir/test/lib/Dialect/Test/CMakeLists.txt
    mlir/test/lib/Dialect/Test/TestDialect.h
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/lib/Dialect/Test/TestPatterns.cpp
    mlir/test/mlir-tblgen/op-attribute.td
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 282267daf339..8c6561b73f63 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1419,13 +1419,29 @@ def SymbolRefArrayAttr :
 // Note: All derived attributes should be materializable as an Attribute. E.g.,
 // do not use DerivedAttr for things that could not have been stored as
 // Attribute.
-class DerivedAttr<code ret, code b> : Attr<CPred<"true">, "derived attribute"> {
+//
+class DerivedAttr<code ret, code b, code convert = ""> :
+    Attr<CPred<"true">, "derived attribute"> {
   let returnType = ret;
   code body = b;
+
+  // Specifies how to convert the derived attribute value to an Attribute. If
+  // empty (the default), the derived attribute cannot be materialized.
+  // ## Special placeholders
+  //
+  // Special placeholders can be used to refer to entities during conversion:
+  //
+  // * `$_builder` will be replaced by a mlir::Builder instance.
+  // * `$_ctx` will be replaced by the MLIRContext* instance.
+  // * `$_self` will be replaced with the derived attribute value (of type
+  //    `returnType`).
+  let convertFromStorage = convert;
 }
 
 // Derived attribute that returns a mlir::Type.
-class DerivedTypeAttr<code body> : DerivedAttr<"Type", body>;
+class DerivedTypeAttr<code body> : DerivedAttr<"Type", body> {
+  let convertFromStorage = "TypeAttr::get($_self)";
+}
 
 //===----------------------------------------------------------------------===//
 // Constant attribute kinds

diff  --git a/mlir/include/mlir/Interfaces/DerivedAttributeOpInterface.td b/mlir/include/mlir/Interfaces/DerivedAttributeOpInterface.td
index 68b9c6c39f43..e6f370752bcf 100644
--- a/mlir/include/mlir/Interfaces/DerivedAttributeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/DerivedAttributeOpInterface.td
@@ -31,6 +31,14 @@ def DerivedAttributeOpInterface : OpInterface<"DerivedAttributeOpInterface"> {
       /*methodName=*/"isDerivedAttribute",
       /*args=*/(ins "StringRef":$name)
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Materializes the derived attributes. Returns a null attribute if
+        any derived attribute cannot be materialized as an attribute.
+      }],
+      /*retTy=*/"DictionaryAttr",
+      /*methodName=*/"materializeDerivedAttributes"
+    >,
   ];
 }
 

diff  --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index 0d8b13c61030..5b97e113d684 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -21,15 +21,16 @@ add_llvm_library(MLIRTestDialect
 )
 target_link_libraries(MLIRTestDialect
   PUBLIC
+  LLVMSupport
   MLIRControlFlowInterfaces
+  MLIRDerivedAttributeOpInterface
   MLIRDialect
   MLIRIR
+  MLIRInferTypeOpInterface
   MLIRLinalgTransforms
   MLIRPass
   MLIRStandardOps
   MLIRStandardToStandard
-  MLIRTransforms
   MLIRTransformUtils
-  MLIRInferTypeOpInterface
-  LLVMSupport
+  MLIRTransforms
 )

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h
index 9b4dfee2daaa..b4ca125cb3d6 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.h
+++ b/mlir/test/lib/Dialect/Test/TestDialect.h
@@ -22,6 +22,7 @@
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/DerivedAttributeOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffects.h"
 

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 524780b89552..c633bde2d769 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -236,6 +236,15 @@ def RankedIntElementsAttrOp : TEST_Op<"ranked_int_elements_attr"> {
   );
 }
 
+def DerivedTypeAttrOp : TEST_Op<"derived_type_attr", []> {
+  let results = (outs AnyTensor:$output);
+  DerivedTypeAttr element_dtype =
+    DerivedTypeAttr<"return getElementTypeOrSelf(output().getType());">;
+  DerivedAttr size = DerivedAttr<"int",
+    "return output().getType().cast<ShapedType>().getSizeInBits();",
+    "$_builder.getI32IntegerAttr($_self)">;
+}
+
 //===----------------------------------------------------------------------===//
 // Test Attribute Constraints
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index d36eb985512a..c14dea2b9534 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -11,6 +11,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
+
 using namespace mlir;
 
 // Native function for testing NativeCodeCall
@@ -129,6 +130,23 @@ struct TestReturnTypeDriver
 };
 } // end anonymous namespace
 
+namespace {
+struct TestDerivedAttributeDriver
+    : public PassWrapper<TestDerivedAttributeDriver, FunctionPass> {
+  void runOnFunction() override;
+};
+} // end anonymous namespace
+
+void TestDerivedAttributeDriver::runOnFunction() {
+  getFunction().walk([](DerivedAttributeOpInterface dOp) {
+    auto dAttr = dOp.materializeDerivedAttributes();
+    if (!dAttr)
+      return;
+    for (auto d : dAttr)
+      dOp.emitRemark() << d.first << " = " << d.second;
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // Legalization Driver.
 //===----------------------------------------------------------------------===//
@@ -589,6 +607,9 @@ void registerPatternsTestPass() {
   mlir::PassRegistration<TestReturnTypeDriver>("test-return-type",
                                                "Run return type functions");
 
+  mlir::PassRegistration<TestDerivedAttributeDriver>(
+      "test-derived-attr", "Run test derived attributes");
+
   mlir::PassRegistration<TestPatternDriver>("test-patterns",
                                             "Run test dialect patterns");
 

diff  --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td
index c38d59cc246a..6e22912d1a2b 100644
--- a/mlir/test/mlir-tblgen/op-attribute.td
+++ b/mlir/test/mlir-tblgen/op-attribute.td
@@ -215,6 +215,7 @@ def DerivedTypeAttrOp : NS_Op<"derived_type_attr_op", []> {
 // DEF:   if (name == "element_dtype") return true;
 // DEF:   return false;
 // DEF: }
+// DEF: DerivedTypeAttrOp::materializeDerivedAttributes
 
 // Test that only default valued attributes at the end of the arguments
 // list get default values in the builder signature

diff  --git a/mlir/test/mlir-tblgen/op-derived-attribute.mlir b/mlir/test/mlir-tblgen/op-derived-attribute.mlir
new file mode 100644
index 000000000000..b11df48a319c
--- /dev/null
+++ b/mlir/test/mlir-tblgen/op-derived-attribute.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt -test-derived-attr -verify-diagnostics %s | FileCheck %s --dump-input-on-failure
+
+// CHECK-LABEL: verifyDerivedAttributes
+func @verifyDerivedAttributes() {
+  // expected-remark @+2 {{element_dtype = f32}}
+  // expected-remark @+1 {{size = 320}}
+  %0 = "test.derived_type_attr"() : () -> tensor<10xf32>
+  // expected-remark @+2 {{element_dtype = i79}}
+  // expected-remark @+1 {{size = 948}}
+  %1 = "test.derived_type_attr"() : () -> tensor<12xi79>
+
+  return
+}

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 41f392e67f62..7195165a3572 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -396,21 +396,66 @@ void OpEmitter::genAttrGetters() {
     }
   }
 
-  // Generate helper method to query whether a named attribute is a derived
-  // attribute. This enables, for example, avoiding adding an attribute that
-  // overlaps with a derived attribute.
-  auto derivedAttr = make_filter_range(op.getAttributes(),
-                                       [](const NamedAttribute &namedAttr) {
-                                         return namedAttr.attr.isDerivedAttr();
-                                       });
-  if (!derivedAttr.empty()) {
+  auto derivedAttrs = make_filter_range(op.getAttributes(),
+                                        [](const NamedAttribute &namedAttr) {
+                                          return namedAttr.attr.isDerivedAttr();
+                                        });
+  if (!derivedAttrs.empty()) {
     opClass.addTrait("DerivedAttributeOpInterface::Trait");
-    auto &method = opClass.newMethod("bool", "isDerivedAttribute",
-                                     "StringRef name", OpMethod::MP_Static);
-    auto &body = method.body();
-    for (auto namedAttr : derivedAttr)
-      body << "    if (name == \"" << namedAttr.name << "\") return true;\n";
-    body << " return false;";
+    // Generate helper method to query whether a named attribute is a derived
+    // attribute. This enables, for example, avoiding adding an attribute that
+    // overlaps with a derived attribute.
+    {
+      auto &method = opClass.newMethod("bool", "isDerivedAttribute",
+                                       "StringRef name", OpMethod::MP_Static);
+      auto &body = method.body();
+      for (auto namedAttr : derivedAttrs)
+        body << "  if (name == \"" << namedAttr.name << "\") return true;\n";
+      body << " return false;";
+    }
+    // Generate method to materialize derived attributes as a DictionaryAttr.
+    {
+      OpMethod &method =
+          opClass.newMethod("DictionaryAttr", "materializeDerivedAttributes");
+      auto &body = method.body();
+
+      auto nonMaterializable =
+          make_filter_range(derivedAttrs, [](const NamedAttribute &namedAttr) {
+            return namedAttr.attr.getConvertFromStorageCall().empty();
+          });
+      if (!nonMaterializable.empty()) {
+        std::string attrs;
+        llvm::raw_string_ostream os(attrs);
+        interleaveComma(nonMaterializable, os,
+                        [&](const NamedAttribute &attr) { os << attr.name; });
+        PrintWarning(
+            op.getLoc(),
+            formatv(
+                "op has non-materializable derived attributes '{0}', skipping",
+                os.str()));
+        body << formatv("  emitOpError(\"has non-materializable derived "
+                        "attributes '{0}'\");\n",
+                        attrs);
+        body << "  return nullptr;";
+        return;
+      }
+
+      body << "  MLIRContext* ctx = getContext();\n";
+      body << "  Builder odsBuilder(ctx); (void)odsBuilder;\n";
+      body << "  return DictionaryAttr::get({\n";
+      interleave(
+          derivedAttrs, body,
+          [&](const NamedAttribute &namedAttr) {
+            auto tmpl = namedAttr.attr.getConvertFromStorageCall();
+            body << "    {Identifier::get(\"" << namedAttr.name << "\", ctx),\n"
+                 << tgfmt(tmpl, &fctx.withSelf(namedAttr.name + "()")
+                                     .withBuilder("odsBuilder")
+                                     .addSubst("_ctx", "ctx"))
+                 << "}";
+          },
+          ",\n");
+      body << "\n    }, ctx);";
+    }
   }
 }
 
@@ -1115,16 +1160,14 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
     body << "  " << builderOpState
          << ".addAttribute(\"operand_segment_sizes\", "
             "odsBuilder->getI32VectorAttr({";
-    llvm::interleaveComma(
-        llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
-          if (op.getOperand(i).isOptional())
-            body << "(" << getArgumentName(op, i) << " ? 1 : 0)";
-          else if (op.getOperand(i).isVariadic())
-            body << "static_cast<int32_t>(" << getArgumentName(op, i)
-                 << ".size())";
-          else
-            body << "1";
-        });
+    interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
+      if (op.getOperand(i).isOptional())
+        body << "(" << getArgumentName(op, i) << " ? 1 : 0)";
+      else if (op.getOperand(i).isVariadic())
+        body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())";
+      else
+        body << "1";
+    });
     body << "}));\n";
   }
 
@@ -1222,10 +1265,10 @@ void OpEmitter::genOpInterfaceMethods() {
         continue;
       std::string args;
       llvm::raw_string_ostream os(args);
-      llvm::interleaveComma(method.getArguments(), os,
-                            [&](const OpInterfaceMethod::Argument &arg) {
-                              os << arg.type << " " << arg.name;
-                            });
+      interleaveComma(method.getArguments(), os,
+                      [&](const OpInterfaceMethod::Argument &arg) {
+                        os << arg.type << " " << arg.name;
+                      });
       opClass.newMethod(method.getReturnType(), method.getName(), os.str(),
                         method.isStatic() ? OpMethod::MP_Static
                                           : OpMethod::MP_None,
@@ -1776,7 +1819,7 @@ static void emitOpClasses(const std::vector<Record *> &defs, raw_ostream &os,
 static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
   IfDefScope scope("GET_OP_LIST", os);
 
-  llvm::interleave(
+  interleave(
       // TODO: We are constructing the Operator wrapper instance just for
       // getting it's qualified class name here. Reduce the overhead by having a
       // lightweight version of Operator class just for that purpose.


        


More information about the Mlir-commits mailing list