[Mlir-commits] [mlir] [mlir][ods] Populate properties in generated builder (PR #90430)

Jacques Pienaar llvmlistbot at llvm.org
Thu May 9 21:07:02 PDT 2024


https://github.com/jpienaar updated https://github.com/llvm/llvm-project/pull/90430

>From f0c82919af51ca7fbcf43b57c19da61348c3d470 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Fri, 10 May 2024 04:05:45 +0000
Subject: [PATCH] [mlir][ods] Populate properties in generated builder

Previously, properties were only populated later, in the create method. This
fixes some invalid builder paths. It may also mean that type inference
functions no longer have to consider whether property conversion has happened
(not yet verified).

This also makes Attributes corresponding to Properties optional inside the
set-from-attributes method. Today this is effectively what happens with Property value initialization, and folks use it to define custom C++ types whose default initialization is what they want; it is also the behavior users get when they use properties directly. Propagating Attributes without allowing partial setting would require iterating over the dictionary attribute while considering the properties of the op type that will be created. This could also have been an additional generated method, or optional behavior on the set method, but doing it consistently seems better. In terms of what's lost, nothing seems to be lost compared to the pure Property path, where the Property is default-value initialized and then partially overwritten (this doesn't seem to buy anything else verification-wise).

Default-valued Properties (specified on the ODS side rather than the C++ side) triggered an error: the adaptor constructor's `Properties` parameter could not be given a default initializer, because `Properties` is a struct defined in the enclosing op class, which is not yet complete at that point. Added an additional forwarding constructor to avoid needing to update call sites. This could be split out into a separate change.

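Inlined a templated helper in the unit test that was only used once. Where builders set attributes or properties after delegating to another build method, moved that initialization before the delegated call.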
---
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td   |   4 +-
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       |   2 +-
 .../mlir/Dialect/Tensor/IR/TensorOps.td       |   2 +-
 mlir/include/mlir/IR/Operation.h              |   5 +-
 .../Func/Transforms/OneToNFuncConversions.cpp |   6 +-
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      |   8 +-
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      |  12 +-
 mlir/test/Dialect/OpenMP/invalid.mlir         |  14 +-
 mlir/test/Dialect/OpenMP/ops.mlir             |   4 +-
 mlir/test/lib/Dialect/Test/TestOps.td         |   7 +
 mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp   | 122 ++++++++++++++----
 mlir/unittests/TableGen/OpBuildGen.cpp        | 114 ++++++++++++----
 12 files changed, 224 insertions(+), 76 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 6655ce6f123e1..7f394972a697b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -65,13 +65,13 @@ class LLVM_IntArithmeticOpWithOverflowFlag<string mnemonic, string instName,
   let builders = [
     OpBuilder<(ins "Type":$type, "Value":$lhs, "Value":$rhs,
                    "IntegerOverflowFlags":$overflowFlags), [{
-      build($_builder, $_state, type, lhs, rhs);
       $_state.getOrAddProperties<Properties>().overflowFlags = overflowFlags;
+      build($_builder, $_state, type, lhs, rhs);
     }]>,
     OpBuilder<(ins "Value":$lhs, "Value":$rhs,
                    "IntegerOverflowFlags":$overflowFlags), [{
-      build($_builder, $_state, lhs, rhs);
       $_state.getOrAddProperties<Properties>().overflowFlags = overflowFlags;
+      build($_builder, $_state, lhs, rhs);
     }]>
   ];
 
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 5738b6ca51c12..63e6ed059deb1 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1764,9 +1764,9 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
       "ArrayRef<ReassociationIndices>":$reassociation,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
-      build($_builder, $_state, resultType, src, attrs);
       $_state.addAttribute("reassociation",
                           getReassociationIndicesAttribute($_builder, reassociation));
+      build($_builder, $_state, resultType, src, attrs);
     }]>,
     OpBuilder<(ins "Type":$resultType, "Value":$src,
       "ArrayRef<ReassociationExprs>":$reassociation,
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index a403e89a39f98..cafc3d91fd1e9 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1216,9 +1216,9 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
       "ArrayRef<ReassociationIndices>":$reassociation,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
-      build($_builder, $_state, resultType, src, attrs);
       $_state.addAttribute("reassociation",
           getReassociationIndicesAttribute($_builder, reassociation));
+      build($_builder, $_state, resultType, src, attrs);
     }]>,
     OpBuilder<(ins "Type":$resultType, "Value":$src,
       "ArrayRef<ReassociationExprs>":$reassociation,
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index c52a6fcac10c1..f0dd7c5178056 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -916,11 +916,12 @@ class alignas(8) Operation final
   /// operation. Returns an empty attribute if no properties are present.
   Attribute getPropertiesAsAttribute();
 
-  /// Set the properties from the provided  attribute.
+  /// Set the properties from the provided attribute.
   /// This is an expensive operation that can fail if the attribute is not
   /// matching the expectations of the properties for this operation. This is
   /// mostly useful for unregistered operations or used when parsing the
-  /// generic format. An optional diagnostic can be passed in for richer errors.
+  /// generic format. An optional diagnostic emitter can be passed in for richer
+  /// errors; if none is passed, behavior is undefined in the error case.
   LogicalResult
   setPropertiesFromAttribute(Attribute attr,
                              function_ref<InFlightDiagnostic()> emitError);
diff --git a/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp
index c04986cad84f9..a5b88338e6381 100644
--- a/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp
@@ -40,9 +40,9 @@ class ConvertTypesInFuncCallOp : public OneToNOpConversionPattern<CallOp> {
       return failure();
 
     // Create new CallOp.
-    auto newOp = rewriter.create<CallOp>(loc, resultMapping.getConvertedTypes(),
-                                         adaptor.getFlatOperands());
-    newOp->setAttrs(op->getAttrs());
+    auto newOp =
+        rewriter.create<CallOp>(loc, resultMapping.getConvertedTypes(),
+                                adaptor.getFlatOperands(), op->getAttrs());
 
     rewriter.replaceOp(op, newOp->getResults(), resultMapping);
     return success();
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 78201ae29cd9b..47bd56a7c01e7 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1808,11 +1808,11 @@ void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
   dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
+  result.addAttributes(attrs);
   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
         dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
         b.getDenseI64ArrayAttr(staticSizes),
         b.getDenseI64ArrayAttr(staticStrides));
-  result.addAttributes(attrs);
 }
 
 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
@@ -2486,9 +2486,9 @@ void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
   auto srcType = llvm::cast<MemRefType>(src.getType());
   MemRefType resultType =
       CollapseShapeOp::computeCollapsedType(srcType, reassociation);
-  build(b, result, resultType, src, attrs);
   result.addAttribute(::mlir::getReassociationAttrName(),
                       getReassociationIndicesAttribute(b, reassociation));
+  build(b, result, resultType, src, attrs);
 }
 
 LogicalResult CollapseShapeOp::verify() {
@@ -2784,11 +2784,11 @@ void SubViewOp::build(OpBuilder &b, OperationState &result,
     resultType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
         sourceMemRefType, staticOffsets, staticSizes, staticStrides));
   }
+  result.addAttributes(attrs);
   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
         dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
         b.getDenseI64ArrayAttr(staticSizes),
         b.getDenseI64ArrayAttr(staticStrides));
-  result.addAttributes(attrs);
 }
 
 // Build a SubViewOp with mixed static and dynamic entries and inferred result
@@ -3323,8 +3323,8 @@ void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
   // Compute result type.
   MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
 
-  build(b, result, resultType, in, attrs);
   result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
+  build(b, result, resultType, in, attrs);
 }
 
 // transpose $in $permutation attr-dict : type($in) `to` type(results)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 7a13f7a7d1355..2c1d43d002664 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1743,9 +1743,9 @@ void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
       llvm::cast<RankedTensorType>(src.getType()),
       getSymbolLessAffineMaps(
           convertReassociationIndicesToExprs(b.getContext(), reassociation)));
-  build(b, result, resultType, src, attrs);
   result.addAttribute(getReassociationAttrStrName(),
                       getReassociationIndicesAttribute(b, reassociation));
+  build(b, result, resultType, src, attrs);
 }
 
 template <typename TensorReshapeOp, bool isExpansion = std::is_same<
@@ -2099,11 +2099,11 @@ void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
     resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
         sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
   }
+  result.addAttributes(attrs);
   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
         dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
         b.getDenseI64ArrayAttr(staticSizes),
         b.getDenseI64ArrayAttr(staticStrides));
-  result.addAttributes(attrs);
 }
 
 /// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
@@ -2498,11 +2498,11 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
+  result.addAttributes(attrs);
   build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
         dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
         b.getDenseI64ArrayAttr(staticSizes),
         b.getDenseI64ArrayAttr(staticStrides));
-  result.addAttributes(attrs);
 }
 
 /// Build an InsertSliceOp with mixed static and dynamic entries packed into a
@@ -2943,10 +2943,10 @@ void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
   auto sourceType = llvm::cast<RankedTensorType>(source.getType());
   if (!resultType)
     resultType = inferResultType(sourceType, staticLow, staticHigh);
+  result.addAttributes(attrs);
   build(b, result, resultType, source, low, high,
         b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
         nofold ? b.getUnitAttr() : UnitAttr());
-  result.addAttributes(attrs);
 }
 
 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
@@ -2976,10 +2976,10 @@ void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
     resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
   }
   assert(llvm::isa<RankedTensorType>(resultType));
+  result.addAttributes(attrs);
   build(b, result, resultType, source, dynamicLow, dynamicHigh,
         b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
         nofold ? b.getUnitAttr() : UnitAttr());
-  result.addAttributes(attrs);
 }
 
 void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
@@ -3423,11 +3423,11 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
+  result.addAttributes(attrs);
   build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
         dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
         b.getDenseI64ArrayAttr(staticSizes),
         b.getDenseI64ArrayAttr(staticStrides));
-  result.addAttributes(attrs);
 }
 
 /// Build an ParallelInsertSliceOp with mixed static and dynamic entries
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 511e7d396c687..115b22986fab8 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2113,23 +2113,23 @@ func.func @omp_distribute_allocate(%data_var : memref<i32>) -> () {
 
 func.func @omp_distribute_wrapper() -> () {
   // expected-error @below {{op must be a loop wrapper}}
-  "omp.distribute"() ({
+  omp.distribute {
       %0 = arith.constant 0 : i32
       "omp.terminator"() : () -> ()
-    }) : () -> ()
+  }
 }
 
 // -----
 
 func.func @omp_distribute_nested_wrapper(%data_var : memref<i32>) -> () {
   // expected-error @below {{only supported nested wrappers are 'omp.parallel' and 'omp.simd'}}
-  "omp.distribute"() ({
-      "omp.wsloop"() ({
-        %0 = arith.constant 0 : i32
-        "omp.terminator"() : () -> ()
-      }) : () -> ()
+  omp.distribute {
+    "omp.wsloop"() ({
+      %0 = arith.constant 0 : i32
       "omp.terminator"() : () -> ()
     }) : () -> ()
+    "omp.terminator"() : () -> ()
+  }
 }
 
 // -----
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 60fc10f9d64b7..3032914186d90 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -521,13 +521,13 @@ func.func @omp_wsloop_pretty(%lb : index, %ub : index, %step : index, %data_var
 // CHECK-LABEL: omp_simd
 func.func @omp_simd(%lb : index, %ub : index, %step : index) -> () {
   // CHECK: omp.simd
-  "omp.simd" () ({
+  omp.simd {
     "omp.loop_nest" (%lb, %ub, %step) ({
     ^bb1(%iv2: index):
       "omp.yield"() : () -> ()
     }) : (index, index, index) -> ()
     "omp.terminator"() : () -> ()
-  }) : () -> ()
+  }
 
   return
 }
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 5352d574ac394..52fa0f69dbb4c 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2401,6 +2401,13 @@ def TableGenBuildOp5 : TableGenBuildInferReturnTypeBaseOp<
   let regions = (region AnyRegion:$body);
 }
 
+// Two variadic operands, one non-variadic result, AttrSizedOperandSegments.
+// Tests build method generation for property conversion & type inference.
+def TableGenBuildOp6 : TEST_Op<"tblgen_build_6", [AttrSizedOperandSegments]> {
+  let arguments = (ins Variadic<AnyType>:$a, Variadic<AnyType>:$b);
+  let results = (outs F32:$result);
+}
+
 //===----------------------------------------------------------------------===//
 // Test BufferPlacement
 //===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 63fe5a8099074..c52ca1e4f51ac 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -41,6 +41,14 @@
 
 #define DEBUG_TYPE "mlir-tblgen-opdefgen"
 
+#if 0
+#define DBG_ODS_PRINT(body, X)                                                 \
+  body << "fprintf(stderr, \"Generated from " << X                             \
+       << " at %s:%d\\n\", __FILE__, __LINE__);\n";
+#else
+#define DBG_ODS_PRINT(body, X)
+#endif
+
 using namespace llvm;
 using namespace mlir;
 using namespace mlir::tblgen;
@@ -464,7 +472,7 @@ void OpOrAdaptorHelper::computeAttrMetadata() {
         /*hashPropertyCall=*/
         "::llvm::hash_combine_range(std::begin($_storage), "
         "std::end($_storage));",
-        /*StringRef defaultValue=*/"");
+        /*StringRef defaultValue=*/"{}");
   };
   // Include key attributes from several traits as implicitly registered.
   if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
@@ -1311,22 +1319,26 @@ void OpEmitter::genPropertiesSupport() {
     return ::mlir::failure();
   }
     )decl";
-  // TODO: properties might be optional as well.
-  const char *propFromAttrFmt = R"decl(;
-    {{
+  // TODO: properties might be optional as well. The below handles default ones
+  // as optional for the sake of setting attributes.
+  const char *propFromAttrFmt = R"decl(
       auto setFromAttr = [] (auto &propStorage, ::mlir::Attribute propAttr,
                ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {{
-        {0};
+        {0}
       };
       {2};
-      if (!attr) {{
-        emitError() << "expected key entry for {1} in DictionaryAttr to set "
-                   "Properties.";
+)decl";
+  const char *attrGetNoDefaultFmt = R"decl(;
+      if (attr && ::mlir::failed(setFromAttr(prop.{0}, attr, emitError)))
         return ::mlir::failure();
+)decl";
+  const char *attrGetDefaultFmt = R"decl(;
+      if (attr) {{
+        if (::mlir::failed(setFromAttr(prop.{0}, attr, emitError)))
+          return ::mlir::failure();
+      } else {{
+        prop.{0} = {1};
       }
-      if (::mlir::failed(setFromAttr(prop.{1}, attr, emitError)))
-        return ::mlir::failure();
-    }
 )decl";
 
   for (const auto &attrOrProp : attrOrProperties) {
@@ -1349,13 +1361,20 @@ void OpEmitter::genPropertiesSupport() {
       }
       os.flush();
 
-      setPropMethod << formatv(propFromAttrFmt,
+      setPropMethod << "{\n"
+                    << formatv(propFromAttrFmt,
                                tgfmt(prop.getConvertFromAttributeCall(),
                                      &fctx.addSubst("_attr", propertyAttr)
                                           .addSubst("_storage", propertyStorage)
                                           .addSubst("_diag", propertyDiag)),
                                name, getAttr);
-
+      if (prop.hasDefaultValue()) {
+        setPropMethod << formatv(attrGetDefaultFmt, name,
+                                 prop.getDefaultValue());
+      } else {
+        setPropMethod << formatv(attrGetNoDefaultFmt, name);
+      }
+      setPropMethod << "  }\n";
     } else {
       const auto *namedAttr =
           llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp);
@@ -1376,13 +1395,8 @@ void OpEmitter::genPropertiesSupport() {
       setPropMethod << formatv(R"decl(
   {{
     auto &propStorage = prop.{0};
-    {2}
-    if (attr || /*isRequired=*/{1}) {{
-      if (!attr) {{
-        emitError() << "expected key entry for {0} in DictionaryAttr to set "
-                   "Properties.";
-        return ::mlir::failure();
-      }
+    {1}
+    if (attr) {{
       auto convertedAttr = ::llvm::dyn_cast<std::remove_reference_t<decltype(propStorage)>>(attr);
       if (convertedAttr) {{
         propStorage = convertedAttr;
@@ -1393,7 +1407,7 @@ void OpEmitter::genPropertiesSupport() {
     }
   }
 )decl",
-                               name, namedAttr->isRequired, getAttr);
+                               name, getAttr);
     }
   }
   setPropMethod << "  return ::mlir::success();\n";
@@ -2397,6 +2411,7 @@ void OpEmitter::genSeparateArgParamBuilder() {
     if (!m)
       return;
     auto &body = m->body();
+    DBG_ODS_PRINT(body, __LINE__);
     genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
                                            /*isRawValueAttr=*/attrType ==
                                                AttrParamKind::UnwrappedValue);
@@ -2519,6 +2534,7 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
   if (!m)
     return;
   auto &body = m->body();
+  DBG_ODS_PRINT(body, __LINE__);
 
   // Operands
   body << "  " << builderOpState << ".addOperands(operands);\n";
@@ -2623,6 +2639,7 @@ void OpEmitter::genInferredTypeCollectiveParamBuilder() {
   if (!m)
     return;
   auto &body = m->body();
+  DBG_ODS_PRINT(body, __LINE__);
 
   int numResults = op.getNumResults();
   int numVariadicResults = op.getNumVariableLengthResults();
@@ -2650,6 +2667,21 @@ void OpEmitter::genInferredTypeCollectiveParamBuilder() {
   }
 
   // Result types
+  if (emitHelper.hasProperties()) {
+    // Populate Properties from attributes so type inference can see them.
+    body << formatv(R"(
+  if (!attributes.empty()) {
+    ::mlir::OpaqueProperties properties =
+      &{1}.getOrAddProperties<{0}::Properties>();
+    std::optional<::mlir::RegisteredOperationName> info =
+      {1}.name.getRegisteredInfo();
+    if (!info || failed(info->setOpPropertiesFromAttribute({1}.name,
+        properties, {1}.attributes.getDictionary({1}.getContext()),
+        [&] { return ::mlir::emitError({1}.location); })))
+      ::llvm::report_fatal_error("Property conversion failed.");
+  })",
+                    opClass.getClassName(), builderOpState);
+  }
   body << formatv(R"(
   ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
   if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
@@ -2684,6 +2716,7 @@ void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
     if (!m)
       return;
     auto &body = m->body();
+    DBG_ODS_PRINT(body, __LINE__);
     genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
                                            /*isRawValueAttr=*/attrType ==
                                                AttrParamKind::UnwrappedValue);
@@ -2721,6 +2754,7 @@ void OpEmitter::genUseAttrAsResultTypeBuilder() {
     return;
 
   auto &body = m->body();
+  DBG_ODS_PRINT(body, __LINE__);
 
   // Push all result types to the operation state
   std::string resultType;
@@ -2852,6 +2886,7 @@ void OpEmitter::genCollectiveParamBuilder() {
   if (!m)
     return;
   auto &body = m->body();
+  DBG_ODS_PRINT(body, __LINE__);
 
   // Operands
   if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
@@ -2879,6 +2914,22 @@ void OpEmitter::genCollectiveParamBuilder() {
          << "u && \"mismatched number of return types\");\n";
   body << "  " << builderOpState << ".addTypes(resultTypes);\n";
 
+  if (emitHelper.hasProperties()) {
+    // Populate Properties from any property-backed attributes passed in.
+    body << formatv(R"(
+  if (!attributes.empty()) {
+    ::mlir::OpaqueProperties properties =
+      &{1}.getOrAddProperties<{0}::Properties>();
+    std::optional<::mlir::RegisteredOperationName> info =
+      {1}.name.getRegisteredInfo();
+    if (!info || failed(info->setOpPropertiesFromAttribute({1}.name,
+        properties, {1}.attributes.getDictionary({1}.getContext()),
+        [&] { return ::mlir::emitError({1}.location); })))
+      ::llvm::report_fatal_error("Property conversion failed.");
+  })",
+                    opClass.getClassName(), builderOpState);
+  }
+
   // Generate builder that infers type too.
   // TODO: Expand to handle successors.
   if (canInferType(op) && op.getNumSuccessors() == 0)
@@ -4054,13 +4105,17 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
       op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
   {
     SmallVector<MethodParameter> paramList;
-    paramList.emplace_back("::mlir::DictionaryAttr", "attrs",
-                           attrSizedOperands ? "" : "nullptr");
-    if (useProperties)
-      paramList.emplace_back("const Properties &", "properties", "{}");
-    else
+    if (useProperties) {
+      // Properties can't be given a default argument here because the
+      // Properties struct is defined in the enclosing class, which isn't
+      // complete at this point.
+      paramList.emplace_back("::mlir::DictionaryAttr", "attrs");
+      paramList.emplace_back("const Properties &", "properties");
+    } else {
+      paramList.emplace_back("::mlir::DictionaryAttr", "attrs", "{}");
       paramList.emplace_back("const ::mlir::EmptyProperties &", "properties",
                              "{}");
+    }
     paramList.emplace_back("::mlir::RegionRange", "regions", "{}");
     auto *baseConstructor = genericAdaptorBase.addConstructor(paramList);
     baseConstructor->addMemberInitializer("odsAttrs", "attrs");
@@ -4102,6 +4157,21 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
           "::mlir::EmptyProperties{}), "
           "regions");
     }
+
+    // Add forwarding constructor that constructs Properties.
+    if (useProperties) {
+      SmallVector<MethodParameter> paramList;
+      paramList.emplace_back("RangeT", "values");
+      paramList.emplace_back("::mlir::DictionaryAttr", "attrs",
+                             attrSizedOperands ? "" : "nullptr");
+      auto *noPropertiesConstructor =
+          genericAdaptor.addConstructor(std::move(paramList));
+      noPropertiesConstructor->addMemberInitializer(
+          genericAdaptor.getClassName(), "values, "
+                                         "attrs, "
+                                         "Properties{}, "
+                                         "{}");
+    }
   }
 
   // Create constructors constructing the adaptor from an instance of the op.
diff --git a/mlir/unittests/TableGen/OpBuildGen.cpp b/mlir/unittests/TableGen/OpBuildGen.cpp
index c83ac9088114c..94fbfa28803c4 100644
--- a/mlir/unittests/TableGen/OpBuildGen.cpp
+++ b/mlir/unittests/TableGen/OpBuildGen.cpp
@@ -66,29 +66,44 @@ class OpBuildGenTest : public ::testing::Test {
       EXPECT_EQ(op->getAttr(attrs[idx].getName().strref()),
                 attrs[idx].getValue());
 
+    EXPECT_TRUE(mlir::succeeded(concreteOp.verify()));
     concreteOp.erase();
   }
 
-  // Helper method to test ops with inferred result types and single variadic
-  // input.
   template <typename OpTy>
-  void testSingleVariadicInputInferredType() {
-    // Test separate arg, separate param build method.
-    auto op = builder.create<OpTy>(loc, i32Ty, ValueRange{*cstI32, *cstI32});
-    verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, noAttrs);
-
-    // Test collective params build method.
-    op = builder.create<OpTy>(loc, TypeRange{i32Ty},
-                              ValueRange{*cstI32, *cstI32});
-    verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, noAttrs);
-
-    // Test build method with no result types, default value of attributes.
-    op = builder.create<OpTy>(loc, ValueRange{*cstI32, *cstI32});
-    verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, noAttrs);
-
-    // Test build method with no result types and supplied attributes.
-    op = builder.create<OpTy>(loc, ValueRange{*cstI32, *cstI32}, attrs);
-    verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, attrs);
+  void verifyOp(OpTy &&concreteOp, std::vector<Type> resultTypes,
+                std::vector<Value> operands1, std::vector<Value> operands2,
+                std::vector<NamedAttribute> attrs) {
+    ASSERT_NE(concreteOp, nullptr);
+    Operation *op = concreteOp.getOperation();
+
+    EXPECT_EQ(op->getNumResults(), resultTypes.size());
+    for (unsigned idx : llvm::seq(0U, op->getNumResults()))
+      EXPECT_EQ(op->getResult(idx).getType(), resultTypes[idx]);
+
+    auto operands = llvm::to_vector(llvm::concat<Value>(operands1, operands2));
+    EXPECT_EQ(op->getNumOperands(), operands.size());
+    for (unsigned idx : llvm::seq(0U, op->getNumOperands()))
+      EXPECT_EQ(op->getOperand(idx), operands[idx]);
+
+    EXPECT_EQ(op->getAttrs().size(), attrs.size());
+    if (op->getAttrs().size() != attrs.size()) {
+      // Dump both attribute lists when the counts mismatch.
+      llvm::errs() << "Op attrs:\n";
+      for (auto it : op->getAttrs())
+        llvm::errs() << "\t" << it.getName() << " = " << it.getValue() << "\n";
+
+      llvm::errs() << "Expected attrs:\n";
+      for (auto it : attrs)
+        llvm::errs() << "\t" << it.getName() << " = " << it.getValue() << "\n";
+    } else {
+      for (unsigned idx : llvm::seq<unsigned>(0U, attrs.size()))
+        EXPECT_EQ(op->getAttr(attrs[idx].getName().strref()),
+                  attrs[idx].getValue());
+    }
+
+    EXPECT_TRUE(mlir::succeeded(concreteOp.verify()));
+    concreteOp.erase();
   }
 
 protected:
@@ -205,13 +220,31 @@ TEST_F(OpBuildGenTest,
   verifyOp(op, {i32Ty, f32Ty}, {*cstI32}, attrs);
 }
 
-// The next test checks supression of ambiguous build methods for ops that
+// The next test checks suppression of ambiguous build methods for ops that
 // have a single variadic input, and single non-variadic result, and which
-// support the SameOperandsAndResultType trait and and optionally the
+// support the SameOperandsAndResultType trait and optionally the
 // InferOpTypeInterface interface. For such ops, the ODS framework generates
 // build methods with no result types as they are inferred from the input types.
 TEST_F(OpBuildGenTest, BuildMethodsSameOperandsAndResultTypeSuppression) {
-  testSingleVariadicInputInferredType<test::TableGenBuildOp4>();
+  // Test separate arg, separate param build method.
+  auto op = builder.create<test::TableGenBuildOp4>(
+      loc, i32Ty, ValueRange{*cstI32, *cstI32});
+  verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, noAttrs);
+
+  // Test collective params build method.
+  op = builder.create<test::TableGenBuildOp4>(loc, TypeRange{i32Ty},
+                                              ValueRange{*cstI32, *cstI32});
+  verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, noAttrs);
+
+  // Test build method with no result types, default value of attributes.
+  op =
+      builder.create<test::TableGenBuildOp4>(loc, ValueRange{*cstI32, *cstI32});
+  verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, noAttrs);
+
+  // Test build method with no result types and supplied attributes.
+  op = builder.create<test::TableGenBuildOp4>(loc, ValueRange{*cstI32, *cstI32},
+                                              attrs);
+  verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, attrs);
 }
 
 TEST_F(OpBuildGenTest, BuildMethodsRegionsAndInferredType) {
@@ -221,4 +254,41 @@ TEST_F(OpBuildGenTest, BuildMethodsRegionsAndInferredType) {
   verifyOp(op, {i32Ty}, {*cstI32, *cstF32}, noAttrs);
 }
 
+TEST_F(OpBuildGenTest, BuildMethodsVariadicProperties) {
+  // Account for conversion as part of getAttrs().
+  std::vector<NamedAttribute> noAttrsStorage;
+  auto segmentSize = builder.getNamedAttr("operandSegmentSizes",
+                                          builder.getDenseI32ArrayAttr({1, 1}));
+  noAttrsStorage.push_back(segmentSize);
+  ArrayRef<NamedAttribute> noAttrs(noAttrsStorage);
+  std::vector<NamedAttribute> attrsStorage = this->attrStorage;
+  attrsStorage.push_back(segmentSize);
+  ArrayRef<NamedAttribute> attrs(attrsStorage);
+
+  // Test separate arg, separate param build method.
+  auto op = builder.create<test::TableGenBuildOp6>(
+      loc, f32Ty, ValueRange{*cstI32}, ValueRange{*cstI32});
+  verifyOp(std::move(op), {f32Ty}, {*cstI32}, {*cstI32}, noAttrs);
+
+  // Test build method with no result types, default value of attributes.
+  op = builder.create<test::TableGenBuildOp6>(loc, ValueRange{*cstI32},
+                                              ValueRange{*cstI32});
+  verifyOp(std::move(op), {f32Ty}, {*cstI32}, {*cstI32}, noAttrs);
+
+  // Test collective params build method.
+  op = builder.create<test::TableGenBuildOp6>(
+      loc, TypeRange{f32Ty}, ValueRange{*cstI32}, ValueRange{*cstI32});
+  verifyOp(std::move(op), {f32Ty}, {*cstI32}, {*cstI32}, noAttrs);
+
+  // Test build method with result types, supplied attributes.
+  op = builder.create<test::TableGenBuildOp6>(
+      loc, TypeRange{f32Ty}, ValueRange{*cstI32, *cstI32}, attrs);
+  verifyOp(std::move(op), {f32Ty}, {*cstI32}, {*cstI32}, attrs);
+
+  // Test build method with no result types and supplied attributes.
+  op = builder.create<test::TableGenBuildOp6>(loc, ValueRange{*cstI32, *cstI32},
+                                              attrs);
+  verifyOp(std::move(op), {f32Ty}, {*cstI32}, {*cstI32}, attrs);
+}
+
 } // namespace mlir



More information about the Mlir-commits mailing list