[Mlir-commits] [mlir] 4c640e4 - [mlir][linalg] Verify indexing map required attributes

Lei Zhang llvmlistbot at llvm.org
Tue Feb 9 05:51:32 PST 2021


Author: Lei Zhang
Date: 2021-02-09T08:48:29-05:00
New Revision: 4c640e49c9553363bc0e6fcbdbfe8d678683db97

URL: https://github.com/llvm/llvm-project/commit/4c640e49c9553363bc0e6fcbdbfe8d678683db97
DIFF: https://github.com/llvm/llvm-project/commit/4c640e49c9553363bc0e6fcbdbfe8d678683db97.diff

LOG: [mlir][linalg] Verify indexing map required attributes

Indexing maps for named ops can reference attributes so that
we can synthesize the indexing map dynamically. This supports
cases like strides for convolution ops. However, it does cause
an issue: now the indexing_maps() function call is dependent
on those attributes.

Linalg ops inherit LinalgOpInterfaceTraits, which calls
verifyStructuredOpInterface() to verify the interface.
verifyStructuredOpInterface() further calls indexing_maps().
Note that trait verification runs before the op's own
verifier, which is where ODS generates the checks for those
attributes. So indexing_maps() can end up referencing
nonexistent or invalid attributes before the ODS-generated
verification kicks in.

There isn't a dependency-handling mechanism for traits.
This commit adds new interface methods so that
verifyStructuredOpInterface() can query whether an op
hasDynamicIndexingMaps() and, if so, first run
verifyIndexingMapRequiredAttributes() to handle the
dependency issue.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D96297

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
    mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
    mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index dd15f2278650..c26f02208215 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -813,6 +813,29 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
         return $_op.iterator_types();
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return true if the indexing maps depend on the current op instance.
+        This means that the indexing maps are dynamically synthesized from the
+        op instance's concrete attributes, instead of being static for all
+        instances of the same op kind.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"hasDynamicIndexingMaps",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{ return false; }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Verify that all attributes used by indexing maps are present and valid.
+      }],
+      /*retTy=*/"LogicalResult",
+      /*methodName=*/"verifyIndexingMapRequiredAttributes",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{ return success(); }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the indexing maps attribute within the current operation.

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index d256942e8d02..1f511a4ffcc4 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -302,6 +302,12 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   if (op->getNumResults() > linalgOp.getNumOutputTensors())
     return op->emitError("unexpected #results > #outputs");
 
+  // Before checking indexing maps, we need to make sure the attributes
+  // referenced by it are valid.
+  if (linalgOp.hasDynamicIndexingMaps())
+    if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
+      return failure();
+
   // All shaped operands must be indexed.
   if (linalgOp.indexing_maps().size() != linalgOp.getNumShapedOperands())
     return linalgOp.emitOpError("expected the number of indexing_map (")

diff  --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
index 9bd6152f07da..a16a2b85a9ec 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
@@ -92,6 +92,18 @@ def test3(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) {
 // ODS: RankedI32ElementsAttr<[5, 6]>:$ivec_attr,
 // ODS: OptionalAttr<F32>:$optional_attr
 //
+// ODS: bool hasDynamicIndexingMaps();
+// ODS: LogicalResult verifyIndexingMapRequiredAttributes();
+//
+// IMPL: bool Test4Op::hasDynamicIndexingMaps() { return true; }
+// IMPL: LogicalResult Test4Op::verifyIndexingMapRequiredAttributes()
+// IMPL:   op->getAttrOfType<ArrayAttr>("array_attr")
+// IMPL:   op->getAttr("f32_attr")
+// IMPL:   op->getAttrOfType<DenseElementsAttr>("fvec_attr")
+// IMPL:   op->getAttr("i32_attr")
+// IMPL:   op->getAttr("i64_attr")
+// IMPL:   op->getAttrOfType<DenseElementsAttr>("ivec_attr")
+//
 ods_def<Test4Op> :
 def test4(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N))
 attr(

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
index a222d67bcb4e..0934967f516c 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -1126,6 +1126,15 @@ class TCParser {
   void printReferenceIterators(llvm::raw_ostream &os, StringRef cppOpName,
                                ComprehensionParsingState &state);
 
+  /// Print the C++ definitions of the indexing-map required-attribute methods
+  /// for `cppOpName`; prints nothing if the op has no registered attributes.
+  /// Specifically, this prints the definitions for the following methods:
+  ///   bool hasDynamicIndexingMaps();
+  ///   LogicalResult verifyIndexingMapRequiredAttributes();
+  void printIndexingMapRequiredAttrMethods(llvm::raw_ostream &os,
+                                           StringRef cppOpName,
+                                           ComprehensionParsingState &state);
+
   /// Print the C++ StructuredOpsInterface impl of `indexing_maps`.
   void printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef cppOpName,
                                   ComprehensionParsingState &state);
@@ -1770,6 +1779,7 @@ LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
     std::string extraMethods;
     llvm::raw_string_ostream ss(extraMethods);
     printReferenceIterators(ss, cppOpName, state);
+    printIndexingMapRequiredAttrMethods(ss, cppOpName, state);
     printReferenceIndexingMaps(ss, cppOpName, state);
     printRegionBuilder(ss, cppOpName, state);
     printCanonicalizersAndFolders(ss, cppOpName);
@@ -1827,6 +1837,15 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
   if (!attrList.empty())
     attrList = ",\n" + attrList;
 
+  // Template for Linalg named ops' ODS definitions. Parameters:
+  // {0}: ODS/C++ op name
+  // {1}: assembly op mnemonic
+  // {2}: op interface list
+  // {3}: documentation (summary + description)
+  // {4}: op attribute list
+  // {5}: the number of arguments for the op region
+  // {6}: builder methods taking standalone attribute parameters
+  // {7}: additional methods for attributes used by indexing maps
   const char *header = R"FMT(  def {0} : LinalgStructuredBase_Op<"{1}", [
     AttrSizedOperandSegments,
     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
@@ -1906,6 +1925,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
         std::string getLibraryCallName() {{
           return generateLibraryCallName(getOperation());
         }
+
+        {7}
       }];
   })FMT";
 
