[Mlir-commits] [mlir] 8c2bff1 - Lazy initialize diagnostic when handling MLIR properties (#65868)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 11 13:22:39 PDT 2023


Author: Mehdi Amini
Date: 2023-09-11T13:22:35-07:00
New Revision: 8c2bff1ab929289aa060d59df90b3bc3752eba32

URL: https://github.com/llvm/llvm-project/commit/8c2bff1ab929289aa060d59df90b3bc3752eba32
DIFF: https://github.com/llvm/llvm-project/commit/8c2bff1ab929289aa060d59df90b3bc3752eba32.diff

LOG: Lazy initialize diagnostic when handling MLIR properties (#65868)

Instead of eagerly creating a diagnostic that will be discarded in the
normal case, switch to lazy initialization on error.

Added: 
    

Modified: 
    mlir/include/mlir/IR/ExtensibleDialect.h
    mlir/include/mlir/IR/ODSSupport.h
    mlir/include/mlir/IR/OpDefinition.h
    mlir/include/mlir/IR/Operation.h
    mlir/include/mlir/IR/OperationSupport.h
    mlir/include/mlir/IR/Properties.td
    mlir/lib/AsmParser/Parser.cpp
    mlir/lib/CAPI/IR/IR.cpp
    mlir/lib/IR/MLIRContext.cpp
    mlir/lib/IR/ODSSupport.cpp
    mlir/lib/IR/Operation.cpp
    mlir/lib/IR/OperationSupport.cpp
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestDialect.h
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
    mlir/unittests/IR/OpPropertiesTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/ExtensibleDialect.h b/mlir/include/mlir/IR/ExtensibleDialect.h
index e4d8d2d6000fc60..37821d3a2a5163f 100644
--- a/mlir/include/mlir/IR/ExtensibleDialect.h
+++ b/mlir/include/mlir/IR/ExtensibleDialect.h
@@ -486,10 +486,11 @@ class DynamicOpDefinition : public OperationName::Impl {
   void populateDefaultProperties(OperationName opName,
                                  OpaqueProperties properties) final {}
 
-  LogicalResult setPropertiesFromAttr(OperationName opName,
-                                      OpaqueProperties properties,
-                                      Attribute attr,
-                                      InFlightDiagnostic *diag) final {
+  LogicalResult
+  setPropertiesFromAttr(OperationName opName, OpaqueProperties properties,
+                        Attribute attr,
+                        function_ref<InFlightDiagnostic &()> getDiag) final {
+    getDiag() << "extensible Dialects don't support properties";
     return failure();
   }
   Attribute getPropertiesAsAttr(Operation *op) final { return {}; }

diff  --git a/mlir/include/mlir/IR/ODSSupport.h b/mlir/include/mlir/IR/ODSSupport.h
index 687f764ae95fd99..748bf52a55c557a 100644
--- a/mlir/include/mlir/IR/ODSSupport.h
+++ b/mlir/include/mlir/IR/ODSSupport.h
@@ -14,6 +14,8 @@
 #define MLIR_IR_ODSSUPPORT_H
 
 #include "mlir/IR/Attributes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Support/LLVM.h"
 
 namespace mlir {
 
@@ -24,8 +26,9 @@ namespace mlir {
 /// Convert an IntegerAttr attribute to an int64_t, or return an error if the
 /// attribute isn't an IntegerAttr. If the optional diagnostic is provided an
 /// error message is also emitted.
-LogicalResult convertFromAttribute(int64_t &storage, Attribute attr,
-                                   InFlightDiagnostic *diag);
+LogicalResult
+convertFromAttribute(int64_t &storage, Attribute attr,
+                     function_ref<InFlightDiagnostic &()> getDiag);
 
 /// Convert the provided int64_t to an IntegerAttr attribute.
 Attribute convertToAttribute(MLIRContext *ctx, int64_t storage);
@@ -34,15 +37,17 @@ Attribute convertToAttribute(MLIRContext *ctx, int64_t storage);
 /// storage has the same size as the array. An error is returned if the
 /// attribute isn't a DenseI64ArrayAttr or it does not have the same size. If
 /// the optional diagnostic is provided an error message is also emitted.
-LogicalResult convertFromAttribute(MutableArrayRef<int64_t> storage,
-                                   Attribute attr, InFlightDiagnostic *diag);
+LogicalResult
+convertFromAttribute(MutableArrayRef<int64_t> storage, Attribute attr,
+                     function_ref<InFlightDiagnostic &()> getDiag);
 
 /// Convert a DenseI32ArrayAttr to the provided storage. It is expected that the
 /// storage has the same size as the array. An error is returned if the
 /// attribute isn't a DenseI32ArrayAttr or it does not have the same size. If
 /// the optional diagnostic is provided an error message is also emitted.
-LogicalResult convertFromAttribute(MutableArrayRef<int32_t> storage,
-                                   Attribute attr, InFlightDiagnostic *diag);
+LogicalResult
+convertFromAttribute(MutableArrayRef<int32_t> storage, Attribute attr,
+                     function_ref<InFlightDiagnostic &()> getDiag);
 
 /// Convert the provided ArrayRef<int64_t> to a DenseI64ArrayAttr attribute.
 Attribute convertToAttribute(MLIRContext *ctx, ArrayRef<int64_t> storage);

diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 84ba46f4d6f3ec1..895f17dfe1d07c8 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1769,9 +1769,10 @@ class Op : public OpState, public Traits<ConcreteType>... {
   /// the namespace where the properties are defined. It can also be overridden
   /// in the derived ConcreteOp.
   template <typename PropertiesTy>
-  static LogicalResult setPropertiesFromAttr(PropertiesTy &prop, Attribute attr,
-                                             InFlightDiagnostic *diag) {
-    return setPropertiesFromAttribute(prop, attr, diag);
+  static LogicalResult
+  setPropertiesFromAttr(PropertiesTy &prop, Attribute attr,
+                        function_ref<InFlightDiagnostic &()> getDiag) {
+    return setPropertiesFromAttribute(prop, attr, getDiag);
   }
   /// Convert the provided properties to an attribute. This default
   /// implementation forwards to a free function `getPropertiesAsAttribute` that

diff  --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 361a38e87b6ba32..b815eaf8899d6fc 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -882,8 +882,9 @@ class alignas(8) Operation final
   /// matching the expectations of the properties for this operation. This is
   /// mostly useful for unregistered operations or used when parsing the
   /// generic format. An optional diagnostic can be passed in for richer errors.
-  LogicalResult setPropertiesFromAttribute(Attribute attr,
-                                           InFlightDiagnostic *diagnostic);
+  LogicalResult
+  setPropertiesFromAttribute(Attribute attr,
+                             function_ref<InFlightDiagnostic &()> getDiag);
 
   /// Copy properties from an existing other properties object. The two objects
   /// must be the same type.

diff  --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 670dd289c480a30..19ffddc30904897 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -136,9 +136,9 @@ class OperationName {
     virtual void deleteProperties(OpaqueProperties) = 0;
     virtual void populateDefaultProperties(OperationName opName,
                                            OpaqueProperties properties) = 0;
-    virtual LogicalResult setPropertiesFromAttr(OperationName, OpaqueProperties,
-                                                Attribute,
-                                                InFlightDiagnostic *) = 0;
+    virtual LogicalResult
+    setPropertiesFromAttr(OperationName, OpaqueProperties, Attribute,
+                          function_ref<InFlightDiagnostic &()> getDiag) = 0;
     virtual Attribute getPropertiesAsAttr(Operation *) = 0;
     virtual void copyProperties(OpaqueProperties, OpaqueProperties) = 0;
     virtual llvm::hash_code hashProperties(OpaqueProperties) = 0;
@@ -216,8 +216,9 @@ class OperationName {
     void deleteProperties(OpaqueProperties) final;
     void populateDefaultProperties(OperationName opName,
                                    OpaqueProperties properties) final;
-    LogicalResult setPropertiesFromAttr(OperationName, OpaqueProperties,
-                                        Attribute, InFlightDiagnostic *) final;
+    LogicalResult
+    setPropertiesFromAttr(OperationName, OpaqueProperties, Attribute,
+                          function_ref<InFlightDiagnostic &()> getDiag) final;
     Attribute getPropertiesAsAttr(Operation *) final;
     void copyProperties(OpaqueProperties, OpaqueProperties) final;
     llvm::hash_code hashProperties(OpaqueProperties) final;
@@ -434,12 +435,10 @@ class OperationName {
   }
 
   /// Define the op properties from the provided Attribute.
-  LogicalResult
-  setOpPropertiesFromAttribute(OperationName opName,
-                               OpaqueProperties properties, Attribute attr,
-                               InFlightDiagnostic *diagnostic) const {
-    return getImpl()->setPropertiesFromAttr(opName, properties, attr,
-                                            diagnostic);
+  LogicalResult setOpPropertiesFromAttribute(
+      OperationName opName, OpaqueProperties properties, Attribute attr,
+      function_ref<InFlightDiagnostic &()> getDiag) const {
+    return getImpl()->setPropertiesFromAttr(opName, properties, attr, getDiag);
   }
 
   void copyOpProperties(OpaqueProperties lhs, OpaqueProperties rhs) const {
@@ -628,16 +627,15 @@ class RegisteredOperationName : public OperationName {
                                               *properties.as<Properties *>());
     }
 
-    LogicalResult setPropertiesFromAttr(OperationName opName,
-                                        OpaqueProperties properties,
-                                        Attribute attr,
-                                        InFlightDiagnostic *diag) final {
+    LogicalResult
+    setPropertiesFromAttr(OperationName opName, OpaqueProperties properties,
+                          Attribute attr,
+                          function_ref<InFlightDiagnostic &()> getDiag) final {
       if constexpr (hasProperties) {
         auto p = properties.as<Properties *>();
-        return ConcreteOp::setPropertiesFromAttr(*p, attr, diag);
+        return ConcreteOp::setPropertiesFromAttr(*p, attr, getDiag);
       }
-      if (diag)
-        *diag << "This operation does not support properties";
+      getDiag() << "this operation does not support properties";
       return failure();
     }
     Attribute getPropertiesAsAttr(Operation *op) final {
@@ -997,8 +995,9 @@ struct OperationState {
 
   // Set the properties defined on this OpState on the given operation,
   // optionally emit diagnostics on error through the provided diagnostic.
-  LogicalResult setProperties(Operation *op,
-                              InFlightDiagnostic *diagnostic) const;
+  LogicalResult
+  setProperties(Operation *op,
+                function_ref<InFlightDiagnostic &()> getDiag) const;
 
   void addOperands(ValueRange newOperands);
 

diff  --git a/mlir/include/mlir/IR/Properties.td b/mlir/include/mlir/IR/Properties.td
index 3d7de7e4f460e51..99da1763524fa94 100644
--- a/mlir/include/mlir/IR/Properties.td
+++ b/mlir/include/mlir/IR/Properties.td
@@ -55,7 +55,7 @@ class Property<string storageTypeParam = "", string desc = ""> {
   // Format:
   // - `$_storage` is the storage type value.
   // - `$_attr` is the attribute.
-  // - `$_diag` is an optional Diagnostic pointer to emit error.
+  // - `$_diag` is a callback to get a Diagnostic to emit error.
   //
   // The expression must return a LogicalResult
   code convertFromAttribute = [{

diff  --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 02bf9a418063991..84f44dba806df01 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -18,8 +18,10 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Verifier.h"
+#include "mlir/Support/InterfaceSupport.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/StringSet.h"
@@ -29,6 +31,7 @@
 #include "llvm/Support/PrettyStackTrace.h"
 #include "llvm/Support/SourceMgr.h"
 #include <algorithm>
+#include <memory>
 #include <optional>
 
 using namespace mlir;
@@ -1443,12 +1446,17 @@ Operation *OperationParser::parseGenericOperation() {
   // Try setting the properties for the operation, using a diagnostic to print
   // errors.
   if (properties) {
-    InFlightDiagnostic diagnostic =
-        mlir::emitError(srcLocation, "invalid properties ")
-        << properties << " for op " << name << ": ";
-    if (failed(op->setPropertiesFromAttribute(properties, &diagnostic)))
+    std::unique_ptr<InFlightDiagnostic> diagnostic;
+    auto getDiag = [&]() -> InFlightDiagnostic & {
+      if (!diagnostic) {
+        diagnostic = std::make_unique<InFlightDiagnostic>(
+            mlir::emitError(srcLocation, "invalid properties ")
+            << properties << " for op " << name << ": ");
+      }
+      return *diagnostic;
+    };
+    if (failed(op->setPropertiesFromAttribute(properties, getDiag)))
       return nullptr;
-    diagnostic.abandon();
   }
 
   return op;
@@ -2001,12 +2009,18 @@ OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
 
   // Try setting the properties for the operation.
   if (properties) {
-    InFlightDiagnostic diagnostic =
-        mlir::emitError(srcLocation, "invalid properties ")
-        << properties << " for op " << op->getName().getStringRef() << ": ";
-    if (failed(op->setPropertiesFromAttribute(properties, &diagnostic)))
+    std::unique_ptr<InFlightDiagnostic> diagnostic;
+    auto getDiag = [&]() -> InFlightDiagnostic & {
+      if (!diagnostic) {
+        diagnostic = std::make_unique<InFlightDiagnostic>(
+            mlir::emitError(srcLocation, "invalid properties ")
+            << properties << " for op " << op->getName().getStringRef()
+            << ": ");
+      }
+      return *diagnostic;
+    };
+    if (failed(op->setPropertiesFromAttribute(properties, getDiag)))
       return nullptr;
-    diagnostic.abandon();
   }
   return op;
 }

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index ccdae1424998567..ef234a912490eea 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -370,16 +370,21 @@ static LogicalResult inferOperationTypes(OperationState &state) {
   if (!properties && info->getOpPropertyByteSize() > 0 && !attributes.empty()) {
     auto prop = std::make_unique<char[]>(info->getOpPropertyByteSize());
     properties = OpaqueProperties(prop.get());
-    InFlightDiagnostic diag = emitError(state.location)
-                              << " failed properties conversion while building "
-                              << state.name.getStringRef() << " with `"
-                              << attributes << "`: ";
-    if (failed(info->setOpPropertiesFromAttribute(state.name, properties,
-                                                  attributes, &diag))) {
-      return failure();
+    if (properties) {
+      std::unique_ptr<InFlightDiagnostic> diagnostic;
+      auto getDiag = [&]() -> InFlightDiagnostic & {
+        if (!diagnostic) {
+          diagnostic = std::make_unique<InFlightDiagnostic>(
+              emitError(state.location)
+              << " failed properties conversion while building "
+              << state.name.getStringRef() << " with `" << attributes << "`: ");
+        }
+        return *diagnostic;
+      };
+      if (failed(info->setOpPropertiesFromAttribute(state.name, properties,
+                                                    attributes, getDiag)))
+        return failure();
     }
-    diag.abandon();
-
     if (succeeded(inferInterface->inferReturnTypes(
             context, state.location, state.operands, attributes, properties,
             state.regions, state.types))) {

diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index e19c3d4d54179b3..5f1d036d22b918e 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -852,7 +852,7 @@ void OperationName::UnregisteredOpModel::populateDefaultProperties(
     OperationName opName, OpaqueProperties properties) {}
 LogicalResult OperationName::UnregisteredOpModel::setPropertiesFromAttr(
     OperationName opName, OpaqueProperties properties, Attribute attr,
-    InFlightDiagnostic *diag) {
+    function_ref<InFlightDiagnostic &()> getDiag) {
   *properties.as<Attribute *>() = attr;
   return success();
 }

diff  --git a/mlir/lib/IR/ODSSupport.cpp b/mlir/lib/IR/ODSSupport.cpp
index f67e7dbf38592e2..0601430c4461651 100644
--- a/mlir/lib/IR/ODSSupport.cpp
+++ b/mlir/lib/IR/ODSSupport.cpp
@@ -18,13 +18,12 @@
 
 using namespace mlir;
 
-LogicalResult mlir::convertFromAttribute(int64_t &storage,
-                                         ::mlir::Attribute attr,
-                                         ::mlir::InFlightDiagnostic *diag) {
+LogicalResult
+mlir::convertFromAttribute(int64_t &storage, Attribute attr,
+                           function_ref<InFlightDiagnostic &()> getDiag) {
   auto valueAttr = dyn_cast<IntegerAttr>(attr);
   if (!valueAttr) {
-    if (diag)
-      *diag << "expected IntegerAttr for key `value`";
+    getDiag() << "expected IntegerAttr for key `value`";
     return failure();
   }
   storage = valueAttr.getValue().getSExtValue();
@@ -35,35 +34,33 @@ Attribute mlir::convertToAttribute(MLIRContext *ctx, int64_t storage) {
 }
 
 template <typename DenseArrayTy, typename T>
-LogicalResult convertDenseArrayFromAttr(MutableArrayRef<T> storage,
-                                        ::mlir::Attribute attr,
-                                        ::mlir::InFlightDiagnostic *diag,
-                                        StringRef denseArrayTyStr) {
+LogicalResult
+convertDenseArrayFromAttr(MutableArrayRef<T> storage, Attribute attr,
+                          function_ref<InFlightDiagnostic &()> getDiag,
+                          StringRef denseArrayTyStr) {
   auto valueAttr = dyn_cast<DenseArrayTy>(attr);
   if (!valueAttr) {
-    if (diag)
-      *diag << "expected " << denseArrayTyStr << " for key `value`";
+    getDiag() << "expected " << denseArrayTyStr << " for key `value`";
     return failure();
   }
   if (valueAttr.size() != static_cast<int64_t>(storage.size())) {
-    if (diag)
-      *diag << "size mismatch in attribute conversion: " << valueAttr.size()
-            << " vs " << storage.size();
+    getDiag() << "size mismatch in attribute conversion: " << valueAttr.size()
+              << " vs " << storage.size();
     return failure();
   }
   llvm::copy(valueAttr.asArrayRef(), storage.begin());
   return success();
 }
-LogicalResult mlir::convertFromAttribute(MutableArrayRef<int64_t> storage,
-                                         ::mlir::Attribute attr,
-                                         ::mlir::InFlightDiagnostic *diag) {
-  return convertDenseArrayFromAttr<DenseI64ArrayAttr>(storage, attr, diag,
+LogicalResult
+mlir::convertFromAttribute(MutableArrayRef<int64_t> storage, Attribute attr,
+                           function_ref<InFlightDiagnostic &()> getDiag) {
+  return convertDenseArrayFromAttr<DenseI64ArrayAttr>(storage, attr, getDiag,
                                                       "DenseI64ArrayAttr");
 }
-LogicalResult mlir::convertFromAttribute(MutableArrayRef<int32_t> storage,
-                                         Attribute attr,
-                                         InFlightDiagnostic *diag) {
-  return convertDenseArrayFromAttr<DenseI32ArrayAttr>(storage, attr, diag,
+LogicalResult
+mlir::convertFromAttribute(MutableArrayRef<int32_t> storage, Attribute attr,
+                           function_ref<InFlightDiagnostic &()> getDiag) {
+  return convertDenseArrayFromAttr<DenseI32ArrayAttr>(storage, attr, getDiag,
                                                       "DenseI32ArrayAttr");
 }
 

diff  --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index ef98a89f4bb49b6..aa577aa089c6860 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -351,16 +351,15 @@ Attribute Operation::getPropertiesAsAttribute() {
     return *getPropertiesStorage().as<Attribute *>();
   return info->getOpPropertiesAsAttribute(this);
 }
-LogicalResult
-Operation::setPropertiesFromAttribute(Attribute attr,
-                                      InFlightDiagnostic *diagnostic) {
+LogicalResult Operation::setPropertiesFromAttribute(
+    Attribute attr, function_ref<InFlightDiagnostic &()> getDiag) {
   std::optional<RegisteredOperationName> info = getRegisteredInfo();
   if (LLVM_UNLIKELY(!info)) {
     *getPropertiesStorage().as<Attribute *>() = attr;
     return success();
   }
   return info->setOpPropertiesFromAttribute(
-      this->getName(), this->getPropertiesStorage(), attr, diagnostic);
+      this->getName(), this->getPropertiesStorage(), attr, getDiag);
 }
 
 void Operation::copyProperties(OpaqueProperties rhs) {

diff  --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 79cc38da051ee1d..0cb6a1cd191b161 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -198,12 +198,11 @@ OperationState::~OperationState() {
     propertiesDeleter(properties);
 }
 
-LogicalResult
-OperationState::setProperties(Operation *op,
-                              InFlightDiagnostic *diagnostic) const {
+LogicalResult OperationState::setProperties(
+    Operation *op, function_ref<InFlightDiagnostic &()> getDiag) const {
   if (LLVM_UNLIKELY(propertiesAttr)) {
     assert(!properties);
-    return op->setPropertiesFromAttribute(propertiesAttr, diagnostic);
+    return op->setPropertiesFromAttribute(propertiesAttr, getDiag);
   }
   if (properties)
     propertiesSetter(op->getPropertiesStorage(), properties);

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index ae4c9a85605e1c5..55cf4246562e6af 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -28,9 +28,11 @@
 #include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/FunctionImplementation.h"
 #include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/STLFunctionalExtras.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringSwitch.h"
@@ -50,12 +52,12 @@ using namespace test;
 Attribute MyPropStruct::asAttribute(MLIRContext *ctx) const {
   return StringAttr::get(ctx, content);
 }
-LogicalResult MyPropStruct::setFromAttr(MyPropStruct &prop, Attribute attr,
-                                        InFlightDiagnostic *diag) {
+LogicalResult
+MyPropStruct::setFromAttr(MyPropStruct &prop, Attribute attr,
+                          function_ref<InFlightDiagnostic &()> getDiag) {
   StringAttr strAttr = dyn_cast<StringAttr>(attr);
   if (!strAttr) {
-    if (diag)
-      *diag << "Expect StringAttr but got " << attr;
+    getDiag() << "Expect StringAttr but got " << attr;
     return failure();
   }
   prop.content = strAttr.getValue();
@@ -103,9 +105,9 @@ static void writeToMlirBytecode(::mlir::DialectBytecodeWriter &writer,
     writer.writeVarInt(elt);
 }
 
-static LogicalResult setPropertiesFromAttribute(PropertiesWithCustomPrint &prop,
-                                                Attribute attr,
-                                                InFlightDiagnostic *diagnostic);
+static LogicalResult
+setPropertiesFromAttribute(PropertiesWithCustomPrint &prop, Attribute attr,
+                           function_ref<InFlightDiagnostic &()> getDiag);
 static DictionaryAttr
 getPropertiesAsAttribute(MLIRContext *ctx,
                          const PropertiesWithCustomPrint &prop);
@@ -114,9 +116,9 @@ static void customPrintProperties(OpAsmPrinter &p,
                                   const PropertiesWithCustomPrint &prop);
 static ParseResult customParseProperties(OpAsmParser &parser,
                                          PropertiesWithCustomPrint &prop);
-static LogicalResult setPropertiesFromAttribute(VersionedProperties &prop,
-                                                Attribute attr,
-                                                InFlightDiagnostic *diagnostic);
+static LogicalResult
+setPropertiesFromAttribute(VersionedProperties &prop, Attribute attr,
+                           function_ref<InFlightDiagnostic &()> getDiag);
 static DictionaryAttr getPropertiesAsAttribute(MLIRContext *ctx,
                                                const VersionedProperties &prop);
 static llvm::hash_code computeHash(const VersionedProperties &prop);
@@ -1135,23 +1137,20 @@ OpFoldResult ManualCppOpWithFold::fold(ArrayRef<Attribute> attributes) {
 
 static LogicalResult
 setPropertiesFromAttribute(PropertiesWithCustomPrint &prop, Attribute attr,
-                           InFlightDiagnostic *diagnostic) {
+                           function_ref<InFlightDiagnostic &()> getDiag) {
   DictionaryAttr dict = dyn_cast<DictionaryAttr>(attr);
   if (!dict) {
-    if (diagnostic)
-      *diagnostic << "expected DictionaryAttr to set TestProperties";
+    getDiag() << "expected DictionaryAttr to set TestProperties";
     return failure();
   }
   auto label = dict.getAs<mlir::StringAttr>("label");
   if (!label) {
-    if (diagnostic)
-      *diagnostic << "expected StringAttr for key `label`";
+    getDiag() << "expected StringAttr for key `label`";
     return failure();
   }
   auto valueAttr = dict.getAs<IntegerAttr>("value");
   if (!valueAttr) {
-    if (diagnostic)
-      *diagnostic << "expected IntegerAttr for key `value`";
+    getDiag() << "expected IntegerAttr for key `value`";
     return failure();
   }
 
@@ -1187,23 +1186,20 @@ static ParseResult customParseProperties(OpAsmParser &parser,
 }
 static LogicalResult
 setPropertiesFromAttribute(VersionedProperties &prop, Attribute attr,
-                           InFlightDiagnostic *diagnostic) {
+                           function_ref<InFlightDiagnostic &()> getDiag) {
   DictionaryAttr dict = dyn_cast<DictionaryAttr>(attr);
   if (!dict) {
-    if (diagnostic)
-      *diagnostic << "expected DictionaryAttr to set VersionedProperties";
+    getDiag() << "expected DictionaryAttr to set VersionedProperties";
     return failure();
   }
   auto value1Attr = dict.getAs<IntegerAttr>("value1");
   if (!value1Attr) {
-    if (diagnostic)
-      *diagnostic << "expected IntegerAttr for key `value1`";
+    getDiag() << "expected IntegerAttr for key `value1`";
     return failure();
   }
   auto value2Attr = dict.getAs<IntegerAttr>("value2");
   if (!value2Attr) {
-    if (diagnostic)
-      *diagnostic << "expected IntegerAttr for key `value2`";
+    getDiag() << "expected IntegerAttr for key `value2`";
     return failure();
   }
 

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h
index 31a29cc7f9f7aa6..0ae0d47615776dd 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.h
+++ b/mlir/test/lib/Dialect/Test/TestDialect.h
@@ -93,9 +93,9 @@ class MyPropStruct {
   // These three methods are invoked through the  `MyStructProperty` wrapper
   // defined in TestOps.td
   mlir::Attribute asAttribute(mlir::MLIRContext *ctx) const;
-  static mlir::LogicalResult setFromAttr(MyPropStruct &prop,
-                                         mlir::Attribute attr,
-                                         mlir::InFlightDiagnostic *diag);
+  static mlir::LogicalResult
+  setFromAttr(MyPropStruct &prop, mlir::Attribute attr,
+              llvm::function_ref<mlir::InFlightDiagnostic &()> getDiag);
   llvm::hash_code hash() const;
   bool operator==(const MyPropStruct &rhs) const {
     return content == rhs.content;

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index c0dfce553905a0e..ad4f53c5af3cff4 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -52,7 +52,7 @@ static const char *const builderOpState = "odsState";
 static const char *const propertyStorage = "propStorage";
 static const char *const propertyValue = "propValue";
 static const char *const propertyAttr = "propAttr";
-static const char *const propertyDiag = "propDiag";
+static const char *const propertyDiag = "getDiag";
 
 /// The names of the implicit attributes that contain variadic operand and
 /// result segment sizes.
@@ -1212,7 +1212,9 @@ void OpEmitter::genPropertiesSupport() {
               "::mlir::LogicalResult", "setPropertiesFromAttr",
               MethodParameter("Properties &", "prop"),
               MethodParameter("::mlir::Attribute", "attr"),
-              MethodParameter("::mlir::InFlightDiagnostic *", "diag"))
+              MethodParameter(
+                  "::llvm::function_ref<::mlir::InFlightDiagnostic &()>",
+                  "getDiag"))
           ->body();
   auto &getPropMethod =
       opClass
@@ -1264,8 +1266,7 @@ void OpEmitter::genPropertiesSupport() {
   setPropMethod << R"decl(
   ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr);
   if (!dict) {
-    if (diag)
-      *diag << "expected DictionaryAttr to set properties";
+    getDiag() << "expected DictionaryAttr to set properties";
     return ::mlir::failure();
   }
     )decl";
@@ -1273,17 +1274,16 @@ void OpEmitter::genPropertiesSupport() {
   const char *propFromAttrFmt = R"decl(;
     {{
       auto setFromAttr = [] (auto &propStorage, ::mlir::Attribute propAttr,
-                             ::mlir::InFlightDiagnostic *propDiag) {{
+               ::llvm::function_ref<::mlir::InFlightDiagnostic &()> getDiag) {{
         {0};
       };
       {2};
       if (!attr) {{
-        if (diag)
-          *diag << "expected key entry for {1} in DictionaryAttr to set "
+        getDiag() << "expected key entry for {1} in DictionaryAttr to set "
                    "Properties.";
         return ::mlir::failure();
       }
-      if (::mlir::failed(setFromAttr(prop.{1}, attr, diag)))
+      if (::mlir::failed(setFromAttr(prop.{1}, attr, getDiag)))
         return ::mlir::failure();
     }
 )decl";
@@ -1338,8 +1338,7 @@ void OpEmitter::genPropertiesSupport() {
     {2}
     if (attr || /*isRequired=*/{1}) {{
       if (!attr) {{
-        if (diag)
-          *diag << "expected key entry for {0} in DictionaryAttr to set "
+        getDiag() << "expected key entry for {0} in DictionaryAttr to set "
                    "Properties.";
         return ::mlir::failure();
       }
@@ -1347,8 +1346,7 @@ void OpEmitter::genPropertiesSupport() {
       if (convertedAttr) {{
         propStorage = convertedAttr;
       } else {{
-        if (diag)
-          *diag << "Invalid attribute `{0}` in property conversion: " << attr;
+        getDiag() << "Invalid attribute `{0}` in property conversion: " << attr;
         return ::mlir::failure();
       }
     }

diff  --git a/mlir/unittests/IR/OpPropertiesTest.cpp b/mlir/unittests/IR/OpPropertiesTest.cpp
index 21ea4488e7190f8..2d272dfb558c81c 100644
--- a/mlir/unittests/IR/OpPropertiesTest.cpp
+++ b/mlir/unittests/IR/OpPropertiesTest.cpp
@@ -33,38 +33,33 @@ struct TestProperties {
 /// parsing with the generic format.
 static LogicalResult
 setPropertiesFromAttribute(TestProperties &prop, Attribute attr,
-                           InFlightDiagnostic *diagnostic) {
+                           function_ref<InFlightDiagnostic &()> getDiag) {
   DictionaryAttr dict = dyn_cast<DictionaryAttr>(attr);
   if (!dict) {
-    if (diagnostic)
-      *diagnostic << "expected DictionaryAttr to set TestProperties";
+    getDiag() << "expected DictionaryAttr to set TestProperties";
     return failure();
   }
   auto aAttr = dict.getAs<IntegerAttr>("a");
   if (!aAttr) {
-    if (diagnostic)
-      *diagnostic << "expected IntegerAttr for key `a`";
+    getDiag() << "expected IntegerAttr for key `a`";
     return failure();
   }
   auto bAttr = dict.getAs<FloatAttr>("b");
   if (!bAttr ||
       &bAttr.getValue().getSemantics() != &llvm::APFloatBase::IEEEsingle()) {
-    if (diagnostic)
-      *diagnostic << "expected FloatAttr for key `b`";
+    getDiag() << "expected FloatAttr for key `b`";
     return failure();
   }
 
   auto arrayAttr = dict.getAs<DenseI64ArrayAttr>("array");
   if (!arrayAttr) {
-    if (diagnostic)
-      *diagnostic << "expected DenseI64ArrayAttr for key `array`";
+    getDiag() << "expected DenseI64ArrayAttr for key `array`";
     return failure();
   }
 
   auto label = dict.getAs<mlir::StringAttr>("label");
   if (!label) {
-    if (diagnostic)
-      *diagnostic << "expected StringAttr for key `label`";
+    getDiag() << "expected StringAttr for key `label`";
     return failure();
   }
 
@@ -257,8 +252,15 @@ TEST(OpPropertiesTest, FailedProperties) {
   attrs.push_back(b.getNamedAttr("a", b.getStringAttr("foo")));
   state.propertiesAttr = attrs.getDictionary(&context);
   {
-    auto diag = op->emitError("setting properties failed: ");
-    auto result = state.setProperties(op, &diag);
+    std::unique_ptr<InFlightDiagnostic> diagnostic;
+    auto getDiag = [&]() -> InFlightDiagnostic & {
+      if (!diagnostic) {
+        diagnostic = std::make_unique<InFlightDiagnostic>(
+            op->emitError("setting properties failed: "));
+      }
+      return *diagnostic;
+    };
+    auto result = state.setProperties(op, getDiag);
     EXPECT_TRUE(result.failed());
   }
   EXPECT_STREQ("setting properties failed: expected IntegerAttr for key `a`",


        


More information about the Mlir-commits mailing list