[Mlir-commits] [mlir] 12d2f75 - [mlir][ods] OpFormat: fix type inference issues
Jeff Niu
llvmlistbot at llvm.org
Mon Aug 29 09:28:48 PDT 2022
Author: Jeff Niu
Date: 2022-08-29T09:28:40-07:00
New Revision: 12d2f75aedf858e460c63039008e1874eaf06e85
URL: https://github.com/llvm/llvm-project/commit/12d2f75aedf858e460c63039008e1874eaf06e85
DIFF: https://github.com/llvm/llvm-project/commit/12d2f75aedf858e460c63039008e1874eaf06e85.diff
LOG: [mlir][ods] OpFormat: fix type inference issues
This patch fixes issues with generating assembly format parsers for
operations that use the `operands` directive or that have unnamed
arguments or results.
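This patch also fixes the `OpAsmParser::resolveOperands` overload that
takes a single type and a location. It wrapped the type in a one-element
type list, so resolving variadic operands with the same type produced an
operand-count error unless exactly one operand was present.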
Fixes #51841
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D131627
Added:
Modified:
mlir/include/mlir/IR/OpImplementation.h
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/op-format-spec.td
mlir/tools/mlir-tblgen/OpFormatGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 2d3431a8b80a0..332b242129971 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -1356,37 +1356,25 @@ class OpAsmParser : public AsmParser {
/// Resolve a list of operands to SSA values, emitting an error on failure, or
/// appending the results to the list on success. This method should be used
/// when all operands have the same type.
- ParseResult resolveOperands(ArrayRef<UnresolvedOperand> operands, Type type,
+ template <typename Operands = ArrayRef<UnresolvedOperand>>
+ ParseResult resolveOperands(Operands &&operands, Type type,
SmallVectorImpl<Value> &result) {
- for (auto elt : operands)
- if (resolveOperand(elt, type, result))
+ for (const UnresolvedOperand &operand : operands)
+ if (resolveOperand(operand, type, result))
return failure();
return success();
}
+ template <typename Operands = ArrayRef<UnresolvedOperand>>
+ ParseResult resolveOperands(Operands &&operands, Type type, SMLoc loc,
+ SmallVectorImpl<Value> &result) {
+ return resolveOperands(std::forward<Operands>(operands), type, result);
+ }
/// Resolve a list of operands and a list of operand types to SSA values,
/// emitting an error and returning failure, or appending the results
/// to the list on success.
- ParseResult resolveOperands(ArrayRef<UnresolvedOperand> operands,
- ArrayRef<Type> types, SMLoc loc,
- SmallVectorImpl<Value> &result) {
- if (operands.size() != types.size())
- return emitError(loc)
- << operands.size() << " operands present, but expected "
- << types.size();
-
- for (unsigned i = 0, e = operands.size(); i != e; ++i)
- if (resolveOperand(operands[i], types[i], result))
- return failure();
- return success();
- }
- template <typename Operands>
- ParseResult resolveOperands(Operands &&operands, Type type, SMLoc loc,
- SmallVectorImpl<Value> &result) {
- return resolveOperands(std::forward<Operands>(operands),
- ArrayRef<Type>(type), loc, result);
- }
- template <typename Operands, typename Types>
+ template <typename Operands = ArrayRef<UnresolvedOperand>,
+ typename Types = ArrayRef<Type>>
std::enable_if_t<!std::is_convertible<Types, Type>::value, ParseResult>
resolveOperands(Operands &&operands, Types &&types, SMLoc loc,
SmallVectorImpl<Value> &result) {
@@ -1396,8 +1384,8 @@ class OpAsmParser : public AsmParser {
return emitError(loc)
<< operandSize << " operands present, but expected " << typeSize;
- for (auto it : llvm::zip(operands, types))
- if (resolveOperand(std::get<0>(it), std::get<1>(it), result))
+ for (auto [operand, type] : llvm::zip(operands, types))
+ if (resolveOperand(operand, type, result))
return failure();
return success();
}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 9567aa153a0a8..18775b7841bc0 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2130,7 +2130,7 @@ def FormatInferVariadicTypeFromNonVariadic
[SameOperandsAndResultType]> {
let arguments = (ins Variadic<AnyType>:$args);
let results = (outs AnyType:$result);
- let assemblyFormat = "$args attr-dict `:` type($result)";
+ let assemblyFormat = "operands attr-dict `:` type($result)";
}
def FormatOptionalUnitAttr : TEST_Op<"format_optional_unit_attribute"> {
diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td
index 6d21afdfe6de7..d546cb97658f2 100644
--- a/mlir/test/mlir-tblgen/op-format-spec.td
+++ b/mlir/test/mlir-tblgen/op-format-spec.td
@@ -172,12 +172,12 @@ def ZCoverageValidC : TestFormat_Op<[{
// Check that we can infer type equalities from certain traits.
def ZCoverageValidD : TestFormat_Op<[{
operands type($result) attr-dict
-}], [SameOperandsAndResultType]>, Arguments<(ins AnyMemRef:$operand)>,
+}], [SameOperandsAndResultType]>, Arguments<(ins AnyMemRef)>,
Results<(outs AnyMemRef:$result)>;
def ZCoverageValidE : TestFormat_Op<[{
$operand type($operand) attr-dict
}], [SameOperandsAndResultType]>, Arguments<(ins AnyMemRef:$operand)>,
- Results<(outs AnyMemRef:$result)>;
+ Results<(outs AnyMemRef)>;
def ZCoverageValidF : TestFormat_Op<[{
operands type($other) attr-dict
}], [SameTypeOperands]>, Arguments<(ins AnyMemRef:$operand, AnyMemRef:$other)>;
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index c1bb87712323d..304af57292f4f 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -1396,6 +1396,8 @@ void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) {
body << tgfmt(*tform, &fmtContext);
} else {
body << var->name << "Types";
+ if (!var->isVariadic())
+ body << "[0]";
}
} else if (const NamedAttribute *attr = resolver.getAttribute()) {
if (Optional<StringRef> tform = resolver.getVarTransformer())
@@ -1484,8 +1486,8 @@ void OperationFormat::genParserOperandTypeResolution(
emitTypeResolver(operandTypes.front(), op.getOperand(0).name);
}
- body << ", allOperandLoc, result.operands))\n"
- << " return ::mlir::failure();\n";
+ body << ", allOperandLoc, result.operands))\n return "
+ "::mlir::failure();\n";
return;
}
@@ -1499,25 +1501,8 @@ void OperationFormat::genParserOperandTypeResolution(
TypeResolution &operandType = operandTypes[i];
emitTypeResolver(operandType, operand.name);
- // If the type is resolved by a non-variadic variable, index into the
- // resolved type list. This allows for resolving the types of a variadic
- // operand list from a non-variadic variable.
- bool verifyOperandAndTypeSize = true;
- if (auto *resolverVar = operandType.getVariable()) {
- if (!resolverVar->isVariadic() && !operandType.getVarTransformer()) {
- body << "[0]";
- verifyOperandAndTypeSize = false;
- }
- } else {
- verifyOperandAndTypeSize = !operandType.getBuilderIdx();
- }
-
- // Check to see if the sizes between the types and operands must match. If
- // they do, provide the operand location to select the proper resolution
- // overload.
- if (verifyOperandAndTypeSize)
- body << ", " << operand.name << "OperandsLoc";
- body << ", result.operands))\n return ::mlir::failure();\n";
+ body << ", " << operand.name
+ << "OperandsLoc, result.operands))\n return ::mlir::failure();\n";
}
}
@@ -2731,11 +2716,11 @@ void OpFormatParser::handleSameTypesConstraint(
// Set the resolvers for each operand and result.
for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i)
- if (!seenOperandTypes.test(i) && !op.getOperand(i).name.empty())
+ if (!seenOperandTypes.test(i))
variableTyResolver[op.getOperand(i).name] = {resolver, llvm::None};
if (includeResults) {
for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
- if (!seenResultTypes.test(i) && !op.getResultName(i).empty())
+ if (!seenResultTypes.test(i))
variableTyResolver[op.getResultName(i)] = {resolver, llvm::None};
}
}
More information about the Mlir-commits
mailing list