[Mlir-commits] [mlir] 2582b2e - [mlir][llvm] Add LLVM TargetExtType

Christian Ulmann llvmlistbot at llvm.org
Tue May 30 08:56:52 PDT 2023


Author: Lukas Sommer
Date: 2023-05-30T15:55:45Z
New Revision: 2582b2e3ac19d3723daf6960b1edb7c0b627ff20

URL: https://github.com/llvm/llvm-project/commit/2582b2e3ac19d3723daf6960b1edb7c0b627ff20
DIFF: https://github.com/llvm/llvm-project/commit/2582b2e3ac19d3723daf6960b1edb7c0b627ff20.diff

LOG: [mlir][llvm] Add LLVM TargetExtType

Add support for the `llvm::TargetExtType` to the MLIR LLVM dialect.

Target extension types were introduced to represent target-specific types, which are opaque to the compiler and optimizations.

The patch also enforces some of the constraints defined for the target extension type in the LLVM language reference manual.

Signed-off-by: Lukas Sommer <lukas.sommer at codeplay.com>

Reviewed By: ftynse, gysit, Dinistro

Differential Revision: https://reviews.llvm.org/D151446

Added: 
    mlir/test/Target/LLVMIR/Import/target-ext-type.ll
    mlir/test/Target/LLVMIR/target-ext-type.mlir

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
    mlir/lib/Target/LLVMIR/ModuleImport.cpp
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
    mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
    mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
    mlir/test/Dialect/LLVMIR/global.mlir
    mlir/test/Dialect/LLVMIR/invalid.mlir
    mlir/test/Dialect/LLVMIR/types-invalid.mlir
    mlir/test/Dialect/LLVMIR/types.mlir
    mlir/test/Target/LLVMIR/llvmir-types.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index 00a3b80ecf69..809e8ed4a332 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -84,10 +84,23 @@ def LLVM_OpaqueStruct : Type<
   And<[LLVM_AnyStruct.predicate,
        CPred<"::llvm::cast<::mlir::LLVM::LLVMStructType>($_self).isOpaque()">]>>;
 
+// Type constraint accepting any LLVM target extension type.
+def LLVM_AnyTargetExt : Type<CPred<"::llvm::isa<::mlir::LLVM::LLVMTargetExtType>($_self)">,
+                            "LLVM target extension type">;
+
+// Type constraint accepting LLVM target extension types with no support for
+// memory operations such as alloca, load and store.
+def LLVM_NonLoadableTargetExtType : Type<
+  And<[LLVM_AnyTargetExt.predicate,
+        CPred<"!::llvm::cast<::mlir::LLVM::LLVMTargetExtType>($_self).supportsMemOps()">]>,
+  "non-loadable LLVM target extension type">;
+
 // Type constraint accepting any LLVM type that can be loaded or stored, i.e. a
-// type that has size (not void, function or opaque struct type).
+// type that has size (not void, function, opaque struct type or a target
+// extension type that does not support memory operations).
 def LLVM_LoadableType : Type<
