[Mlir-commits] [mlir] 541d89b - [mlir] Fix --convert-func-to-llvm=emit-c-wrappers argument and result attribute handling

Alex Zinenko llvmlistbot at llvm.org
Tue Mar 15 07:29:50 PDT 2022


Author: Sam Carroll
Date: 2022-03-15T15:29:43+01:00
New Revision: 541d89b02c10997477f9109945b1d700d6a78c65

URL: https://github.com/llvm/llvm-project/commit/541d89b02c10997477f9109945b1d700d6a78c65
DIFF: https://github.com/llvm/llvm-project/commit/541d89b02c10997477f9109945b1d700d6a78c65.diff

LOG: [mlir] Fix --convert-func-to-llvm=emit-c-wrappers argument and result attribute handling

When using `--convert-func-to-llvm=emit-c-wrappers`, the argument and result attributes of the wrapper function were not created correctly in some cases.
This patch fixes that and adds tests intended to cover all corner cases.

See https://github.com/llvm/llvm-project/issues/53503

Author: Sam Carroll <sam.carroll at lmns.com>
Co-Author: Laszlo Kindrat <laszlo.kindrat at lmns.com>

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D119895

Added: 
    mlir/test/Conversion/StandardToLLVM/emit-c-wrappers-for-external-callers.mlir
    mlir/test/Conversion/StandardToLLVM/emit-c-wrappers-for-external-functions.mlir

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
    mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/test/Dialect/LLVMIR/func.mlir
    mlir/test/Dialect/LLVMIR/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index a2d541a66d45c..4507ddfdfaf3b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -26,6 +26,7 @@ def LLVM_Dialect : Dialect {
   let cppNamespace = "::mlir::LLVM";
 
   let hasRegionArgAttrVerify = 1;
+  let hasRegionResultAttrVerify = 1;
   let hasOperationAttrVerify = 1;
   let extraClassDeclaration = [{
     /// Name of the data layout attributes.
@@ -38,6 +39,11 @@ def LLVM_Dialect : Dialect {
     static StringRef getParallelAccessAttrName() { return "parallel_access"; }
     static StringRef getLoopOptionsAttrName() { return "options"; }
     static StringRef getAccessGroupsAttrName() { return "access_groups"; }
+    static StringRef getStructAttrsAttrName() { return "llvm.struct_attrs"; }
+
+    /// Verifies if the attribute is a well-formed value for "llvm.struct_attrs"
+    static LogicalResult verifyStructAttr(
+        Operation *op, Attribute attr, Type annotatedType);
 
     /// Verifies if the given string is a well-formed data layout descriptor.
     /// Uses `reportError` to report errors.

diff  --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index a2e8c2ceb9e95..39bebd28f1e58 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -40,6 +40,7 @@
 #include "llvm/IR/Type.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/FormatVariadic.h"
+#include <algorithm>
 #include <functional>
 
 using namespace mlir;
@@ -50,19 +51,71 @@ using namespace mlir;
 /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
 /// attributes.
 static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
-                                 bool filterArgAttrs,
+                                 bool filterArgAndResAttrs,
                                  SmallVectorImpl<NamedAttribute> &result) {
   for (const auto &attr : attrs) {
     if (attr.getName() == SymbolTable::getSymbolAttrName() ||
         attr.getName() == FunctionOpInterface::getTypeAttrName() ||
         attr.getName() == "func.varargs" ||
-        (filterArgAttrs &&
-         attr.getName() == FunctionOpInterface::getArgDictAttrName()))
+        (filterArgAndResAttrs &&
+         (attr.getName() == FunctionOpInterface::getArgDictAttrName() ||
+          attr.getName() == FunctionOpInterface::getResultDictAttrName())))
       continue;
     result.push_back(attr);
   }
 }
 
+/// Returns `{llvm.struct_attrs = attrs}`: all of `attrs` under a single key.
+static DictionaryAttr wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) {
+  StringRef key = LLVM::LLVMDialect::getStructAttrsAttrName();
+  NamedAttribute entry = b.getNamedAttr(key, attrs);
+  return DictionaryAttr::get(b.getContext(), entry);
+}
+
+/// Rebuilds the argument attributes of a C wrapper whose result(s) became a
+/// leading pointer argument. Entry 0 receives the result attributes (bundled
+/// via `wrapAsStructAttrs` when several results are packed into one struct)
+/// and entries 1..N the original argument attributes, N being the number of
+/// arguments. The result attribute entry is removed from `attributes`.
+static void
+prependResAttrsToArgAttrs(OpBuilder &builder,
+                          SmallVectorImpl<NamedAttribute> &attributes,
+                          size_t numArguments) {
+  // Every slot defaults to an empty dictionary.
+  SmallVector<Attribute> allAttrs(numArguments + 1,
+                                  DictionaryAttr::get(builder.getContext()));
+  NamedAttribute *argAttrsEntry = nullptr;
+  for (size_t idx = 0; idx < attributes.size();) {
+    NamedAttribute &entry = attributes[idx];
+    if (entry.getName() == FunctionOpInterface::getResultDictAttrName()) {
+      auto resDicts = entry.getValue().cast<ArrayAttr>();
+      assert(!resDicts.empty() && "expected array to be non-empty");
+      allAttrs.front() = (resDicts.size() == 1)
+                             ? resDicts[0]
+                             : wrapAsStructAttrs(builder, resDicts);
+      // Erasing shifts the next entry into `idx`; do not advance.
+      attributes.erase(attributes.begin() + idx);
+      continue;
+    }
+    if (entry.getName() == FunctionOpInterface::getArgDictAttrName()) {
+      auto argDicts = entry.getValue().cast<ArrayAttr>();
+      assert(argDicts.size() == numArguments &&
+             "Number of arg attrs and args should match");
+      std::copy(argDicts.begin(), argDicts.end(), allAttrs.begin() + 1);
+      argAttrsEntry = &entry;
+    }
+    ++idx;
+  }
+
+  NamedAttribute rebuilt =
+      builder.getNamedAttr(FunctionOpInterface::getArgDictAttrName(),
+                           builder.getArrayAttr(allAttrs));
+  if (argAttrsEntry)
+    *argAttrsEntry = rebuilt;
+  else
+    attributes.push_back(rebuilt);
+}
+
 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
 /// arguments instead of unpacked arguments. This function can be called from C
 /// by passing a pointer to a C struct corresponding to a memref descriptor.
@@ -76,12 +129,14 @@ static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
                                    FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) {
   auto type = funcOp.getType();
   SmallVector<NamedAttribute, 4> attributes;
-  filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/false,
+  filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false,
                        attributes);
   Type wrapperFuncType;
   bool resultIsNowArg;
   std::tie(wrapperFuncType, resultIsNowArg) =
       typeConverter.convertFunctionTypeCWrapper(type);
