[Mlir-commits] [mlir] [mlir][ods] Populate properties in generated builder (PR #90430)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Apr 28 21:12:57 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Jacques Pienaar (jpienaar)

<details>
<summary>Changes</summary>

Previously, properties were only populated later, in the create method. The test is expanded so that it also covers the generated builders.

Added a small debug-emission helper that was useful while constructing the test. It could also be dropped, but we might want more of these, or a formalized debug mode, to make tracking easier.

Also changed some error messages to match the error-reporting style (lowercase first word, no trailing period). Inlined a templated function that was only used in one spot.

---
Full diff: https://github.com/llvm/llvm-project/pull/90430.diff


3 Files Affected:

- (modified) mlir/test/lib/Dialect/Test/TestOps.td (+7) 
- (modified) mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp (+44-3) 
- (modified) mlir/unittests/TableGen/OpBuildGen.cpp (+92-22) 


``````````diff
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 5352d574ac3943..52fa0f69dbb4c4 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2401,6 +2401,13 @@ def TableGenBuildOp5 : TableGenBuildInferReturnTypeBaseOp<
   let regions = (region AnyRegion:$body);
 }
 
+// Two variadic args, non variadic results, with AttrSizedOperandSegments
+// Test build method generation for property conversion & type inference.
+def TableGenBuildOp6 : TEST_Op<"tblgen_build_6", [AttrSizedOperandSegments]> {
+  let arguments = (ins Variadic<AnyType>:$a, Variadic<AnyType>:$b);
+  let results = (outs F32:$result);
+}
+
 //===----------------------------------------------------------------------===//
 // Test BufferPlacement
 //===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 53ed5cb7c043ec..78c53202187bc0 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -41,6 +41,14 @@
 
 #define DEBUG_TYPE "mlir-tblgen-opdefgen"
 
+#if 0
+#define DBG_ODS_PRINT(body, X)                                                 \
+  body << "fprintf(stderr, \"Generated from " << X                             \
+       << " at %s:%d\\n\", __FILE__, __LINE__);\n";
+#else
+#define DBG_ODS_PRINT(body, X)
+#endif
+
 using namespace llvm;
 using namespace mlir;
 using namespace mlir::tblgen;
