[Mlir-commits] [mlir] 12d2f75 - [mlir][ods] OpFormat: fix type inference issues

Jeff Niu llvmlistbot at llvm.org
Mon Aug 29 09:28:48 PDT 2022


Author: Jeff Niu
Date: 2022-08-29T09:28:40-07:00
New Revision: 12d2f75aedf858e460c63039008e1874eaf06e85

URL: https://github.com/llvm/llvm-project/commit/12d2f75aedf858e460c63039008e1874eaf06e85
DIFF: https://github.com/llvm/llvm-project/commit/12d2f75aedf858e460c63039008e1874eaf06e85.diff

LOG: [mlir][ods] OpFormat: fix type inference issues

This patch fixes issues with generating assembly format parsers for
operations that use the `operands` directive or that have unnamed
arguments or results.

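This patch also fixes `OpAsmParser::resolveOperands`, whose overload
taking a single type and a location always produced an error when
resolving more than one variadic operand of the same type: it forwarded
a one-element type list to the list-based overload, so the operand/type
count check failed.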

Fixes #51841

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D131627

Added: 
    

Modified: 
    mlir/include/mlir/IR/OpImplementation.h
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/op-format-spec.td
    mlir/tools/mlir-tblgen/OpFormatGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 2d3431a8b80a0..332b242129971 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -1356,37 +1356,25 @@ class OpAsmParser : public AsmParser {
   /// Resolve a list of operands to SSA values, emitting an error on failure, or
   /// appending the results to the list on success. This method should be used
   /// when all operands have the same type.
-  ParseResult resolveOperands(ArrayRef<UnresolvedOperand> operands, Type type,
+  template <typename Operands = ArrayRef<UnresolvedOperand>>
+  ParseResult resolveOperands(Operands &&operands, Type type,
                               SmallVectorImpl<Value> &result) {
-    for (auto elt : operands)
-      if (resolveOperand(elt, type, result))
+    for (const UnresolvedOperand &operand : operands)
+      if (resolveOperand(operand, type, result))
         return failure();
     return success();
   }
+  template <typename Operands = ArrayRef<UnresolvedOperand>>
+  ParseResult resolveOperands(Operands &&operands, Type type, SMLoc loc,
+                              SmallVectorImpl<Value> &result) {
+    return resolveOperands(std::forward<Operands>(operands), type, result);
+  }
 
   /// Resolve a list of operands and a list of operand types to SSA values,
   /// emitting an error and returning failure, or appending the results
   /// to the list on success.
-  ParseResult resolveOperands(ArrayRef<UnresolvedOperand> operands,
-                              ArrayRef<Type> types, SMLoc loc,
-                              SmallVectorImpl<Value> &result) {
-    if (operands.size() != types.size())
-      return emitError(loc)
-             << operands.size() << " operands present, but expected "
-             << types.size();
-
-    for (unsigned i = 0, e = operands.size(); i != e; ++i)
-      if (resolveOperand(operands[i], types[i], result))
-        return failure();
-    return success();
-  }
-  template <typename Operands>
-  ParseResult resolveOperands(Operands &&operands, Type type, SMLoc loc,
-                              SmallVectorImpl<Value> &result) {
-    return resolveOperands(std::forward<Operands>(operands),
-                           ArrayRef<Type>(type), loc, result);
-  }
-  template <typename Operands, typename Types>
+  template <typename Operands = ArrayRef<UnresolvedOperand>,
+            typename Types = ArrayRef<Type>>
   std::enable_if_t<!std::is_convertible<Types, Type>::value, ParseResult>
   resolveOperands(Operands &&operands, Types &&types, SMLoc loc,
                   SmallVectorImpl<Value> &result) {
@@ -1396,8 +1384,8 @@ class OpAsmParser : public AsmParser {
       return emitError(loc)
              << operandSize << " operands present, but expected " << typeSize;
 
-    for (auto it : llvm::zip(operands, types))
-      if (resolveOperand(std::get<0>(it), std::get<1>(it), result))
+    for (auto [operand, type] : llvm::zip(operands, types))
+      if (resolveOperand(operand, type, result))
         return failure();
     return success();
   }

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 9567aa153a0a8..18775b7841bc0 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2130,7 +2130,7 @@ def FormatInferVariadicTypeFromNonVariadic
               [SameOperandsAndResultType]> {
   let arguments = (ins Variadic<AnyType>:$args);
   let results = (outs AnyType:$result);
-  let assemblyFormat = "$args attr-dict `:` type($result)";
+  let assemblyFormat = "operands attr-dict `:` type($result)";
 }
 
 def FormatOptionalUnitAttr : TEST_Op<"format_optional_unit_attribute"> {

diff  --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td
index 6d21afdfe6de7..d546cb97658f2 100644
--- a/mlir/test/mlir-tblgen/op-format-spec.td
+++ b/mlir/test/mlir-tblgen/op-format-spec.td
@@ -172,12 +172,12 @@ def ZCoverageValidC : TestFormat_Op<[{
 // Check that we can infer type equalities from certain traits.
 def ZCoverageValidD : TestFormat_Op<[{
   operands type($result) attr-dict
-}], [SameOperandsAndResultType]>, Arguments<(ins AnyMemRef:$operand)>,
+}], [SameOperandsAndResultType]>, Arguments<(ins AnyMemRef)>,
      Results<(outs AnyMemRef:$result)>;
 def ZCoverageValidE : TestFormat_Op<[{
   $operand type($operand) attr-dict
 }], [SameOperandsAndResultType]>, Arguments<(ins AnyMemRef:$operand)>,
-     Results<(outs AnyMemRef:$result)>;
+     Results<(outs AnyMemRef)>;
 def ZCoverageValidF : TestFormat_Op<[{
   operands type($other) attr-dict
 }], [SameTypeOperands]>, Arguments<(ins AnyMemRef:$operand, AnyMemRef:$other)>;

diff  --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index c1bb87712323d..304af57292f4f 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -1396,6 +1396,8 @@ void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) {
         body << tgfmt(*tform, &fmtContext);
       } else {
         body << var->name << "Types";
+        if (!var->isVariadic())
+          body << "[0]";
       }
     } else if (const NamedAttribute *attr = resolver.getAttribute()) {
       if (Optional<StringRef> tform = resolver.getVarTransformer())
@@ -1484,8 +1486,8 @@ void OperationFormat::genParserOperandTypeResolution(
       emitTypeResolver(operandTypes.front(), op.getOperand(0).name);
     }
 
-    body << ", allOperandLoc, result.operands))\n"
-         << "    return ::mlir::failure();\n";
+    body << ", allOperandLoc, result.operands))\n    return "
+            "::mlir::failure();\n";
     return;
   }
 
@@ -1499,25 +1501,8 @@ void OperationFormat::genParserOperandTypeResolution(
     TypeResolution &operandType = operandTypes[i];
     emitTypeResolver(operandType, operand.name);
 
-    // If the type is resolved by a non-variadic variable, index into the
-    // resolved type list. This allows for resolving the types of a variadic
-    // operand list from a non-variadic variable.
-    bool verifyOperandAndTypeSize = true;
-    if (auto *resolverVar = operandType.getVariable()) {
-      if (!resolverVar->isVariadic() && !operandType.getVarTransformer()) {
-        body << "[0]";
-        verifyOperandAndTypeSize = false;
-      }
-    } else {
-      verifyOperandAndTypeSize = !operandType.getBuilderIdx();
-    }
-
-    // Check to see if the sizes between the types and operands must match. If
-    // they do, provide the operand location to select the proper resolution
-    // overload.
-    if (verifyOperandAndTypeSize)
-      body << ", " << operand.name << "OperandsLoc";
-    body << ", result.operands))\n    return ::mlir::failure();\n";
+    body << ", " << operand.name
+         << "OperandsLoc, result.operands))\n    return ::mlir::failure();\n";
   }
 }
 
@@ -2731,11 +2716,11 @@ void OpFormatParser::handleSameTypesConstraint(
 
   // Set the resolvers for each operand and result.
   for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i)
-    if (!seenOperandTypes.test(i) && !op.getOperand(i).name.empty())
+    if (!seenOperandTypes.test(i))
       variableTyResolver[op.getOperand(i).name] = {resolver, llvm::None};
   if (includeResults) {
     for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
-      if (!seenResultTypes.test(i) && !op.getResultName(i).empty())
+      if (!seenResultTypes.test(i))
         variableTyResolver[op.getResultName(i)] = {resolver, llvm::None};
   }
 }


        


More information about the Mlir-commits mailing list