+  if (resultIsNowArg)
+    prependResAttrsToArgAttrs(rewriter, attributes, funcOp.getNumArguments());
   auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
       loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
       wrapperFuncType, LLVM::Linkage::External, /*dsoLocal*/ false, attributes);
@@ -142,9 +197,11 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
   assert(wrapperType && "unexpected type conversion failure");
 
   SmallVector<NamedAttribute, 4> attributes;
-  filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/false,
+  filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false,
                        attributes);
 
+  if (resultIsNowArg)
+    prependResAttrsToArgAttrs(builder, attributes, funcOp.getNumArguments());
   // Create the auxiliary function.
   auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
       loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
@@ -235,11 +292,21 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
     if (!llvmType)
       return nullptr;
 
-    // Propagate argument attributes to all converted arguments obtained after
-    // converting a given original argument.
+    // Propagate argument/result attributes to all converted arguments/result
+    // obtained after converting a given original argument/result.
     SmallVector<NamedAttribute, 4> attributes;
-    filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/true,
+    filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/true,
                          attributes);
+    if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
+      assert(!resAttrDicts.empty() && "expected array to be non-empty");
+      auto newResAttrDicts =
+          (funcOp.getNumResults() == 1)
+              ? resAttrDicts
+              : rewriter.getArrayAttr(
+                    {wrapAsStructAttrs(rewriter, resAttrDicts)});
+      attributes.push_back(rewriter.getNamedAttr(
+          FunctionOpInterface::getResultDictAttrName(), newResAttrDicts));
+    }
     if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
       SmallVector<Attribute, 4> newArgAttrs(
           llvmType.cast<LLVM::LLVMFunctionType>().getNumParams());

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index a816fb83fa2d0..c6996c6108623 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2605,6 +2605,12 @@ LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op,
              << "' to be a `loopopts` attribute";
   }
 
