[Mlir-commits] [mlir] [mlir][ods] Populate properties in generated builder (PR #90430)

Jacques Pienaar llvmlistbot at llvm.org
Thu May 9 21:13:40 PDT 2024


https://github.com/jpienaar updated https://github.com/llvm/llvm-project/pull/90430

>From df1c7c07e2d942a0fd78d77000c4e03bb0bfe060 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Fri, 10 May 2024 04:05:45 +0000
Subject: [PATCH] [mlir][ods] Populate properties in generated builder

Previously this was only populated in the create method later. This
resolves some invalid builder paths. This may also be sufficient that
type inference functions no longer have to consider whether property
conversion has happened (but I haven't verified that yet).

This also makes Attributes corresponding to Properties optional
inside the set-from-attributes method. Today that is in effect what
happens with Property value initialization and folks use it to define
custom C++ types whose default initialization is what they want. This is
the behavior users get if they use properties directly. Propagating
Attributes without allowing partial setting would require iterating over
the dictionary attribute considering the properties of the op type that
will be created. This could also have been an additional method
generated or optional behavior on the set method. But doing it
consistently seems better. In terms of what's lost, it doesn't seem like
anything compared to the pure Property path where Property is default
value initialized and then partially overwritten (this doesn't seem to
buy anything else verification-wise).

Default-valued Properties (as specified on the ODS side rather than the
C++ side) triggered an error, as the containing class was not yet
complete but referenced a nested class, so we couldn't have a default
initializer for them in the parent class. Added an additional forwarding
builder to avoid needing to update call sites. This could be split out
into a separate change.

Inlined a templated function in the unit test that was only used once.
Moved initialization earlier where seen.
---
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td   |   4 +-
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       |   2 +-
 .../mlir/Dialect/Tensor/IR/TensorOps.td       |   2 +-
 mlir/include/mlir/IR/Operation.h              |   5 +-
 .../Func/Transforms/OneToNFuncConversions.cpp |   6 +-
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      |   8 +-
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      |  12 +-
 mlir/test/Dialect/OpenMP/invalid.mlir         |  14 +--
 mlir/test/Dialect/OpenMP/ops.mlir             |   4 +-
 mlir/test/lib/Dialect/Test/TestOps.td         |   7 ++
 mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp   | 104 ++++++++++++----
 mlir/unittests/TableGen/OpBuildGen.cpp        | 114 ++++++++++++++----
 12 files changed, 207 insertions(+), 75 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 6655ce6f123e1..7f394972a697b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -65,13 +65,13 @@ class LLVM_IntArithmeticOpWithOverflowFlag<string mnemonic, string instName,
   let builders = [
     OpBuilder<(ins "Type":$type, "Value":$lhs, "Value":$rhs,
                    "IntegerOverflowFlags":$overflowFlags), [{
-      build($_builder, $_state, type, lhs, rhs);
       $_state.getOrAddProperties<Properties>().overflowFlags = overflowFlags;
+      build($_builder, $_state, type, lhs, rhs);
     }]>,
     OpBuilder<(ins "Value":$lhs, "Value":$rhs,
                    "IntegerOverflowFlags":$overflowFlags), [{
-      build($_builder, $_state, lhs, rhs);
       $_state.getOrAddProperties<Properties>().overflowFlags = overflowFlags;
+      build($_builder, $_state, lhs, rhs);
     }]>
   ];
 
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 5738b6ca51c12..63e6ed059deb1 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1764,9 +1764,9 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
       "ArrayRef<ReassociationIndices>":$reassociation,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
-      build($_builder, $_state, resultType, src, attrs);
       $_state.addAttribute("reassociation",
                           getReassociationIndicesAttribute($_builder, reassociation));
+      build($_builder, $_state, resultType, src, attrs);
     }]>,
     OpBuilder<(ins "Type":$resultType, "Value":$src,
       "ArrayRef<ReassociationExprs>":$reassociation,
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index a403e89a39f98..cafc3d91fd1e9 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1216,9 +1216,9 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
       "ArrayRef<ReassociationIndices>":$reassociation,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
-      build($_builder, $_state, resultType, src, attrs);
       $_state.addAttribute("reassociation",
           getReassociationIndicesAttribute($_builder, reassociation));
+      build($_builder, $_state, resultType, src, attrs);
     }]>,
     OpBuilder<(ins "Type":$resultType, "Value":$src,
       "ArrayRef<ReassociationExprs>":$reassociation,
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index c52a6fcac10c1..f0dd7c5178056 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -916,11 +916,12 @@ class alignas(8) Operation final
   /// operation. Returns an empty attribute if no properties are present.
   Attribute getPropertiesAsAttribute();
 
-  /// Set the properties from the provided  attribute.
+  /// Set the properties from the provided attribute.
   /// This is an expensive operation that can fail if the attribute is not
   /// matching the expectations of the properties for this operation. This is
   /// mostly useful for unregistered operations or used when parsing the
-  /// generic format. An optional diagnostic can be passed in for richer errors.
+  /// generic format. An optional diagnostic emitter can be passed in for richer
+  /// errors, if none is passed then behavior is undefined in error case.
   LogicalResult
   setPropertiesFromAttribute(Attribute attr,
                              function_ref<InFlightDiagnostic()> emitError);
diff --git a/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp
index c04986cad84f9..a5b88338e6381 100644
--- a/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp
@@ -40,9 +40,9 @@ class ConvertTypesInFuncCallOp : public OneToNOpConversionPattern<CallOp> {
       return failure();
 
     // Create new CallOp.
-    auto newOp = rewriter.create<CallOp>(loc, resultMapping.getConvertedTypes(),
-                                         adaptor.getFlatOperands());
-    newOp->setAttrs(op->getAttrs());
+    auto newOp =
+        rewriter.create<CallOp>(loc, resultMapping.getConvertedTypes(),
+                                adaptor.getFlatOperands(), op->getAttrs());
 
     rewriter.replaceOp(op, newOp->getResults(), resultMapping);
     return success();
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 78201ae29cd9b..47bd56a7c01e7 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1808,11 +1808,11 @@ void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
   dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
+  result.addAttributes(attrs);
   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
         dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
         b.getDenseI64ArrayAttr(staticSizes),
         b.getDenseI64ArrayAttr(staticStrides));
-  result.addAttributes(attrs);
 }
 
 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
