[Mlir-commits] [mlir] be4b873 - [mlir][LLVM] Add param attr verifiers

Christian Ulmann llvmlistbot at llvm.org
Wed Jan 25 06:05:12 PST 2023


Author: Christian Ulmann
Date: 2023-01-25T15:04:45+01:00
New Revision: be4b87353e4230e8dab4173d5200e8d88f6ab032

URL: https://github.com/llvm/llvm-project/commit/be4b87353e4230e8dab4173d5200e8d88f6ab032
DIFF: https://github.com/llvm/llvm-project/commit/be4b87353e4230e8dab4173d5200e8d88f6ab032.diff

LOG: [mlir][LLVM] Add param attr verifiers

This commit introduces unified parameter attribute verifiers to the LLVM
dialect and removes the corresponding checks from the export. As LLVM does
not verify the validity of certain attributes on return values, this commit
unifies the handling of argument and result attributes wherever possible.

Depends on D142212

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D142372

Added: 
    mlir/test/Dialect/LLVMIR/parameter-attrs-invalid.mlir

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
    mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
    mlir/test/Dialect/LLVMIR/func.mlir
    mlir/test/Dialect/LLVMIR/invalid.mlir
    mlir/test/Target/LLVMIR/llvmir-invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index eda30e36086fe..9166ae78bd113 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -54,6 +54,7 @@ def LLVM_Dialect : Dialect {
     static StringRef getSExtAttrName() { return "llvm.signext"; }
     static StringRef getZExtAttrName() { return "llvm.zeroext"; }
     static StringRef getTBAAAttrName() { return "llvm.tbaa"; }
+    static StringRef getNestAttrName() { return "llvm.nest"; }
 
     /// Verifies if the attribute is a well-formed value for "llvm.struct_attrs"
     static LogicalResult verifyStructAttr(
@@ -87,6 +88,12 @@ def LLVM_Dialect : Dialect {
     void printType(Type, DialectAsmPrinter &p) const override;
 
   private:
+    /// Verifies a parameter attribute attached to a parameter of type
+    /// paramType.
+    LogicalResult verifyParameterAttribute(Operation *op,
+                                           Type paramType,
+                                           NamedAttribute paramAttr);
+
     /// Register all types.
     void registerTypes();
 

diff  --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 33f8ab7d066df..4e85dcf6ab348 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -15,8 +15,8 @@
 #define MLIR_TARGET_LLVMIR_MODULETRANSLATION_H
 
 #include "mlir/IR/Operation.h"
-#include "mlir/IR/Value.h"
 #include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Target/LLVMIR/Export.h"
 #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
@@ -273,7 +273,7 @@ class ModuleTranslation {
     ModuleTranslation &moduleTranslation;
   };
 
-  SymbolTableCollection& symbolTable() { return symbolTableCollection; }
+  SymbolTableCollection &symbolTable() { return symbolTableCollection; }
 
 private:
   ModuleTranslation(Operation *module,
@@ -306,6 +306,9 @@ class ModuleTranslation {
   /// Translates dialect attributes attached to the given operation.
   LogicalResult convertDialectAttributes(Operation *op);
 
+  /// Translates parameter attributes and adds them to the returned AttrBuilder.
+  llvm::AttrBuilder convertParameterAttrs(DictionaryAttr paramAttrs);
+
   /// Original and translated module.
   Operation *mlirModule;
   std::unique_ptr<llvm::Module> llvmModule;

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 4f301ff00c953..a381efc189a6f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -3120,123 +3120,157 @@ LogicalResult LLVMDialect::verifyStructAttr(Operation *op, Attribute attr,
   return success();
 }
 
-static LogicalResult verifyFuncOpInterfaceStructAttr(
-    Operation *op, Attribute attr,
-    const std::function<Type(FunctionOpInterface)> &getAnnotatedType) {
-  if (auto funcOp = dyn_cast<FunctionOpInterface>(op))
-    return LLVMDialect::verifyStructAttr(op, attr, getAnnotatedType(funcOp));
+static LogicalResult verifyFuncOpInterfaceStructAttr(Operation *op,
+                                                     Attribute attr,
+                                                     Type annotatedType) {
+  if (isa<FunctionOpInterface>(op))
+    return LLVMDialect::verifyStructAttr(op, attr, annotatedType);
   return op->emitError() << "expected '"
                          << LLVMDialect::getStructAttrsAttrName()
                          << "' to be used on function-like operations";
 }
 
+LogicalResult LLVMDialect::verifyParameterAttribute(Operation *op,
+                                                    Type paramType,
+                                                    NamedAttribute paramAttr) {
+  // An LLVM attribute may be attached to a parameter of an operation that has
+  // not been converted to the LLVM dialect yet, so the parameter may have a
+  // type with unknown representation in the LLVM dialect type space. In this
+  // case we cannot verify whether the attribute may be attached to it.
+  bool verifyValueType = isCompatibleType(paramType);
+  StringAttr name = paramAttr.getName();
+
+  auto checkUnitAttrType = [&]() -> LogicalResult {
+    if (!paramAttr.getValue().isa<UnitAttr>())
+      return op->emitError() << name << " should be a unit attribute";
+    return success();
+  };
+  auto checkTypeAttrType = [&]() -> LogicalResult {
+    if (!paramAttr.getValue().isa<TypeAttr>())
+      return op->emitError() << name << " should be a type attribute";
+    return success();
+  };
+  auto checkIntegerAttrType = [&]() -> LogicalResult {
+    if (!paramAttr.getValue().isa<IntegerAttr>())
+      return op->emitError() << name << " should be an integer attribute";
+    return success();
+  };
+  auto checkPointerType = [&]() -> LogicalResult {
+    if (!paramType.isa<LLVMPointerType>())
+      return op->emitError()
+             << name << " attribute attached to non-pointer LLVM type";
+    return success();
+  };
+  auto checkIntegerType = [&]() -> LogicalResult {
+    if (!paramType.isa<IntegerType>())
+      return op->emitError()
+             << name << " attribute attached to non-integer LLVM type";
+    return success();
+  };
+  auto checkPointerTypeMatches = [&]() -> LogicalResult {
+    if (failed(checkPointerType()))
+      return failure();
+    auto ptrType = paramType.cast<LLVMPointerType>();
+    auto typeAttr = paramAttr.getValue().cast<TypeAttr>();
+
+    if (!ptrType.isOpaque() && ptrType.getElementType() != typeAttr.getValue())
+      return op->emitError()
+             << name
+             << " attribute attached to LLVM pointer argument of "
+                "different type";
+    return success();
+  };
+
+  // Note: The struct parameter attributes are not lowered to LLVM IR.
+  if (name == LLVMDialect::getStructAttrsAttrName())
+    return verifyFuncOpInterfaceStructAttr(op, paramAttr.getValue(), paramType);
+
+  // Check a unit attribute that is attached to a pointer value.
+  if (name == LLVMDialect::getNoAliasAttrName() ||
+      name == LLVMDialect::getReadonlyAttrName() ||
+      name == LLVMDialect::getNestAttrName()) {
+    if (failed(checkUnitAttrType()))
+      return failure();
+    if (verifyValueType && failed(checkPointerType()))
+      return failure();
+    return success();
+  }
+
+  // Check a type attribute that is attached to a pointer value.
+  if (name == LLVMDialect::getStructRetAttrName() ||
+      name == LLVMDialect::getByValAttrName() ||
+      name == LLVMDialect::getByRefAttrName() ||
+      name == LLVMDialect::getInAllocaAttrName()) {
+    if (failed(checkTypeAttrType()))
+      return failure();
+    if (verifyValueType && failed(checkPointerTypeMatches()))
+      return failure();
+    return success();
+  }
+
+  // Check a unit attribute that is attached to an integer value.
+  if (name == LLVMDialect::getSExtAttrName() ||
+      name == LLVMDialect::getZExtAttrName()) {
+    if (failed(checkUnitAttrType()))
+      return failure();
+    if (verifyValueType && failed(checkIntegerType()))
+      return failure();
+    return success();
+  }
+
+  // Check an integer attribute that is attached to a pointer value.
+  if (name == LLVMDialect::getAlignAttrName()) {
+    if (failed(checkIntegerAttrType()))
+      return failure();
+    if (verifyValueType && failed(checkPointerType()))
+      return failure();
+    return success();
+  }
+
+  if (name == LLVMDialect::getNoUndefAttrName())
+    return checkUnitAttrType();
+  return success();
+}
+
 /// Verify LLVMIR function argument attributes.
 LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op,
                                                     unsigned regionIdx,
                                                     unsigned argIdx,
                                                     NamedAttribute argAttr) {
-  // Check that llvm.noalias is a unit attribute.
-  if (argAttr.getName() == LLVMDialect::getNoAliasAttrName() &&
-      !argAttr.getValue().isa<UnitAttr>())
-    return op->emitError()
-           << "expected llvm.noalias argument attribute to be a unit attribute";
-  // Check that llvm.align is an integer attribute.
-  if (argAttr.getName() == LLVMDialect::getAlignAttrName() &&
-      !argAttr.getValue().isa<IntegerAttr>())
-    return op->emitError()
-           << "llvm.align argument attribute of non integer type";
-  if (argAttr.getName() == LLVMDialect::getStructAttrsAttrName()) {
-    return verifyFuncOpInterfaceStructAttr(
-        op, argAttr.getValue(), [argIdx](FunctionOpInterface funcOp) {
-          return funcOp.getArgumentTypes()[argIdx];
-        });
-  }
-  return success();
+  auto funcOp = dyn_cast<FunctionOpInterface>(op);
+  if (!funcOp)
+    return success();
+  Type argType = funcOp.getArgumentTypes()[argIdx];
+
+  return verifyParameterAttribute(op, argType, argAttr);
 }
 
 LogicalResult LLVMDialect::verifyRegionResultAttribute(Operation *op,
                                                        unsigned regionIdx,
                                                        unsigned resIdx,
                                                        NamedAttribute resAttr) {
-  StringAttr name = resAttr.getName();
-  if (name == LLVMDialect::getStructAttrsAttrName()) {
-    return verifyFuncOpInterfaceStructAttr(
-        op, resAttr.getValue(), [resIdx](FunctionOpInterface funcOp) {
-          return funcOp.getResultTypes()[resIdx];
-        });
-  }
-  if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
-    mlir::Type resTy = funcOp.getResultTypes()[resIdx];
-
-    // Check to see if this function has a void return with a result attribute
-    // to it. It isn't clear what semantics we would assign to that.
-    if (resTy.isa<LLVMVoidType>())
-      return op->emitError() << "cannot attach result attributes to functions "
-                                "with a void return";
-
-    // LLVM attribute may be attached to a result of operation
-    // that has not been converted to LLVM dialect yet, so the result
-    // may have a type with unknown representation in LLVM dialect type
-    // space. In this case we cannot verify whether the attribute may be
-    // attached to a result of such type.
-    bool verifyValueType = isCompatibleType(resTy);
-    Attribute attrValue = resAttr.getValue();
-
-    // TODO: get rid of code duplication here and in verifyRegionArgAttribute().
-    if (name == LLVMDialect::getAlignAttrName()) {
-      if (!attrValue.isa<IntegerAttr>())
-        return op->emitError() << "expected llvm.align result attribute to be "
-                                  "an integer attribute";
-      if (verifyValueType && !resTy.isa<LLVMPointerType>())
-        return op->emitError()
-               << "llvm.align attribute attached to non-pointer result";
-      return success();
-    }
-    if (name == LLVMDialect::getNoAliasAttrName()) {
-      if (!attrValue.isa<UnitAttr>())
-        return op->emitError() << "expected llvm.noalias result attribute to "
-                                  "be a unit attribute";
-      if (verifyValueType && !resTy.isa<LLVMPointerType>())
-        return op->emitError()
-               << "llvm.noalias attribute attached to non-pointer result";
-      return success();
-    }
-    if (name == LLVMDialect::getReadonlyAttrName()) {
-      if (!attrValue.isa<UnitAttr>())
-        return op->emitError() << "expected llvm.readonly result attribute to "
-                                  "be a unit attribute";
-      if (verifyValueType && !resTy.isa<LLVMPointerType>())
-        return op->emitError()
-               << "llvm.readonly attribute attached to non-pointer result";
-      return success();
-    }
-    if (name == LLVMDialect::getNoUndefAttrName()) {
-      if (!attrValue.isa<UnitAttr>())
-        return op->emitError() << "expected llvm.noundef result attribute to "
-                                  "be a unit attribute";
-      return success();
-    }
-    if (name == LLVMDialect::getSExtAttrName()) {
-      if (!attrValue.isa<UnitAttr>())
-        return op->emitError() << "expected llvm.signext result attribute to "
-                                  "be a unit attribute";
-      if (verifyValueType && !resTy.isa<mlir::IntegerType>())
-        return op->emitError()
-               << "llvm.signext attribute attached to non-integer result";
-      return success();
-    }
-    if (name == LLVMDialect::getZExtAttrName()) {
-      if (!attrValue.isa<UnitAttr>())
-        return op->emitError() << "expected llvm.zeroext result attribute to "
-                                  "be a unit attribute";
-      if (verifyValueType && !resTy.isa<mlir::IntegerType>())
-        return op->emitError()
-               << "llvm.zeroext attribute attached to non-integer result";
-      return success();
-    }
-  }
-
-  return success();
+  auto funcOp = dyn_cast<FunctionOpInterface>(op);
+  if (!funcOp)
+    return success();
+  Type resType = funcOp.getResultTypes()[resIdx];
+
+  // Check to see if this function has a void return with a result attribute
+  // to it. It isn't clear what semantics we would assign to that.
+  if (resType.isa<LLVMVoidType>())
+    return op->emitError() << "cannot attach result attributes to functions "
+                              "with a void return";
+
+  // Check to see if this attribute is allowed as a result attribute. Only
+  // explicitly forbidden LLVM attributes will cause an error.
+  auto name = resAttr.getName();
+  if (name == LLVMDialect::getReadonlyAttrName() ||
+      name == LLVMDialect::getNestAttrName() ||
+      name == LLVMDialect::getStructRetAttrName() ||
+      name == LLVMDialect::getByValAttrName() ||
+      name == LLVMDialect::getByRefAttrName() ||
+      name == LLVMDialect::getInAllocaAttrName())
+    return op->emitError() << name << " is not a valid result attribute";
+  return verifyParameterAttribute(op, resType, resAttr);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index ecb9908c6cb63..b5323de221d7d 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -896,6 +896,48 @@ static void convertFunctionAttributes(LLVMFuncOp func,
   llvmFunc->setMemoryEffects(newMemEffects);
 }
 
+llvm::AttrBuilder
+ModuleTranslation::convertParameterAttrs(DictionaryAttr paramAttrs) {
+  llvm::AttrBuilder attrBuilder(llvmModule->getContext());
+  if (auto attr = paramAttrs.getAs<UnitAttr>(LLVMDialect::getNoAliasAttrName()))
+    attrBuilder.addAttribute(llvm::Attribute::AttrKind::NoAlias);
+
+  if (auto attr =
+          paramAttrs.getAs<UnitAttr>(LLVMDialect::getReadonlyAttrName()))
+    attrBuilder.addAttribute(llvm::Attribute::AttrKind::ReadOnly);
+
+  if (auto attr =
+          paramAttrs.getAs<IntegerAttr>(LLVMDialect::getAlignAttrName()))
+    attrBuilder.addAlignmentAttr(llvm::Align(attr.getInt()));
+
+  if (auto attr =
+          paramAttrs.getAs<TypeAttr>(LLVMDialect::getStructRetAttrName()))
+    attrBuilder.addStructRetAttr(convertType(attr.getValue()));
+
+  if (auto attr = paramAttrs.getAs<TypeAttr>(LLVMDialect::getByValAttrName()))
+    attrBuilder.addByValAttr(convertType(attr.getValue()));
+
+  if (auto attr = paramAttrs.getAs<TypeAttr>(LLVMDialect::getByRefAttrName()))
+    attrBuilder.addByRefAttr(convertType(attr.getValue()));
+
+  if (auto attr =
+          paramAttrs.getAs<TypeAttr>(LLVMDialect::getInAllocaAttrName()))
+    attrBuilder.addInAllocaAttr(convertType(attr.getValue()));
+
+  if (auto attr = paramAttrs.getAs<UnitAttr>(LLVMDialect::getNestAttrName()))
+    attrBuilder.addAttribute(llvm::Attribute::Nest);
+
+  if (auto attr = paramAttrs.getAs<UnitAttr>(LLVMDialect::getNoUndefAttrName()))
+    attrBuilder.addAttribute(llvm::Attribute::NoUndef);
+
+  if (auto attr = paramAttrs.getAs<UnitAttr>(LLVMDialect::getSExtAttrName()))
+    attrBuilder.addAttribute(llvm::Attribute::SExt);
+
+  if (auto attr = paramAttrs.getAs<UnitAttr>(LLVMDialect::getZExtAttrName()))
+    attrBuilder.addAttribute(llvm::Attribute::ZExt);
+  return attrBuilder;
+}
+
 LogicalResult ModuleTranslation::convertFunctionSignatures() {
   // Declare all functions first because there may be function calls that form a
   // call graph with cycles, or global initializers that reference functions.
@@ -918,149 +960,16 @@ LogicalResult ModuleTranslation::convertFunctionSignatures() {
 
     // Convert result attributes.
     if (ArrayAttr allResultAttrs = function.getAllResultAttrs()) {
-      llvm::AttrBuilder retAttrs(llvmFunc->getContext());
       DictionaryAttr resultAttrs = allResultAttrs[0].cast<DictionaryAttr>();
-      for (const NamedAttribute &attr : resultAttrs) {
-        StringAttr name = attr.getName();
-        if (name == LLVMDialect::getAlignAttrName()) {
-          auto alignAmount = attr.getValue().cast<IntegerAttr>();
-          retAttrs.addAlignmentAttr(llvm::Align(alignAmount.getInt()));
-        } else if (name == LLVMDialect::getNoAliasAttrName()) {
-          retAttrs.addAttribute(llvm::Attribute::NoAlias);
-        } else if (name == LLVMDialect::getNoUndefAttrName()) {
-          retAttrs.addAttribute(llvm::Attribute::NoUndef);
-        } else if (name == LLVMDialect::getSExtAttrName()) {
-          retAttrs.addAttribute(llvm::Attribute::SExt);
-        } else if (name == LLVMDialect::getZExtAttrName()) {
-          retAttrs.addAttribute(llvm::Attribute::ZExt);
-        }
-      }
-      llvmFunc->addRetAttrs(retAttrs);
+      llvmFunc->addRetAttrs(convertParameterAttrs(resultAttrs));
     }
 
     // Convert argument attributes.
-    unsigned int argIdx = 0;
-    for (auto [mlirArgTy, llvmArg] :
-         llvm::zip(function.getArgumentTypes(), llvmFunc->args())) {
-      if (auto attr = function.getArgAttrOfType<UnitAttr>(
-              argIdx, LLVMDialect::getNoAliasAttrName())) {
-        // NB: Attribute already verified to be boolean, so check if we can
-        // indeed attach the attribute to this argument, based on its type.
-        if (!mlirArgTy.isa<LLVM::LLVMPointerType>())
-          return function.emitError(
-              "llvm.noalias attribute attached to LLVM non-pointer argument");
-        llvmArg.addAttr(llvm::Attribute::AttrKind::NoAlias);
-      }
-      if (auto attr = function.getArgAttrOfType<UnitAttr>(
-              argIdx, LLVMDialect::getReadonlyAttrName())) {
-        if (!mlirArgTy.isa<LLVM::LLVMPointerType>())
-          return function.emitError(
-              "llvm.readonly attribute attached to LLVM non-pointer argument");
-        llvmArg.addAttr(llvm::Attribute::AttrKind::ReadOnly);
-      }
-
-      if (auto attr = function.getArgAttrOfType<IntegerAttr>(
-              argIdx, LLVMDialect::getAlignAttrName())) {
-        // NB: Attribute already verified to be int, so check if we can indeed
-        // attach the attribute to this argument, based on its type.
-        if (!mlirArgTy.isa<LLVM::LLVMPointerType>())
-          return function.emitError(
-              "llvm.align attribute attached to LLVM non-pointer argument");
-        llvmArg.addAttrs(llvm::AttrBuilder(llvmArg.getContext())
-                             .addAlignmentAttr(llvm::Align(attr.getInt())));
-      }
-
-      if (auto attr = function.getArgAttrOfType<TypeAttr>(
-              argIdx, LLVMDialect::getStructRetAttrName())) {
-        auto argTy = mlirArgTy.dyn_cast<LLVM::LLVMPointerType>();
-        if (!argTy)
-          return function.emitError(
-              "llvm.sret attribute attached to LLVM non-pointer argument");
-        if (!argTy.isOpaque() && argTy.getElementType() != attr.getValue())
-          return function.emitError(
-              "llvm.sret attribute attached to LLVM pointer "
-              "argument of a different type");
-        llvmArg.addAttrs(llvm::AttrBuilder(llvmArg.getContext())
-                             .addStructRetAttr(convertType(attr.getValue())));
-      }
-
-      if (auto attr = function.getArgAttrOfType<TypeAttr>(
-              argIdx, LLVMDialect::getByValAttrName())) {
-        auto argTy = mlirArgTy.dyn_cast<LLVM::LLVMPointerType>();
-        if (!argTy)
-          return function.emitError(
-              "llvm.byval attribute attached to LLVM non-pointer argument");
-        if (!argTy.isOpaque() && argTy.getElementType() != attr.getValue())
-          return function.emitError(
-              "llvm.byval attribute attached to LLVM pointer "
-              "argument of a different type");
-        llvmArg.addAttrs(llvm::AttrBuilder(llvmArg.getContext())
-                             .addByValAttr(convertType(attr.getValue())));
-      }
-
-      if (auto attr = function.getArgAttrOfType<TypeAttr>(
-              argIdx, LLVMDialect::getByRefAttrName())) {
-        auto argTy = mlirArgTy.dyn_cast<LLVM::LLVMPointerType>();
-        if (!argTy)
-          return function.emitError(
-              "llvm.byref attribute attached to LLVM non-pointer argument");
-        if (!argTy.isOpaque() && argTy.getElementType() != attr.getValue())
-          return function.emitError(
-              "llvm.byref attribute attached to LLVM pointer "
-              "argument of a different type");
-        llvmArg.addAttrs(llvm::AttrBuilder(llvmArg.getContext())
-                             .addByRefAttr(convertType(attr.getValue())));
-      }
-
-      if (auto attr = function.getArgAttrOfType<TypeAttr>(
-              argIdx, LLVMDialect::getInAllocaAttrName())) {
-        auto argTy = mlirArgTy.dyn_cast<LLVM::LLVMPointerType>();
-        if (!argTy)
-          return function.emitError(
-              "llvm.inalloca attribute attached to LLVM non-pointer argument");
-        if (!argTy.isOpaque() && argTy.getElementType() != attr.getValue())
-          return function.emitError(
-              "llvm.inalloca attribute attached to LLVM pointer "
-              "argument of a different type");
-        llvmArg.addAttrs(llvm::AttrBuilder(llvmArg.getContext())
-                             .addInAllocaAttr(convertType(attr.getValue())));
-      }
-
-      if (auto attr =
-              function.getArgAttrOfType<UnitAttr>(argIdx, "llvm.nest")) {
-        if (!mlirArgTy.isa<LLVM::LLVMPointerType>())
-          return function.emitError(
-              "llvm.nest attribute attached to LLVM non-pointer argument");
-        llvmArg.addAttrs(llvm::AttrBuilder(llvmArg.getContext())
-                             .addAttribute(llvm::Attribute::Nest));
-      }
-
-      if (auto attr = function.getArgAttrOfType<UnitAttr>(
-              argIdx, LLVMDialect::getNoUndefAttrName())) {
-        // llvm.noundef can be added to any argument type.
-        llvmArg.addAttrs(llvm::AttrBuilder(llvmArg.getContext())
-                             .addAttribute(llvm::Attribute::NoUndef));
-      }
-      if (auto attr = function.getArgAttrOfType<UnitAttr>(
-              argIdx, LLVMDialect::getSExtAttrName())) {
-        // llvm.signext can be added to any integer argument type.
-        if (!mlirArgTy.isa<mlir::IntegerType>())
-          return function.emitError(
-              "llvm.signext attribute attached to LLVM non-integer argument");
-        llvmArg.addAttrs(llvm::AttrBuilder(llvmArg.getContext())
-                             .addAttribute(llvm::Attribute::SExt));
+    for (auto [argIdx, llvmArg] : llvm::enumerate(llvmFunc->args())) {
+      if (DictionaryAttr argAttrs = function.getArgAttrDict(argIdx)) {
+        llvm::AttrBuilder attrBuilder = convertParameterAttrs(argAttrs);
+        llvmArg.addAttrs(attrBuilder);
       }
-      if (auto attr = function.getArgAttrOfType<UnitAttr>(
-              argIdx, LLVMDialect::getZExtAttrName())) {
-        // llvm.zeroext can be added to any integer argument type.
-        if (!mlirArgTy.isa<mlir::IntegerType>())
-          return function.emitError(
-              "llvm.zeroext attribute attached to LLVM non-integer argument");
-        llvmArg.addAttrs(llvm::AttrBuilder(llvmArg.getContext())
-                             .addAttribute(llvm::Attribute::ZExt));
-      }
-
-      ++argIdx;
     }
 
     // Forward the pass-through attributes to LLVM.

diff  --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir
index c4afd42e91bb3..4ad8f445c7844 100644
--- a/mlir/test/Dialect/LLVMIR/func.mlir
+++ b/mlir/test/Dialect/LLVMIR/func.mlir
@@ -89,8 +89,8 @@ module {
     llvm.return
   }
 
-  // CHECK: llvm.func @byvalattr(%{{.*}}: !llvm.ptr<i32> {llvm.byval})
-  llvm.func @byvalattr(%arg0: !llvm.ptr<i32> {llvm.byval}) {
+  // CHECK: llvm.func @byvalattr(%{{.*}}: !llvm.ptr<i32> {llvm.byval = i32})
+  llvm.func @byvalattr(%arg0: !llvm.ptr<i32> {llvm.byval = i32}) {
     llvm.return
   }
 
@@ -267,58 +267,6 @@ module {
 
 // -----
 
-module {
-  // expected-error@+1 {{cannot attach result attributes to functions with a void return}}
-  llvm.func @variadic_def() -> (!llvm.void {llvm.noundef})
-}
-
-// -----
-
-// expected-error @below{{expected llvm.align result attribute to be an integer attribute}}
-llvm.func @alignattr_ret() -> (!llvm.ptr {llvm.align = 1.0 : f32})
-
-// -----
-
-// expected-error @below{{llvm.align attribute attached to non-pointer result}}
-llvm.func @alignattr_ret() -> (i32 {llvm.align = 4})
-
-// -----
-
-// expected-error @below{{expected llvm.noalias result attribute to be a unit attribute}}
-llvm.func @noaliasattr_ret() -> (!llvm.ptr {llvm.noalias = 1})
-
-// -----
-
-// expected-error @below{{llvm.noalias attribute attached to non-pointer result}}
-llvm.func @noaliasattr_ret() -> (i32 {llvm.noalias})
-
-// -----
-
-// expected-error @below{{expected llvm.noundef result attribute to be a unit attribute}}
-llvm.func @noundefattr_ret() -> (!llvm.ptr {llvm.noundef = 1})
-
-// -----
-
-// expected-error @below{{expected llvm.signext result attribute to be a unit attribute}}
-llvm.func @signextattr_ret() -> (i32 {llvm.signext = 1})
-
-// -----
-
-// expected-error @below{{llvm.signext attribute attached to non-integer result}}
-llvm.func @signextattr_ret() -> (f32 {llvm.signext})
-
-// -----
-
-// expected-error @below{{expected llvm.zeroext result attribute to be a unit attribute}}
-llvm.func @zeroextattr_ret() -> (i32 {llvm.zeroext = 1})
-
-// -----
-
-// expected-error @below{{llvm.zeroext attribute attached to non-integer result}}
-llvm.func @zeroextattr_ret() -> (f32 {llvm.zeroext})
-
-// -----
-
 module {
   // expected-error@+1 {{variadic arguments must be in the end of the argument list}}
   llvm.func @variadic_inside(%arg0: i32, ..., %arg1: i32)

diff  --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index fe76979c9808a..aa2c43ff42b7b 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -33,20 +33,6 @@ llvm.func @dtor()
 // expected-error@+1{{'dtor' does not have a definition}}
 llvm.mlir.global_dtors {dtors = [@dtor], priorities = [0 : i32]}
 
-// -----
-
-// expected-error@+1{{expected llvm.noalias argument attribute to be a unit attribute}}
-func.func @invalid_noalias(%arg0: i32 {llvm.noalias = 3}) {
-  "llvm.return"() : () -> ()
-}
-
-// -----
-
-// expected-error@+1{{llvm.align argument attribute of non integer type}}
-func.func @invalid_align(%arg0: i32 {llvm.align = "foo"}) {
-  "llvm.return"() : () -> ()
-}
-
 ////////////////////////////////////////////////////////////////////////////////
 
 // Check that parser errors are properly produced and do not crash the compiler.

diff  --git a/mlir/test/Dialect/LLVMIR/parameter-attrs-invalid.mlir b/mlir/test/Dialect/LLVMIR/parameter-attrs-invalid.mlir
new file mode 100644
index 0000000000000..f6305f171013c
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/parameter-attrs-invalid.mlir
@@ -0,0 +1,188 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// Argument attributes
+
+// expected-error at below {{"llvm.noalias" attribute attached to non-pointer LLVM type}}
+llvm.func @invalid_noalias_arg_type(%0 : i32 {llvm.noalias})
+
+// -----
+
+// expected-error at below {{"llvm.noalias" should be a unit attribute}}
+llvm.func @invalid_noalias_attr_type(%0 : !llvm.ptr {llvm.noalias = 10 : i32})
+
+// -----
+
+// expected-error at below {{"llvm.readonly" attribute attached to non-pointer LLVM type}}
+llvm.func @invalid_readonly_arg_type(%0 : i32 {llvm.readonly})
+
+// -----
+
+// expected-error at below {{"llvm.readonly" should be a unit attribute}}
+llvm.func @invalid_readonly_attr_type(%0 : i32 {llvm.readonly = i32})
+
+// -----
+
+// expected-error at below {{"llvm.nest" attribute attached to non-pointer LLVM type}}
+llvm.func @invalid_nest_arg_type(%0 : i32 {llvm.nest})
+
+// -----
+
+// expected-error at below {{"llvm.nest" should be a unit attribute}}
+llvm.func @invalid_nest_attr_type(%0 : i32 {llvm.nest = "foo"})
+
+// -----
+
+// expected-error at below {{"llvm.align" attribute attached to non-pointer LLVM type}}
+llvm.func @invalid_align_arg_type(%0 : i32 {llvm.align = 10 : i32})
+
+// -----
+
+// expected-error at below {{"llvm.align" should be an integer attribute}}
+llvm.func @invalid_align_attr_type(%0 : i32 {llvm.align = "foo"})
+
+// -----
+
+// expected-error at below {{"llvm.sret" attribute attached to non-pointer LLVM type}}
+llvm.func @invalid_sret_arg_type(%0 : i32 {llvm.sret = !llvm.struct<(i32)>})
+
+// -----
+
+// expected-error at below {{"llvm.sret" attribute attached to LLVM pointer argument of different type}}
+llvm.func @invalid_sret_attr_type(%0 : !llvm.ptr<f32> {llvm.sret = !llvm.struct<(i32)>})
+
+// -----
+
+// expected-error at below {{"llvm.byval" attribute attached to non-pointer LLVM type}}
+llvm.func @invalid_byval_arg_type(%0 : i32 {llvm.byval = !llvm.struct<(i32)>})
+
+// -----
+
+// expected-error at below {{"llvm.byval" attribute attached to LLVM pointer argument of different type}}
+llvm.func @invalid_byval_attr_type(%0 : !llvm.ptr<!llvm.struct<(f32)>> {llvm.byval = !llvm.struct<(i32)>})
+
+// -----
+
+// expected-error at below {{"llvm.byref" attribute attached to non-pointer LLVM type}}
+llvm.func @invalid_byref_arg_type(%0 : i32 {llvm.byref = !llvm.struct<(i32)>})
+
+// -----
+
+// expected-error at below {{"llvm.byref" attribute attached to LLVM pointer argument of different type}}
+llvm.func @invalid_byref_attr_type(%0 : !llvm.ptr<!llvm.struct<(f32)>> {llvm.byref = !llvm.struct<(i32)>})
+
+// -----
+
+// expected-error at below {{"llvm.inalloca" attribute attached to non-pointer LLVM type}}
+llvm.func @invalid_inalloca_arg_type(%0 : i32 {llvm.inalloca = !llvm.struct<(i32)>})
+
+// -----
+
+// expected-error at below {{"llvm.inalloca" attribute attached to LLVM pointer argument of different type}}
+llvm.func @invalid_inalloca_attr_type(%0 : !llvm.ptr<!llvm.struct<(f32)>> {llvm.inalloca = !llvm.struct<(i32)>})
+
+// -----
+
+// expected-error at below {{"llvm.signext" attribute attached to non-integer LLVM type}}
+llvm.func @invalid_signext_arg_type(%0 : f32 {llvm.signext})
+
+// -----
+
+// expected-error at below {{"llvm.signext" should be a unit attribute}}
+llvm.func @invalid_signext_attr_type(%0 : i32 {llvm.signext = !llvm.struct<(i32)>})
+
+// -----
+
+// expected-error at below {{"llvm.zeroext" attribute attached to non-integer LLVM type}}
+llvm.func @invalid_zeroext_arg_type(%0 : f32 {llvm.zeroext})
+
+// -----
+
+// expected-error at below {{"llvm.zeroext" should be a unit attribute}}
+llvm.func @invalid_zeroext_attr_type(%0 : i32 {llvm.zeroext = !llvm.struct<(i32)>})
+
+// -----
+
+// expected-error at below {{"llvm.noundef" should be a unit attribute}}
+llvm.func @invalid_noundef_attr_type(%0 : i32 {llvm.noundef = !llvm.ptr})
+
+// -----
+
+// Result attributes
+
+// expected-error at below {{cannot attach result attributes to functions with a void return}}
+llvm.func @void_def() -> (!llvm.void {llvm.noundef})
+
+// -----
+
+// expected-error @below{{"llvm.align" should be an integer attribute}}
+llvm.func @alignattr_ret() -> (!llvm.ptr {llvm.align = 1.0 : f32})
+
+// -----
+
+// expected-error @below{{"llvm.align" attribute attached to non-pointer LLVM type}}
+llvm.func @alignattr_ret() -> (i32 {llvm.align = 4})
+
+// -----
+
+// expected-error @below{{"llvm.noalias" should be a unit attribute}}
+llvm.func @noaliasattr_ret() -> (!llvm.ptr {llvm.noalias = 1})
+
+// -----
+
+// expected-error @below{{"llvm.noalias" attribute attached to non-pointer LLVM type}}
+llvm.func @noaliasattr_ret() -> (i32 {llvm.noalias})
+
+// -----
+
+// expected-error @below{{"llvm.noundef" should be a unit attribute}}
+llvm.func @noundefattr_ret() -> (!llvm.ptr {llvm.noundef = 1})
+
+// -----
+
+// expected-error @below{{"llvm.signext" should be a unit attribute}}
+llvm.func @signextattr_ret() -> (i32 {llvm.signext = 1})
+
+// -----
+
+// expected-error @below{{"llvm.signext" attribute attached to non-integer LLVM type}}
+llvm.func @signextattr_ret() -> (f32 {llvm.signext})
+
+// -----
+
+// expected-error @below{{"llvm.zeroext" should be a unit attribute}}
+llvm.func @zeroextattr_ret() -> (i32 {llvm.zeroext = 1})
+
+// -----
+
+// expected-error @below{{"llvm.zeroext" attribute attached to non-integer LLVM type}}
+llvm.func @zeroextattr_ret() -> (f32 {llvm.zeroext})
+
+// -----
+
+// expected-error @below{{"llvm.readonly" is not a valid result attribute}}
+llvm.func @readonly_ret() -> (f32 {llvm.readonly})
+
+// -----
+
+// expected-error @below{{"llvm.nest" is not a valid result attribute}}
+llvm.func @nest_ret() -> (f32 {llvm.nest})
+
+// -----
+
+// expected-error @below{{"llvm.sret" is not a valid result attribute}}
+llvm.func @sret_ret() -> (!llvm.ptr {llvm.sret = i64})
+
+// -----
+
+// expected-error @below{{"llvm.byval" is not a valid result attribute}}
+llvm.func @byval_ret() -> (!llvm.ptr {llvm.byval = i64})
+
+// -----
+
+// expected-error @below{{"llvm.byref" is not a valid result attribute}}
+llvm.func @byref_ret() -> (!llvm.ptr {llvm.byref = i64})
+
+// -----
+
+// expected-error @below{{"llvm.inalloca" is not a valid result attribute}}
+llvm.func @inalloca_ret() -> (!llvm.ptr {llvm.inalloca = i64})

diff  --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
index e8571a992e5c3..19e0b1501e090 100644
--- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
@@ -7,96 +7,6 @@ func.func @foo() {
 
 // -----
 
-// expected-error @below{{llvm.noalias attribute attached to LLVM non-pointer argument}}
-llvm.func @invalid_noalias(%arg0 : f32 {llvm.noalias}) -> f32 {
-  llvm.return %arg0 : f32
-}
-
-// -----
-
-// expected-error @below{{llvm.sret attribute attached to LLVM non-pointer argument}}
-llvm.func @invalid_sret(%arg0 : f32 {llvm.sret = f32}) -> f32 {
-  llvm.return %arg0 : f32
-}
-
-// -----
-
-// expected-error @below{{llvm.sret attribute attached to LLVM pointer argument of a different type}}
-llvm.func @invalid_sret(%arg0 : !llvm.ptr<f32> {llvm.sret = i32}) -> !llvm.ptr<f32> {
-  llvm.return %arg0 : !llvm.ptr<f32>
-}
-
-// -----
-
-// expected-error @below{{llvm.nest attribute attached to LLVM non-pointer argument}}
-llvm.func @invalid_nest(%arg0 : f32 {llvm.nest}) -> f32 {
-  llvm.return %arg0 : f32
-}
-// -----
-
-// expected-error @below{{llvm.byval attribute attached to LLVM non-pointer argument}}
-llvm.func @invalid_byval(%arg0 : f32 {llvm.byval = f32}) -> f32 {
-  llvm.return %arg0 : f32
-}
-
-// -----
-
-// expected-error @below{{llvm.byval attribute attached to LLVM pointer argument of a different type}}
-llvm.func @invalid_sret(%arg0 : !llvm.ptr<f32> {llvm.byval = i32}) -> !llvm.ptr<f32> {
-  llvm.return %arg0 : !llvm.ptr<f32>
-}
-
-// -----
-
-// expected-error @below{{llvm.byref attribute attached to LLVM non-pointer argument}}
-llvm.func @invalid_byval(%arg0 : f32 {llvm.byref = f32}) -> f32 {
-  llvm.return %arg0 : f32
-}
-
-// -----
-
-// expected-error @below{{llvm.byref attribute attached to LLVM pointer argument of a different type}}
-llvm.func @invalid_sret(%arg0 : !llvm.ptr<f32> {llvm.byref = i32}) -> !llvm.ptr<f32> {
-  llvm.return %arg0 : !llvm.ptr<f32>
-}
-
-// -----
-
-// expected-error @below{{llvm.inalloca attribute attached to LLVM non-pointer argument}}
-llvm.func @invalid_byval(%arg0 : f32 {llvm.inalloca = f32}) -> f32 {
-  llvm.return %arg0 : f32
-}
-
-// -----
-
-// expected-error @below{{llvm.inalloca attribute attached to LLVM pointer argument of a different type}}
-llvm.func @invalid_sret(%arg0 : !llvm.ptr<f32> {llvm.inalloca = i32}) -> !llvm.ptr<f32> {
-  llvm.return %arg0 : !llvm.ptr<f32>
-}
-
-// -----
-
-// expected-error @below{{llvm.align attribute attached to LLVM non-pointer argument}}
-llvm.func @invalid_align(%arg0 : f32 {llvm.align = 4}) -> f32 {
-  llvm.return %arg0 : f32
-}
-
-// -----
-
-// expected-error @below{{llvm.signext attribute attached to LLVM non-integer argument}}
-llvm.func @invalid_signext(%arg0: f32 {llvm.signext}) {
-  "llvm.return"() : () -> ()
-}
-
-// -----
-
-// expected-error @below{{llvm.zeroext attribute attached to LLVM non-integer argument}}
-llvm.func @invalid_zeroext(%arg0: f32 {llvm.zeroext}) {
-  "llvm.return"() : () -> ()
-}
-
-// -----
-
 llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
   // expected-error @below{{expected struct type to be a complex number}}
   %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
@@ -348,10 +258,3 @@ llvm.func @stepvector_intr_wrong_type() -> vector<7xf32> {
   %0 = llvm.intr.experimental.stepvector : vector<7xf32>
   llvm.return %0 : vector<7xf32>
 }
-
-// -----
-
-// expected-error @below{{llvm.readonly attribute attached to LLVM non-pointer argument}}
-llvm.func @wrong_readonly_attribute(%vec : f32 {llvm.readonly}) {
-  llvm.return
-}


        


More information about the Mlir-commits mailing list