-  Or<[And<[LLVM_PrimitiveType.predicate, Neg<LLVM_OpaqueStruct.predicate>]>,
+  Or<[And<[LLVM_PrimitiveType.predicate, Neg<LLVM_OpaqueStruct.predicate>,
+          Neg<LLVM_NonLoadableTargetExtType.predicate>]>,
       LLVM_PointerElementTypeInterface.predicate]>,
   "LLVM type with size">;
 

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
index e26d9d8acc79..5f746e21e831 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
@@ -225,4 +225,40 @@ def LLVMScalableVectorType : LLVMType<"LLVMScalableVector", "vec"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// LLVMTargetExtType
+//===----------------------------------------------------------------------===//
+
+def LLVMTargetExtType : LLVMType<"LLVMTargetExt", "target"> {
+  let summary = "LLVM target-specific extension type";
+  let description = [{
+    LLVM dialect target extension type. Such types are generally opaque to
+    target-independent optimizations.
+
+    Target extension types have a string name and, optionally, type and/or
+    integer parameters whose exact meaning is target-dependent.
+  }];
+
+  let parameters = (ins StringRefParameter<>:$extTypeName,
+                        OptionalArrayRefParameter<"Type">:$typeParams,
+                        OptionalArrayRefParameter<"unsigned int">:$intParams);
+
+  let assemblyFormat = [{
+    `<` $extTypeName (`,` custom<ExtTypeParams>($typeParams, $intParams)^ )? `>`
+  }];
+
+  let extraClassDeclaration = [{
+    enum Property {
+      /// zeroinitializer is valid for this target extension type.
+      HasZeroInit = 1U << 0,
+      /// This type may be used as the value type of a global variable.
+      CanBeGlobal = 1U << 1,
+    };
+    /// Returns true if this type has all of the property flags in `prop`.
+    bool hasProperty(Property prop) const;
+    /// Returns true if alloca, load and store may operate on this type.
+    bool supportsMemOps() const;
+  }];
+}
+
 #endif // LLVMTYPES_TD

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 06e8d79f68a6..a31daaab17e1 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -266,8 +266,19 @@ static LogicalResult verifyOpaquePtr(Operation *op, LLVMPointerType ptrType,
 }
 
 LogicalResult AllocaOp::verify() {
-  return verifyOpaquePtr(getOperation(), llvm::cast<LLVMPointerType>(getType()),
-                         getElemType());
+  auto ptrType = llvm::cast<LLVMPointerType>(getType());
+  if (failed(verifyOpaquePtr(getOperation(), ptrType, getElemType())))
+    return failure();
+  // Opaque pointers do not carry an element type, so the allocated type then
+  // comes from getElemType(); typed pointers expose it directly.
+  Type allocatedType =
+      ptrType.isOpaque() ? *getElemType() : ptrType.getElementType();
+  // Only certain target extension types can be used in 'alloca'.
+  auto extType = dyn_cast<LLVMTargetExtType>(allocatedType);
+  if (extType && !extType.supportsMemOps())
+    return emitOpError()
+           << "this target extension type cannot be used in alloca";
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -1832,6 +1843,22 @@ LogicalResult GlobalOp::verify() {
           "attribute");
   }
 
+  if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
+    if (!targetExtType.hasProperty(LLVMTargetExtType::CanBeGlobal))
+      return emitOpError()
+             << "this target extension type cannot be used in a global";
+
+    if (Attribute value = getValueOrNull()) {
+      // Only a single, zero integer attribute (=zeroinitializer) is allowed for
+      // a global value with TargetExtType.
+      // TODO: Replace with 'zeroinitializer' once there is a dedicated
+      // zeroinitializer operation in the LLVM dialect.
+      if (!isa<IntegerAttr>(value) || !isZeroAttribute(value))
+        return emitOpError()
+               << "expected zero value for global with target extension type";
+    }
+  }
+
   if (getLinkage() == Linkage::Common) {
     if (Attribute value = getValueOrNull()) {
       if (!isZeroAttribute(value)) {
@@ -2288,6 +2315,18 @@ LogicalResult LLVM::ConstantOp::verify() {
     }
     return success();
   }
+  if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
+    if (!targetExtType.hasProperty(LLVM::LLVMTargetExtType::HasZeroInit))
+      return emitOpError()
+             << "target extension type does not support zero-initializer";
+    // Only a single, zero integer attribute (=zeroinitializer) is allowed for a
+    // constant with TargetExtType.
+    if (!isa<IntegerAttr>(getValue()) || !isZeroAttribute(getValue()))
+      return emitOpError()
+             << "only zero-initializer allowed for target extension types";
+
+    return success();
+  }
   if (!llvm::isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>(getValue()))
     return emitOpError()
            << "only supports integer, float, string or elements attributes";

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
index 742ab5b63266..afb8c9060619 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -45,6 +45,7 @@ static StringRef getTypeKeyword(Type type) {
           [&](Type) { return "vec"; })
       .Case<LLVMArrayType>([&](Type) { return "array"; })
       .Case<LLVMStructType>([&](Type) { return "struct"; })
+      .Case<LLVMTargetExtType>([&](Type) { return "target"; })
       .Default([](Type) -> StringRef {
         llvm_unreachable("unexpected 'llvm' type kind");
       });
@@ -119,7 +120,7 @@ void mlir::LLVM::detail::printType(Type type, AsmPrinter &printer) {
 
   llvm::TypeSwitch<Type>(type)
       .Case<LLVMPointerType, LLVMArrayType, LLVMFixedVectorType,
-            LLVMScalableVectorType, LLVMFunctionType>(
+            LLVMScalableVectorType, LLVMFunctionType, LLVMTargetExtType>(
           [&](auto type) { type.print(printer); })
       .Case([&](LLVMStructType structType) {
         printStructType(printer, structType);
@@ -332,6 +333,7 @@ static Type dispatchParse(AsmParser &parser, bool allowAny = true) {
       .Case("vec", [&] { return parseVectorType(parser); })
       .Case("array", [&] { return LLVMArrayType::parse(parser); })
       .Case("struct", [&] { return parseStructType(parser); })
+      .Case("target", [&] { return LLVMTargetExtType::parse(parser); })
       .Default([&] {
         parser.emitError(keyLoc) << "unknown LLVM type: " << key;
         return Type();

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index be129ffe2aad..95d76a14d2bd 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -109,6 +109,59 @@ static void printPointer(AsmPrinter &p, Type elementType,
   }
 }
 
+//===----------------------------------------------------------------------===//
+// custom<ExtTypeParams>
+//===----------------------------------------------------------------------===//
+
+/// Parses the non-empty parameter list of a target extension type: an optional
+/// list of type parameters followed by an optional list of integer parameters,
+/// with at least one parameter overall. Type and integer parameters cannot be
+/// interleaved.
+/// extTypeParams ::= typeList | intList | (typeList "," intList)
+/// typeList      ::= type ("," type)*
+/// intList       ::= integer ("," integer)*
+static ParseResult
+parseExtTypeParams(AsmParser &p, SmallVectorImpl<Type> &typeParams,
+                   SmallVectorImpl<unsigned int> &intParams) {
+  bool parseType = true;
+  auto typeOrIntParser = [&]() -> ParseResult {
+    unsigned int i = 0;
+    auto intResult = p.parseOptionalInteger(i);
+    if (intResult.has_value()) {
+      // The integer token has already been consumed; if it was rejected (e.g.
+      // out of range), do not retry the following token as a type.
+      if (failed(*intResult))
+        return failure();
+      intParams.push_back(i);
+      // Once an integer has been parsed, no more types are accepted.
+      parseType = false;
+      return success();
+    }
+    // Types are only accepted before the first integer parameter.
+    Type t;
+    if (parseType && !parsePrettyLLVMType(p, t)) {
+      typeParams.push_back(t);
+      return success();
+    }
+    return failure();
+  };
+  if (p.parseCommaSeparatedList(typeOrIntParser)) {
+    p.emitError(p.getCurrentLocation(),
+                "failed to parse parameter list for target extension type");
+    return failure();
+  }
+  return success();
+}
+
+/// Prints type parameters first, then integer parameters, comma-separated.
+static void printExtTypeParams(AsmPrinter &p, ArrayRef<Type> typeParams,
+                               ArrayRef<unsigned int> intParams) {
+  p << typeParams;
+  if (!(typeParams.empty() || intParams.empty()))
+    p << ", ";
+  p << intParams;
+}
+
 //===----------------------------------------------------------------------===//
 // ODS-Generated Definitions
 //===----------------------------------------------------------------------===//
@@ -721,6 +774,35 @@ LLVMScalableVectorType::verify(function_ref<InFlightDiagnostic()> emitError,
       emitError, elementType, numElements);
 }
 
+//===----------------------------------------------------------------------===//
+// LLVMTargetExtType
+//===----------------------------------------------------------------------===//
+
+static constexpr llvm::StringRef kSpirvPrefix = "spirv.";
+static constexpr llvm::StringRef kArmSVCount = "aarch64.svcount";
+
+bool LLVM::LLVMTargetExtType::hasProperty(Property prop) const {
+  // Keep in sync with llvm/lib/IR/Type.cpp. At the moment only SPIR-V types
+  // advertise properties (zero-initialization and use in globals).
+  uint64_t supported = 0;
+  if (getExtTypeName().starts_with(kSpirvPrefix))
+    supported = HasZeroInit | CanBeGlobal;
+  // Every flag in `prop` must be present in the supported set.
+  uint64_t requested = prop;
+  return (supported & requested) == requested;
+}
+
+bool LLVM::LLVMTargetExtType::supportsMemOps() const {
+  // Keep in sync with llvm/lib/IR/Type.cpp. Memory operations such as
+  // alloca, load and store are only accepted for SPIR-V types and for
+  // `aarch64.svcount`; every other target extension type rejects them.
+  llvm::StringRef name = getExtTypeName();
+  if (name.starts_with(kSpirvPrefix))
+    return true;
+
+  return name == kArmSVCount;
+}
+
 //===----------------------------------------------------------------------===//
 // Utility functions.
 //===----------------------------------------------------------------------===//
@@ -746,6 +828,7 @@ bool mlir::LLVM::isCompatibleOuterType(Type type) {
       LLVMTokenType,
       LLVMFixedVectorType,
       LLVMScalableVectorType,
+      LLVMTargetExtType,
       LLVMVoidType,
       LLVMX86MMXType
     >(type)) {
@@ -791,6 +874,9 @@ static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
               return true;
             return isCompatible(pointerType.getElementType());
           })
+          .Case<LLVMTargetExtType>([&](auto extType) {
+            return llvm::all_of(extType.getTypeParams(), isCompatible);
+          })
           // clang-format off
           .Case<
               LLVMFixedVectorType,
@@ -974,7 +1060,8 @@ llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
       .Default([](Type ty) {
         assert((llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
                           LLVMTokenType, LLVMStructType, LLVMArrayType,
-                          LLVMPointerType, LLVMFunctionType>(ty)) &&
+                          LLVMPointerType, LLVMFunctionType, LLVMTargetExtType>(
+                   ty)) &&
                "unexpected missing support for primitive type");
         return llvm::TypeSize::Fixed(0);
       });

diff  --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 05d6b7827d83..5f9eb1835cd2 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1065,6 +1065,23 @@ FailureOr<Value> ModuleImport::convertConstant(llvm::Constant *constant) {
     return root;
   }
 
+  if (auto *constTargetNone = dyn_cast<llvm::ConstantTargetNone>(constant)) {
+    LLVMTargetExtType targetExtType =
+        cast<LLVMTargetExtType>(convertType(constTargetNone->getType()));
+    assert(targetExtType.hasProperty(LLVMTargetExtType::HasZeroInit) &&
+           "target extension type does not support zero-initialization");
+    // As the number of values needed for initialization is target-specific and
+    // opaque to the compiler, use a single i64 zero-valued attribute to
+    // represent the 'zeroinitializer', which is the only constant value allowed
+    // for target extension types (besides poison and undef).
+    // TODO: Replace with 'zeroinitializer' once there is a dedicated
+    // zeroinitializer operation in the LLVM dialect.
+    return builder
+        .create<LLVM::ConstantOp>(loc, targetExtType,
+                                  builder.getI64IntegerAttr(0))
+        .getRes();
+  }
+
   StringRef error = "";
   if (isa<llvm::BlockAddress>(constant))
     error = " since blockaddress(...) is unsupported";

diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 772721e31e1c..9b8e9a3ee1f3 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -342,6 +342,19 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
       return nullptr;
     return llvm::ConstantStruct::get(structType, {real, imag});
   }
+  if (auto *targetExtType = dyn_cast<::llvm::TargetExtType>(llvmType)) {
+    // TODO: Replace with 'zeroinitializer' once there is a dedicated
+    // zeroinitializer operation in the LLVM dialect.
+    auto intAttr = dyn_cast<IntegerAttr>(attr);
+    if (!intAttr || intAttr.getInt() != 0) {
+      emitError(loc,
+                "Only zero-initialization allowed for target extension type");
+      // Return null to signal failure, like the other conversion paths in this
+      // function, instead of producing a zeroinitializer for a bad attribute.
+      return nullptr;
+    }
+    return llvm::ConstantTargetNone::get(targetExtType);
+  }
   // For integer types, we allow a mismatch in sizes as the index type in
   // MLIR might have a different size than the index type in the LLVM module.
   if (auto intAttr = dyn_cast<IntegerAttr>(attr))

diff  --git a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
index 26e426b02327..458e71953e6c 100644
--- a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
+++ b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
@@ -36,7 +36,7 @@ class TypeFromLLVMIRTranslatorImpl {
         llvm::TypeSwitch<llvm::Type *, Type>(type)
             .Case<llvm::ArrayType, llvm::FunctionType, llvm::IntegerType,
                   llvm::PointerType, llvm::StructType, llvm::FixedVectorType,
-                  llvm::ScalableVectorType>(
+                  llvm::ScalableVectorType, llvm::TargetExtType>(
                 [this](auto *type) { return this->translate(type); })
             .Default([this](llvm::Type *type) {
               return translatePrimitiveType(type);
@@ -135,6 +135,15 @@ class TypeFromLLVMIRTranslatorImpl {
         translateType(type->getElementType()), type->getMinNumElements());
   }
 
+  /// Translates a target extension type, translating its type parameters and
+  /// forwarding its integer parameters unchanged.
+  Type translate(llvm::TargetExtType *type) {
+    SmallVector<Type> mlirTypeParams;
+    translateTypes(type->type_params(), mlirTypeParams);
+    return LLVM::LLVMTargetExtType::get(&context, type->getName(),
+                                        mlirTypeParams, type->int_params());
+  }
+
   /// Translates a list of types.
   void translateTypes(ArrayRef<llvm::Type *> types,
                       SmallVectorImpl<Type> &result) {

diff  --git a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
index d3ecede27915..6d8b415ff09d 100644
--- a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
+++ b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp
@@ -73,7 +73,7 @@ class TypeToLLVMIRTranslatorImpl {
             .Case<LLVM::LLVMArrayType, IntegerType, LLVM::LLVMFunctionType,
                   LLVM::LLVMPointerType, LLVM::LLVMStructType,
                   LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType,
-                  VectorType>(
+                  VectorType, LLVM::LLVMTargetExtType>(
                 [this](auto type) { return this->translate(type); })
             .Default([](Type t) -> llvm::Type * {
               llvm_unreachable("unknown LLVM dialect type");
@@ -155,6 +155,14 @@ class TypeToLLVMIRTranslatorImpl {
                                          type.getMinNumElements());
   }
 
+  /// Translates a target extension type, converting its type parameters.
+  llvm::Type *translate(LLVM::LLVMTargetExtType type) {
+    SmallVector<llvm::Type *> llvmTypeParams;
+    translateTypes(type.getTypeParams(), llvmTypeParams);
+    return llvm::TargetExtType::get(context, type.getExtTypeName(),
+                                    llvmTypeParams, type.getIntParams());
+  }
+
   /// Translates a list of types.
   void translateTypes(ArrayRef<Type> types,
                       SmallVectorImpl<llvm::Type *> &result) {

diff  --git a/mlir/test/Dialect/LLVMIR/global.mlir b/mlir/test/Dialect/LLVMIR/global.mlir
index c53fdeff925d..00b73f0549fa 100644
--- a/mlir/test/Dialect/LLVMIR/global.mlir
+++ b/mlir/test/Dialect/LLVMIR/global.mlir
@@ -232,3 +232,21 @@ llvm.func @dtor() {
 
 // CHECK: llvm.mlir.global_dtors {dtors = [@dtor], priorities = [0 : i32]}
 llvm.mlir.global_dtors { dtors = [@dtor], priorities = [0 : i32]}
+
+// -----
+
+// CHECK: llvm.mlir.global external @target_ext() {addr_space = 0 : i32} : !llvm.target<"spirv.Image", i32, 0>
+llvm.mlir.global @target_ext() : !llvm.target<"spirv.Image", i32, 0>
+
+// CHECK: llvm.mlir.global external @target_ext_init(0 : i64) {addr_space = 0 : i32} : !llvm.target<"spirv.Image", i32, 0>
+llvm.mlir.global @target_ext_init(0 : i64) : !llvm.target<"spirv.Image", i32, 0>
+
+// -----
+
+// expected-error @+1 {{expected zero value for global with target extension type}}
+llvm.mlir.global @target_fail(1 : i64) : !llvm.target<"spirv.Image", i32, 0>
+
+// -----
+
+// expected-error @+1 {{this target extension type cannot be used in a global}}
+llvm.mlir.global @target_fail_not_global() : !llvm.target<"no_global">

diff  --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index aa3498a5ee95..b88619b1e388 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1387,3 +1387,39 @@ func.func @invalid_bitcast_addr_cast_vec(%arg : !llvm.vec<4 x ptr<1>>) {
   // expected-error@+1 {{cannot cast pointers of different address spaces, use 'llvm.addrspacecast' instead}}
   %0 = llvm.bitcast %arg : !llvm.vec<4 x ptr<1>> to !llvm.vec<4 x ptr>
 }
+
+// -----
+
+func.func @invalid_target_ext_alloca() {
+  %0 = llvm.mlir.constant(1 : i64) : i64
+  // expected-error@+1 {{this target extension type cannot be used in alloca}}
+  %1 = llvm.alloca %0 x !llvm.target<"no_alloca"> : (i64) -> !llvm.ptr
+}
+
+// -----
+
+func.func @invalid_target_ext_load(%arg0 : !llvm.ptr) {
+  // expected-error@+1 {{result #0 must be LLVM type with size, but got '!llvm.target<"no_load">'}}
+  %0 = llvm.load %arg0 {alignment = 8 : i64} : !llvm.ptr -> !llvm.target<"no_load">
+}
+
+// -----
+
+func.func @invalid_target_ext_atomic(%arg0 : !llvm.ptr) {
+  // expected-error@+1 {{unsupported type '!llvm.target<"spirv.Event">' for atomic access}}
+  %0 = llvm.load %arg0 atomic monotonic {alignment = 8 : i64} : !llvm.ptr -> !llvm.target<"spirv.Event">
+}
+
+// -----
+
+func.func @invalid_target_ext_constant() {
+  // expected-error@+1 {{target extension type does not support zero-initializer}}
+  %0 = llvm.mlir.constant(0 : index) : !llvm.target<"invalid_constant">
+}
+
+// -----
+
+func.func @invalid_target_ext_constant_nonzero() {
+  // expected-error@+1 {{only zero-initializer allowed for target extension types}}
+  %0 = llvm.mlir.constant(42 : index) : !llvm.target<"spirv.Event">
+}

diff  --git a/mlir/test/Dialect/LLVMIR/types-invalid.mlir b/mlir/test/Dialect/LLVMIR/types-invalid.mlir
index fce100e6a865..f06f056cf490 100644
--- a/mlir/test/Dialect/LLVMIR/types-invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/types-invalid.mlir
@@ -158,3 +158,18 @@ func.func private @unexpected_type() -> !llvm.f32
 
 // expected-error @below {{cannot use !llvm.vec for built-in primitives, use 'vector' instead}}
 func.func private @llvm_vector_primitive() -> !llvm.vec<4 x f32>
+
+// -----
+
+func.func private @target_ext_invalid_order() {
+  // expected-error @+1 {{failed to parse parameter list for target extension type}}
+  "some.op"() : () -> !llvm.target<"target1", 5, i32, 1>
+}
+
+// -----
+
+func.func private @target_ext_no_name() {
+  // expected-error@below {{expected string}}
+  // expected-error@below {{failed to parse LLVMTargetExtType parameter 'extTypeName' which is to be a `::llvm::StringRef`}}
+  "some.op"() : () -> !llvm.target<i32, 42>
+}

diff  --git a/mlir/test/Dialect/LLVMIR/types.mlir b/mlir/test/Dialect/LLVMIR/types.mlir
index 42352ce697f0..c9bce337a3b8 100644
--- a/mlir/test/Dialect/LLVMIR/types.mlir
+++ b/mlir/test/Dialect/LLVMIR/types.mlir
@@ -176,3 +176,20 @@ llvm.func @aliases() {
   "some.op"() : () -> !llvm.struct<(i32, f32, !qux)>
   llvm.return
 }
+
+// -----
+
+// CHECK-LABEL: ext_target
+llvm.func @ext_target() {
+  // CHECK: !llvm.target<"target1", i32, 1>
+  %0 = "some.op"() : () -> !llvm.target<"target1", i32, 1>
+  // CHECK: !llvm.target<"target2">
+  %1 = "some.op"() : () -> !llvm.target<"target2">
+  // CHECK: !llvm.target<"target3", i32, i64, f64>
+  %2 = "some.op"() : () -> !llvm.target<"target3", i32, i64, f64>
+  // CHECK: !llvm.target<"target4", 1, 0, 42>
+  %3 = "some.op"() : () -> !llvm.target<"target4", 1, 0, 42>
+  // CHECK: !llvm.target<"target5", i32, f64, 0, 5>
+  %4 = "some.op"() : () -> !llvm.target<"target5", i32, f64, 0, 5>
+  llvm.return
+}

diff  --git a/mlir/test/Target/LLVMIR/Import/target-ext-type.ll b/mlir/test/Target/LLVMIR/Import/target-ext-type.ll
new file mode 100644
index 000000000000..62194cad9152
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/target-ext-type.ll
@@ -0,0 +1,53 @@
+; RUN: mlir-translate -import-llvm %s | FileCheck %s
+
+; CHECK-LABEL: llvm.mlir.global external @global() {addr_space = 0 : i32}
+; CHECK-SAME:    !llvm.target<"spirv.DeviceEvent">
+; CHECK-NEXT:      %0 = llvm.mlir.constant(0 : i64) : !llvm.target<"spirv.DeviceEvent">
+; CHECK-NEXT:      llvm.return %0 : !llvm.target<"spirv.DeviceEvent">
+@global = global target("spirv.DeviceEvent") zeroinitializer
+
+; CHECK-LABEL: llvm.func spir_kernelcc @func1(
+define spir_kernel void @func1(
+  ; CHECK-SAME: %arg0: !llvm.target<"spirv.Pipe", 0>
+  target("spirv.Pipe", 0) %a,
+  ; CHECK-SAME:    %arg1: !llvm.target<"spirv.Pipe", 1>
+  target("spirv.Pipe", 1) %b,
+  ; CHECK-SAME:    %arg2: !llvm.target<"spirv.Image", !llvm.void, 0, 0, 0, 0, 0, 0, 0>
+  target("spirv.Image", void, 0, 0, 0, 0, 0, 0, 0) %c1,
+  ; CHECK-SAME:    %arg3: !llvm.target<"spirv.Image", i32, 1, 0, 0, 0, 0, 0, 0>
+  target("spirv.Image", i32, 1, 0, 0, 0, 0, 0, 0) %d1,
+  ; CHECK-SAME:    %arg4: !llvm.target<"spirv.Image", i32, 2, 0, 0, 0, 0, 0, 0>
+  target("spirv.Image", i32, 2, 0, 0, 0, 0, 0, 0) %e1,
+  ; CHECK-SAME:    %arg5: !llvm.target<"spirv.Image", f16, 1, 0, 1, 0, 0, 0, 0>
+  target("spirv.Image", half, 1, 0, 1, 0, 0, 0, 0) %f1,
+  ; CHECK-SAME:    %arg6: !llvm.target<"spirv.Image", f32, 5, 0, 0, 0, 0, 0, 0>
+  target("spirv.Image", float, 5, 0, 0, 0, 0, 0, 0) %g1,
+  ; CHECK-SAME:    %arg7: !llvm.target<"spirv.Image", !llvm.void, 0, 0, 0, 0, 0, 0, 1>
+  target("spirv.Image", void, 0, 0, 0, 0, 0, 0, 1) %c2,
+  ; CHECK-SAME:    %arg8: !llvm.target<"spirv.Image", !llvm.void, 1, 0, 0, 0, 0, 0, 2>)
+  target("spirv.Image", void, 1, 0, 0, 0, 0, 0, 2) %d3) {
+entry:
+  ret void
+}
+
+; CHECK-LABEL: llvm.func @func2()
+; CHECK-SAME:      !llvm.target<"spirv.Event"> {
+define target("spirv.Event") @func2() {
+  ; CHECK-NEXT:    %0 = llvm.mlir.constant(1 : i32) : i32
+  ; CHECK-NEXT:    %1 = llvm.mlir.poison : !llvm.target<"spirv.Event">
+  ; CHECK-NEXT:    %2 = llvm.alloca %0 x !llvm.target<"spirv.Event"> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %mem = alloca target("spirv.Event")
+  ; CHECK-NEXT:    %3 = llvm.load %2 {alignment = 8 : i64} : !llvm.ptr -> !llvm.target<"spirv.Event">
+  %val = load target("spirv.Event"), ptr %mem
+  ; CHECK-NEXT:    llvm.return %1 : !llvm.target<"spirv.Event">
+  ret target("spirv.Event") poison
+}
+
+; CHECK-LABEL: llvm.func @func3()
+define void @func3() {
+  ; CHECK-NEXT:    %0 = llvm.mlir.constant(0 : i64) : !llvm.target<"spirv.DeviceEvent">
+  ; CHECK-NEXT:    %1 = llvm.freeze %0 : !llvm.target<"spirv.DeviceEvent">
+  %val = freeze target("spirv.DeviceEvent") zeroinitializer
+  ; CHECK-NEXT:    llvm.return
+  ret void
+}

diff  --git a/mlir/test/Target/LLVMIR/llvmir-types.mlir b/mlir/test/Target/LLVMIR/llvmir-types.mlir
index 9d972f6fa6b6..a92d46dfadfe 100644
--- a/mlir/test/Target/LLVMIR/llvmir-types.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-types.mlir
@@ -141,6 +141,22 @@ llvm.func @return_s_sp_i32() -> !llvm.struct<(struct<packed (i32)>)>
 // CHECK: declare <{ { i32 } }> @return_sp_s_i32()
 llvm.func @return_sp_s_i32() -> !llvm.struct<packed (struct<(i32)>)>
 
+// CHECK: declare target("target-no-param") @return_target_ext_no_param()
+llvm.func @return_target_ext_no_param() -> !llvm.target<"target-no-param">
+
+// CHECK: declare target("target-type-param", i32, double) @return_target_ext_type_params()
+llvm.func @return_target_ext_type_params() -> !llvm.target<"target-type-param", i32, f64>
+
+// CHECK: declare target("target-int-param", 0, 42) @return_target_ext_int_params()
+llvm.func @return_target_ext_int_params() -> !llvm.target<"target-int-param", 0, 42>
+
+// CHECK: declare target("target-params", i32, double, 0, 5) @return_target_ext_params()
+llvm.func @return_target_ext_params() -> !llvm.target<"target-params", i32, f64, 0, 5>
+
+// Target extension types are also translated in argument position.
+// CHECK: declare void @target_ext_arg(target("target-arg", i32, 3))
+llvm.func @target_ext_arg(!llvm.target<"target-arg", i32, 3>)
+
 // -----
 // Put structs into a separate split so that we can match their declarations
 // locally.

diff  --git a/mlir/test/Target/LLVMIR/target-ext-type.mlir b/mlir/test/Target/LLVMIR/target-ext-type.mlir
new file mode 100644
index 000000000000..e7004b2699dc
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/target-ext-type.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK: @global = global target("spirv.DeviceEvent") zeroinitializer
+llvm.mlir.global external @global() {addr_space = 0 : i32} : !llvm.target<"spirv.DeviceEvent"> {
+  %0 = llvm.mlir.constant(0 : i64) : !llvm.target<"spirv.DeviceEvent">
+  llvm.return %0 : !llvm.target<"spirv.DeviceEvent">
+}
+
+// Globals may also be zero-initialized through an integer attribute.
+// CHECK: @global_attr = global target("spirv.DeviceEvent") zeroinitializer
+llvm.mlir.global external @global_attr(0 : i64) {addr_space = 0 : i32} : !llvm.target<"spirv.DeviceEvent">
+
+// CHECK-LABEL: define target("spirv.Event") @func2() {
+// CHECK-NEXT:    %1 = alloca target("spirv.Event"), align 8
+// CHECK-NEXT:    %2 = load target("spirv.Event"), ptr %1, align 8
+// CHECK-NEXT:    ret target("spirv.Event") poison
+llvm.func @func2() -> !llvm.target<"spirv.Event"> {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.mlir.poison : !llvm.target<"spirv.Event">
+  %2 = llvm.alloca %0 x !llvm.target<"spirv.Event"> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %3 = llvm.load %2 {alignment = 8 : i64} : !llvm.ptr -> !llvm.target<"spirv.Event">
+  llvm.return %1 : !llvm.target<"spirv.Event">
+}
+
+// CHECK-LABEL: define void @func3() {
+// CHECK-NEXT:    %1 = freeze target("spirv.DeviceEvent") zeroinitializer
+// CHECK-NEXT:    ret void
+llvm.func @func3() {
+  %0 = llvm.mlir.constant(0 : i64) : !llvm.target<"spirv.DeviceEvent">
+  %1 = llvm.freeze %0 : !llvm.target<"spirv.DeviceEvent">
+  llvm.return
+}


        


More information about the Mlir-commits mailing list