@@ -2486,9 +2486,9 @@ void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
   auto srcType = llvm::cast<MemRefType>(src.getType());
   MemRefType resultType =
       CollapseShapeOp::computeCollapsedType(srcType, reassociation);
-  build(b, result, resultType, src, attrs);
   result.addAttribute(::mlir::getReassociationAttrName(),
                       getReassociationIndicesAttribute(b, reassociation));
+  build(b, result, resultType, src, attrs);
 }
 
 LogicalResult CollapseShapeOp::verify() {
@@ -2784,11 +2784,11 @@ void SubViewOp::build(OpBuilder &b, OperationState &result,
     resultType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
         sourceMemRefType, staticOffsets, staticSizes, staticStrides));
   }
+  result.addAttributes(attrs);
   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
         dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
         b.getDenseI64ArrayAttr(staticSizes),
         b.getDenseI64ArrayAttr(staticStrides));
-  result.addAttributes(attrs);
 }
 
 // Build a SubViewOp with mixed static and dynamic entries and inferred result
@@ -3323,8 +3323,8 @@ void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
   // Compute result type.
   MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
 
-  build(b, result, resultType, in, attrs);
   result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
+  build(b, result, resultType, in, attrs);
 }
 
 // transpose $in $permutation attr-dict : type($in) `to` type(results)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 7a13f7a7d1355..2c1d43d002664 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1743,9 +1743,9 @@ void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
       llvm::cast<RankedTensorType>(src.getType()),
       getSymbolLessAffineMaps(
           convertReassociationIndicesToExprs(b.getContext(), reassociation)));
