[Mlir-commits] [mlir] d85eb4e - [AsmParser] Introduce a new "Argument" abstraction + supporting logic
Chris Lattner
llvmlistbot at llvm.org
Fri Apr 29 12:22:25 PDT 2022
Author: Chris Lattner
Date: 2022-04-29T12:19:34-07:00
New Revision: d85eb4e2d62e51645922ec17678a319b3c7d872c
URL: https://github.com/llvm/llvm-project/commit/d85eb4e2d62e51645922ec17678a319b3c7d872c
DIFF: https://github.com/llvm/llvm-project/commit/d85eb4e2d62e51645922ec17678a319b3c7d872c.diff
LOG: [AsmParser] Introduce a new "Argument" abstraction + supporting logic
MLIR has a common pattern for "arguments" that uses syntax
like `%x : i32 {attrs} loc("sourceloc")`, which is implemented
in ad hoc ways throughout the codebase. The existing approach
is verbose (because it is implemented with parallel arrays) and
inconsistent (e.g. lots of things drop source location info).
Solve this by introducing OpAsmParser::Argument and making parseRegion
(which sets up BlockArguments for the region) take it. Convert all
existing users to propagate this down. This means that we correctly
capture and propagate source location information in a lot more
cases (e.g. see the affine.for testcase example), and it also
simplifies much code.
Differential Revision: https://reviews.llvm.org/D124649
Added:
Modified:
flang/lib/Optimizer/Dialect/FIROps.cpp
mlir/include/mlir/IR/FunctionImplementation.h
mlir/include/mlir/IR/OpImplementation.h
mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/lib/Dialect/Async/IR/Async.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
mlir/lib/Dialect/SCF/SCF.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/lib/IR/FunctionImplementation.cpp
mlir/lib/Parser/AttributeParser.cpp
mlir/lib/Parser/Parser.cpp
mlir/test/Dialect/GPU/invalid.mlir
mlir/test/IR/locations.mlir
mlir/test/lib/Dialect/Test/TestDialect.cpp
Removed:
################################################################################
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 29cbbcfb9d288..2104173065461 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -1200,8 +1200,8 @@ mlir::ParseResult fir::GlobalOp::parse(mlir::OpAsmParser &parser,
result.addRegion();
} else {
// Parse the optional initializer body.
- auto parseResult = parser.parseOptionalRegion(
- *result.addRegion(), /*arguments=*/llvm::None, /*argTypes=*/llvm::None);
+ auto parseResult =
+ parser.parseOptionalRegion(*result.addRegion(), /*arguments=*/{});
if (parseResult.hasValue() && mlir::failed(*parseResult))
return mlir::failure();
}
@@ -1562,9 +1562,9 @@ void fir::IterWhileOp::build(mlir::OpBuilder &builder,
mlir::ParseResult fir::IterWhileOp::parse(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
auto &builder = parser.getBuilder();
- mlir::OpAsmParser::UnresolvedOperand inductionVariable, lb, ub, step;
- if (parser.parseLParen() ||
- parser.parseOperand(inductionVariable, /*allowResultNumber=*/false) ||
+ mlir::OpAsmParser::Argument inductionVariable, iterateVar;
+ mlir::OpAsmParser::UnresolvedOperand lb, ub, step, iterateInput;
+ if (parser.parseLParen() || parser.parseArgument(inductionVariable) ||
parser.parseEqual())
return mlir::failure();
@@ -1577,22 +1577,18 @@ mlir::ParseResult fir::IterWhileOp::parse(mlir::OpAsmParser &parser,
parser.resolveOperand(ub, indexType, result.operands) ||
parser.parseKeyword("step") || parser.parseOperand(step) ||
parser.parseRParen() ||
- parser.resolveOperand(step, indexType, result.operands))
- return mlir::failure();
-
- mlir::OpAsmParser::UnresolvedOperand iterateVar, iterateInput;
- if (parser.parseKeyword("and") || parser.parseLParen() ||
- parser.parseOperand(iterateVar, /*allowResultNumber=*/false) ||
- parser.parseEqual() || parser.parseOperand(iterateInput) ||
- parser.parseRParen() ||
+ parser.resolveOperand(step, indexType, result.operands) ||
+ parser.parseKeyword("and") || parser.parseLParen() ||
+ parser.parseArgument(iterateVar) || parser.parseEqual() ||
+ parser.parseOperand(iterateInput) || parser.parseRParen() ||
parser.resolveOperand(iterateInput, i1Type, result.operands))
return mlir::failure();
// Parse the initial iteration arguments.
- llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> regionArgs;
auto prependCount = false;
// Induction variable.
+ llvm::SmallVector<mlir::OpAsmParser::Argument> regionArgs;
regionArgs.push_back(inductionVariable);
regionArgs.push_back(iterateVar);
@@ -1652,7 +1648,10 @@ mlir::ParseResult fir::IterWhileOp::parse(mlir::OpAsmParser &parser,
parser.getNameLoc(),
"mismatch in number of loop-carried values and defined values");
- if (parser.parseRegion(*body, regionArgs, argTypes))
+ for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
+ regionArgs[i].type = argTypes[i];
+
+ if (parser.parseRegion(*body, regionArgs))
return mlir::failure();
fir::IterWhileOp::ensureTerminator(*body, builder, result.location);
@@ -1876,10 +1875,10 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder,
mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
auto &builder = parser.getBuilder();
- mlir::OpAsmParser::UnresolvedOperand inductionVariable, lb, ub, step;
+ mlir::OpAsmParser::Argument inductionVariable;
+ mlir::OpAsmParser::UnresolvedOperand lb, ub, step;
// Parse the induction variable followed by '='.
- if (parser.parseOperand(inductionVariable, /*allowResultNumber=*/false) ||
- parser.parseEqual())
+ if (parser.parseArgument(inductionVariable) || parser.parseEqual())
return mlir::failure();
// Parse loop bounds.
@@ -1896,7 +1895,8 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
result.addAttribute("unordered", builder.getUnitAttr());
// Parse the optional initial iteration arguments.
- llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> regionArgs, operands;
+ llvm::SmallVector<mlir::OpAsmParser::Argument> regionArgs;
+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands;
llvm::SmallVector<mlir::Type> argTypes;
bool prependCount = false;
regionArgs.push_back(inductionVariable);
@@ -1939,8 +1939,10 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
return parser.emitError(
parser.getNameLoc(),
"mismatch in number of loop-carried values and defined values");
+ for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
+ regionArgs[i].type = argTypes[i];
- if (parser.parseRegion(*body, regionArgs, argTypes))
+ if (parser.parseRegion(*body, regionArgs))
return mlir::failure();
DoLoopOp::ensureTerminator(*body, builder, result.location);
diff --git a/mlir/include/mlir/IR/FunctionImplementation.h b/mlir/include/mlir/IR/FunctionImplementation.h
index 04684e8c346a0..5265f781d1a77 100644
--- a/mlir/include/mlir/IR/FunctionImplementation.h
+++ b/mlir/include/mlir/IR/FunctionImplementation.h
@@ -41,8 +41,8 @@ void addArgAndResultAttrs(Builder &builder, OperationState &result,
ArrayRef<DictionaryAttr> argAttrs,
ArrayRef<DictionaryAttr> resultAttrs);
void addArgAndResultAttrs(Builder &builder, OperationState &result,
- ArrayRef<NamedAttrList> argAttrs,
- ArrayRef<NamedAttrList> resultAttrs);
+ ArrayRef<OpAsmParser::Argument> argAttrs,
+ ArrayRef<DictionaryAttr> resultAttrs);
/// Callback type for `parseFunctionOp`, the callback should produce the
/// type that will be associated with a function-like operation from lists of
@@ -52,26 +52,20 @@ void addArgAndResultAttrs(Builder &builder, OperationState &result,
using FuncTypeBuilder = function_ref<Type(
Builder &, ArrayRef<Type>, ArrayRef<Type>, VariadicFlag, std::string &)>;
-/// Parses function arguments using `parser`. The `allowVariadic` argument
-/// indicates whether functions with variadic arguments are supported. The
-/// trailing arguments are populated by this function with names, types,
-/// attributes and locations of the arguments.
-ParseResult parseFunctionArgumentList(
- OpAsmParser &parser, bool allowAttributes, bool allowVariadic,
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
- SmallVectorImpl<Type> &argTypes, SmallVectorImpl<NamedAttrList> &argAttrs,
- bool &isVariadic);
-
/// Parses a function signature using `parser`. The `allowVariadic` argument
/// indicates whether functions with variadic arguments are supported. The
/// trailing arguments are populated by this function with names, types,
/// attributes and locations of the arguments and those of the results.
-ParseResult parseFunctionSignature(
- OpAsmParser &parser, bool allowVariadic,
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
- SmallVectorImpl<Type> &argTypes, SmallVectorImpl<NamedAttrList> &argAttrs,
- bool &isVariadic, SmallVectorImpl<Type> &resultTypes,
- SmallVectorImpl<NamedAttrList> &resultAttrs);
+ParseResult
+parseFunctionSignature(OpAsmParser &parser, bool allowVariadic,
+ SmallVectorImpl<OpAsmParser::Argument> &arguments,
+ bool &isVariadic, SmallVectorImpl<Type> &resultTypes,
+ SmallVectorImpl<DictionaryAttr> &resultAttrs);
+
+/// Get a function type corresponding to an array of arguments (which have
+/// types) and a set of result types.
+Type getFunctionType(Builder &builder, ArrayRef<OpAsmParser::Argument> argAttrs,
+ ArrayRef<Type> resultTypes);
/// Parser implementation for function-like operations. Uses
/// `funcTypeBuilder` to construct the custom function type given lists of
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 0399f10825696..0a1311832a6d5 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -633,14 +633,14 @@ class AsmParser {
/// unlike `OpBuilder::getType`, this method does not implicitly insert a
/// context parameter.
template <typename T, typename... ParamsT>
- T getChecked(SMLoc loc, ParamsT &&... params) {
+ T getChecked(SMLoc loc, ParamsT &&...params) {
return T::getChecked([&] { return emitError(loc); },
std::forward<ParamsT>(params)...);
}
/// A variant of `getChecked` that uses the result of `getNameLoc` to emit
/// errors.
template <typename T, typename... ParamsT>
- T getChecked(ParamsT &&... params) {
+ T getChecked(ParamsT &&...params) {
return T::getChecked([&] { return emitError(getNameLoc()); },
std::forward<ParamsT>(params)...);
}
@@ -1093,7 +1093,6 @@ class OpAsmParser : public AsmParser {
SMLoc location; // Location of the token.
StringRef name; // Value name, e.g. %42 or %abc
unsigned number; // Number, e.g. 12 for an operand like %xyz#12
- Optional<Location> sourceLoc; // Source location specifier if present.
};
/// Parse different components, viz., use-info of operand(s), successor(s),
@@ -1219,34 +1218,64 @@ class OpAsmParser : public AsmParser {
SmallVectorImpl<UnresolvedOperand> &symbOperands,
AffineExpr &expr) = 0;
+ //===--------------------------------------------------------------------===//
+ // Argument Parsing
+ //===--------------------------------------------------------------------===//
+
+ struct Argument {
+ UnresolvedOperand ssaName; // SourceLoc, SSA name, result #.
+ Type type; // Type.
+ DictionaryAttr attrs; // Attributes if present.
+ Optional<Location> sourceLoc; // Source location specifier if present.
+ };
+
+ /// Parse a single argument with the following syntax:
+ ///
+ /// `%ssaName : !type { optionalAttrDict} loc(optionalSourceLoc)`
+ ///
+ /// If `allowType` is false or `allowAttrs` are false then the respective
+ /// parts of the grammar are not parsed.
+ virtual ParseResult parseArgument(Argument &result, bool allowType = false,
+ bool allowAttrs = false) = 0;
+
+ /// Parse a single argument if present.
+ virtual OptionalParseResult
+ parseOptionalArgument(Argument &result, bool allowType = false,
+ bool allowAttrs = false) = 0;
+
+ /// Parse zero or more arguments with a specified surrounding delimiter.
+ virtual ParseResult parseArgumentList(SmallVectorImpl<Argument> &result,
+ Delimiter delimiter = Delimiter::None,
+ bool allowType = false,
+ bool allowAttrs = false) = 0;
+
//===--------------------------------------------------------------------===//
// Region Parsing
//===--------------------------------------------------------------------===//
/// Parses a region. Any parsed blocks are appended to 'region' and must be
/// moved to the op regions after the op is created. The first block of the
- /// region takes 'arguments' of types 'argTypes'. If 'enableNameShadowing' is
- /// set to true, the argument names are allowed to shadow the names of other
- /// existing SSA values defined above the region scope. 'enableNameShadowing'
- /// can only be set to true for regions attached to operations that are
- /// 'IsolatedFromAbove'.
+ /// region takes 'arguments'.
+ ///
+ /// If 'enableNameShadowing' is set to true, the argument names are allowed to
+ /// shadow the names of other existing SSA values defined above the region
+ /// scope. 'enableNameShadowing' can only be set to true for regions attached
+ /// to operations that are 'IsolatedFromAbove'.
virtual ParseResult parseRegion(Region &region,
- ArrayRef<UnresolvedOperand> arguments = {},
- ArrayRef<Type> argTypes = {},
+ ArrayRef<Argument> arguments = {},
bool enableNameShadowing = false) = 0;
/// Parses a region if present.
- virtual OptionalParseResult parseOptionalRegion(
- Region &region, ArrayRef<UnresolvedOperand> arguments = {},
- ArrayRef<Type> argTypes = {}, bool enableNameShadowing = false) = 0;
+ virtual OptionalParseResult
+ parseOptionalRegion(Region &region, ArrayRef<Argument> arguments = {},
+ bool enableNameShadowing = false) = 0;
/// Parses a region if present. If the region is present, a new region is
/// allocated and placed in `region`. If no region is present or on failure,
/// `region` remains untouched.
virtual OptionalParseResult
parseOptionalRegion(std::unique_ptr<Region> &region,
- ArrayRef<UnresolvedOperand> arguments = {},
- ArrayRef<Type> argTypes = {},
+ ArrayRef<Argument> arguments = {},
bool enableNameShadowing = false) = 0;
//===--------------------------------------------------------------------===//
@@ -1269,7 +1298,7 @@ class OpAsmParser : public AsmParser {
/// Parse a list of assignments of the form
/// (%x1 = %y1, %x2 = %y2, ...)
- ParseResult parseAssignmentList(SmallVectorImpl<UnresolvedOperand> &lhs,
+ ParseResult parseAssignmentList(SmallVectorImpl<Argument> &lhs,
SmallVectorImpl<UnresolvedOperand> &rhs) {
OptionalParseResult result = parseOptionalAssignmentList(lhs, rhs);
if (!result.hasValue())
@@ -1278,26 +1307,8 @@ class OpAsmParser : public AsmParser {
}
virtual OptionalParseResult
- parseOptionalAssignmentList(SmallVectorImpl<UnresolvedOperand> &lhs,
+ parseOptionalAssignmentList(SmallVectorImpl<Argument> &lhs,
SmallVectorImpl<UnresolvedOperand> &rhs) = 0;
-
- /// Parse a list of assignments of the form
- /// (%x1 = %y1 : type1, %x2 = %y2 : type2, ...)
- ParseResult
- parseAssignmentListWithTypes(SmallVectorImpl<UnresolvedOperand> &lhs,
- SmallVectorImpl<UnresolvedOperand> &rhs,
- SmallVectorImpl<Type> &types) {
- OptionalParseResult result =
- parseOptionalAssignmentListWithTypes(lhs, rhs, types);
- if (!result.hasValue())
- return emitError(getCurrentLocation(), "expected '('");
- return result.getValue();
- }
-
- virtual OptionalParseResult
- parseOptionalAssignmentListWithTypes(SmallVectorImpl<UnresolvedOperand> &lhs,
- SmallVectorImpl<UnresolvedOperand> &rhs,
- SmallVectorImpl<Type> &types) = 0;
};
//===--------------------------------------------------------------------===//
@@ -1339,7 +1350,6 @@ class OpAsmDialectInterface
virtual AliasResult getAlias(Type type, raw_ostream &os) const {
return AliasResult::NoAlias;
}
-
};
} // namespace mlir
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index ad04c6f7e210c..2fb385dd94c6f 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1431,10 +1431,10 @@ static ParseResult parseBound(bool isLower, OperationState &result,
ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) {
auto &builder = parser.getBuilder();
- OpAsmParser::UnresolvedOperand inductionVariable;
+ OpAsmParser::Argument inductionVariable;
+ inductionVariable.type = builder.getIndexType();
// Parse the induction variable followed by '='.
- if (parser.parseOperand(inductionVariable, /*allowResultNumber=*/false) ||
- parser.parseEqual())
+ if (parser.parseArgument(inductionVariable) || parser.parseEqual())
return failure();
// Parse loop bounds.
@@ -1463,8 +1463,10 @@ ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) {
}
// Parse the optional initial iteration arguments.
- SmallVector<OpAsmParser::UnresolvedOperand, 4> regionArgs, operands;
- SmallVector<Type, 4> argTypes;
+ SmallVector<OpAsmParser::Argument, 4> regionArgs;
+ SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
+
+ // Induction variable.
regionArgs.push_back(inductionVariable);
if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
@@ -1473,23 +1475,23 @@ ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) {
parser.parseArrowTypeList(result.types))
return failure();
// Resolve input operands.
- for (auto operandType : llvm::zip(operands, result.types))
- if (parser.resolveOperand(std::get<0>(operandType),
- std::get<1>(operandType), result.operands))
+ for (auto argOperandType :
+ llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
+ Type type = std::get<2>(argOperandType);
+ std::get<0>(argOperandType).type = type;
+ if (parser.resolveOperand(std::get<1>(argOperandType), type,
+ result.operands))
return failure();
+ }
}
- // Induction variable.
- Type indexType = builder.getIndexType();
- argTypes.push_back(indexType);
- // Loop carried variables.
- argTypes.append(result.types.begin(), result.types.end());
+
// Parse the body region.
Region *body = result.addRegion();
- if (regionArgs.size() != argTypes.size())
+ if (regionArgs.size() != result.types.size() + 1)
return parser.emitError(
parser.getNameLoc(),
"mismatch between the number of loop-carried values and results");
- if (parser.parseRegion(*body, regionArgs, argTypes))
+ if (parser.parseRegion(*body, regionArgs))
return failure();
AffineForOp::ensureTerminator(*body, builder, result.location);
@@ -1548,7 +1550,8 @@ unsigned AffineForOp::getNumIterOperands() {
void AffineForOp::print(OpAsmPrinter &p) {
p << ' ';
- p.printOperand(getBody()->getArgument(0));
+ p.printRegionArgument(getBody()->getArgument(0), /*argAtrs=*/{},
+ /*omitType=*/true);
p << " = ";
printBound(getLowerBoundMapAttr(), getLowerBoundOperands(), "max", p);
p << " to ";
@@ -3527,9 +3530,8 @@ ParseResult AffineParallelOp::parse(OpAsmParser &parser,
OperationState &result) {
auto &builder = parser.getBuilder();
auto indexType = builder.getIndexType();
- SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
- if (parser.parseOperandList(ivs, OpAsmParser::Delimiter::Paren,
- /*allowResultNumber=*/false) ||
+ SmallVector<OpAsmParser::Argument, 4> ivs;
+ if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
parser.parseEqual() ||
parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) ||
parser.parseKeyword("to") ||
@@ -3600,8 +3602,9 @@ ParseResult AffineParallelOp::parse(OpAsmParser &parser,
// Now parse the body.
Region *body = result.addRegion();
- SmallVector<Type, 4> types(ivs.size(), indexType);
- if (parser.parseRegion(*body, ivs, types) ||
+ for (auto &iv : ivs)
+ iv.type = indexType;
+ if (parser.parseRegion(*body, ivs) ||
parser.parseOptionalAttrDict(result.attributes))
return failure();
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index fb64c39bf2e2d..4e0e890a73b4b 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -178,21 +178,19 @@ ParseResult ExecuteOp::parse(OpAsmParser &parser, OperationState &result) {
// Parse async value operands (%value as %unwrapped : !async.value<!type>).
SmallVector<OpAsmParser::UnresolvedOperand, 4> valueArgs;
- SmallVector<OpAsmParser::UnresolvedOperand, 4> unwrappedArgs;
+ SmallVector<OpAsmParser::Argument, 4> unwrappedArgs;
SmallVector<Type, 4> valueTypes;
- SmallVector<Type, 4> unwrappedTypes;
// Parse a single instance of `%value as %unwrapped : !async.value<!type>`.
auto parseAsyncValueArg = [&]() -> ParseResult {
if (parser.parseOperand(valueArgs.emplace_back()) ||
parser.parseKeyword("as") ||
- parser.parseOperand(unwrappedArgs.emplace_back()) ||
+ parser.parseArgument(unwrappedArgs.emplace_back()) ||
parser.parseColonType(valueTypes.emplace_back()))
return failure();
auto valueTy = valueTypes.back().dyn_cast<ValueType>();
- unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type());
-
+ unwrappedArgs.back().type = valueTy ? valueTy.getValueType() : Type();
return success();
};
@@ -227,12 +225,7 @@ ParseResult ExecuteOp::parse(OpAsmParser &parser, OperationState &result) {
// Parse asynchronous region.
Region *body = result.addRegion();
- if (parser.parseRegion(*body, /*arguments=*/{unwrappedArgs},
- /*argTypes=*/{unwrappedTypes},
- /*enableNameShadowing=*/false))
- return failure();
-
- return success();
+ return parser.parseRegion(*body, /*arguments=*/unwrappedArgs);
}
LogicalResult ExecuteOp::verifyRegions() {
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 4cd5b1f5a0d60..9d297f1bc35f3 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -622,8 +622,17 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
Type index = parser.getBuilder().getIndexType();
SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
LaunchOp::kNumConfigRegionAttributes, index);
+
+ SmallVector<OpAsmParser::Argument> regionArguments;
+ for (auto ssaValueAndType : llvm::zip(regionArgs, dataTypes)) {
+ OpAsmParser::Argument arg;
+ arg.ssaName = std::get<0>(ssaValueAndType);
+ arg.type = std::get<1>(ssaValueAndType);
+ regionArguments.push_back(arg);
+ }
+
Region *body = result.addRegion();
- if (parser.parseRegion(*body, regionArgs, dataTypes) ||
+ if (parser.parseRegion(*body, regionArguments) ||
parser.parseOptionalAttrDict(result.attributes))
return failure();
@@ -758,11 +767,16 @@ static ParseResult parseLaunchFuncOperands(
SmallVectorImpl<Type> &argTypes) {
if (parser.parseOptionalKeyword("args"))
return success();
- SmallVector<NamedAttrList> argAttrs;
- bool isVariadic = false;
- return function_interface_impl::parseFunctionArgumentList(
- parser, /*allowAttributes=*/false,
- /*allowVariadic=*/false, argNames, argTypes, argAttrs, isVariadic);
+
+ SmallVector<OpAsmParser::Argument> args;
+ if (parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
+ /*allowType=*/true))
+ return failure();
+ for (auto &arg : args) {
+ argNames.push_back(arg.ssaName);
+ argTypes.push_back(arg.type);
+ }
+ return success();
}
static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
@@ -779,8 +793,6 @@ static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
printer << ")";
}
-//
-
//===----------------------------------------------------------------------===//
// ShuffleOp
//===----------------------------------------------------------------------===//
@@ -852,32 +864,13 @@ void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
/// keyword provided as argument.
static ParseResult
parseAttributions(OpAsmParser &parser, StringRef keyword,
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &args,
- SmallVectorImpl<Type> &argTypes) {
+ SmallVectorImpl<OpAsmParser::Argument> &args) {
// If we could not parse the keyword, just assume empty list and succeed.
if (failed(parser.parseOptionalKeyword(keyword)))
return success();
- if (failed(parser.parseLParen()))
- return failure();
-
- // Early exit for an empty list.
- if (succeeded(parser.parseOptionalRParen()))
- return success();
-
- do {
- OpAsmParser::UnresolvedOperand arg;
- Type type;
-
- if (parser.parseOperand(arg, /*allowResultNumber=*/false) ||
- parser.parseColonType(type))
- return failure();
-
- args.push_back(arg);
- argTypes.push_back(type);
- } while (succeeded(parser.parseOptionalComma()));
-
- return parser.parseRParen();
+ return parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
+ /*allowType=*/true);
}
/// Parses a GPU function.
@@ -886,10 +879,8 @@ parseAttributions(OpAsmParser &parser, StringRef keyword,
/// (`->` function-result-list)? memory-attribution `kernel`?
/// function-attributes? region
ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
- SmallVector<OpAsmParser::UnresolvedOperand> entryArgs;
- SmallVector<NamedAttrList> argAttrs;
- SmallVector<NamedAttrList> resultAttrs;
- SmallVector<Type> argTypes;
+ SmallVector<OpAsmParser::Argument> entryArgs;
+ SmallVector<DictionaryAttr> resultAttrs;
SmallVector<Type> resultTypes;
bool isVariadic;
@@ -901,34 +892,41 @@ ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
auto signatureLocation = parser.getCurrentLocation();
if (failed(function_interface_impl::parseFunctionSignature(
- parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs,
- isVariadic, resultTypes, resultAttrs)))
+ parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
+ resultAttrs)))
return failure();
- if (entryArgs.empty() && !argTypes.empty())
+ if (!entryArgs.empty() && entryArgs[0].ssaName.name.empty())
return parser.emitError(signatureLocation)
<< "gpu.func requires named arguments";
// Construct the function type. More types will be added to the region, but
// not to the function type.
Builder &builder = parser.getBuilder();
+
+ SmallVector<Type> argTypes;
+ for (auto &arg : entryArgs)
+ argTypes.push_back(arg.type);
auto type = builder.getFunctionType(argTypes, resultTypes);
result.addAttribute(GPUFuncOp::getTypeAttrName(), TypeAttr::get(type));
+ function_interface_impl::addArgAndResultAttrs(builder, result, entryArgs,
+ resultAttrs);
+
// Parse workgroup memory attributions.
if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(),
- entryArgs, argTypes)))
+ entryArgs)))
return failure();
// Store the number of operands we just parsed as the number of workgroup
// memory attributions.
- unsigned numWorkgroupAttrs = argTypes.size() - type.getNumInputs();
+ unsigned numWorkgroupAttrs = entryArgs.size() - type.getNumInputs();
result.addAttribute(GPUFuncOp::getNumWorkgroupAttributionsAttrName(),
builder.getI64IntegerAttr(numWorkgroupAttrs));
// Parse private memory attributions.
- if (failed(parseAttributions(parser, GPUFuncOp::getPrivateKeyword(),
- entryArgs, argTypes)))
+ if (failed(
+ parseAttributions(parser, GPUFuncOp::getPrivateKeyword(), entryArgs)))
return failure();
// Parse the kernel attribute if present.
@@ -939,13 +937,11 @@ ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
// Parse attributes.
if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
return failure();
- function_interface_impl::addArgAndResultAttrs(builder, result, argAttrs,
- resultAttrs);
// Parse the region. If no argument names were provided, take all names
// (including those of attributions) from the entry block.
auto *body = result.addRegion();
- return parser.parseRegion(*body, entryArgs, argTypes);
+ return parser.parseRegion(*body, entryArgs);
}
static void printAttributions(OpAsmPrinter &p, StringRef keyword,
@@ -1078,16 +1074,14 @@ void GPUModuleOp::build(OpBuilder &builder, OperationState &result,
ParseResult GPUModuleOp::parse(OpAsmParser &parser, OperationState &result) {
StringAttr nameAttr;
if (parser.parseSymbolName(nameAttr, mlir::SymbolTable::getSymbolAttrName(),
- result.attributes))
- return failure();
-
- // If module attributes are present, parse them.
- if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
+ result.attributes) ||
+ // If module attributes are present, parse them.
+ parser.parseOptionalAttrDictWithKeyword(result.attributes))
return failure();
// Parse the module body.
auto *body = result.addRegion();
- if (parser.parseRegion(*body, None, None))
+ if (parser.parseRegion(*body, {}))
return failure();
// Ensure that this module has a valid terminator.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index c8188cac5a628..f80a42919df23 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2152,10 +2152,8 @@ ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
parser, result, LLVM::Linkage::External)));
StringAttr nameAttr;
- SmallVector<OpAsmParser::UnresolvedOperand> entryArgs;
- SmallVector<NamedAttrList> argAttrs;
- SmallVector<NamedAttrList> resultAttrs;
- SmallVector<Type> argTypes;
+ SmallVector<OpAsmParser::Argument> entryArgs;
+ SmallVector<DictionaryAttr> resultAttrs;
SmallVector<Type> resultTypes;
bool isVariadic;
@@ -2163,10 +2161,13 @@ ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
result.attributes) ||
function_interface_impl::parseFunctionSignature(
- parser, /*allowVariadic=*/true, entryArgs, argTypes, argAttrs,
- isVariadic, resultTypes, resultAttrs))
+ parser, /*allowVariadic=*/true, entryArgs, isVariadic, resultTypes,
+ resultAttrs))
return failure();
+ SmallVector<Type> argTypes;
+ for (auto &arg : entryArgs)
+ argTypes.push_back(arg.type);
auto type =
buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes,
function_interface_impl::VariadicFlag(isVariadic));
@@ -2178,11 +2179,11 @@ ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
return failure();
function_interface_impl::addArgAndResultAttrs(parser.getBuilder(), result,
- argAttrs, resultAttrs);
+ entryArgs, resultAttrs);
auto *body = result.addRegion();
- OptionalParseResult parseResult = parser.parseOptionalRegion(
- *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes);
+ OptionalParseResult parseResult =
+ parser.parseOptionalRegion(*body, entryArgs);
return failure(parseResult.hasValue() && failed(*parseResult));
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 8bac54da30b15..d58c253bbfcea 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -799,10 +799,8 @@ ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
failed(parser.parseOptionalAttrDict(result.attributes)))
return failure();
- SmallVector<OpAsmParser::UnresolvedOperand, 8> regionOperands;
std::unique_ptr<Region> region = std::make_unique<Region>();
- SmallVector<Type, 8> operandTypes, regionTypes;
- if (parser.parseRegion(*region, regionOperands, regionTypes))
+ if (parser.parseRegion(*region, {}))
return failure();
result.addRegion(std::move(region));
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index e7f2e03434899..d200fd0485c5f 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -275,7 +275,7 @@ ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
return failure();
// Parse the body region.
- if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{}))
+ if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
return failure();
AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
result.location);
@@ -1215,7 +1215,7 @@ ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
return failure();
Region *body = result.addRegion();
- if (parser.parseRegion(*body, llvm::None, llvm::None) ||
+ if (parser.parseRegion(*body, {}) ||
parser.parseOptionalAttrDict(result.attributes))
return failure();
result.types.push_back(memrefType.cast<MemRefType>().getElementType());
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index e2fb9bbcc02a2..fa2becae7e637 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -523,20 +523,16 @@ parseWsLoopControl(OpAsmParser &parser, Region ®ion,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &steps,
SmallVectorImpl<Type> &loopVarTypes, UnitAttr &inclusive) {
// Parse an opening `(` followed by induction variables followed by `)`
- SmallVector<OpAsmParser::UnresolvedOperand> ivs;
- if (parser.parseOperandList(ivs, OpAsmParser::Delimiter::Paren,
- /*allowResultNumber=*/false))
- return failure();
-
- size_t numIVs = ivs.size();
+ SmallVector<OpAsmParser::Argument> ivs;
Type loopVarType;
- if (parser.parseColonType(loopVarType) ||
+ if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
+ parser.parseColonType(loopVarType) ||
// Parse loop bounds.
parser.parseEqual() ||
- parser.parseOperandList(lowerBound, numIVs,
+ parser.parseOperandList(lowerBound, ivs.size(),
OpAsmParser::Delimiter::Paren) ||
parser.parseKeyword("to") ||
- parser.parseOperandList(upperBound, numIVs,
+ parser.parseOperandList(upperBound, ivs.size(),
OpAsmParser::Delimiter::Paren))
return failure();
@@ -545,15 +541,14 @@ parseWsLoopControl(OpAsmParser &parser, Region ®ion,
// Parse step values.
if (parser.parseKeyword("step") ||
- parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren))
+ parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
return failure();
// Now parse the body.
- loopVarTypes = SmallVector<Type>(numIVs, loopVarType);
- SmallVector<OpAsmParser::UnresolvedOperand> blockArgs(ivs);
- if (parser.parseRegion(region, blockArgs, loopVarTypes))
- return failure();
- return success();
+ loopVarTypes = SmallVector<Type>(ivs.size(), loopVarType);
+ for (auto &iv : ivs)
+ iv.type = loopVarType;
+ return parser.parseRegion(region, ivs);
}
void printWsLoopControl(OpAsmPrinter &p, Operation *op, Region ®ion,
@@ -582,33 +577,28 @@ void printWsLoopControl(OpAsmPrinter &p, Operation *op, Region ®ion,
/// clause ::= TODO
ParseResult SimdLoopOp::parse(OpAsmParser &parser, OperationState &result) {
// Parse an opening `(` followed by induction variables followed by `)`
- SmallVector<OpAsmParser::UnresolvedOperand> ivs;
- if (parser.parseOperandList(ivs, OpAsmParser::Delimiter::Paren,
- /*allowResultNumber=*/false))
- return failure();
- int numIVs = static_cast<int>(ivs.size());
+ SmallVector<OpAsmParser::Argument> ivs;
Type loopVarType;
- if (parser.parseColonType(loopVarType))
- return failure();
- // Parse loop bounds.
- SmallVector<OpAsmParser::UnresolvedOperand> lower;
- if (parser.parseEqual() ||
- parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) ||
- parser.resolveOperands(lower, loopVarType, result.operands))
- return failure();
- SmallVector<OpAsmParser::UnresolvedOperand> upper;
- if (parser.parseKeyword("to") ||
- parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) ||
- parser.resolveOperands(upper, loopVarType, result.operands))
- return failure();
-
- // Parse step values.
- SmallVector<OpAsmParser::UnresolvedOperand> steps;
- if (parser.parseKeyword("step") ||
- parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) ||
+ SmallVector<OpAsmParser::UnresolvedOperand> lower, upper, steps;
+ if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
+ parser.parseColonType(loopVarType) ||
+ // Parse loop bounds.
+ parser.parseEqual() ||
+ parser.parseOperandList(lower, ivs.size(),
+ OpAsmParser::Delimiter::Paren) ||
+ parser.resolveOperands(lower, loopVarType, result.operands) ||
+ parser.parseKeyword("to") ||
+ parser.parseOperandList(upper, ivs.size(),
+ OpAsmParser::Delimiter::Paren) ||
+ parser.resolveOperands(upper, loopVarType, result.operands) ||
+ // Parse step values.
+ parser.parseKeyword("step") ||
+ parser.parseOperandList(steps, ivs.size(),
+ OpAsmParser::Delimiter::Paren) ||
parser.resolveOperands(steps, loopVarType, result.operands))
return failure();
+ int numIVs = static_cast<int>(ivs.size());
SmallVector<int> segments{numIVs, numIVs, numIVs};
// TODO: Add parseClauses() when we support clauses
result.addAttribute("operand_segment_sizes",
@@ -616,11 +606,9 @@ ParseResult SimdLoopOp::parse(OpAsmParser &parser, OperationState &result) {
// Now parse the body.
Region *body = result.addRegion();
- SmallVector<Type> ivTypes(numIVs, loopVarType);
- SmallVector<OpAsmParser::UnresolvedOperand> blockArgs(ivs);
- if (parser.parseRegion(*body, blockArgs, ivTypes))
- return failure();
- return success();
+ for (auto &iv : ivs)
+ iv.type = loopVarType;
+ return parser.parseRegion(*body, ivs);
}
void SimdLoopOp::print(OpAsmPrinter &p) {
diff --git a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
index 185b91eae935e..64a2be6a37778 100644
--- a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
+++ b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
@@ -101,41 +101,29 @@ void ForEachOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
ParseResult ForEachOp::parse(OpAsmParser &parser, OperationState &result) {
// Parse the loop variable followed by type.
- OpAsmParser::UnresolvedOperand loopVariable;
- Type loopVariableType;
- if (parser.parseOperand(loopVariable, /*allowResultNumber=*/false) ||
- parser.parseColonType(loopVariableType))
- return failure();
-
- // Parse the "in" keyword.
- if (parser.parseKeyword("in", " after loop variable"))
- return failure();
-
- // Parse the operand (value range).
+ OpAsmParser::Argument loopVariable;
OpAsmParser::UnresolvedOperand operandInfo;
- if (parser.parseOperand(operandInfo))
+ if (parser.parseArgument(loopVariable, /*allowType=*/true) ||
+ parser.parseKeyword("in", " after loop variable") ||
+ // Parse the operand (value range).
+ parser.parseOperand(operandInfo))
return failure();
// Resolve the operand.
- Type rangeType = pdl::RangeType::get(loopVariableType);
+ Type rangeType = pdl::RangeType::get(loopVariable.type);
if (parser.resolveOperand(operandInfo, rangeType, result.operands))
return failure();
// Parse the body region.
Region *body = result.addRegion();
- if (parser.parseRegion(*body, {loopVariable}, {loopVariableType}))
- return failure();
-
- // Parse the attribute dictionary.
- if (parser.parseOptionalAttrDict(result.attributes))
- return failure();
-
- // Parse the successor.
Block *successor;
- if (parser.parseArrow() || parser.parseSuccessor(successor))
+ if (parser.parseRegion(*body, loopVariable) ||
+ parser.parseOptionalAttrDict(result.attributes) ||
+ // Parse the successor.
+ parser.parseArrow() || parser.parseSuccessor(successor))
return failure();
- result.addSuccessors(successor);
+ result.addSuccessors(successor);
return success();
}
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 299e60b4d4a9f..cee31916babf8 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -399,15 +399,16 @@ void ForOp::print(OpAsmPrinter &p) {
ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
auto &builder = parser.getBuilder();
- OpAsmParser::UnresolvedOperand inductionVariable, lb, ub, step;
- // Parse the induction variable followed by '='.
- if (parser.parseOperand(inductionVariable, /*allowResultNumber=*/false) ||
- parser.parseEqual())
- return failure();
-
- // Parse loop bounds.
Type indexType = builder.getIndexType();
- if (parser.parseOperand(lb) ||
+
+ OpAsmParser::Argument inductionVariable;
+ inductionVariable.type = indexType;
+ OpAsmParser::UnresolvedOperand lb, ub, step;
+
+ // Parse the induction variable followed by '='.
+ if (parser.parseArgument(inductionVariable) || parser.parseEqual() ||
+ // Parse loop bounds.
+ parser.parseOperand(lb) ||
parser.resolveOperand(lb, indexType, result.operands) ||
parser.parseKeyword("to") || parser.parseOperand(ub) ||
parser.resolveOperand(ub, indexType, result.operands) ||
@@ -416,8 +417,8 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
return failure();
// Parse the optional initial iteration arguments.
- SmallVector<OpAsmParser::UnresolvedOperand, 4> regionArgs, operands;
- SmallVector<Type, 4> argTypes;
+ SmallVector<OpAsmParser::Argument, 4> regionArgs;
+ SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
regionArgs.push_back(inductionVariable);
if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
@@ -425,24 +426,26 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
if (parser.parseAssignmentList(regionArgs, operands) ||
parser.parseArrowTypeList(result.types))
return failure();
+
// Resolve input operands.
- for (auto operandType : llvm::zip(operands, result.types))
- if (parser.resolveOperand(std::get<0>(operandType),
- std::get<1>(operandType), result.operands))
+ for (auto argOperandType :
+ llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
+ Type type = std::get<2>(argOperandType);
+ std::get<0>(argOperandType).type = type;
+ if (parser.resolveOperand(std::get<1>(argOperandType), type,
+ result.operands))
return failure();
+ }
}
- // Induction variable.
- argTypes.push_back(indexType);
- // Loop carried variables
- argTypes.append(result.types.begin(), result.types.end());
- // Parse the body region.
- Region *body = result.addRegion();
- if (regionArgs.size() != argTypes.size())
+
+ if (regionArgs.size() != result.types.size() + 1)
return parser.emitError(
parser.getNameLoc(),
"mismatch in number of loop-carried values and defined values");
- if (parser.parseRegion(*body, regionArgs, argTypes))
+ // Parse the body region.
+ Region *body = result.addRegion();
+ if (parser.parseRegion(*body, regionArgs))
return failure();
ForOp::ensureTerminator(*body, builder, result.location);
@@ -1975,9 +1978,8 @@ LogicalResult ParallelOp::verify() {
ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
auto &builder = parser.getBuilder();
// Parse an opening `(` followed by induction variables followed by `)`
- SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
- if (parser.parseOperandList(ivs, OpAsmParser::Delimiter::Paren,
- /*allowResultNumber=*/false))
+ SmallVector<OpAsmParser::Argument, 4> ivs;
+ if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren))
return failure();
// Parse loop bounds.
@@ -2016,8 +2018,9 @@ ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
// Now parse the body.
Region *body = result.addRegion();
- SmallVector<Type, 4> types(ivs.size(), builder.getIndexType());
- if (parser.parseRegion(*body, ivs, types))
+ for (auto &iv : ivs)
+ iv.type = builder.getIndexType();
+ if (parser.parseRegion(*body, ivs))
return failure();
// Set `operand_segment_sizes` attribute.
@@ -2370,7 +2373,8 @@ void WhileOp::getSuccessorRegions(Optional<unsigned> index,
/// assignment-list ::= assignment | assignment `,` assignment-list
/// assignment ::= ssa-value `=` ssa-value
ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
- SmallVector<OpAsmParser::UnresolvedOperand, 4> regionArgs, operands;
+ SmallVector<OpAsmParser::Argument, 4> regionArgs;
+ SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
Region *before = result.addRegion();
Region *after = result.addRegion();
@@ -2399,10 +2403,13 @@ ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
result.operands)))
return failure();
- return failure(
- parser.parseRegion(*before, regionArgs, functionType.getInputs()) ||
- parser.parseKeyword("do") || parser.parseRegion(*after) ||
- parser.parseOptionalAttrDictWithKeyword(result.attributes));
+ // Propagate the types into the region arguments.
+ for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
+ regionArgs[i].type = functionType.getInput(i);
+
+ return failure(parser.parseRegion(*before, regionArgs) ||
+ parser.parseKeyword("do") || parser.parseRegion(*after) ||
+ parser.parseOptionalAttrDictWithKeyword(result.attributes));
}
/// Prints a `while` op.
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 337bd41ce45fa..ca2925b903d2d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -2193,10 +2193,8 @@ LogicalResult spirv::UConvertOp::verify() {
//===----------------------------------------------------------------------===//
ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &state) {
- SmallVector<OpAsmParser::UnresolvedOperand> entryArgs;
- SmallVector<NamedAttrList> argAttrs;
- SmallVector<NamedAttrList> resultAttrs;
- SmallVector<Type> argTypes;
+ SmallVector<OpAsmParser::Argument> entryArgs;
+ SmallVector<DictionaryAttr> resultAttrs;
SmallVector<Type> resultTypes;
auto &builder = parser.getBuilder();
@@ -2209,10 +2207,13 @@ ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &state) {
// Parse the function signature.
bool isVariadic = false;
if (function_interface_impl::parseFunctionSignature(
- parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs,
- isVariadic, resultTypes, resultAttrs))
+ parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
+ resultAttrs))
return failure();
+ SmallVector<Type> argTypes;
+ for (auto &arg : entryArgs)
+ argTypes.push_back(arg.type);
auto fnType = builder.getFunctionType(argTypes, resultTypes);
state.addAttribute(FunctionOpInterface::getTypeAttrName(),
TypeAttr::get(fnType));
@@ -2227,15 +2228,13 @@ ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &state) {
return failure();
// Add the attributes to the function arguments.
- assert(argAttrs.size() == argTypes.size());
assert(resultAttrs.size() == resultTypes.size());
- function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs,
+ function_interface_impl::addArgAndResultAttrs(builder, state, entryArgs,
resultAttrs);
// Parse the optional function body.
auto *body = state.addRegion();
- OptionalParseResult result = parser.parseOptionalRegion(
- *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes);
+ OptionalParseResult result = parser.parseOptionalRegion(*body, entryArgs);
return failure(result.hasValue() && failed(*result));
}
diff --git a/mlir/lib/IR/FunctionImplementation.cpp b/mlir/lib/IR/FunctionImplementation.cpp
index 45751daf811a1..a73da11a114cb 100644
--- a/mlir/lib/IR/FunctionImplementation.cpp
+++ b/mlir/lib/IR/FunctionImplementation.cpp
@@ -13,83 +13,61 @@
using namespace mlir;
-ParseResult mlir::function_interface_impl::parseFunctionArgumentList(
- OpAsmParser &parser, bool allowAttributes, bool allowVariadic,
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
- SmallVectorImpl<Type> &argTypes, SmallVectorImpl<NamedAttrList> &argAttrs,
- bool &isVariadic) {
- if (parser.parseLParen())
- return failure();
-
- // The argument list either has to consistently have ssa-id's followed by
- // types, or just be a type list. It isn't ok to sometimes have SSA ID's and
- // sometimes not.
- auto parseArgument = [&]() -> ParseResult {
- SMLoc loc = parser.getCurrentLocation();
-
- // Parse argument name if present.
- OpAsmParser::UnresolvedOperand argument;
- Type argumentType;
- auto hadSSAValue = parser.parseOptionalOperand(argument,
- /*allowResultNumber=*/false);
- if (hadSSAValue.hasValue()) {
- if (failed(hadSSAValue.getValue()))
- return failure(); // Argument was present but malformed.
-
- // Reject this if the preceding argument was missing a name.
- if (argNames.empty() && !argTypes.empty())
- return parser.emitError(loc, "expected type instead of SSA identifier");
-
- // Parse required type.
- if (parser.parseColonType(argumentType))
- return failure();
- } else if (allowVariadic && succeeded(parser.parseOptionalEllipsis())) {
- isVariadic = true;
- return success();
- } else if (!argNames.empty()) {
- // Reject this if the preceding argument had a name.
- return parser.emitError(loc, "expected SSA identifier");
- } else if (parser.parseType(argumentType)) {
- return failure();
- }
-
- // Add the argument type.
- argTypes.push_back(argumentType);
-
- // Parse any argument attributes and source location information.
- NamedAttrList attrs;
- if (parser.parseOptionalAttrDict(attrs) ||
- parser.parseOptionalLocationSpecifier(argument.sourceLoc))
- return failure();
-
- if (!allowAttributes && !attrs.empty())
- return parser.emitError(loc, "expected arguments without attributes");
- argAttrs.push_back(attrs);
-
- // If we had an argument name, then remember the parsed argument.
- if (!argument.name.empty())
- argNames.push_back(argument);
- return success();
- };
+static ParseResult
+parseFunctionArgumentList(OpAsmParser &parser, bool allowVariadic,
+ SmallVectorImpl<OpAsmParser::Argument> &arguments,
+ bool &isVariadic) {
- // Parse the function arguments.
+ // Parse the function arguments. The argument list either has to consistently
+ // have ssa-id's followed by types, or just be a type list. It isn't ok to
+ // sometimes have SSA ID's and sometimes not.
isVariadic = false;
- if (failed(parser.parseOptionalRParen())) {
- do {
- unsigned numTypedArguments = argTypes.size();
- if (parseArgument())
- return failure();
-
- SMLoc loc = parser.getCurrentLocation();
- if (argTypes.size() == numTypedArguments &&
- succeeded(parser.parseOptionalComma()))
- return parser.emitError(
- loc, "variadic arguments must be in the end of the argument list");
- } while (succeeded(parser.parseOptionalComma()));
- parser.parseRParen();
- }
- return success();
+ return parser.parseCommaSeparatedList(
+ OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
+ // Ellipsis must be at end of the list.
+ if (isVariadic)
+ return parser.emitError(
+ parser.getCurrentLocation(),
+ "variadic arguments must be in the end of the argument list");
+
+ // Handle ellipsis as a special case.
+ if (allowVariadic && succeeded(parser.parseOptionalEllipsis())) {
+ // This is a variadic designator.
+ isVariadic = true;
+ return success(); // Stop parsing arguments.
+ }
+ // Parse argument name if present.
+ OpAsmParser::Argument argument;
+ auto argPresent = parser.parseOptionalArgument(
+ argument, /*allowType=*/true, /*allowAttrs=*/true);
+ if (argPresent.hasValue()) {
+ if (failed(argPresent.getValue()))
+ return failure(); // Present but malformed.
+
+ // Reject this if the preceding argument was missing a name.
+ if (!arguments.empty() && arguments.back().ssaName.name.empty())
+ return parser.emitError(argument.ssaName.location,
+ "expected type instead of SSA identifier");
+
+ } else {
+ argument.ssaName.location = parser.getCurrentLocation();
+ // Otherwise we just have a type list without SSA names. Reject
+ // this if the preceding argument had a name.
+ if (!arguments.empty() && !arguments.back().ssaName.name.empty())
+ return parser.emitError(argument.ssaName.location,
+ "expected SSA identifier");
+
+ NamedAttrList attrs;
+ if (parser.parseType(argument.type) ||
+ parser.parseOptionalAttrDict(attrs) ||
+ parser.parseOptionalLocationSpecifier(argument.sourceLoc))
+ return failure();
+ argument.attrs = attrs.getDictionary(parser.getContext());
+ }
+ arguments.push_back(argument);
+ return success();
+ });
}
/// Parse a function result list.
@@ -103,7 +81,7 @@ ParseResult mlir::function_interface_impl::parseFunctionArgumentList(
///
static ParseResult
parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
- SmallVectorImpl<NamedAttrList> &resultAttrs) {
+ SmallVectorImpl<DictionaryAttr> &resultAttrs) {
if (failed(parser.parseOptionalLParen())) {
// We already know that there is no `(`, so parse a type.
// Because there is no `(`, it cannot be a function type.
@@ -120,83 +98,74 @@ parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
return success();
// Parse individual function results.
- do {
- resultTypes.emplace_back();
- resultAttrs.emplace_back();
- if (parser.parseType(resultTypes.back()) ||
- parser.parseOptionalAttrDict(resultAttrs.back())) {
- return failure();
- }
- } while (succeeded(parser.parseOptionalComma()));
+ if (parser.parseCommaSeparatedList([&]() -> ParseResult {
+ resultTypes.emplace_back();
+ resultAttrs.emplace_back();
+ NamedAttrList attrs;
+ if (parser.parseType(resultTypes.back()) ||
+ parser.parseOptionalAttrDict(attrs))
+ return failure();
+ resultAttrs.back() = attrs.getDictionary(parser.getContext());
+ return success();
+ }))
+ return failure();
+
return parser.parseRParen();
}
ParseResult mlir::function_interface_impl::parseFunctionSignature(
OpAsmParser &parser, bool allowVariadic,
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &argNames,
- SmallVectorImpl<Type> &argTypes, SmallVectorImpl<NamedAttrList> &argAttrs,
- bool &isVariadic, SmallVectorImpl<Type> &resultTypes,
- SmallVectorImpl<NamedAttrList> &resultAttrs) {
- bool allowArgAttrs = true;
- if (parseFunctionArgumentList(parser, allowArgAttrs, allowVariadic, argNames,
- argTypes, argAttrs, isVariadic))
+ SmallVectorImpl<OpAsmParser::Argument> &arguments, bool &isVariadic,
+ SmallVectorImpl<Type> &resultTypes,
+ SmallVectorImpl<DictionaryAttr> &resultAttrs) {
+ if (parseFunctionArgumentList(parser, allowVariadic, arguments, isVariadic))
return failure();
if (succeeded(parser.parseOptionalArrow()))
return parseFunctionResultList(parser, resultTypes, resultAttrs);
return success();
}
-/// Implementation of `addArgAndResultAttrs` that is attribute list type
-/// agnostic.
-template <typename AttrListT, typename AttrArrayBuildFnT>
-static void addArgAndResultAttrsImpl(Builder &builder, OperationState &result,
- ArrayRef<AttrListT> argAttrs,
- ArrayRef<AttrListT> resultAttrs,
- AttrArrayBuildFnT &&buildAttrArrayFn) {
- auto nonEmptyAttrsFn = [](const AttrListT &attrs) { return !attrs.empty(); };
+void mlir::function_interface_impl::addArgAndResultAttrs(
+ Builder &builder, OperationState &result, ArrayRef<DictionaryAttr> argAttrs,
+ ArrayRef<DictionaryAttr> resultAttrs) {
+ auto nonEmptyAttrsFn = [](DictionaryAttr attrs) {
+ return attrs && !attrs.empty();
+ };
+ // Convert the specified array of dictionary attrs (which may have null
+ // entries) to an ArrayAttr of dictionaries.
+ auto getArrayAttr = [&](ArrayRef<DictionaryAttr> dictAttrs) {
+ SmallVector<Attribute> attrs;
+ for (auto &dict : dictAttrs)
+ attrs.push_back(dict ? dict : builder.getDictionaryAttr({}));
+ return builder.getArrayAttr(attrs);
+ };
// Add the attributes to the function arguments.
- if (!argAttrs.empty() && llvm::any_of(argAttrs, nonEmptyAttrsFn)) {
- ArrayAttr attrDicts = builder.getArrayAttr(buildAttrArrayFn(argAttrs));
+ if (llvm::any_of(argAttrs, nonEmptyAttrsFn))
result.addAttribute(function_interface_impl::getArgDictAttrName(),
- attrDicts);
- }
+ getArrayAttr(argAttrs));
+
// Add the attributes to the function results.
- if (!resultAttrs.empty() && llvm::any_of(resultAttrs, nonEmptyAttrsFn)) {
- ArrayAttr attrDicts = builder.getArrayAttr(buildAttrArrayFn(resultAttrs));
+ if (llvm::any_of(resultAttrs, nonEmptyAttrsFn))
result.addAttribute(function_interface_impl::getResultDictAttrName(),
- attrDicts);
- }
+ getArrayAttr(resultAttrs));
}
void mlir::function_interface_impl::addArgAndResultAttrs(
- Builder &builder, OperationState &result, ArrayRef<DictionaryAttr> argAttrs,
+ Builder &builder, OperationState &result,
+ ArrayRef<OpAsmParser::Argument> args,
ArrayRef<DictionaryAttr> resultAttrs) {
- auto buildFn = [](ArrayRef<DictionaryAttr> attrs) {
- return ArrayRef<Attribute>(attrs.data(), attrs.size());
- };
- addArgAndResultAttrsImpl(builder, result, argAttrs, resultAttrs, buildFn);
-}
-void mlir::function_interface_impl::addArgAndResultAttrs(
- Builder &builder, OperationState &result, ArrayRef<NamedAttrList> argAttrs,
- ArrayRef<NamedAttrList> resultAttrs) {
- MLIRContext *context = builder.getContext();
- auto buildFn = [=](ArrayRef<NamedAttrList> attrs) {
- return llvm::to_vector<8>(
- llvm::map_range(attrs, [=](const NamedAttrList &attrList) -> Attribute {
- return attrList.getDictionary(context);
- }));
- };
- addArgAndResultAttrsImpl(builder, result, argAttrs, resultAttrs, buildFn);
+ SmallVector<DictionaryAttr> argAttrs;
+ for (const auto &arg : args)
+ argAttrs.push_back(arg.attrs);
+ addArgAndResultAttrs(builder, result, argAttrs, resultAttrs);
}
ParseResult mlir::function_interface_impl::parseFunctionOp(
OpAsmParser &parser, OperationState &result, bool allowVariadic,
FuncTypeBuilder funcTypeBuilder) {
- SmallVector<OpAsmParser::UnresolvedOperand> entryArgs;
- SmallVector<NamedAttrList> argAttrs;
- SmallVector<NamedAttrList> resultAttrs;
- SmallVector<Type> argTypes;
+ SmallVector<OpAsmParser::Argument> entryArgs;
+ SmallVector<DictionaryAttr> resultAttrs;
SmallVector<Type> resultTypes;
auto &builder = parser.getBuilder();
@@ -212,11 +181,15 @@ ParseResult mlir::function_interface_impl::parseFunctionOp(
// Parse the function signature.
SMLoc signatureLocation = parser.getCurrentLocation();
bool isVariadic = false;
- if (parseFunctionSignature(parser, allowVariadic, entryArgs, argTypes,
- argAttrs, isVariadic, resultTypes, resultAttrs))
+ if (parseFunctionSignature(parser, allowVariadic, entryArgs, isVariadic,
+ resultTypes, resultAttrs))
return failure();
std::string errorMessage;
+ SmallVector<Type> argTypes;
+ argTypes.reserve(entryArgs.size());
+ for (auto &arg : entryArgs)
+ argTypes.push_back(arg.type);
Type type = funcTypeBuilder(builder, argTypes, resultTypes,
VariadicFlag(isVariadic), errorMessage);
if (!type) {
@@ -246,17 +219,16 @@ ParseResult mlir::function_interface_impl::parseFunctionOp(
result.attributes.append(parsedAttributes);
// Add the attributes to the function arguments.
- assert(argAttrs.size() == argTypes.size());
assert(resultAttrs.size() == resultTypes.size());
- addArgAndResultAttrs(builder, result, argAttrs, resultAttrs);
+ addArgAndResultAttrs(builder, result, entryArgs, resultAttrs);
// Parse the optional function body. The printer will not print the body if
// its empty, so disallow parsing of empty body in the parser.
auto *body = result.addRegion();
SMLoc loc = parser.getCurrentLocation();
- OptionalParseResult parseResult = parser.parseOptionalRegion(
- *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes,
- /*enableNameShadowing=*/false);
+ OptionalParseResult parseResult =
+ parser.parseOptionalRegion(*body, entryArgs,
+ /*enableNameShadowing=*/false);
if (parseResult.hasValue()) {
if (failed(*parseResult))
return failure();
diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
index 562b4fe7e5113..7a8b2431c4152 100644
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -301,11 +301,8 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
return success();
};
- if (parseCommaSeparatedList(Delimiter::Braces, parseElt,
- " in attribute dictionary"))
- return failure();
-
- return success();
+ return parseCommaSeparatedList(Delimiter::Braces, parseElt,
+ " in attribute dictionary");
}
/// Parse a float attribute.
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index c50cd31badc51..722039ec9265a 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -249,6 +249,7 @@ class OperationParser : public Parser {
//===--------------------------------------------------------------------===//
using UnresolvedOperand = OpAsmParser::UnresolvedOperand;
+ using Argument = OpAsmParser::Argument;
struct DeferredLocInfo {
SMLoc loc;
@@ -364,16 +365,13 @@ class OperationParser : public Parser {
/// Parse a region into 'region' with the provided entry block arguments.
/// 'isIsolatedNameScope' indicates if the naming scope of this region is
/// isolated from those above.
- ParseResult
- parseRegion(Region ®ion,
- ArrayRef<std::pair<UnresolvedOperand, Type>> entryArguments,
- bool isIsolatedNameScope = false);
+ ParseResult parseRegion(Region ®ion, ArrayRef<Argument> entryArguments,
+ bool isIsolatedNameScope = false);
/// Parse a region body into 'region'.
- ParseResult
- parseRegionBody(Region ®ion, SMLoc startLoc,
- ArrayRef<std::pair<UnresolvedOperand, Type>> entryArguments,
- bool isIsolatedNameScope);
+ ParseResult parseRegionBody(Region ®ion, SMLoc startLoc,
+ ArrayRef<Argument> entryArguments,
+ bool isIsolatedNameScope);
//===--------------------------------------------------------------------===//
// Block Parsing
@@ -947,7 +945,7 @@ ParseResult OperationParser::parseOperation() {
unsigned opResI = 0;
for (ResultRecord &resIt : resultIDs) {
for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) {
- if (addDefinition({std::get<2>(resIt), std::get<0>(resIt), subRes, {}},
+ if (addDefinition({std::get<2>(resIt), std::get<0>(resIt), subRes},
op->getResult(opResI++)))
return failure();
}
@@ -1279,10 +1277,8 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
if (parser.parseSSAUse(useInfo, allowResultNumber))
return failure();
- result = {useInfo.location, useInfo.name, useInfo.number, {}};
-
- // Parse a source locator on the operand if present.
- return parseOptionalLocationSpecifier(result.sourceLoc);
+ result = {useInfo.location, useInfo.name, useInfo.number};
+ return success();
}
/// Parse a single operand if present.
@@ -1321,11 +1317,7 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
}
auto parseOneOperand = [&]() -> ParseResult {
- UnresolvedOperand operandOrArg;
- if (parseOperand(operandOrArg, allowResultNumber))
- return failure();
- result.push_back(operandOrArg);
- return success();
+ return parseOperand(result.emplace_back(), allowResultNumber);
};
if (parseCommaSeparatedList(delimiter, parseOneOperand, " in operand list"))
@@ -1402,52 +1394,88 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
return parser.parseAffineExprOfSSAIds(expr, parseElement);
}
+ //===--------------------------------------------------------------------===//
+ // Argument Parsing
+ //===--------------------------------------------------------------------===//
+
+ /// Parse a single argument with the following syntax:
+ ///
+ /// `%ssaname : !type { optionalAttrDict} loc(optionalSourceLoc)`
+ ///
+ /// If `allowType` is false or `allowAttrs` are false then the respective
+ /// parts of the grammar are not parsed.
+ ParseResult parseArgument(Argument &result, bool allowType = false,
+ bool allowAttrs = false) override {
+ NamedAttrList attrs;
+ if (parseOperand(result.ssaName, /*allowResultNumber=*/false) ||
+ (allowType && parseColonType(result.type)) ||
+ (allowAttrs && parseOptionalAttrDict(attrs)) ||
+ parseOptionalLocationSpecifier(result.sourceLoc))
+ return failure();
+ result.attrs = attrs.getDictionary(getContext());
+ return success();
+ }
+
+ /// Parse a single argument if present.
+ OptionalParseResult parseOptionalArgument(Argument &result, bool allowType,
+ bool allowAttrs) override {
+ if (parser.getToken().is(Token::percent_identifier))
+ return parseArgument(result, allowType, allowAttrs);
+ return llvm::None;
+ }
+
+ ParseResult parseArgumentList(SmallVectorImpl<Argument> &result,
+ Delimiter delimiter, bool allowType,
+ bool allowAttrs) override {
+ // The no-delimiter case has some special handling for the empty case.
+ if (delimiter == Delimiter::None &&
+ parser.getToken().isNot(Token::percent_identifier))
+ return success();
+
+ auto parseOneArgument = [&]() -> ParseResult {
+ return parseArgument(result.emplace_back(), allowType, allowAttrs);
+ };
+ return parseCommaSeparatedList(delimiter, parseOneArgument,
+ " in argument list");
+ }
+
//===--------------------------------------------------------------------===//
// Region Parsing
//===--------------------------------------------------------------------===//
/// Parse a region that takes `arguments` of `argTypes` types. This
/// effectively defines the SSA values of `arguments` and assigns their type.
- ParseResult parseRegion(Region ®ion, ArrayRef<UnresolvedOperand> arguments,
- ArrayRef<Type> argTypes,
+ ParseResult parseRegion(Region ®ion, ArrayRef<Argument> arguments,
bool enableNameShadowing) override {
- assert(arguments.size() == argTypes.size() &&
- "mismatching number of arguments and types");
-
- SmallVector<std::pair<OperationParser::UnresolvedOperand, Type>, 2>
- regionArguments;
- for (auto pair : llvm::zip(arguments, argTypes))
- regionArguments.emplace_back(std::get<0>(pair), std::get<1>(pair));
-
// Try to parse the region.
(void)isIsolatedFromAbove;
assert((!enableNameShadowing || isIsolatedFromAbove) &&
"name shadowing is only allowed on isolated regions");
- if (parser.parseRegion(region, regionArguments, enableNameShadowing))
+ if (parser.parseRegion(region, arguments, enableNameShadowing))
return failure();
return success();
}
/// Parses a region if present.
OptionalParseResult parseOptionalRegion(Region &region,
- ArrayRef<UnresolvedOperand> arguments,
- ArrayRef<Type> argTypes,
+ ArrayRef<Argument> arguments,
bool enableNameShadowing) override {
if (parser.getToken().isNot(Token::l_brace))
return llvm::None;
- return parseRegion(region, arguments, argTypes, enableNameShadowing);
+ return parseRegion(region, arguments, enableNameShadowing);
}
/// Parses a region if present. If the region is present, a new region is
/// allocated and placed in `region`. If no region is present, `region`
/// remains untouched.
- OptionalParseResult parseOptionalRegion(
- std::unique_ptr<Region> &region, ArrayRef<UnresolvedOperand> arguments,
- ArrayRef<Type> argTypes, bool enableNameShadowing = false) override {
+ OptionalParseResult
+ parseOptionalRegion(std::unique_ptr<Region> &region,
+ ArrayRef<Argument> arguments,
+ bool enableNameShadowing = false) override {
if (parser.getToken().isNot(Token::l_brace))
return llvm::None;
std::unique_ptr<Region> newRegion = std::make_unique<Region>();
- if (parseRegion(*newRegion, arguments, argTypes, enableNameShadowing))
+ if (parseRegion(*newRegion, arguments, enableNameShadowing))
return failure();
region = std::move(newRegion);
@@ -1492,42 +1520,15 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
/// Parse a list of assignments of the form
/// (%x1 = %y1, %x2 = %y2, ...).
OptionalParseResult parseOptionalAssignmentList(
- SmallVectorImpl<UnresolvedOperand> &lhs,
+ SmallVectorImpl<Argument> &lhs,
SmallVectorImpl<UnresolvedOperand> &rhs) override {
if (failed(parseOptionalLParen()))
return llvm::None;
auto parseElt = [&]() -> ParseResult {
- UnresolvedOperand regionArg, operand;
- if (parseOperand(regionArg, /*allowResultNumber=*/false) ||
- parseEqual() || parseOperand(operand))
- return failure();
- lhs.push_back(regionArg);
- rhs.push_back(operand);
- return success();
- };
- return parser.parseCommaSeparatedListUntil(Token::r_paren, parseElt);
- }
-
- /// Parse a list of assignments of the form
- /// (%x1 = %y1 : type1, %x2 = %y2 : type2, ...).
- OptionalParseResult
- parseOptionalAssignmentListWithTypes(SmallVectorImpl<UnresolvedOperand> &lhs,
- SmallVectorImpl<UnresolvedOperand> &rhs,
- SmallVectorImpl<Type> &types) override {
- if (failed(parseOptionalLParen()))
- return llvm::None;
-
- auto parseElt = [&]() -> ParseResult {
- UnresolvedOperand regionArg, operand;
- Type type;
- if (parseOperand(regionArg, /*allowResultNumber=*/false) ||
- parseEqual() || parseOperand(operand) || parseColon() ||
- parseType(type))
+ if (parseArgument(lhs.emplace_back()) || parseEqual() ||
+ parseOperand(rhs.emplace_back()))
return failure();
- lhs.push_back(regionArg);
- rhs.push_back(operand);
- types.push_back(type);
return success();
};
return parser.parseCommaSeparatedListUntil(Token::r_paren, parseElt);
@@ -1749,11 +1750,9 @@ OperationParser::parseTrailingLocationSpecifier(OpOrArgument opOrArgument) {
// Region Parsing
//===----------------------------------------------------------------------===//
-ParseResult OperationParser::parseRegion(
- Region &region,
- ArrayRef<std::pair<OperationParser::UnresolvedOperand, Type>>
- entryArguments,
- bool isIsolatedNameScope) {
+ParseResult OperationParser::parseRegion(Region &region,
+ ArrayRef<Argument> entryArguments,
+ bool isIsolatedNameScope) {
// Parse the '{'.
Token lBraceTok = getToken();
if (parseToken(Token::l_brace, "expected '{' to begin a region"))
@@ -1778,11 +1777,9 @@ ParseResult OperationParser::parseRegion(
return success();
}
-ParseResult OperationParser::parseRegionBody(
- Region &region, SMLoc startLoc,
- ArrayRef<std::pair<OperationParser::UnresolvedOperand, Type>>
- entryArguments,
- bool isIsolatedNameScope) {
+ParseResult OperationParser::parseRegionBody(Region &region, SMLoc startLoc,
+ ArrayRef<Argument> entryArguments,
+ bool isIsolatedNameScope) {
auto currentPt = opBuilder.saveInsertionPoint();
// Push a new named value scope.
@@ -1798,14 +1795,14 @@ ParseResult OperationParser::parseRegionBody(
if (state.asmState && getToken().isNot(Token::caret_identifier))
state.asmState->addDefinition(block, startLoc);
- // Add arguments to the entry block.
- if (!entryArguments.empty()) {
+ // Add arguments to the entry block if we had the form with explicit names.
+ if (!entryArguments.empty() && !entryArguments[0].ssaName.name.empty()) {
// If we had named arguments, then don't allow a block name.
if (getToken().is(Token::caret_identifier))
return emitError("invalid block name in region with named arguments");
- for (auto &placeholderArgPair : entryArguments) {
- auto &argInfo = placeholderArgPair.first;
+ for (auto &entryArg : entryArguments) {
+ auto &argInfo = entryArg.ssaName;
// Ensure that the argument was not already defined.
if (auto defLoc = getReferenceLoc(argInfo.name, argInfo.number)) {
@@ -1815,10 +1812,10 @@ ParseResult OperationParser::parseRegionBody(
.attachNote(getEncodedSourceLocation(*defLoc))
<< "previously referenced here";
}
- Location loc = argInfo.sourceLoc.hasValue()
- ? argInfo.sourceLoc.getValue()
+ Location loc = entryArg.sourceLoc.hasValue()
+ ? entryArg.sourceLoc.getValue()
: getEncodedSourceLocation(argInfo.location);
- BlockArgument arg = block->addArgument(placeholderArgPair.second, loc);
+ BlockArgument arg = block->addArgument(entryArg.type, loc);
// Add a definition of this arg to the assembly state if provided.
if (state.asmState)
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index fd94c81a05a25..ff9def1e5f191 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -202,7 +202,7 @@ module attributes {gpu.container_module} {
module attributes {gpu.container_module} {
func.func @launch_func_kernel_operand_attr(%sz : index) {
- // expected-error@+1 {{expected arguments without attributes}}
+ // expected-error@+1 {{expected ')' in argument list}}
gpu.launch_func @foo::@bar blocks in (%sz, %sz, %sz) threads in (%sz, %sz, %sz) args(%sz : index {foo})
return
}
diff --git a/mlir/test/IR/locations.mlir b/mlir/test/IR/locations.mlir
index eb12e8a311747..60be67f035fe1 100644
--- a/mlir/test/IR/locations.mlir
+++ b/mlir/test/IR/locations.mlir
@@ -13,8 +13,9 @@ func.func @inline_notation() -> i32 {
// CHECK: arith.constant 4 : index loc(callsite("foo" at "mysource.cc":10:8))
%2 = arith.constant 4 : index loc(callsite("foo" at "mysource.cc":10:8))
+ // CHECK: affine.for %arg0 loc("IVlocation") = 0 to 8 {
// CHECK: } loc(fused["foo", "mysource.cc":10:8])
- affine.for %i0 = 0 to 8 {
+ affine.for %i0 loc("IVlocation") = 0 to 8 {
} loc(fused["foo", "mysource.cc":10:8])
// CHECK: } loc(fused<"myPass">["foo", "foo2"])
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 0978e27634f7c..d319363a71eab 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -691,18 +691,16 @@ static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
ParseResult IsolatedRegionOp::parse(OpAsmParser &parser,
OperationState &result) {
- OpAsmParser::UnresolvedOperand argInfo;
- Type argType = parser.getBuilder().getIndexType();
-
// Parse the input operand.
- if (parser.parseOperand(argInfo) ||
- parser.resolveOperand(argInfo, argType, result.operands))
+ OpAsmParser::Argument argInfo;
+ argInfo.type = parser.getBuilder().getIndexType();
+ if (parser.parseOperand(argInfo.ssaName) ||
+ parser.resolveOperand(argInfo.ssaName, argInfo.type, result.operands))
return failure();
// Parse the body region, and reuse the operand info as the argument info.
Region *body = result.addRegion();
- return parser.parseRegion(*body, argInfo, argType,
- /*enableNameShadowing=*/true);
+ return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/true);
}
void IsolatedRegionOp::print(OpAsmPrinter &p) {
@@ -930,17 +928,16 @@ void PrettyPrintedRegionOp::print(OpAsmPrinter &p) {
//===----------------------------------------------------------------------===//
ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) {
- SmallVector<OpAsmParser::UnresolvedOperand, 4> ivsInfo;
+ SmallVector<OpAsmParser::Argument, 4> ivsInfo;
// Parse list of region arguments without a delimiter.
- if (parser.parseOperandList(ivsInfo, OpAsmParser::Delimiter::None,
- /*allowResultNumber=*/false))
+ if (parser.parseArgumentList(ivsInfo, OpAsmParser::Delimiter::None))
return failure();
// Parse the body region.
Region *body = result.addRegion();
- auto &builder = parser.getBuilder();
- SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
- return parser.parseRegion(*body, ivsInfo, argTypes);
+ for (auto &iv : ivsInfo)
+ iv.type = parser.getBuilder().getIndexType();
+ return parser.parseRegion(*body, ivsInfo);
}
void PolyForOp::print(OpAsmPrinter &p) { p.printGenericOp(*this); }
More information about the Mlir-commits
mailing list