+  if (attr.getName() == LLVMDialect::getStructAttrsAttrName()) {
+    return op->emitOpError()
+           << "'" << LLVM::LLVMDialect::getStructAttrsAttrName()
+           << "' is permitted only in argument or result attributes";
+  }
+
   // If the data layout attribute is present, it must use the LLVM data layout
   // syntax. Try parsing it and report errors in case of failure. Users of this
   // attribute may assume it is well-formed and can pass it to the (asserting)
@@ -2621,6 +2627,46 @@ LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op,
                            << "' to be a string attribute";
 }
 
+LogicalResult LLVMDialect::verifyStructAttr(Operation *op, Attribute attr,
+                                            Type annotatedType) {
+  // The annotated value must be a struct or a pointer to a struct.
+  auto structType = annotatedType.dyn_cast<LLVMStructType>();
+  if (auto ptrType = annotatedType.dyn_cast<LLVMPointerType>())
+    structType = ptrType.getElementType().dyn_cast_or_null<LLVMStructType>();
+  if (!structType)
+    return op->emitError()
+           << "expected '" << LLVMDialect::getStructAttrsAttrName()
+           << "' to annotate '!llvm.struct' or '!llvm.ptr<struct<...>>'";
+
+  // Expect exactly one attribute dictionary per struct element.
+  auto arrAttrs = attr.dyn_cast<ArrayAttr>();
+  if (!arrAttrs)
+    return op->emitError() << "expected '"
+                           << LLVMDialect::getStructAttrsAttrName()
+                           << "' to be an array attribute";
+  if (structType.getBody().size() != arrAttrs.size())
+    return op->emitError()
+           << "size of '" << LLVMDialect::getStructAttrsAttrName()
+           << "' must match the size of the annotated '!llvm.struct'";
+  if (llvm::any_of(arrAttrs, [](Attribute elementAttrs) {
+        return !elementAttrs.isa<DictionaryAttr>();
+      }))
+    return op->emitError() << "expected '"
+                           << LLVMDialect::getStructAttrsAttrName()
+                           << "' elements to be dictionary attributes";
+  return success();
+}
+
+/// Verifies struct attrs against the type `fn` picks from function-like `op`.
+static LogicalResult verifyFuncOpInterfaceStructAttr(
+    Operation *op, Attribute attr, function_ref<Type(FunctionOpInterface)> fn) {
+  if (auto funcOp = dyn_cast<FunctionOpInterface>(op))
+    return LLVMDialect::verifyStructAttr(op, attr, fn(funcOp));
+  return op->emitError() << "expected '"
+                         << LLVMDialect::getStructAttrsAttrName()
+                         << "' to be used on function-like operations";
+}
+
 /// Verify LLVMIR function argument attributes.
 LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op,
                                                     unsigned regionIdx,
@@ -2636,6 +2682,25 @@ LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op,
       !argAttr.getValue().isa<IntegerAttr>())
     return op->emitError()
            << "llvm.align argument attribute of non integer type";
+  if (argAttr.getName() == LLVMDialect::getStructAttrsAttrName()) {
+    return verifyFuncOpInterfaceStructAttr(
+        op, argAttr.getValue(), [argIdx](FunctionOpInterface funcOp) {
+          return funcOp.getArgumentTypes()[argIdx];
+        });
+  }
+  return success();
+}
+
+LogicalResult LLVMDialect::verifyRegionResultAttribute(Operation *op,
+                                                       unsigned regionIdx,
+                                                       unsigned resIdx,
+                                                       NamedAttribute resAttr) {
+  if (resAttr.getName() == LLVMDialect::getStructAttrsAttrName()) {
+    auto resultType = [resIdx](FunctionOpInterface funcOp) -> Type {
+      return funcOp.getResultTypes()[resIdx];
+    };
+    return verifyFuncOpInterfaceStructAttr(op, resAttr.getValue(), resultType);
+  }
   return success();
 }
 