-  build(b, result, resultType, src, attrs);
   result.addAttribute(getReassociationAttrStrName(),
                       getReassociationIndicesAttribute(b, reassociation));
+  build(b, result, resultType, src, attrs);
 }
 
 template <typename TensorReshapeOp, bool isExpansion = std::is_same<
@@ -2099,11 +2099,11 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
     resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
         sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
   }
+  result.addAttributes(attrs);
   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
         dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
         b.getDenseI64ArrayAttr(staticSizes),
         b.getDenseI64ArrayAttr(staticStrides));
-  result.addAttributes(attrs);
 }
 
 /// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
@@ -2498,11 +2498,11 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
+  result.addAttributes(attrs);
   build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
         dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
         b.getDenseI64ArrayAttr(staticSizes),
         b.getDenseI64ArrayAttr(staticStrides));
-  result.addAttributes(attrs);
 }
 
 /// Build an InsertSliceOp with mixed static and dynamic entries packed into a
@@ -2943,10 +2943,10 @@ void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
   auto sourceType = llvm::cast<RankedTensorType>(source.getType());
   if (!resultType)
     resultType = inferResultType(sourceType, staticLow, staticHigh);
+  result.addAttributes(attrs);
   build(b, result, resultType, source, low, high,
         b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
         nofold ? b.getUnitAttr() : UnitAttr());
-  result.addAttributes(attrs);
 }
 
 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
@@ -2976,10 +2976,10 @@ void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
     resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
   }
   assert(llvm::isa<RankedTensorType>(resultType));
+  result.addAttributes(attrs);
   build(b, result, resultType, source, dynamicLow, dynamicHigh,
         b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
         nofold ? b.getUnitAttr() : UnitAttr());
-  result.addAttributes(attrs);
 }
 
 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
@@ -3423,11 +3423,11 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
+  result.addAttributes(attrs);
   build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
         dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
         b.getDenseI64ArrayAttr(staticSizes),
         b.getDenseI64ArrayAttr(staticStrides));
