[Mlir-commits] [mlir] d4fbbab - [mlir] translate types between MLIR LLVM dialect and LLVM IR

Alex Zinenko llvmlistbot at llvm.org
Tue Aug 4 04:43:01 PDT 2020


Author: Alex Zinenko
Date: 2020-08-04T13:42:43+02:00
New Revision: d4fbbab2e494a59480096a257136ed2b75d07e87

URL: https://github.com/llvm/llvm-project/commit/d4fbbab2e494a59480096a257136ed2b75d07e87
DIFF: https://github.com/llvm/llvm-project/commit/d4fbbab2e494a59480096a257136ed2b75d07e87.diff

LOG: [mlir] translate types between MLIR LLVM dialect and LLVM IR

With new LLVM dialect type modeling, the dialect types no longer wrap LLVM IR
types. Therefore, they need to be translated to and from LLVM IR during export
and import. Introduce the relevant functionality for translating types. It is
currently exercised by an ad-hoc type translation roundtripping test that will
be subsumed by the actual translation test when the type system transition is
complete.

Depends On D84339

Reviewed By: herhut

Differential Revision: https://reviews.llvm.org/D85019

Added: 
    mlir/include/mlir/Target/LLVMIR/TypeTranslation.h
    mlir/lib/Target/LLVMIR/TypeTranslation.cpp
    mlir/test/Target/llvmir-types.mlir
    mlir/test/lib/Target/CMakeLists.txt
    mlir/test/lib/Target/TestLLVMTypeTranslation.cpp

Modified: 
    mlir/lib/Target/CMakeLists.txt
    mlir/test/lib/CMakeLists.txt
    mlir/tools/mlir-translate/CMakeLists.txt
    mlir/tools/mlir-translate/mlir-translate.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Target/LLVMIR/TypeTranslation.h b/mlir/include/mlir/Target/LLVMIR/TypeTranslation.h
new file mode 100644
index 000000000000..5a82f0a096df
--- /dev/null
+++ b/mlir/include/mlir/Target/LLVMIR/TypeTranslation.h
@@ -0,0 +1,36 @@
+//===- TypeTranslation.h - Translate types between MLIR & LLVM --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the type translation function going from MLIR LLVM dialect
+// to LLVM IR and back.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_LLVMIR_TYPETRANSLATION_H
+#define MLIR_TARGET_LLVMIR_TYPETRANSLATION_H
+
+namespace llvm {
+class LLVMContext;
+class Type;
+} // namespace llvm
+
+namespace mlir {
+
+class MLIRContext;
+
+namespace LLVM {
+
+class LLVMTypeNew;
+
+/// Translates the given MLIR LLVM dialect type into an equivalent LLVM IR type
+/// created in `context`. Each call uses a fresh translator that does not share
+/// state with previous calls, so identified structs are re-created (and
+/// auto-renamed by LLVM on name clashes) every time. Intended for testing.
+llvm::Type *translateTypeToLLVMIR(LLVMTypeNew type, llvm::LLVMContext &context);
+
+/// Translates the given LLVM IR type into an equivalent MLIR LLVM dialect type
+/// created in `context`. Each call uses a fresh translator. Intended for
+/// testing.
+LLVMTypeNew translateTypeFromLLVMIR(llvm::Type *type, MLIRContext &context);
+
+} // namespace LLVM
+} // namespace mlir
+
+#endif // MLIR_TARGET_LLVMIR_TYPETRANSLATION_H