diff  --git a/mlir/test/Conversion/StandardToLLVM/emit-c-wrappers-for-external-callers.mlir b/mlir/test/Conversion/StandardToLLVM/emit-c-wrappers-for-external-callers.mlir
new file mode 100644
index 0000000000000..d18bb6f692ed5
--- /dev/null
+++ b/mlir/test/Conversion/StandardToLLVM/emit-c-wrappers-for-external-callers.mlir
@@ -0,0 +1,70 @@
+// RUN: mlir-opt -convert-func-to-llvm='emit-c-wrappers=1' %s | FileCheck %s
+
+// CHECK: llvm.func @res_attrs_with_memref_return() -> (!llvm.struct{{.*}} {test.returnOne})
+// CHECK-LABEL: llvm.func @_mlir_ciface_res_attrs_with_memref_return
+// CHECK: %{{.*}}: !llvm.ptr{{.*}} {test.returnOne}
+func @res_attrs_with_memref_return() -> (memref<f32> {test.returnOne}) {
+  %0 = memref.alloc() : memref<f32>
+  return %0 : memref<f32>
+}
+
+// CHECK: llvm.func @res_attrs_with_value_return() -> (f32 {test.returnOne = 1 : i64})
+// CHECK-LABEL: llvm.func @_mlir_ciface_res_attrs_with_value_return
+// CHECK: -> (f32 {test.returnOne = 1 : i64})
+func @res_attrs_with_value_return() -> (f32 {test.returnOne = 1}) {
+  %0 = arith.constant 1.00 : f32
+  return %0 : f32
+}
+
+// CHECK: llvm.func @multiple_return() -> (!llvm.struct<{{.*}}> {llvm.struct_attrs = [{test.returnOne = 1 : i64}, {test.returnThree = 3 : i64, test.returnTwo = 2 : i64}]})
+// CHECK-LABEL: llvm.func @_mlir_ciface_multiple_return
+// CHECK: (%{{.*}}: !llvm.ptr<{{.*}}> {llvm.struct_attrs = [{test.returnOne = 1 : i64}, {test.returnThree = 3 : i64, test.returnTwo = 2 : i64}]})
+func @multiple_return() -> (memref<f32> {test.returnOne = 1}, f32 {test.returnTwo = 2, test.returnThree = 3}) {
+  %0 = memref.alloc() : memref<f32>
+  %1 = arith.constant 1.00 : f32
+  return %0, %1 : memref<f32>, f32
+}
+
+// CHECK: llvm.func @multiple_return_missing_res_attr() -> (!llvm.struct<{{.*}}> {llvm.struct_attrs = [{test.returnOne = 1 : i64}, {}, {test.returnThree = 3 : i64, test.returnTwo = 2 : i64}]})
+// CHECK-LABEL: llvm.func @_mlir_ciface_multiple_return_missing_res_attr
+// CHECK: (%{{.*}}: !llvm.ptr<{{.*}}> {llvm.struct_attrs = [{test.returnOne = 1 : i64}, {}, {test.returnThree = 3 : i64, test.returnTwo = 2 : i64}]})
+func @multiple_return_missing_res_attr() -> (memref<f32> {test.returnOne = 1}, i64, f32 {test.returnTwo = 2, test.returnThree = 3}) {
+  %0 = memref.alloc() : memref<f32>
+  %1 = arith.constant 2 : i64
+  %2 = arith.constant 1.00 : f32
+  return %0, %1, %2 : memref<f32>, i64, f32
+}
+
+// CHECK: llvm.func @one_arg_attr_no_res_attrs_with_memref_return({{.*}}) -> !llvm.struct{{.*}}
+// CHECK-LABEL: llvm.func @_mlir_ciface_one_arg_attr_no_res_attrs_with_memref_return
+// CHECK: %{{.*}}: !llvm.ptr<{{.*}}>, %{{.*}}: !llvm.ptr<{{.*}}> {test.argOne = 1 : i64}
+func @one_arg_attr_no_res_attrs_with_memref_return(%arg0: memref<f32> {test.argOne = 1}) -> memref<f32> {
+  %0 = memref.alloc() : memref<f32>
+  return %0 : memref<f32>
+}
+
+// CHECK: llvm.func @one_arg_attr_one_res_attr_with_memref_return({{.*}}) -> (!llvm.struct<{{.*}}> {test.returnOne = 1 : i64})
+// CHECK-LABEL: llvm.func @_mlir_ciface_one_arg_attr_one_res_attr_with_memref_return
+// CHECK: (%{{.*}}: !llvm.ptr<{{.*}}> {test.returnOne = 1 : i64}, %{{.*}}: !llvm.ptr<{{.*}}> {test.argOne = 1 : i64}
+func @one_arg_attr_one_res_attr_with_memref_return(%arg0: memref<f32> {test.argOne = 1}) -> (memref<f32> {test.returnOne = 1}) {
+  %0 = memref.alloc() : memref<f32>
+  return %0 : memref<f32>
+}
+
+// CHECK: llvm.func @one_arg_attr_one_res_attr_with_value_return({{.*}}) -> (f32 {test.returnOne = 1 : i64})
+// CHECK-LABEL: llvm.func @_mlir_ciface_one_arg_attr_one_res_attr_with_value_return
+// CHECK: (%{{.*}}: !llvm.ptr<{{.*}}> {test.argOne = 1 : i64}) -> (f32 {test.returnOne = 1 : i64})
+func @one_arg_attr_one_res_attr_with_value_return(%arg0: memref<f32> {test.argOne = 1}) -> (f32 {test.returnOne = 1}) {
+  %0 = arith.constant 1.00 : f32
+  return %0 : f32
+}
+
+// CHECK: llvm.func @multiple_arg_attr_multiple_res_attr({{.*}}) -> (!llvm.struct<{{.*}}> {llvm.struct_attrs = [{}, {test.returnOne = 1 : i64}, {test.returnTwo = 2 : i64}]})
+// CHECK-LABEL: llvm.func @_mlir_ciface_multiple_arg_attr_multiple_res_attr
+// CHECK: (%{{.*}}: !llvm.ptr<{{.*}}> {llvm.struct_attrs = [{}, {test.returnOne = 1 : i64}, {test.returnTwo = 2 : i64}]}, %{{.*}}: !llvm.ptr<{{.*}}> {test.argZero = 0 : i64}, %{{.*}}: f32, %{{.*}}: i32 {test.argTwo = 2 : i64}
+func @multiple_arg_attr_multiple_res_attr(%arg0: memref<f32> {test.argZero = 0}, %arg1: f32, %arg2: i32 {test.argTwo = 2}) -> (f32, memref<i32> {test.returnOne = 1}, i32 {test.returnTwo = 2}) {
+  %0 = arith.constant 1.00 : f32
+  %1 = memref.alloc() : memref<i32>
+  %2 = arith.constant 2 : i32
+  return %0, %1, %2 : f32, memref<i32>, i32
+}