-  result.addAttributes(attrs);
 }
 
 /// Build an ParallelInsertSliceOp with mixed static and dynamic entries
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 511e7d396c687..115b22986fab8 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2113,23 +2113,23 @@ func.func @omp_distribute_allocate(%data_var : memref<i32>) -> () {
 
 func.func @omp_distribute_wrapper() -> () {
   // expected-error @below {{op must be a loop wrapper}}
-  "omp.distribute"() ({
+  omp.distribute {
       %0 = arith.constant 0 : i32
       "omp.terminator"() : () -> ()
-    }) : () -> ()
+  }
 }
 
 // -----
 
 func.func @omp_distribute_nested_wrapper(%data_var : memref<i32>) -> () {
   // expected-error @below {{only supported nested wrappers are 'omp.parallel' and 'omp.simd'}}
-  "omp.distribute"() ({
-      "omp.wsloop"() ({
-        %0 = arith.constant 0 : i32
-        "omp.terminator"() : () -> ()
-      }) : () -> ()
+  omp.distribute {
+    "omp.wsloop"() ({
+      %0 = arith.constant 0 : i32
       "omp.terminator"() : () -> ()
     }) : () -> ()
+    "omp.terminator"() : () -> ()
+  }
 }
 
 // -----
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 60fc10f9d64b7..3032914186d90 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -521,13 +521,13 @@ func.func @omp_wsloop_pretty(%lb : index, %ub : index, %step : index, %data_var
 // CHECK-LABEL: omp_simd
 func.func @omp_simd(%lb : index, %ub : index, %step : index) -> () {
   // CHECK: omp.simd
-  "omp.simd" () ({
+  omp.simd {
     "omp.loop_nest" (%lb, %ub, %step) ({
     ^bb1(%iv2: index):
       "omp.yield"() : () -> ()
     }) : (index, index, index) -> ()
     "omp.terminator"() : () -> ()
-  }) : () -> ()
+  }
 
   return
 }
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 5352d574ac394..52fa0f69dbb4c 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2401,6 +2401,13 @@ def TableGenBuildOp5 : TableGenBuildInferReturnTypeBaseOp<
   let regions = (region AnyRegion:$body);
 }
 
+// Two variadic args, non variadic results, with AttrSizedOperandSegments
+// Test build method generation for property conversion & type inference.
+def TableGenBuildOp6 : TEST_Op<"tblgen_build_6", [AttrSizedOperandSegments]> {
+  let arguments = (ins Variadic<AnyType>:$a, Variadic<AnyType>:$b);
+  let results = (outs F32:$result);
+}
+
 //===----------------------------------------------------------------------===//
 // Test BufferPlacement
 //===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 63fe5a8099074..e013ccac5dd0f 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1311,22 +1311,24 @@ void OpEmitter::genPropertiesSupport() {
     return ::mlir::failure();
   }
     )decl";
-  // TODO: properties might be optional as well.
-  const char *propFromAttrFmt = R"decl(;
-    {{
+  const char *propFromAttrFmt = R"decl(
       auto setFromAttr = [] (auto &propStorage, ::mlir::Attribute propAttr,
                ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {{
-        {0};
+        {0}
       };
       {2};
-      if (!attr) {{
-        emitError() << "expected key entry for {1} in DictionaryAttr to set "
-                   "Properties.";
+)decl";
+  const char *attrGetNoDefaultFmt = R"decl(;
+      if (attr && ::mlir::failed(setFromAttr(prop.{0}, attr, emitError)))
         return ::mlir::failure();
+)decl";
+  const char *attrGetDefaultFmt = R"decl(;
+      if (attr) {{
+        if (::mlir::failed(setFromAttr(prop.{0}, attr, emitError)))
+          return ::mlir::failure();
+      } else {{
+        prop.{0} = {1};
       }
-      if (::mlir::failed(setFromAttr(prop.{1}, attr, emitError)))
-        return ::mlir::failure();
-    }
 )decl";
 
   for (const auto &attrOrProp : attrOrProperties) {
@@ -1349,13 +1351,20 @@ void OpEmitter::genPropertiesSupport() {
       }
       os.flush();
 
-      setPropMethod << formatv(propFromAttrFmt,
+      setPropMethod << "{\n"
+                    << formatv(propFromAttrFmt,
                                tgfmt(prop.getConvertFromAttributeCall(),
                                      &fctx.addSubst("_attr", propertyAttr)
                                           .addSubst("_storage", propertyStorage)
                                           .addSubst("_diag", propertyDiag)),
                                name, getAttr);
-
+      if (prop.hasDefaultValue()) {
+        setPropMethod << formatv(attrGetDefaultFmt, name,
+                                 prop.getDefaultValue());
+      } else {
+        setPropMethod << formatv(attrGetNoDefaultFmt, name);
+      }
+      setPropMethod << "  }\n";
     } else {
       const auto *namedAttr =
           llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp);
@@ -1376,13 +1385,8 @@ void OpEmitter::genPropertiesSupport() {
       setPropMethod << formatv(R"decl(
   {{
     auto &propStorage = prop.{0};
-    {2}
-    if (attr || /*isRequired=*/{1}) {{
-      if (!attr) {{
-        emitError() << "expected key entry for {0} in DictionaryAttr to set "
-                   "Properties.";
-        return ::mlir::failure();
-      }
+    {1}
+    if (attr) {{
       auto convertedAttr = ::llvm::dyn_cast<std::remove_reference_t<decltype(propStorage)>>(attr);
       if (convertedAttr) {{
         propStorage = convertedAttr;
@@ -1393,7 +1397,7 @@ void OpEmitter::genPropertiesSupport() {
     }
   }
 )decl",
-                               name, namedAttr->isRequired, getAttr);
+                               name, getAttr);
     }
   }
   setPropMethod << "  return ::mlir::success();\n";
@@ -2650,6 +2654,21 @@ void OpEmitter::genInferredTypeCollectiveParamBuilder() {
   }
 
   // Result types
+  if (emitHelper.hasProperties()) {
+    // Initialize the properties from Attributes before invoking the infer
+    // function.
+    body << formatv(R"(
+  if (!attributes.empty()) {
+    ::mlir::OpaqueProperties properties =
+      &{1}.getOrAddProperties<{0}::Properties>();
+    std::optional<::mlir::RegisteredOperationName> info =
+      {1}.name.getRegisteredInfo();
+    if (failed(info->setOpPropertiesFromAttribute({1}.name, properties,
+        {1}.attributes.getDictionary({1}.getContext()), nullptr)))
+      ::llvm::report_fatal_error("Property conversion failed.");
+  })",
+                    opClass.getClassName(), builderOpState);
+  }
   body << formatv(R"(
   ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
   if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
@@ -2879,6 +2898,22 @@ void OpEmitter::genCollectiveParamBuilder() {
          << "u && \"mismatched number of return types\");\n";
   body << "  " << builderOpState << ".addTypes(resultTypes);\n";
 
+  if (emitHelper.hasProperties()) {
+    // Initialize the properties from Attributes before invoking the infer
+    // function.
+    body << formatv(R"(
+  if (!attributes.empty()) {
+    ::mlir::OpaqueProperties properties =
+      &{1}.getOrAddProperties<{0}::Properties>();
+    std::optional<::mlir::RegisteredOperationName> info =
+      {1}.name.getRegisteredInfo();
+    if (failed(info->setOpPropertiesFromAttribute({1}.name, properties,
+        {1}.attributes.getDictionary({1}.getContext()), nullptr)))
+      ::llvm::report_fatal_error("Property conversion failed.");
+  })",
+                    opClass.getClassName(), builderOpState);
+  }
+
   // Generate builder that infers type too.
   // TODO: Expand to handle successors.
   if (canInferType(op) && op.getNumSuccessors() == 0)
@@ -4054,13 +4089,17 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
       op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
   {
     SmallVector<MethodParameter> paramList;
-    paramList.emplace_back("::mlir::DictionaryAttr", "attrs",
-                           attrSizedOperands ? "" : "nullptr");
-    if (useProperties)
-      paramList.emplace_back("const Properties &", "properties", "{}");
-    else
+    if (useProperties) {
+      // Properties can't be given a default constructor here due to Properties
+      // struct being defined in the enclosing class which isn't complete by
+      // here.
+      paramList.emplace_back("::mlir::DictionaryAttr", "attrs");
+      paramList.emplace_back("const Properties &", "properties");
+    } else {
+      paramList.emplace_back("::mlir::DictionaryAttr", "attrs", "{}");
       paramList.emplace_back("const ::mlir::EmptyProperties &", "properties",
                              "{}");
+    }
     paramList.emplace_back("::mlir::RegionRange", "regions", "{}");
     auto *baseConstructor = genericAdaptorBase.addConstructor(paramList);
     baseConstructor->addMemberInitializer("odsAttrs", "attrs");
@@ -4102,6 +4141,21 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
           "::mlir::EmptyProperties{}), "
           "regions");
     }
+
+    // Add forwarding constructor that constructs Properties.
+    if (useProperties) {
+      SmallVector<MethodParameter> paramList;
+      paramList.emplace_back("RangeT", "values");
+      paramList.emplace_back("::mlir::DictionaryAttr", "attrs",
+                             attrSizedOperands ? "" : "nullptr");
+      auto *noPropertiesConstructor =
+          genericAdaptor.addConstructor(std::move(paramList));
+      noPropertiesConstructor->addMemberInitializer(
+          genericAdaptor.getClassName(), "values, "
+                                         "attrs, "
+                                         "Properties{}, "
+                                         "{}");
+    }
   }
 
   // Create constructors constructing the adaptor from an instance of the op.
diff --git a/mlir/unittests/TableGen/OpBuildGen.cpp b/mlir/unittests/TableGen/OpBuildGen.cpp
index c83ac9088114c..94fbfa28803c4 100644
--- a/mlir/unittests/TableGen/OpBuildGen.cpp
+++ b/mlir/unittests/TableGen/OpBuildGen.cpp
@@ -66,29 +66,44 @@ class OpBuildGenTest : public ::testing::Test {
       EXPECT_EQ(op->getAttr(attrs[idx].getName().strref()),
                 attrs[idx].getValue());
 
+    EXPECT_TRUE(mlir::succeeded(concreteOp.verify()));
     concreteOp.erase();
   }
 
-  // Helper method to test ops with inferred result types and single variadic
-  // input.
   template <typename OpTy>
-  void testSingleVariadicInputInferredType() {
-    // Test separate arg, separate param build method.
-    auto op = builder.create<OpTy>(loc, i32Ty, ValueRange{*cstI32, *cstI32});
-    verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, noAttrs);
-
-    // Test collective params build method.
-    op = builder.create<OpTy>(loc, TypeRange{i32Ty},
-                              ValueRange{*cstI32, *cstI32});
-    verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, noAttrs);
-
-    // Test build method with no result types, default value of attributes.
-    op = builder.create<OpTy>(loc, ValueRange{*cstI32, *cstI32});
-    verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, noAttrs);
-
-    // Test build method with no result types and supplied attributes.
-    op = builder.create<OpTy>(loc, ValueRange{*cstI32, *cstI32}, attrs);
-    verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, attrs);
+  void verifyOp(OpTy &&concreteOp, std::vector<Type> resultTypes,
+                std::vector<Value> operands1, std::vector<Value> operands2,
+                std::vector<NamedAttribute> attrs) {
+    ASSERT_NE(concreteOp, nullptr);
+    Operation *op = concreteOp.getOperation();
+
+    EXPECT_EQ(op->getNumResults(), resultTypes.size());
+    for (unsigned idx : llvm::seq(0U, op->getNumResults()))
+      EXPECT_EQ(op->getResult(idx).getType(), resultTypes[idx]);
+
+    auto operands = llvm::to_vector(llvm::concat<Value>(operands1, operands2));
+    EXPECT_EQ(op->getNumOperands(), operands.size());
+    for (unsigned idx : llvm::seq(0U, op->getNumOperands()))
+      EXPECT_EQ(op->getOperand(idx), operands[idx]);
+
+    EXPECT_EQ(op->getAttrs().size(), attrs.size());
+    if (op->getAttrs().size() != attrs.size()) {
+      // Simple export where there is mismatch count.
+      llvm::errs() << "Op attrs:\n";
+      for (auto it : op->getAttrs())
+        llvm::errs() << "\t" << it.getName() << " = " << it.getValue() << "\n";
+
+      llvm::errs() << "Expected attrs:\n";
+      for (auto it : attrs)
+        llvm::errs() << "\t" << it.getName() << " = " << it.getValue() << "\n";
+    } else {
+      for (unsigned idx : llvm::seq<unsigned>(0U, attrs.size()))
+        EXPECT_EQ(op->getAttr(attrs[idx].getName().strref()),
+                  attrs[idx].getValue());
+    }
+
+    EXPECT_TRUE(mlir::succeeded(concreteOp.verify()));
+    concreteOp.erase();
   }
 
 protected:
@@ -205,13 +220,31 @@ TEST_F(OpBuildGenTest,
   verifyOp(op, {i32Ty, f32Ty}, {*cstI32}, attrs);
 }
 
-// The next test checks supression of ambiguous build methods for ops that
+// The next test checks suppression of ambiguous build methods for ops that
 // have a single variadic input, and single non-variadic result, and which
-// support the SameOperandsAndResultType trait and and optionally the
+// support the SameOperandsAndResultType trait and optionally the
 // InferOpTypeInterface interface. For such ops, the ODS framework generates
 // build methods with no result types as they are inferred from the input types.
 TEST_F(OpBuildGenTest, BuildMethodsSameOperandsAndResultTypeSuppression) {
-  testSingleVariadicInputInferredType<test::TableGenBuildOp4>();
+  // Test separate arg, separate param build method.
+  auto op = builder.create<test::TableGenBuildOp4>(
+      loc, i32Ty, ValueRange{*cstI32, *cstI32});
+  verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, noAttrs);
+
+  // Test collective params build method.
+  op = builder.create<test::TableGenBuildOp4>(loc, TypeRange{i32Ty},
+                                              ValueRange{*cstI32, *cstI32});
+  verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, noAttrs);
+
+  // Test build method with no result types, default value of attributes.
+  op =
+      builder.create<test::TableGenBuildOp4>(loc, ValueRange{*cstI32, *cstI32});
+  verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, noAttrs);
+
+  // Test build method with no result types and supplied attributes.
+  op = builder.create<test::TableGenBuildOp4>(loc, ValueRange{*cstI32, *cstI32},
+                                              attrs);
+  verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, attrs);
 }
 
 TEST_F(OpBuildGenTest, BuildMethodsRegionsAndInferredType) {
@@ -221,4 +254,41 @@ TEST_F(OpBuildGenTest, BuildMethodsRegionsAndInferredType) {
   verifyOp(op, {i32Ty}, {*cstI32, *cstF32}, noAttrs);
 }
 
+TEST_F(OpBuildGenTest, BuildMethodsVariadicProperties) {
+  // Account for conversion as part of getAttrs().
+  std::vector<NamedAttribute> noAttrsStorage;
+  auto segmentSize = builder.getNamedAttr("operandSegmentSizes",
+                                          builder.getDenseI32ArrayAttr({1, 1}));
+  noAttrsStorage.push_back(segmentSize);
+  ArrayRef<NamedAttribute> noAttrs(noAttrsStorage);
+  std::vector<NamedAttribute> attrsStorage = this->attrStorage;
+  attrsStorage.push_back(segmentSize);
+  ArrayRef<NamedAttribute> attrs(attrsStorage);
+
+  // Test separate arg, separate param build method.
+  auto op = builder.create<test::TableGenBuildOp6>(
+      loc, f32Ty, ValueRange{*cstI32}, ValueRange{*cstI32});
+  verifyOp(std::move(op), {f32Ty}, {*cstI32}, {*cstI32}, noAttrs);
+
+  // Test build method with no result types, default value of attributes.
+  op = builder.create<test::TableGenBuildOp6>(loc, ValueRange{*cstI32},
+                                              ValueRange{*cstI32});
+  verifyOp(std::move(op), {f32Ty}, {*cstI32}, {*cstI32}, noAttrs);
+
+  // Test collective params build method.
+  op = builder.create<test::TableGenBuildOp6>(
+      loc, TypeRange{f32Ty}, ValueRange{*cstI32}, ValueRange{*cstI32});
+  verifyOp(std::move(op), {f32Ty}, {*cstI32}, {*cstI32}, noAttrs);
+
+  // Test build method with result types, supplied attributes.
+  op = builder.create<test::TableGenBuildOp6>(
+      loc, TypeRange{f32Ty}, ValueRange{*cstI32, *cstI32}, attrs);
+  verifyOp(std::move(op), {f32Ty}, {*cstI32}, {*cstI32}, attrs);
+
+  // Test build method with no result types and supplied attributes.
+  op = builder.create<test::TableGenBuildOp6>(loc, ValueRange{*cstI32, *cstI32},
+                                              attrs);
+  verifyOp(std::move(op), {f32Ty}, {*cstI32}, {*cstI32}, attrs);
+}
+
 } // namespace mlir



More information about the Mlir-commits mailing list