@@ -1971,9 +1992,18 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
         llvm::formatv(builderFmt, cppOpName, attrParamsList, attrStmtsList);
   }
 
+  std::string attrMethods;
+  if (!registeredAttrs.empty()) {
+    attrMethods = R"(
+      bool hasDynamicIndexingMaps();
+      LogicalResult verifyIndexingMapRequiredAttributes();
+    )";
+  }
+
   // Finally put everything together.
   os << llvm::formatv(header, cppOpName, linalgOpName, interfaceNameList, doc,
-                      attrList, state.orderedTensorArgs.size(), attrBuilder);
+                      attrList, state.orderedTensorArgs.size(), attrBuilder,
+                      attrMethods);
 }
 
 /// Print the C++ StructuredOpsInterface impl of `iterator_types`.
@@ -2032,6 +2062,111 @@ void TCParser::printCanonicalizersAndFolders(llvm::raw_ostream &os,
   os << llvm::formatv(canonicalizersAndFoldersFmt, cppOpName);
 }
 
+// Prints the definitions of hasDynamicIndexingMaps() and
+// verifyIndexingMapRequiredAttributes() for the current named op. The latter
+// checks that every required attribute is present and has the expected type.
+void TCParser::printIndexingMapRequiredAttrMethods(
+    llvm::raw_ostream &os, StringRef cppOpName,
+    ComprehensionParsingState &state) {
+  // If no attributes are used by the whole definition, there is nothing to do.
+  if (registeredAttrs.empty())
+    return;
+
+  // Otherwise, generate a presence/type check for each required attribute.
+  // Optional attributes may legitimately be absent, so they are skipped.
+  SmallVector<std::string, 4> attributes;
+  for (const auto &attr : registeredAttrs) {
+    if (attr.second.isOptional)
+      continue;
+
+    llvm::StringRef name = attr.first;
+    llvm::StringRef elementType = attr.second.elementType;
+    const auto &dims = attr.second.vectorDims;
+
+    // Build the Type method call that checks for the expected element kind.
+    std::string elemTypeCheck = llvm::StringSwitch<std::string>(elementType)
+                                    .Case("f32", "isF32()")
+                                    .Case("i32", "isInteger(32)")
+                                    .Case("i64", "isInteger(64)")
+                                    .Default("");
+    if (elemTypeCheck.empty()) {
+      (void)parser.emitError(
+          "unimplemented support for attribute element type: " + elementType);
+      return;
+    }
+
+    // Scalar case: a single attribute whose type must match.
+    if (dims.empty() && !attr.second.isArray) {
+      const char *attrFmt = R"FMT(
+        if (auto attr = op->getAttr("{0}")) {{
+          if (!attr.getType().{1}) return op->emitError(
+            "incorrect type for indexing map required attribute '{0}'");
+        } else {{
+          return op->emitError(
+            "missing indexing map required attribute '{0}'");
+        }
+      )FMT";
+
+      attributes.push_back(llvm::formatv(attrFmt, name, elemTypeCheck));
+      continue;
+    }
+
+    // Vector case: a DenseElementsAttr; element type and shape must match.
+    if (!dims.empty()) {
+      SmallVector<std::string, 4> dimStrs;
+      for (uint64_t dim : dims)
+        dimStrs.push_back(std::to_string(dim));
+
+      const char *attrFmt = R"FMT(
+        if (auto attr = op->getAttrOfType<DenseElementsAttr>("{0}")) {{
+          if (!attr.getType().getElementType().{1}) return op->emitError(
+            "incorrect element type for indexing map required attribute '{0}'");
+          if (attr.getType().getShape() != ArrayRef<int64_t>{{ {2} })
+            return op->emitError(
+              "incorrect shape for indexing map required attribute '{0}'");
+        } else {{
+          return op->emitError(
+            "missing indexing map required attribute '{0}'");
+        }
+      )FMT";
+
+      attributes.push_back(llvm::formatv(attrFmt, name, elemTypeCheck,
+                                         llvm::join(dimStrs, ", ")));
+      continue;
+    }
+
+    // Array case: an ArrayAttr; every element's type must match.
+    {
+      const char *attrFmt = R"FMT(
+        if (auto attr = op->getAttrOfType<ArrayAttr>("{0}")) {{
+          for (Attribute element : attr) {{
+            if (!element.getType().{1}) return op->emitError(
+              "incorrect element type for indexing map required attribute '{0}'");
+          }
+        } else {{
+          return op->emitError(
+            "missing indexing map required attribute '{0}'");
+        }
+      )FMT";
+
+      attributes.push_back(llvm::formatv(attrFmt, name, elemTypeCheck));
+    }
+  }
+
+  const char *methodFmt = R"FMT(
+  bool {0}::hasDynamicIndexingMaps() {{ return true; }
+
+  LogicalResult {0}::verifyIndexingMapRequiredAttributes() {{
+    Operation *op = getOperation();
+    {1}
+    return success();
+  }
+  )FMT";
+
+  // Emit both out-of-line method definitions for this op.
+  os << llvm::formatv(methodFmt, cppOpName, llvm::join(attributes, "\n"));
+}
+
 /// Print the C++ StructuredOpsInterface impl of `referenceIndexingMaps`.
 void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
                                           StringRef cppOpName,


        


More information about the Mlir-commits mailing list