@@ -1321,7 +1329,7 @@ void OpEmitter::genPropertiesSupport() {
       {2};
       if (!attr) {{
         emitError() << "expected key entry for {1} in DictionaryAttr to set "
-                   "Properties.";
+                   "Properties";
         return ::mlir::failure();
       }
       if (::mlir::failed(setFromAttr(prop.{1}, attr, emitError)))
@@ -1380,14 +1388,14 @@ void OpEmitter::genPropertiesSupport() {
     if (attr || /*isRequired=*/{1}) {{
       if (!attr) {{
         emitError() << "expected key entry for {0} in DictionaryAttr to set "
-                   "Properties.";
+                   "Properties";
         return ::mlir::failure();
       }
       auto convertedAttr = ::llvm::dyn_cast<std::remove_reference_t<decltype(propStorage)>>(attr);
       if (convertedAttr) {{
         propStorage = convertedAttr;
       } else {{
-        emitError() << "Invalid attribute `{0}` in property conversion: " << attr;
+        emitError() << "invalid attribute `{0}` in property conversion: " << attr;
         return ::mlir::failure();
       }
     }
@@ -2397,6 +2405,7 @@ void OpEmitter::genSeparateArgParamBuilder() {
     if (!m)
       return;
     auto &body = m->body();
+    DBG_ODS_PRINT(body, __LINE__);
     genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
                                            /*isRawValueAttr=*/attrType ==
                                                AttrParamKind::UnwrappedValue);
@@ -2519,6 +2528,7 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
   if (!m)
     return;
   auto &body = m->body();
+  DBG_ODS_PRINT(body, __LINE__);
 
   // Operands
   body << "  " << builderOpState << ".addOperands(operands);\n";
@@ -2623,6 +2633,7 @@ void OpEmitter::genInferredTypeCollectiveParamBuilder() {
   if (!m)
     return;
   auto &body = m->body();
+  DBG_ODS_PRINT(body, __LINE__);
 
   int numResults = op.getNumResults();
   int numVariadicResults = op.getNumVariableLengthResults();
@@ -2650,6 +2661,19 @@ void OpEmitter::genInferredTypeCollectiveParamBuilder() {
   }
 
   // Result types
+  if (emitHelper.hasProperties()) {
+    // Initialize the properties from Attributes before invoking the infer
+    // function.
+    body << formatv(R"(
+  ::mlir::OpaqueProperties properties =
+    &{1}.getOrAddProperties<{0}::Properties>();
+  std::optional<::mlir::RegisteredOperationName> info =
+    {1}.name.getRegisteredInfo();
+  if (failed(info->setOpPropertiesFromAttribute({1}.name, properties,
+      {1}.attributes.getDictionary({1}.getContext()), nullptr)))
+    ::llvm::report_fatal_error("Property conversion failed.");)",
+                    opClass.getClassName(), builderOpState);
+  }
   body << formatv(R"(
   ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
   if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
@@ -2684,6 +2708,7 @@ void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
     if (!m)
       return;
     auto &body = m->body();
+    DBG_ODS_PRINT(body, __LINE__);
     genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
                                            /*isRawValueAttr=*/attrType ==
                                                AttrParamKind::UnwrappedValue);
@@ -2721,6 +2746,7 @@ void OpEmitter::genUseAttrAsResultTypeBuilder() {
     return;
 
   auto &body = m->body();
+  DBG_ODS_PRINT(body, __LINE__);
 
   // Push all result types to the operation state
   std::string resultType;
@@ -2852,6 +2878,7 @@ void OpEmitter::genCollectiveParamBuilder() {
   if (!m)
     return;
   auto &body = m->body();
+  DBG_ODS_PRINT(body, __LINE__);
 
   // Operands
   if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
@@ -2879,6 +2906,20 @@ void OpEmitter::genCollectiveParamBuilder() {
          << "u && \"mismatched number of return types\");\n";
   body << "  " << builderOpState << ".addTypes(resultTypes);\n";
 
+  if (emitHelper.hasProperties()) {
+    // Initialize the properties from Attributes before invoking the infer
+    // function.
+    body << formatv(R"(
+  ::mlir::OpaqueProperties properties =
+    &{1}.getOrAddProperties<{0}::Properties>();
+  std::optional<::mlir::RegisteredOperationName> info =
+    {1}.name.getRegisteredInfo();
+  if (failed(info->setOpPropertiesFromAttribute({1}.name, properties,
+      {1}.attributes.getDictionary({1}.getContext()), nullptr)))
+    ::llvm::report_fatal_error("Property conversion failed.");)",
+                    opClass.getClassName(), builderOpState);
+  }
+
   // Generate builder that infers type too.
   // TODO: Expand to handle successors.
   if (canInferType(op) && op.getNumSuccessors() == 0)
diff --git a/mlir/unittests/TableGen/OpBuildGen.cpp b/mlir/unittests/TableGen/OpBuildGen.cpp
index c83ac9088114ce..94fbfa28803c48 100644
--- a/mlir/unittests/TableGen/OpBuildGen.cpp
+++ b/mlir/unittests/TableGen/OpBuildGen.cpp
@@ -66,29 +66,44 @@ class OpBuildGenTest : public ::testing::Test {
       EXPECT_EQ(op->getAttr(attrs[idx].getName().strref()),
                 attrs[idx].getValue());
 
+    EXPECT_TRUE(mlir::succeeded(concreteOp.verify()));
     concreteOp.erase();
   }
 
-  // Helper method to test ops with inferred result types and single variadic
-  // input.
   template <typename OpTy>
-  void testSingleVariadicInputInferredType() {
-    // Test separate arg, separate param build method.
-    auto op = builder.create<OpTy>(loc, i32Ty, ValueRange{*cstI32, *cstI32});
-    verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, noAttrs);
-
-    // Test collective params build method.
-    op = builder.create<OpTy>(loc, TypeRange{i32Ty},
-                              ValueRange{*cstI32, *cstI32});
-    verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, noAttrs);
-
-    // Test build method with no result types, default value of attributes.
-    op = builder.create<OpTy>(loc, ValueRange{*cstI32, *cstI32});
-    verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, noAttrs);
-
-    // Test build method with no result types and supplied attributes.
-    op = builder.create<OpTy>(loc, ValueRange{*cstI32, *cstI32}, attrs);
-    verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, attrs);
+  void verifyOp(OpTy &&concreteOp, std::vector<Type> resultTypes,
+                std::vector<Value> operands1, std::vector<Value> operands2,
+                std::vector<NamedAttribute> attrs) {
+    ASSERT_NE(concreteOp, nullptr);
+    Operation *op = concreteOp.getOperation();
+
+    EXPECT_EQ(op->getNumResults(), resultTypes.size());
+    for (unsigned idx : llvm::seq(0U, op->getNumResults()))
+      EXPECT_EQ(op->getResult(idx).getType(), resultTypes[idx]);
+
+    auto operands = llvm::to_vector(llvm::concat<Value>(operands1, operands2));
+    EXPECT_EQ(op->getNumOperands(), operands.size());
+    for (unsigned idx : llvm::seq(0U, op->getNumOperands()))
+      EXPECT_EQ(op->getOperand(idx), operands[idx]);
+
+    EXPECT_EQ(op->getAttrs().size(), attrs.size());
+    if (op->getAttrs().size() != attrs.size()) {
+      // Simple export where there is mismatch count.
+      llvm::errs() << "Op attrs:\n";
+      for (auto it : op->getAttrs())
+        llvm::errs() << "\t" << it.getName() << " = " << it.getValue() << "\n";
+
+      llvm::errs() << "Expected attrs:\n";
+      for (auto it : attrs)
+        llvm::errs() << "\t" << it.getName() << " = " << it.getValue() << "\n";
+    } else {
+      for (unsigned idx : llvm::seq<unsigned>(0U, attrs.size()))
+        EXPECT_EQ(op->getAttr(attrs[idx].getName().strref()),
+                  attrs[idx].getValue());
+    }
+
+    EXPECT_TRUE(mlir::succeeded(concreteOp.verify()));
+    concreteOp.erase();
   }
 
 protected:
@@ -205,13 +220,31 @@ TEST_F(OpBuildGenTest,
   verifyOp(op, {i32Ty, f32Ty}, {*cstI32}, attrs);
 }
 
-// The next test checks supression of ambiguous build methods for ops that
+// The next test checks suppression of ambiguous build methods for ops that
 // have a single variadic input, and single non-variadic result, and which
-// support the SameOperandsAndResultType trait and and optionally the
+// support the SameOperandsAndResultType trait and optionally the
 // InferOpTypeInterface interface. For such ops, the ODS framework generates
 // build methods with no result types as they are inferred from the input types.
 TEST_F(OpBuildGenTest, BuildMethodsSameOperandsAndResultTypeSuppression) {
-  testSingleVariadicInputInferredType<test::TableGenBuildOp4>();
+  // Test separate arg, separate param build method.
+  auto op = builder.create<test::TableGenBuildOp4>(
+      loc, i32Ty, ValueRange{*cstI32, *cstI32});
+  verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, noAttrs);
+
+  // Test collective params build method.
+  op = builder.create<test::TableGenBuildOp4>(loc, TypeRange{i32Ty},
+                                              ValueRange{*cstI32, *cstI32});
+  verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, noAttrs);
+
+  // Test build method with no result types, default value of attributes.
+  op =
+      builder.create<test::TableGenBuildOp4>(loc, ValueRange{*cstI32, *cstI32});
+  verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, noAttrs);
+
+  // Test build method with no result types and supplied attributes.
+  op = builder.create<test::TableGenBuildOp4>(loc, ValueRange{*cstI32, *cstI32},
+                                              attrs);
+  verifyOp(std::move(op), {i32Ty}, {*cstI32, *cstI32}, attrs);
 }
 
 TEST_F(OpBuildGenTest, BuildMethodsRegionsAndInferredType) {
@@ -221,4 +254,41 @@ TEST_F(OpBuildGenTest, BuildMethodsRegionsAndInferredType) {
   verifyOp(op, {i32Ty}, {*cstI32, *cstF32}, noAttrs);
 }
 
+TEST_F(OpBuildGenTest, BuildMethodsVariadicProperties) {
+  // Account for conversion as part of getAttrs().
+  std::vector<NamedAttribute> noAttrsStorage;
+  auto segmentSize = builder.getNamedAttr("operandSegmentSizes",
+                                          builder.getDenseI32ArrayAttr({1, 1}));
+  noAttrsStorage.push_back(segmentSize);
+  ArrayRef<NamedAttribute> noAttrs(noAttrsStorage);
+  std::vector<NamedAttribute> attrsStorage = this->attrStorage;
+  attrsStorage.push_back(segmentSize);
+  ArrayRef<NamedAttribute> attrs(attrsStorage);
+
+  // Test separate arg, separate param build method.
+  auto op = builder.create<test::TableGenBuildOp6>(
+      loc, f32Ty, ValueRange{*cstI32}, ValueRange{*cstI32});
+  verifyOp(std::move(op), {f32Ty}, {*cstI32}, {*cstI32}, noAttrs);
+
+  // Test build method with no result types, default value of attributes.
+  op = builder.create<test::TableGenBuildOp6>(loc, ValueRange{*cstI32},
+                                              ValueRange{*cstI32});
+  verifyOp(std::move(op), {f32Ty}, {*cstI32}, {*cstI32}, noAttrs);
+
+  // Test collective params build method.
+  op = builder.create<test::TableGenBuildOp6>(
+      loc, TypeRange{f32Ty}, ValueRange{*cstI32}, ValueRange{*cstI32});
+  verifyOp(std::move(op), {f32Ty}, {*cstI32}, {*cstI32}, noAttrs);
+
+  // Test build method with result types, supplied attributes.
+  op = builder.create<test::TableGenBuildOp6>(
+      loc, TypeRange{f32Ty}, ValueRange{*cstI32, *cstI32}, attrs);
+  verifyOp(std::move(op), {f32Ty}, {*cstI32}, {*cstI32}, attrs);
+
+  // Test build method with no result types and supplied attributes.
+  op = builder.create<test::TableGenBuildOp6>(loc, ValueRange{*cstI32, *cstI32},
+                                              attrs);
+  verifyOp(std::move(op), {f32Ty}, {*cstI32}, {*cstI32}, attrs);
+}
+
 } // namespace mlir

``````````

</details>


https://github.com/llvm/llvm-project/pull/90430


More information about the Mlir-commits mailing list