[flang-commits] [flang] [mlir] [mlir] share argument attributes interface between calls and callables (PR #123176)
Tobias Gysi via flang-commits
flang-commits at lists.llvm.org
Sun Jan 26 23:17:38 PST 2025
================
@@ -7,9 +7,178 @@
//===----------------------------------------------------------------------===//
#include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/IR/Builders.h"
using namespace mlir;
+//===----------------------------------------------------------------------===//
+// Argument and result attribute utilities
+//===----------------------------------------------------------------------===//
+
+/// Parse a comma-separated list of `type attr-dict?` entries. This is shared
+/// by argument lists and result lists. For every entry, the parsed type is
+/// appended to `types` and its attribute dictionary to `attrs`, so both
+/// vectors grow in lockstep.
+static ParseResult
+parseTypeAndAttrList(OpAsmParser &parser, SmallVectorImpl<Type> &types,
+                     SmallVectorImpl<DictionaryAttr> &attrs) {
+  return parser.parseCommaSeparatedList([&]() -> ParseResult {
+    // Parse into locals first so that an element that fails to parse does not
+    // leave a default-constructed (null) type/dictionary in the output.
+    Type type;
+    NamedAttrList attrList;
+    if (parser.parseType(type) || parser.parseOptionalAttrDict(attrList))
+      return failure();
+    types.push_back(type);
+    attrs.push_back(attrList.getDictionary(parser.getContext()));
+    return success();
+  });
+}
+
+/// Parse the result list that follows `->` in a function signature. It is
+/// either a parenthesized (possibly empty) list of `type attr-dict?` entries,
+/// or a single bare type that carries no attributes.
+ParseResult call_interface_impl::parseFunctionResultList(
+    OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
+    SmallVectorImpl<DictionaryAttr> &resultAttrs) {
+  // Parenthesized form: `()` or `(type attrs?, ...)`.
+  if (succeeded(parser.parseOptionalLParen())) {
+    if (succeeded(parser.parseOptionalRParen()))
+      return success();
+    if (parseTypeAndAttrList(parser, resultTypes, resultAttrs))
+      return failure();
+    return parser.parseRParen();
+  }
+
+  // Bare form: a single type. Since no `(` was seen, it cannot be a function
+  // type. It gets a null attribute dictionary.
+  Type singleResult;
+  if (parser.parseType(singleResult))
+    return failure();
+  resultTypes.push_back(singleResult);
+  resultAttrs.emplace_back();
+  return success();
+}
+
+/// Parse a function signature of the form
+///   `(type attr-dict?, ...) (-> result-list)?`
+/// appending argument and result types/attributes to the given vectors.
+/// When `mustParseEmptyResult` is set, the `-> result-list` part is mandatory,
+/// even when the result list is empty.
+ParseResult call_interface_impl::parseFunctionSignature(
+    OpAsmParser &parser, SmallVectorImpl<Type> &argTypes,
+    SmallVectorImpl<DictionaryAttr> &argAttrs,
+    SmallVectorImpl<Type> &resultTypes,
+    SmallVectorImpl<DictionaryAttr> &resultAttrs, bool mustParseEmptyResult) {
+  // Parse arguments.
+  if (parser.parseLParen())
+    return failure();
+  if (failed(parser.parseOptionalRParen())) {
+    if (parseTypeAndAttrList(parser, argTypes, argAttrs))
+      return failure();
+    if (parser.parseRParen())
+      return failure();
+  }
+
+  // Parse results.
+  if (mustParseEmptyResult) {
+    // The arrow is required here. Use the non-optional form so that a missing
+    // `->` emits a diagnostic instead of failing without any error message.
+    if (parser.parseArrow())
+      return failure();
+  } else if (failed(parser.parseOptionalArrow())) {
+    // No result list: the function has no results.
+    return success();
+  }
+  return call_interface_impl::parseFunctionResultList(parser, resultTypes,
+                                                       resultAttrs);
+}
+
+/// Print a non-empty function result list. `attrs` must either be null or
+/// hold exactly one DictionaryAttr per entry of `types`.
+static void printFunctionResultList(OpAsmPrinter &p, TypeRange types,
+                                    ArrayAttr attrs) {
+  assert(!types.empty() && "Should not be called for empty result list.");
+  assert((!attrs || attrs.size() == types.size()) &&
+         "Invalid number of attributes.");
+
+  // Parentheses are needed for several results, or for a single result that
+  // is a function type or carries attributes. The parser accepts a bare
+  // result only as a plain type without a leading `(`.
+  bool wrap = types.size() > 1;
+  if (!wrap)
+    wrap = llvm::isa<FunctionType>(types[0]) ||
+           (attrs && !llvm::cast<DictionaryAttr>(attrs[0]).empty());
+
+  auto &stream = p.getStream();
+  if (wrap)
+    stream << '(';
+  for (size_t idx = 0, count = types.size(); idx < count; ++idx) {
+    if (idx != 0)
+      stream << ", ";
+    p.printType(types[idx]);
+    if (attrs)
+      p.printOptionalAttrDict(
+          llvm::cast<DictionaryAttr>(attrs[idx]).getValue());
+  }
+  if (wrap)
+    stream << ')';
+}
+
+/// Print a function signature.
+/// - With a non-empty `body`, each argument is printed via
+///   `printRegionArgument` on the corresponding block argument.
+/// - Otherwise only argument types (and their attributes) are printed.
+/// - `argAttrs` / `resultAttrs` must either be null or hold one
+///   DictionaryAttr per argument / result.
+/// - `...` is appended when `isVariadic` is set.
+/// - `-> ()` is printed for an empty result list when `printEmptyResult` is
+///   set.
+void call_interface_impl::printFunctionSignature(
+    OpAsmPrinter &p, TypeRange argTypes, ArrayAttr argAttrs, bool isVariadic,
+    TypeRange resultTypes, ArrayAttr resultAttrs, Region *body,
+    bool printEmptyResult) {
+  assert((!argAttrs || argAttrs.size() == argTypes.size()) &&
+         "Invalid number of argument attributes.");
+  bool isExternal = !body || body->empty();
+  // Fast path for signatures without a body: with no region arguments,
+  // attributes, or varargs to print, the signature is a plain functional
+  // type. printFunctionalType only receives the types, so this path must NOT
+  // be taken when a body exists; otherwise the region arguments printed below
+  // would be lost.
+  if (isExternal && !isVariadic && !argAttrs && !resultAttrs &&
+      printEmptyResult) {
+    p.printFunctionalType(argTypes, resultTypes);
+    return;
+  }
+
+  p << '(';
+  for (unsigned i = 0, e = argTypes.size(); i < e; ++i) {
+    if (i > 0)
+      p << ", ";
+
+    if (!isExternal) {
+      // Print the block argument together with its attributes.
+      ArrayRef<NamedAttribute> attrs;
+      if (argAttrs)
+        attrs = llvm::cast<DictionaryAttr>(argAttrs[i]).getValue();
+      p.printRegionArgument(body->getArgument(i), attrs);
+    } else {
+      p.printType(argTypes[i]);
+      if (argAttrs)
+        p.printOptionalAttrDict(
+            llvm::cast<DictionaryAttr>(argAttrs[i]).getValue());
+    }
+  }
+
+  if (isVariadic) {
+    if (!argTypes.empty())
+      p << ", ";
+    p << "...";
+  }
+
+  p << ')';
+
+  if (!resultTypes.empty()) {
+    p << " -> ";
+    printFunctionResultList(p, resultTypes, resultAttrs);
+  } else if (printEmptyResult) {
+    p << " -> ()";
+  }
+}
+
+void call_interface_impl::addArgAndResultAttrs(
+ Builder &builder, OperationState &result, ArrayRef<DictionaryAttr> argAttrs,
+ ArrayRef<DictionaryAttr> resultAttrs, StringAttr argAttrsName,
+ StringAttr resAttrsName) {
+ auto nonEmptyAttrsFn = [](DictionaryAttr attrs) {
+ return attrs && !attrs.empty();
+ };
+ // Convert the specified array of dictionary attrs (which may have null
+ // entries) to an ArrayAttr of dictionaries.
+ auto getArrayAttr = [&](ArrayRef<DictionaryAttr> dictAttrs) {
+ SmallVector<Attribute> attrs;
+ for (auto &dict : dictAttrs)
+ attrs.push_back(dict ? dict : builder.getDictionaryAttr({}));
+ return builder.getArrayAttr(attrs);
+ };
+
+ // Add the attributes to the function arguments.
+ if (llvm::any_of(argAttrs, nonEmptyAttrsFn))
+ result.addAttribute(argAttrsName, getArrayAttr(argAttrs));
+
+ // Add the attributes to the function results.
+ if (llvm::any_of(resultAttrs, nonEmptyAttrsFn))
+ result.addAttribute(resAttrsName, getArrayAttr(resultAttrs));
----------------
gysit wrote:
```suggestion
// Add the attributes to the operation arguments.
if (llvm::any_of(argAttrs, nonEmptyAttrsFn))
result.addAttribute(argAttrsName, getArrayAttr(argAttrs));
// Add the attributes to the operation results.
if (llvm::any_of(resultAttrs, nonEmptyAttrsFn))
result.addAttribute(resAttrsName, getArrayAttr(resultAttrs));
```
Ultra nit: if I understand correctly, this will be used for both calls and functions, so I would just say "operation" in these comments.
https://github.com/llvm/llvm-project/pull/123176
More information about the flang-commits
mailing list