diff  --git a/mlir/test/Conversion/StandardToLLVM/emit-c-wrappers-for-external-functions.mlir b/mlir/test/Conversion/StandardToLLVM/emit-c-wrappers-for-external-functions.mlir
new file mode 100644
index 0000000000000..ca136a8993d42
--- /dev/null
+++ b/mlir/test/Conversion/StandardToLLVM/emit-c-wrappers-for-external-functions.mlir
@@ -0,0 +1,41 @@
+// RUN: mlir-opt -convert-func-to-llvm='emit-c-wrappers=1' %s | FileCheck %s
+
+// CHECK: llvm.func @res_attrs_with_memref_return() -> (!llvm.struct{{.*}} {test.returnOne})
+// CHECK-LABEL: llvm.func @_mlir_ciface_res_attrs_with_memref_return
+// CHECK: !llvm.ptr{{.*}} {test.returnOne}
+func private @res_attrs_with_memref_return() -> (memref<f32> {test.returnOne})
+
+// CHECK: llvm.func @res_attrs_with_value_return() -> (f32 {test.returnOne = 1 : i64})
+// CHECK-LABEL: llvm.func @_mlir_ciface_res_attrs_with_value_return
+// CHECK: -> (f32 {test.returnOne = 1 : i64})
+func private @res_attrs_with_value_return() -> (f32 {test.returnOne = 1})
+
+// CHECK: llvm.func @multiple_return() -> (!llvm.struct<{{.*}}> {llvm.struct_attrs = [{test.returnOne = 1 : i64}, {test.returnThree = 3 : i64, test.returnTwo = 2 : i64}]})
+// CHECK-LABEL: llvm.func @_mlir_ciface_multiple_return
+// CHECK: (!llvm.ptr<{{.*}}> {llvm.struct_attrs = [{test.returnOne = 1 : i64}, {test.returnThree = 3 : i64, test.returnTwo = 2 : i64}]})
+func private @multiple_return() -> (memref<f32> {test.returnOne = 1}, f32 {test.returnTwo = 2, test.returnThree = 3})
+
+// CHECK: llvm.func @multiple_return_missing_res_attr() -> (!llvm.struct<{{.*}}> {llvm.struct_attrs = [{test.returnOne = 1 : i64}, {}, {test.returnThree = 3 : i64, test.returnTwo = 2 : i64}]})
+// CHECK-LABEL: llvm.func @_mlir_ciface_multiple_return_missing_res_attr
+// CHECK: (!llvm.ptr<{{.*}}> {llvm.struct_attrs = [{test.returnOne = 1 : i64}, {}, {test.returnThree = 3 : i64, test.returnTwo = 2 : i64}]})
+func private @multiple_return_missing_res_attr() -> (memref<f32> {test.returnOne = 1}, i64, f32 {test.returnTwo = 2, test.returnThree = 3})
+
+// CHECK: llvm.func @one_arg_attr_no_res_attrs_with_memref_return({{.*}}) -> !llvm.struct{{.*}}
+// CHECK-LABEL: llvm.func @_mlir_ciface_one_arg_attr_no_res_attrs_with_memref_return
+// CHECK: !llvm.ptr<{{.*}}>, !llvm.ptr<{{.*}}> {test.argOne = 1 : i64}
+func private @one_arg_attr_no_res_attrs_with_memref_return(%arg0: memref<f32> {test.argOne = 1}) -> memref<f32>
+
+// CHECK: llvm.func @one_arg_attr_one_res_attr_with_memref_return({{.*}}) -> (!llvm.struct<{{.*}}> {test.returnOne = 1 : i64})
+// CHECK-LABEL: llvm.func @_mlir_ciface_one_arg_attr_one_res_attr_with_memref_return
+// CHECK: (!llvm.ptr<{{.*}}> {test.returnOne = 1 : i64}, !llvm.ptr<{{.*}}> {test.argOne = 1 : i64}
+func private @one_arg_attr_one_res_attr_with_memref_return(%arg0: memref<f32> {test.argOne = 1}) -> (memref<f32> {test.returnOne = 1})
+
+// CHECK: llvm.func @one_arg_attr_one_res_attr_with_value_return({{.*}}) -> (f32 {test.returnOne = 1 : i64})
+// CHECK-LABEL: llvm.func @_mlir_ciface_one_arg_attr_one_res_attr_with_value_return
+// CHECK: (!llvm.ptr<{{.*}}> {test.argOne = 1 : i64}) -> (f32 {test.returnOne = 1 : i64})
+func private @one_arg_attr_one_res_attr_with_value_return(%arg0: memref<f32> {test.argOne = 1}) -> (f32 {test.returnOne = 1})
+
+// CHECK: llvm.func @multiple_arg_attr_multiple_res_attr({{.*}}) -> (!llvm.struct<{{.*}}> {llvm.struct_attrs = [{}, {test.returnOne = 1 : i64}, {test.returnTwo = 2 : i64}]})
+// CHECK-LABEL: llvm.func @_mlir_ciface_multiple_arg_attr_multiple_res_attr
+// CHECK: (!llvm.ptr<{{.*}}> {llvm.struct_attrs = [{}, {test.returnOne = 1 : i64}, {test.returnTwo = 2 : i64}]}, !llvm.ptr<{{.*}}> {test.argZero = 0 : i64}, f32, i32 {test.argTwo = 2 : i64}
+func private @multiple_arg_attr_multiple_res_attr(%arg0: memref<f32> {test.argZero = 0}, %arg1: f32, %arg2: i32 {test.argTwo = 2}) -> (f32, memref<i32> {test.returnOne = 1}, i32 {test.returnTwo = 2})