diff  --git a/mlir/lib/Target/CMakeLists.txt b/mlir/lib/Target/CMakeLists.txt
index 4a0af66a04b1..5ca335b4b4b5 100644
--- a/mlir/lib/Target/CMakeLists.txt
+++ b/mlir/lib/Target/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_translation_library(MLIRTargetLLVMIRModuleTranslation
   LLVMIR/DebugTranslation.cpp
   LLVMIR/ModuleTranslation.cpp
+  LLVMIR/TypeTranslation.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/LLVMIR

diff  --git a/mlir/lib/Target/LLVMIR/TypeTranslation.cpp b/mlir/lib/Target/LLVMIR/TypeTranslation.cpp
new file mode 100644
index 000000000000..6163334d3b4e
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/TypeTranslation.cpp
@@ -0,0 +1,309 @@
+//===- TypeTranslation.cpp - type translation between MLIR LLVM & LLVM IR -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVMIR/TypeTranslation.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/MLIRContext.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Type.h"
+
+using namespace mlir;
+
+namespace {
+/// Support for translating MLIR LLVM dialect types to LLVM IR.
+class TypeToLLVMIRTranslator {
+public:
+  /// Constructs a class creating types in the given LLVM context.
+  explicit TypeToLLVMIRTranslator(llvm::LLVMContext &context)
+      : context(context) {}
+
+  /// Translates a single type. Results are memoized, so translating the same
+  /// dialect type twice with one translator instance yields the same LLVM IR
+  /// type (in particular, the same identified struct).
+  llvm::Type *translateType(LLVM::LLVMTypeNew type) {
+    // If the conversion is already known, just return it. A single `find`
+    // avoids the double hash lookup of `count` followed by `lookup`.
+    auto it = knownTranslations.find(type);
+    if (it != knownTranslations.end())
+      return it->second;
+
+    // Dispatch to an appropriate function: primitive types map directly onto
+    // the LLVM context singletons, parametric types recurse via `translate`.
+    llvm::Type *translated =
+        llvm::TypeSwitch<LLVM::LLVMTypeNew, llvm::Type *>(type)
+            .Case([this](LLVM::LLVMVoidType) {
+              return llvm::Type::getVoidTy(context);
+            })
+            .Case([this](LLVM::LLVMHalfType) {
+              return llvm::Type::getHalfTy(context);
+            })
+            .Case([this](LLVM::LLVMBFloatType) {
+              return llvm::Type::getBFloatTy(context);
+            })
+            .Case([this](LLVM::LLVMFloatType) {
+              return llvm::Type::getFloatTy(context);
+            })
+            .Case([this](LLVM::LLVMDoubleType) {
+              return llvm::Type::getDoubleTy(context);
+            })
+            .Case([this](LLVM::LLVMFP128Type) {
+              return llvm::Type::getFP128Ty(context);
+            })
+            .Case([this](LLVM::LLVMX86FP80Type) {
+              return llvm::Type::getX86_FP80Ty(context);
+            })
+            .Case([this](LLVM::LLVMPPCFP128Type) {
+              return llvm::Type::getPPC_FP128Ty(context);
+            })
+            .Case([this](LLVM::LLVMX86MMXType) {
+              return llvm::Type::getX86_MMXTy(context);
+            })
+            .Case([this](LLVM::LLVMTokenType) {
+              return llvm::Type::getTokenTy(context);
+            })
+            .Case([this](LLVM::LLVMLabelType) {
+              return llvm::Type::getLabelTy(context);
+            })
+            .Case([this](LLVM::LLVMMetadataType) {
+              return llvm::Type::getMetadataTy(context);
+            })
+            .Case<LLVM::LLVMArrayType, LLVM::LLVMIntegerType,
+                  LLVM::LLVMFunctionType, LLVM::LLVMPointerType,
+                  LLVM::LLVMStructType, LLVM::LLVMFixedVectorType,
+                  LLVM::LLVMScalableVectorType>(
+                [this](auto parametric) { return translate(parametric); })
+            .Default([](LLVM::LLVMTypeNew) -> llvm::Type * {
+              llvm_unreachable("unknown LLVM dialect type");
+            });
+
+    // Cache the result of the conversion and return. For identified structs
+    // the entry was already inserted by `translate(LLVMStructType)` to break
+    // recursion; `try_emplace` then leaves the existing entry untouched.
+    knownTranslations.try_emplace(type, translated);
+    return translated;
+  }
+
+private:
+  /// Translates the given array type.
+  llvm::Type *translate(LLVM::LLVMArrayType type) {
+    return llvm::ArrayType::get(translateType(type.getElementType()),
+                                type.getNumElements());
+  }
+
+  /// Translates the given function type.
+  llvm::Type *translate(LLVM::LLVMFunctionType type) {
+    SmallVector<llvm::Type *, 8> paramTypes;
+    translateTypes(type.getParams(), paramTypes);
+    return llvm::FunctionType::get(translateType(type.getReturnType()),
+                                   paramTypes, type.isVarArg());
+  }
+
+  /// Translates the given integer type.
+  llvm::Type *translate(LLVM::LLVMIntegerType type) {
+    return llvm::IntegerType::get(context, type.getBitWidth());
+  }
+
+  /// Translates the given pointer type.
+  llvm::Type *translate(LLVM::LLVMPointerType type) {
+    return llvm::PointerType::get(translateType(type.getElementType()),
+                                  type.getAddressSpace());
+  }
+
+  /// Translates the given structure type, supports both identified and literal
+  /// structs. This will _create_ a new identified structure every time, use
+  /// `translateType` if a structure with the same name must be looked up in
+  /// the cache of known translations instead.
+  llvm::Type *translate(LLVM::LLVMStructType type) {
+    SmallVector<llvm::Type *, 8> subtypes;
+    if (!type.isIdentified()) {
+      translateTypes(type.getBody(), subtypes);
+      return llvm::StructType::get(context, subtypes, type.isPacked());
+    }
+
+    llvm::StructType *structType =
+        llvm::StructType::create(context, type.getName());
+    // Mark the type we just created as known so that recursive calls can pick
+    // it up and use directly.
+    knownTranslations.try_emplace(type, structType);
+    if (type.isOpaque())
+      return structType;
+
+    translateTypes(type.getBody(), subtypes);
+    structType->setBody(subtypes, type.isPacked());
+    return structType;
+  }
+
+  /// Translates the given fixed-vector type.
+  llvm::Type *translate(LLVM::LLVMFixedVectorType type) {
+    return llvm::FixedVectorType::get(translateType(type.getElementType()),
+                                      type.getNumElements());
+  }
+
+  /// Translates the given scalable-vector type.
+  llvm::Type *translate(LLVM::LLVMScalableVectorType type) {
+    return llvm::ScalableVectorType::get(translateType(type.getElementType()),
+                                         type.getMinNumElements());
+  }
+
+  /// Translates a list of types, appending the results to `result`.
+  void translateTypes(ArrayRef<LLVM::LLVMTypeNew> types,
+                      SmallVectorImpl<llvm::Type *> &result) {
+    result.reserve(result.size() + types.size());
+    for (auto type : types)
+      result.push_back(translateType(type));
+  }
+
+  /// Reference to the context in which the LLVM IR types are created.
+  llvm::LLVMContext &context;
+
+  /// Map of known translation. This serves a double purpose: caches translation
+  /// results to avoid repeated recursive calls and makes sure identified
+  /// structs with the same name (that is, equal) are resolved to an existing
+  /// type instead of creating a new type.
+  llvm::DenseMap<LLVM::LLVMTypeNew, llvm::Type *> knownTranslations;
+};
+} // end namespace
+
+/// Converts an MLIR LLVM dialect type into LLVM IR. A brand-new translator is
+/// built for every invocation, so no identified-struct mapping survives between
+/// calls and structs are re-created (with auto-renaming) each time. Meant for
+/// testing only.
+llvm::Type *mlir::LLVM::translateTypeToLLVMIR(LLVM::LLVMTypeNew type,
+                                              llvm::LLVMContext &context) {
+  TypeToLLVMIRTranslator translator(context);
+  llvm::Type *result = translator.translateType(type);
+  return result;
+}
+
+namespace {
+/// Support for translating LLVM IR types to MLIR LLVM dialect types.
+class TypeFromLLVMIRTranslator {
+public:
+  /// Constructs a class creating types in the given MLIR context.
+  explicit TypeFromLLVMIRTranslator(MLIRContext &context) : context(context) {}
+
+  /// Translates the given type, memoizing the result.
+  LLVM::LLVMTypeNew translateType(llvm::Type *type) {
+    // Single lookup instead of `count` followed by `lookup`.
+    auto it = knownTranslations.find(type);
+    if (it != knownTranslations.end())
+      return it->second;
+
+    LLVM::LLVMTypeNew translated =
+        llvm::TypeSwitch<llvm::Type *, LLVM::LLVMTypeNew>(type)
+            .Case<llvm::ArrayType, llvm::FunctionType, llvm::IntegerType,
+                  llvm::PointerType, llvm::StructType, llvm::FixedVectorType,
+                  llvm::ScalableVectorType>(
+                [this](auto *derived) { return translate(derived); })
+            .Default([this](llvm::Type *primitive) {
+              return translatePrimitiveType(primitive);
+            });
+    // For non-opaque identified structs the entry was already inserted by
+    // `translate(llvm::StructType *)`; `try_emplace` keeps it in that case.
+    knownTranslations.try_emplace(type, translated);
+    return translated;
+  }
+
+private:
+  /// Translates the given primitive, i.e. non-parametric in MLIR nomenclature,
+  /// type. Covers every primitive that the MLIR-to-LLVM IR direction can
+  /// produce, including `token`, so that types round-trip.
+  LLVM::LLVMTypeNew translatePrimitiveType(llvm::Type *type) {
+    if (type->isVoidTy())
+      return LLVM::LLVMVoidType::get(&context);
+    if (type->isHalfTy())
+      return LLVM::LLVMHalfType::get(&context);
+    if (type->isBFloatTy())
+      return LLVM::LLVMBFloatType::get(&context);
+    if (type->isFloatTy())
+      return LLVM::LLVMFloatType::get(&context);
+    if (type->isDoubleTy())
+      return LLVM::LLVMDoubleType::get(&context);
+    if (type->isFP128Ty())
+      return LLVM::LLVMFP128Type::get(&context);
+    if (type->isX86_FP80Ty())
+      return LLVM::LLVMX86FP80Type::get(&context);
+    if (type->isPPC_FP128Ty())
+      return LLVM::LLVMPPCFP128Type::get(&context);
+    if (type->isX86_MMXTy())
+      return LLVM::LLVMX86MMXType::get(&context);
+    // Previously missing: the forward translation emits `token`, so importing
+    // it must be supported as well instead of hitting `llvm_unreachable`.
+    if (type->isTokenTy())
+      return LLVM::LLVMTokenType::get(&context);
+    if (type->isLabelTy())
+      return LLVM::LLVMLabelType::get(&context);
+    if (type->isMetadataTy())
+      return LLVM::LLVMMetadataType::get(&context);
+    llvm_unreachable("not a primitive type");
+  }
+
+  /// Translates the given array type.
+  LLVM::LLVMTypeNew translate(llvm::ArrayType *type) {
+    return LLVM::LLVMArrayType::get(translateType(type->getElementType()),
+                                    type->getNumElements());
+  }
+
+  /// Translates the given function type.
+  LLVM::LLVMTypeNew translate(llvm::FunctionType *type) {
+    SmallVector<LLVM::LLVMTypeNew, 8> paramTypes;
+    translateTypes(type->params(), paramTypes);
+    return LLVM::LLVMFunctionType::get(translateType(type->getReturnType()),
+                                       paramTypes, type->isVarArg());
+  }
+
+  /// Translates the given integer type.
+  LLVM::LLVMTypeNew translate(llvm::IntegerType *type) {
+    return LLVM::LLVMIntegerType::get(&context, type->getBitWidth());
+  }
+
+  /// Translates the given pointer type.
+  LLVM::LLVMTypeNew translate(llvm::PointerType *type) {
+    return LLVM::LLVMPointerType::get(translateType(type->getElementType()),
+                                      type->getAddressSpace());
+  }
+
+  /// Translates the given structure type.
+  LLVM::LLVMTypeNew translate(llvm::StructType *type) {
+    SmallVector<LLVM::LLVMTypeNew, 8> subtypes;
+    if (type->isLiteral()) {
+      translateTypes(type->subtypes(), subtypes);
+      return LLVM::LLVMStructType::getLiteral(&context, subtypes,
+                                              type->isPacked());
+    }
+
+    if (type->isOpaque())
+      return LLVM::LLVMStructType::getOpaque(type->getName(), &context);
+
+    // Register the identified struct before translating its body so that
+    // recursive references resolve to it instead of recursing forever.
+    LLVM::LLVMStructType translated =
+        LLVM::LLVMStructType::getIdentified(&context, type->getName());
+    knownTranslations.try_emplace(type, translated);
+    translateTypes(type->subtypes(), subtypes);
+    LogicalResult bodySet = translated.setBody(subtypes, type->isPacked());
+    assert(succeeded(bodySet) &&
+           "could not set the body of an identified struct");
+    (void)bodySet;
+    return translated;
+  }
+
+  /// Translates the given fixed-vector type.
+  LLVM::LLVMTypeNew translate(llvm::FixedVectorType *type) {
+    return LLVM::LLVMFixedVectorType::get(translateType(type->getElementType()),
+                                          type->getNumElements());
+  }
+
+  /// Translates the given scalable-vector type.
+  LLVM::LLVMTypeNew translate(llvm::ScalableVectorType *type) {
+    return LLVM::LLVMScalableVectorType::get(
+        translateType(type->getElementType()), type->getMinNumElements());
+  }
+
+  /// Translates a list of types, appending the results to `result`.
+  void translateTypes(ArrayRef<llvm::Type *> types,
+                      SmallVectorImpl<LLVM::LLVMTypeNew> &result) {
+    result.reserve(result.size() + types.size());
+    for (llvm::Type *type : types)
+      result.push_back(translateType(type));
+  }
+
+  /// Map of known translations. Serves as a cache and as recursion stopper for
+  /// translating recursive structs.
+  llvm::DenseMap<llvm::Type *, LLVM::LLVMTypeNew> knownTranslations;
+
+  /// The context in which MLIR types are created.
+  MLIRContext &context;
+};
+} // end namespace
+
+/// Converts an LLVM IR type into the MLIR LLVM dialect using a translator that
+/// lives only for the duration of this call. Meant for testing only.
+LLVM::LLVMTypeNew mlir::LLVM::translateTypeFromLLVMIR(llvm::Type *type,
+                                                      MLIRContext &context) {
+  TypeFromLLVMIRTranslator translator(context);
+  LLVM::LLVMTypeNew result = translator.translateType(type);
+  return result;
+}

diff  --git a/mlir/test/Target/llvmir-types.mlir b/mlir/test/Target/llvmir-types.mlir
new file mode 100644
index 000000000000..d807562d1a2d
--- /dev/null
+++ b/mlir/test/Target/llvmir-types.mlir
@@ -0,0 +1,228 @@
+// RUN: mlir-translate -test-mlir-to-llvmir -split-input-file %s | FileCheck %s
+
+// Non-parametric types: each op declares a function returning the type and a
+// "_round" twin whose return type went MLIR -> LLVM IR -> MLIR -> LLVM IR.
+llvm.func @primitives() {
+  // CHECK: declare void @return_void()
+  // CHECK: declare void @return_void_round()
+  "llvm.test_introduce_func"() { name = "return_void", type = !llvm2.void } : () -> ()
+  // CHECK: declare half @return_half()
+  // CHECK: declare half @return_half_round()
+  "llvm.test_introduce_func"() { name = "return_half", type = !llvm2.half } : () -> ()
+  // CHECK: declare bfloat @return_bfloat()
+  // CHECK: declare bfloat @return_bfloat_round()
+  "llvm.test_introduce_func"() { name = "return_bfloat", type = !llvm2.bfloat } : () -> ()
+  // CHECK: declare float @return_float()
+  // CHECK: declare float @return_float_round()
+  "llvm.test_introduce_func"() { name = "return_float", type = !llvm2.float } : () -> ()
+  // CHECK: declare double @return_double()
+  // CHECK: declare double @return_double_round()
+  "llvm.test_introduce_func"() { name = "return_double", type = !llvm2.double } : () -> ()
+  // CHECK: declare fp128 @return_fp128()
+  // CHECK: declare fp128 @return_fp128_round()
+  "llvm.test_introduce_func"() { name = "return_fp128", type = !llvm2.fp128 } : () -> ()
+  // CHECK: declare x86_fp80 @return_x86_fp80()
+  // CHECK: declare x86_fp80 @return_x86_fp80_round()
+  "llvm.test_introduce_func"() { name = "return_x86_fp80", type = !llvm2.x86_fp80 } : () -> ()
+  // CHECK: declare ppc_fp128 @return_ppc_fp128()
+  // CHECK: declare ppc_fp128 @return_ppc_fp128_round()
+  "llvm.test_introduce_func"() { name = "return_ppc_fp128", type = !llvm2.ppc_fp128 } : () -> ()
+  // CHECK: declare x86_mmx @return_x86_mmx()
+  // CHECK: declare x86_mmx @return_x86_mmx_round()
+  "llvm.test_introduce_func"() { name = "return_x86_mmx", type = !llvm2.x86_mmx } : () -> ()
+  llvm.return
+}
+
+// Function types: the translated type is used directly as the signature,
+// including variadic signatures.
+llvm.func @funcs() {
+  // CHECK: declare void @f_void_i32(i32)
+  // CHECK: declare void @f_void_i32_round(i32)
+  "llvm.test_introduce_func"() { name ="f_void_i32", type = !llvm2.func<void (i32)> } : () -> ()
+  // CHECK: declare i32 @f_i32_empty()
+  // CHECK: declare i32 @f_i32_empty_round()
+  "llvm.test_introduce_func"() { name ="f_i32_empty", type = !llvm2.func<i32 ()> } : () -> ()
+  // CHECK: declare i32 @f_i32_half_bfloat_float_double(half, bfloat, float, double)
+  // CHECK: declare i32 @f_i32_half_bfloat_float_double_round(half, bfloat, float, double)
+  "llvm.test_introduce_func"() { name ="f_i32_half_bfloat_float_double", type = !llvm2.func<i32 (half, bfloat, float, double)> } : () -> ()
+  // CHECK: declare i32 @f_i32_i32_i32(i32, i32)
+  // CHECK: declare i32 @f_i32_i32_i32_round(i32, i32)
+  "llvm.test_introduce_func"() { name ="f_i32_i32_i32", type = !llvm2.func<i32 (i32, i32)> } : () -> ()
+  // CHECK: declare void @f_void_variadic(...)
+  // CHECK: declare void @f_void_variadic_round(...)
+  "llvm.test_introduce_func"() { name ="f_void_variadic", type = !llvm2.func<void (...)> } : () -> ()
+  // CHECK: declare void @f_void_i32_i32_variadic(i32, i32, ...)
+  // CHECK: declare void @f_void_i32_i32_variadic_round(i32, i32, ...)
+  "llvm.test_introduce_func"() { name ="f_void_i32_i32_variadic", type = !llvm2.func<void (i32, i32, ...)> } : () -> ()
+  llvm.return
+}
+
+// Integer types, including non-power-of-two and wider-than-64 bit widths.
+llvm.func @ints() {
+  // CHECK: declare i1 @return_i1()
+  // CHECK: declare i1 @return_i1_round()
+  "llvm.test_introduce_func"() { name = "return_i1", type = !llvm2.i1 } : () -> ()
+  // CHECK: declare i8 @return_i8()
+  // CHECK: declare i8 @return_i8_round()
+  "llvm.test_introduce_func"() { name = "return_i8", type = !llvm2.i8 } : () -> ()
+  // CHECK: declare i16 @return_i16()
+  // CHECK: declare i16 @return_i16_round()
+  "llvm.test_introduce_func"() { name = "return_i16", type = !llvm2.i16 } : () -> ()
+  // CHECK: declare i32 @return_i32()
+  // CHECK: declare i32 @return_i32_round()
+  "llvm.test_introduce_func"() { name = "return_i32", type = !llvm2.i32 } : () -> ()
+  // CHECK: declare i64 @return_i64()
+  // CHECK: declare i64 @return_i64_round()
+  "llvm.test_introduce_func"() { name = "return_i64", type = !llvm2.i64 } : () -> ()
+  // CHECK: declare i57 @return_i57()
+  // CHECK: declare i57 @return_i57_round()
+  "llvm.test_introduce_func"() { name = "return_i57", type = !llvm2.i57 } : () -> ()
+  // CHECK: declare i129 @return_i129()
+  // CHECK: declare i129 @return_i129_round()
+  "llvm.test_introduce_func"() { name = "return_i129", type = !llvm2.i129 } : () -> ()
+  llvm.return
+}
+
+// Pointer types: nesting and explicit address spaces (address space 0 is
+// printed without an addrspace qualifier).
+llvm.func @pointers() {
+  // CHECK: declare i8* @return_pi8()
+  // CHECK: declare i8* @return_pi8_round()
+  "llvm.test_introduce_func"() { name = "return_pi8", type = !llvm2.ptr<i8> } : () -> ()
+  // CHECK: declare float* @return_pfloat()
+  // CHECK: declare float* @return_pfloat_round()
+  "llvm.test_introduce_func"() { name = "return_pfloat", type = !llvm2.ptr<float> } : () -> ()
+  // CHECK: declare i8** @return_ppi8()
+  // CHECK: declare i8** @return_ppi8_round()
+  "llvm.test_introduce_func"() { name = "return_ppi8", type = !llvm2.ptr<ptr<i8>> } : () -> ()
+  // CHECK: declare i8***** @return_pppppi8()
+  // CHECK: declare i8***** @return_pppppi8_round()
+  "llvm.test_introduce_func"() { name = "return_pppppi8", type = !llvm2.ptr<ptr<ptr<ptr<ptr<i8>>>>> } : () -> ()
+  // CHECK: declare i8* @return_pi8_0()
+  // CHECK: declare i8* @return_pi8_0_round()
+  "llvm.test_introduce_func"() { name = "return_pi8_0", type = !llvm2.ptr<i8, 0> } : () -> ()
+  // CHECK: declare i8 addrspace(1)* @return_pi8_1()
+  // CHECK: declare i8 addrspace(1)* @return_pi8_1_round()
+  "llvm.test_introduce_func"() { name = "return_pi8_1", type = !llvm2.ptr<i8, 1> } : () -> ()
+  // CHECK: declare i8 addrspace(42)* @return_pi8_42()
+  // CHECK: declare i8 addrspace(42)* @return_pi8_42_round()
+  "llvm.test_introduce_func"() { name = "return_pi8_42", type = !llvm2.ptr<i8, 42> } : () -> ()
+  // CHECK: declare i8 addrspace(42)* addrspace(9)* @return_ppi8_42_9()
+  // CHECK: declare i8 addrspace(42)* addrspace(9)* @return_ppi8_42_9_round()
+  "llvm.test_introduce_func"() { name = "return_ppi8_42_9", type = !llvm2.ptr<ptr<i8, 42>, 9> } : () -> ()
+  llvm.return
+}
+
+// Fixed-size (`vec<N x T>`) and scalable (`vec<? x N x T>`) vector types.
+llvm.func @vectors() {
+  // CHECK: declare <4 x i32> @return_v4_i32()
+  // CHECK: declare <4 x i32> @return_v4_i32_round()
+  "llvm.test_introduce_func"() { name = "return_v4_i32", type = !llvm2.vec<4 x i32> } : () -> ()
+  // CHECK: declare <4 x float> @return_v4_float()
+  // CHECK: declare <4 x float> @return_v4_float_round()
+  "llvm.test_introduce_func"() { name = "return_v4_float", type = !llvm2.vec<4 x float> } : () -> ()
+  // CHECK: declare <vscale x 4 x i32> @return_vs_4_i32()
+  // CHECK: declare <vscale x 4 x i32> @return_vs_4_i32_round()
+  "llvm.test_introduce_func"() { name = "return_vs_4_i32", type = !llvm2.vec<? x 4 x i32> } : () -> ()
+  // CHECK: declare <vscale x 8 x half> @return_vs_8_half()
+  // CHECK: declare <vscale x 8 x half> @return_vs_8_half_round()
+  "llvm.test_introduce_func"() { name = "return_vs_8_half", type = !llvm2.vec<? x 8 x half> } : () -> ()
+  // CHECK: declare <4 x i8*> @return_v_4_pi8()
+  // CHECK: declare <4 x i8*> @return_v_4_pi8_round()
+  "llvm.test_introduce_func"() { name = "return_v_4_pi8", type = !llvm2.vec<4 x ptr<i8>> } : () -> ()
+  llvm.return
+}
+
+// Array types, including arrays of pointers and nested arrays.
+llvm.func @arrays() {
+  // CHECK: declare [10 x i32] @return_a10_i32()
+  // CHECK: declare [10 x i32] @return_a10_i32_round()
+  "llvm.test_introduce_func"() { name = "return_a10_i32", type = !llvm2.array<10 x i32> } : () -> ()
+  // CHECK: declare [8 x float] @return_a8_float()
+  // CHECK: declare [8 x float] @return_a8_float_round()
+  "llvm.test_introduce_func"() { name = "return_a8_float", type = !llvm2.array<8 x float> } : () -> ()
+  // CHECK: declare [10 x i32 addrspace(4)*] @return_a10_pi32_4()
+  // CHECK: declare [10 x i32 addrspace(4)*] @return_a10_pi32_4_round()
+  "llvm.test_introduce_func"() { name = "return_a10_pi32_4", type = !llvm2.array<10 x ptr<i32, 4>> } : () -> ()
+  // CHECK: declare [10 x [4 x float]] @return_a10_a4_float()
+  // CHECK: declare [10 x [4 x float]] @return_a10_a4_float_round()
+  "llvm.test_introduce_func"() { name = "return_a10_a4_float", type = !llvm2.array<10 x array<4 x float>> } : () -> ()
+  llvm.return
+}
+
+// Literal (unnamed) struct types: regular, packed, and mixed nestings.
+llvm.func @literal_structs() {
+  // CHECK: declare {} @return_struct_empty()
+  // CHECK: declare {} @return_struct_empty_round()
+  "llvm.test_introduce_func"() { name = "return_struct_empty", type = !llvm2.struct<()> } : () -> ()
+  // CHECK: declare { i32 } @return_s_i32()
+  // CHECK: declare { i32 } @return_s_i32_round()
+  "llvm.test_introduce_func"() { name = "return_s_i32", type = !llvm2.struct<(i32)> } : () -> ()
+  // CHECK: declare { float, i32 } @return_s_float_i32()
+  // CHECK: declare { float, i32 } @return_s_float_i32_round()
+  "llvm.test_introduce_func"() { name = "return_s_float_i32", type = !llvm2.struct<(float, i32)> } : () -> ()
+  // CHECK: declare { { i32 } } @return_s_s_i32()
+  // CHECK: declare { { i32 } } @return_s_s_i32_round()
+  "llvm.test_introduce_func"() { name = "return_s_s_i32", type = !llvm2.struct<(struct<(i32)>)> } : () -> ()
+  // CHECK: declare { i32, { i32 }, float } @return_s_i32_s_i32_float()
+  // CHECK: declare { i32, { i32 }, float } @return_s_i32_s_i32_float_round()
+  "llvm.test_introduce_func"() { name = "return_s_i32_s_i32_float", type = !llvm2.struct<(i32, struct<(i32)>, float)> } : () -> ()
+
+  // CHECK: declare <{}> @return_sp_empty()
+  // CHECK: declare <{}> @return_sp_empty_round()
+  "llvm.test_introduce_func"() { name = "return_sp_empty", type = !llvm2.struct<packed ()> } : () -> ()
+  // CHECK: declare <{ i32 }> @return_sp_i32()
+  // CHECK: declare <{ i32 }> @return_sp_i32_round()
+  "llvm.test_introduce_func"() { name = "return_sp_i32", type = !llvm2.struct<packed (i32)> } : () -> ()
+  // CHECK: declare <{ float, i32 }> @return_sp_float_i32()
+  // CHECK: declare <{ float, i32 }> @return_sp_float_i32_round()
+  "llvm.test_introduce_func"() { name = "return_sp_float_i32", type = !llvm2.struct<packed (float, i32)> } : () -> ()
+  // CHECK: declare <{ i32, { i32, i1 }, float }> @return_sp_i32_s_i32_i1_float()
+  // CHECK: declare <{ i32, { i32, i1 }, float }> @return_sp_i32_s_i32_i1_float_round()
+  "llvm.test_introduce_func"() { name = "return_sp_i32_s_i32_i1_float", type = !llvm2.struct<packed (i32, struct<(i32, i1)>, float)> } : () -> ()
+
+  // CHECK: declare { <{ i32 }> } @return_s_sp_i32()
+  // CHECK: declare { <{ i32 }> } @return_s_sp_i32_round()
+  "llvm.test_introduce_func"() { name = "return_s_sp_i32", type = !llvm2.struct<(struct<packed (i32)>)> } : () -> ()
+  // CHECK: declare <{ { i32 } }> @return_sp_s_i32()
+  // CHECK: declare <{ { i32 } }> @return_sp_s_i32_round()
+  "llvm.test_introduce_func"() { name = "return_sp_s_i32", type = !llvm2.struct<packed (struct<(i32)>)> } : () -> ()
+  llvm.return
+}
+
+// -----
+// Put structs into a separate split so that we can match their declarations
+// locally.
+
+// CHECK: %empty = type {}
+// CHECK: %opaque = type opaque
+// CHECK: %long = type { i32, { i32, i1 }, float, void ()* }
+// CHECK: %self-recursive = type { %self-recursive* }
+// CHECK: %unpacked = type { i32 }
+// CHECK: %packed = type <{ i32 }>
+// CHECK: %"name with spaces and !^$@$#" = type <{ i32 }>
+// CHECK: %mutually-a = type { %mutually-b* }
+// CHECK: %mutually-b = type { %mutually-a addrspace(3)* }
+// CHECK: %struct-of-arrays = type { [10 x i32] }
+// CHECK: %array-of-structs = type { i32 }
+// CHECK: %ptr-to-struct = type { i8 }
+
+// Identified (named) struct types: empty, opaque, packed, self- and mutually
+// recursive, and identified structs nested in arrays and pointers.
+llvm.func @identified_structs() {
+  // CHECK: declare %empty
+  "llvm.test_introduce_func"() { name = "return_s_empty", type = !llvm2.struct<"empty", ()> } : () -> ()
+  // CHECK: declare %opaque
+  "llvm.test_introduce_func"() { name = "return_s_opaque", type = !llvm2.struct<"opaque", opaque> } : () -> ()
+  // CHECK: declare %long
+  "llvm.test_introduce_func"() { name = "return_s_long", type = !llvm2.struct<"long", (i32, struct<(i32, i1)>, float, ptr<func<void ()>>)> } : () -> ()
+  // CHECK: declare %self-recursive
+  "llvm.test_introduce_func"() { name = "return_s_self_recursive", type = !llvm2.struct<"self-recursive", (ptr<struct<"self-recursive">>)> } : () -> ()
+  // CHECK: declare %unpacked
+  "llvm.test_introduce_func"() { name = "return_s_unpacked", type = !llvm2.struct<"unpacked", (i32)> } : () -> ()
+  // CHECK: declare %packed
+  "llvm.test_introduce_func"() { name = "return_s_packed", type = !llvm2.struct<"packed", packed (i32)> } : () -> ()
+  // CHECK: declare %"name with spaces and !^$@$#"
+  "llvm.test_introduce_func"() { name = "return_s_symbols", type = !llvm2.struct<"name with spaces and !^$@$#", packed (i32)> } : () -> ()
+
+  // CHECK: declare %mutually-a
+  "llvm.test_introduce_func"() { name = "return_s_mutually_a", type = !llvm2.struct<"mutually-a", (ptr<struct<"mutually-b", (ptr<struct<"mutually-a">, 3>)>>)> } : () -> ()
+  // CHECK: declare %mutually-b
+  "llvm.test_introduce_func"() { name = "return_s_mutually_b", type = !llvm2.struct<"mutually-b", (ptr<struct<"mutually-a", (ptr<struct<"mutually-b">>)>, 3>)> } : () -> ()
+
+  // CHECK: declare %struct-of-arrays
+  "llvm.test_introduce_func"() { name = "return_s_struct_of_arrays", type = !llvm2.struct<"struct-of-arrays", (array<10 x i32>)> } : () -> ()
+  // CHECK: declare [10 x %array-of-structs]
+  "llvm.test_introduce_func"() { name = "return_s_array_of_structs", type = !llvm2.array<10 x struct<"array-of-structs", (i32)>> } : () -> ()
+  // CHECK: declare %ptr-to-struct*
+  "llvm.test_introduce_func"() { name = "return_s_ptr_to_struct", type = !llvm2.ptr<struct<"ptr-to-struct", (i8)>> } : () -> ()
+  llvm.return
+}

diff  --git a/mlir/test/lib/CMakeLists.txt b/mlir/test/lib/CMakeLists.txt
index 0df357c8c355..ec9e5cd99801 100644
--- a/mlir/test/lib/CMakeLists.txt
+++ b/mlir/test/lib/CMakeLists.txt
@@ -2,4 +2,5 @@ add_subdirectory(Dialect)
 add_subdirectory(IR)
 add_subdirectory(Pass)
 add_subdirectory(Reducer)
+add_subdirectory(Target)
 add_subdirectory(Transforms)

diff  --git a/mlir/test/lib/Target/CMakeLists.txt b/mlir/test/lib/Target/CMakeLists.txt
new file mode 100644
index 000000000000..cb8f206469ae
--- /dev/null
+++ b/mlir/test/lib/Target/CMakeLists.txt
@@ -0,0 +1,13 @@
+# Test-only translation library providing the "test-mlir-to-llvmir"
+# registration used by mlir/test/Target/llvmir-types.mlir.
+add_mlir_translation_library(MLIRTestLLVMTypeTranslation
+  TestLLVMTypeTranslation.cpp
+
+  LINK_COMPONENTS
+  Core
+  TransformUtils
+
+  LINK_LIBS PUBLIC
+  MLIRLLVMIR
+  MLIRTargetLLVMIRModuleTranslation
+  MLIRTestIR
+  MLIRTranslation
+  )

diff  --git a/mlir/test/lib/Target/TestLLVMTypeTranslation.cpp b/mlir/test/lib/Target/TestLLVMTypeTranslation.cpp
new file mode 100644
index 000000000000..b76ac2a13344
--- /dev/null
+++ b/mlir/test/lib/Target/TestLLVMTypeTranslation.cpp
@@ -0,0 +1,79 @@
+//===- TestLLVMTypeTranslation.cpp - Test MLIR/LLVM IR type translation ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+#include "mlir/Target/LLVMIR/TypeTranslation.h"
+#include "mlir/Translation.h"
+
+using namespace mlir;
+
+namespace {
+class TestLLVMTypeTranslation : public LLVM::ModuleTranslation {
+  // Allow access to the constructors under MSVC.
+  friend LLVM::ModuleTranslation;
+
+public:
+  using LLVM::ModuleTranslation::ModuleTranslation;
+
+protected:
+  /// Simple test facility for translating types from MLIR LLVM dialect to LLVM
+  /// IR. This converts the "llvm.test_introduce_func" operation into an LLVM IR
+  /// function with the name extracted from the `name` attribute that returns
+  /// the type contained in the `type` attribute if it is a non-function type or
+  /// that has the signature obtained by converting `type` if it is a function
+  /// type. A second function, suffixed with "_round", is declared with the type
+  /// obtained by translating to LLVM IR, back to MLIR and to LLVM IR again.
+  /// This is a temporary check before type translation is substituted into the
+  /// main translation flow and exercised here.
+  LogicalResult convertOperation(Operation &op,
+                                 llvm::IRBuilder<> &builder) override {
+    if (op.getName().getStringRef() != "llvm.test_introduce_func")
+      return LLVM::ModuleTranslation::convertOperation(op, builder);
+
+    // Report malformed test input as diagnostics rather than asserting:
+    // asserts are compiled out in release builds, where a missing attribute
+    // would otherwise be dereferenced and a non-LLVM type blindly cast.
+    auto attr = op.getAttrOfType<TypeAttr>("type");
+    if (!attr)
+      return op.emitOpError("expected 'type' attribute");
+    auto type = attr.getValue().dyn_cast<LLVM::LLVMTypeNew>();
+    if (!type)
+      return op.emitOpError(
+          "expected 'type' attribute to contain an LLVM dialect type");
+
+    auto nameAttr = op.getAttrOfType<StringAttr>("name");
+    if (!nameAttr)
+      return op.emitOpError("expected 'name' attribute");
+
+    llvm::Module *module = builder.GetInsertBlock()->getModule();
+
+    // Declares `name` as a function: with `declType` as its signature if it is
+    // a function type, otherwise as a function taking no arguments and
+    // returning `declType`.
+    auto declare = [module](StringRef name, llvm::Type *declType) {
+      if (auto *funcType = dyn_cast<llvm::FunctionType>(declType))
+        module->getOrInsertFunction(name, funcType);
+      else
+        module->getOrInsertFunction(name, declType);
+    };
+
+    llvm::Type *translated =
+        LLVM::translateTypeToLLVMIR(type, builder.getContext());
+    declare(nameAttr.getValue(), translated);
+
+    std::string roundtripName = (Twine(nameAttr.getValue()) + "_round").str();
+    LLVM::LLVMTypeNew translatedBack =
+        LLVM::translateTypeFromLLVMIR(translated, *op.getContext());
+    llvm::Type *translatedBackAndForth =
+        LLVM::translateTypeToLLVMIR(translatedBack, builder.getContext());
+    declare(roundtripName, translatedBackAndForth);
+    return success();
+  }
+};
+} // namespace
+
+namespace mlir {
+/// Registers the "test-mlir-to-llvmir" translation, which translates a module
+/// with `TestLLVMTypeTranslation` and prints the resulting LLVM IR.
+void registerTestLLVMTypeTranslation() {
+  TranslateFromMLIRRegistration reg(
+      "test-mlir-to-llvmir", [](ModuleOp module, raw_ostream &output) {
+        std::unique_ptr<llvm::Module> llvmModule =
+            LLVM::ModuleTranslation::translateModule<TestLLVMTypeTranslation>(
+                module.getOperation());
+        // Translation failures (e.g. from `convertOperation`) yield a null
+        // module; propagate the failure instead of dereferencing it.
+        if (!llvmModule)
+          return failure();
+        llvmModule->print(output, nullptr);
+        return success();
+      });
+}
+} // namespace mlir

diff  --git a/mlir/tools/mlir-translate/CMakeLists.txt b/mlir/tools/mlir-translate/CMakeLists.txt
index 897e7adc03bd..1e6cdfe0f3b1 100644
--- a/mlir/tools/mlir-translate/CMakeLists.txt
+++ b/mlir/tools/mlir-translate/CMakeLists.txt
@@ -13,7 +13,11 @@ target_link_libraries(mlir-translate
   PRIVATE
   ${dialect_libs}
   ${translation_libs}
+  ${test_libs}
   MLIRIR
+  # TODO: remove after LLVM dialect transition is complete; translation uses a
+  # registration function defined in this library unconditionally.
+  MLIRLLVMTypeTestDialect
   MLIRParser
   MLIRPass
   MLIRSPIRV

diff  --git a/mlir/tools/mlir-translate/mlir-translate.cpp b/mlir/tools/mlir-translate/mlir-translate.cpp
index 914bd340b3f5..70bf285112a4 100644
--- a/mlir/tools/mlir-translate/mlir-translate.cpp
+++ b/mlir/tools/mlir-translate/mlir-translate.cpp
@@ -49,17 +49,21 @@ static llvm::cl::opt<bool> verifyDiagnostics(
 
 namespace mlir {
 // Defined in the test directory, no public header.
+void registerLLVMTypeTestDialect();
+void registerTestLLVMTypeTranslation();
 void registerTestRoundtripSPIRV();
 void registerTestRoundtripDebugSPIRV();
 } // namespace mlir
 
+/// Registers the translations implemented in the test directory so that they
+/// are available as mlir-translate options.
 static void registerTestTranslations() {
+  registerTestLLVMTypeTranslation();
   registerTestRoundtripSPIRV();
   registerTestRoundtripDebugSPIRV();
 }
 
 int main(int argc, char **argv) {
   registerAllDialects();
+  registerLLVMTypeTestDialect();
   registerAllTranslations();
   registerTestTranslations();
   llvm::InitLLVM y(argc, argv);


        


More information about the Mlir-commits mailing list