diff  --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir
index cdee3931f1a7a..cd8314838cde4 100644
--- a/mlir/test/Dialect/LLVMIR/func.mlir
+++ b/mlir/test/Dialect/LLVMIR/func.mlir
@@ -123,6 +123,20 @@ module {
   // CHECK: llvm.func @external_func
   // GENERIC: linkage = #llvm.linkage<external>
   llvm.func external @external_func()
+
+  // CHECK-LABEL: llvm.func @arg_struct_attr(
+  // CHECK-SAME: %{{.*}}: !llvm.struct<(i32)> {llvm.struct_attrs = [{llvm.noalias}]}) {
+  llvm.func @arg_struct_attr(
+      %arg0 : !llvm.struct<(i32)> {llvm.struct_attrs = [{llvm.noalias}]}) {
+    llvm.return
+  }
+
+   // CHECK-LABEL: llvm.func @res_struct_attr(%{{.*}}: !llvm.struct<(i32)>)
+   // CHECK-SAME:-> (!llvm.struct<(i32)> {llvm.struct_attrs = [{llvm.noalias}]}) {
+  llvm.func @res_struct_attr(%arg0 : !llvm.struct<(i32)>)
+      -> (!llvm.struct<(i32)> {llvm.struct_attrs = [{llvm.noalias}]}) {
+    llvm.return %arg0 : !llvm.struct<(i32)>
+  }
 }
 
 // -----

diff  --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 94c3446821edc..6c7e3ae5712d7 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1300,3 +1300,66 @@ llvm.mlir.global internal @side_effecting_global() : !llvm.struct<(i8)> {
   %2 = llvm.load %1 : !llvm.ptr<struct<(i8)>>
   llvm.return %2 : !llvm.struct<(i8)>
 }
+
+// -----
+
+// expected-error at +1 {{'llvm.struct_attrs' is permitted only in argument or result attributes}}
+func @struct_attrs_in_op() attributes {llvm.struct_attrs = []} {
+  return
+}
+
+// -----
+
+// expected-error at +1 {{expected 'llvm.struct_attrs' to annotate '!llvm.struct' or '!llvm.ptr<struct<...>>'}}
+func @invalid_struct_attr_arg_type(%arg0 : i32 {llvm.struct_attrs = []}) {
+    return
+}
+
+// -----
+
+// expected-error at +1 {{expected 'llvm.struct_attrs' to annotate '!llvm.struct' or '!llvm.ptr<struct<...>>'}}
+func @invalid_struct_attr_pointer_arg_type(%arg0 : !llvm.ptr<i32> {llvm.struct_attrs = []}) {
+    return
+}
+
+// -----
+
+// expected-error at +1 {{expected 'llvm.struct_attrs' to be an array attribute}}
+func @invalid_arg_struct_attr_value(%arg0 : !llvm.struct<(i32)> {llvm.struct_attrs = {}}) {
+    return
+}
+
+// -----
+
+// expected-error at +1 {{size of 'llvm.struct_attrs' must match the size of the annotated '!llvm.struct'}}
+func @invalid_arg_struct_attr_size(%arg0 : !llvm.struct<(i32)> {llvm.struct_attrs = []}) {
+    return
+}
+
+// -----
+
+// expected-error at +1 {{expected 'llvm.struct_attrs' to annotate '!llvm.struct' or '!llvm.ptr<struct<...>>'}}
+func @invalid_struct_attr_res_type(%arg0 : i32) -> (i32 {llvm.struct_attrs = []}) {
+  return %arg0 : i32
+}
+
+// -----
+
+// expected-error at +1 {{expected 'llvm.struct_attrs' to annotate '!llvm.struct' or '!llvm.ptr<struct<...>>'}}
+func @invalid_struct_attr_pointer_res_type(%arg0 : !llvm.ptr<i32>) -> (!llvm.ptr<i32> {llvm.struct_attrs = []}) {
+    return %arg0 : !llvm.ptr<i32>
+}
+
+// -----
+
+// expected-error at +1 {{expected 'llvm.struct_attrs' to be an array attribute}}
+func @invalid_res_struct_attr_value(%arg0 : !llvm.struct<(i32)>) -> (!llvm.struct<(i32)> {llvm.struct_attrs = {}}) {
+    return %arg0 : !llvm.struct<(i32)>
+}
+
+// -----
+
+// expected-error at +1 {{size of 'llvm.struct_attrs' must match the size of the annotated '!llvm.struct'}}
+func @invalid_res_struct_attr_size(%arg0 : !llvm.struct<(i32)>) -> (!llvm.struct<(i32)> {llvm.struct_attrs = []}) {
+    return %arg0 : !llvm.struct<(i32)>
+}


        


More information about the Mlir-commits mailing list