[Mlir-commits] [mlir] d64b3e4 - [mlir] Avoid needlessly converting LLVM named structs with compatible elements

Alex Zinenko llvmlistbot at llvm.org
Mon Dec 6 04:42:19 PST 2021


Author: Alex Zinenko
Date: 2021-12-06T13:42:11+01:00
New Revision: d64b3e47ba6347ef4c68c0666a90eda8f986f525

URL: https://github.com/llvm/llvm-project/commit/d64b3e47ba6347ef4c68c0666a90eda8f986f525
DIFF: https://github.com/llvm/llvm-project/commit/d64b3e47ba6347ef4c68c0666a90eda8f986f525.diff

LOG: [mlir] Avoid needlessly converting LLVM named structs with compatible elements

Conversion of LLVM named structs leads to them being renamed since we cannot
modify the body of the struct type once it is set. Previously, this applied to
all named struct types, even if their element types were not affected by the
conversion. Make this behavior apply only when element types are changed.
This requires making the LLVM dialect type-compatibility check recursively look
at the element types (arguably, it should have been doing that since the moment
the LLVM dialect type system stopped being closed). In addition, introduce a
laxer check that looks only at the outer type, to avoid repeated checks where
appropriate (e.g., in the parser, or in verifiers that will also look at the
inner types).

Reviewed By: wsmoses

Differential Revision: https://reviews.llvm.org/D115037

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
    mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
    mlir/test/Conversion/StandardToLLVM/convert-types.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index 3de4f3ab60db..6abede4b8c55 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -429,6 +429,10 @@ void printType(Type type, AsmPrinter &printer);
 /// Returns `true` if the given type is compatible with the LLVM dialect.
 bool isCompatibleType(Type type);
 
+/// Returns `true` if the given outer type is compatible with the LLVM dialect
+/// without checking its potential nested types such as struct elements.
+bool isCompatibleOuterType(Type type);
+
 /// Returns `true` if the given type is a floating-point type compatible with
 /// the LLVM dialect.
 bool isCompatibleFloatingPointType(Type type);

diff  --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index cd6651cbcf6e..5175b93a9303 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -55,6 +55,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   });
   addConversion([&](LLVM::LLVMStructType type, SmallVectorImpl<Type> &results,
                     ArrayRef<Type> callStack) -> llvm::Optional<LogicalResult> {
+    // Fastpath for types that won't be converted by this callback anyway.
+    if (LLVM::isCompatibleType(type)) {
+      results.push_back(type);
+      return success();
+    }
+
     if (type.isIdentified()) {
       auto convertedType = LLVM::LLVMStructType::getIdentified(
           type.getContext(), ("_Converted_" + type.getName()).str());

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
index ca1bbbf59a69..bca0a5800d91 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -468,7 +468,7 @@ Type mlir::LLVM::detail::parseType(DialectAsmParser &parser) {
   Type type = dispatchParse(parser, /*allowAny=*/false);
   if (!type)
     return type;
-  if (!isCompatibleType(type)) {
+  if (!isCompatibleOuterType(type)) {
     parser.emitError(loc) << "unexpected type, expected keyword";
     return nullptr;
   }

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index ea4d7a69c063..1ce9e4481e9c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -1,4 +1,3 @@
-//===- LLVMTypes.cpp - MLIR LLVM Dialect types ----------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -19,6 +18,7 @@
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/TypeSupport.h"
 
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/TypeSize.h"
 
@@ -120,9 +120,10 @@ LLVMFunctionType::verify(function_ref<InFlightDiagnostic()> emitError,
 //===----------------------------------------------------------------------===//
 
 bool LLVMPointerType::isValidElementType(Type type) {
-  return isCompatibleType(type) ? !type.isa<LLVMVoidType, LLVMTokenType,
-                                            LLVMMetadataType, LLVMLabelType>()
-                                : type.isa<PointerElementTypeInterface>();
+  return isCompatibleOuterType(type)
+             ? !type.isa<LLVMVoidType, LLVMTokenType, LLVMMetadataType,
+                         LLVMLabelType>()
+             : type.isa<PointerElementTypeInterface>();
 }
 
 LLVMPointerType LLVMPointerType::get(Type pointee, unsigned addressSpace) {
@@ -483,17 +484,9 @@ LLVMScalableVectorType::verify(function_ref<InFlightDiagnostic()> emitError,
 // Utility functions.
 //===----------------------------------------------------------------------===//
 
-bool mlir::LLVM::isCompatibleType(Type type) {
-  // Only signless integers are compatible.
-  if (auto intType = type.dyn_cast<IntegerType>())
-    return intType.isSignless();
-
-  // 1D vector types are compatible if their element types are.
-  if (auto vecType = type.dyn_cast<VectorType>())
-    return vecType.getRank() == 1 && isCompatibleType(vecType.getElementType());
-
+bool mlir::LLVM::isCompatibleOuterType(Type type) {
   // clang-format off
-  return type.isa<
+  if (type.isa<
       BFloat16Type,
       Float16Type,
       Float32Type,
@@ -512,8 +505,75 @@ bool mlir::LLVM::isCompatibleType(Type type) {
       LLVMScalableVectorType,
       LLVMVoidType,
       LLVMX86MMXType
-  >();
-  // clang-format on
+    >()) {
+    // clang-format on
+    return true;
+  }
+
+  // Only signless integers are compatible.
+  if (auto intType = type.dyn_cast<IntegerType>())
+    return intType.isSignless();
+
+  // 1D vector types are compatible.
+  if (auto vecType = type.dyn_cast<VectorType>())
+    return vecType.getRank() == 1;
+
+  return false;
+}
+
+static bool isCompatibleImpl(Type type, SetVector<Type> &callstack) {
+  if (callstack.contains(type))
+    return true;
+
+  callstack.insert(type);
+  auto stackPopper = llvm::make_scope_exit([&] { callstack.pop_back(); });
+
+  auto isCompatible = [&](Type type) {
+    return isCompatibleImpl(type, callstack);
+  };
+
+  return llvm::TypeSwitch<Type, bool>(type)
+      .Case<LLVMStructType>([&](auto structType) {
+        return llvm::all_of(structType.getBody(), isCompatible);
+      })
+      .Case<LLVMFunctionType>([&](auto funcType) {
+        return isCompatible(funcType.getReturnType()) &&
+               llvm::all_of(funcType.getParams(), isCompatible);
+      })
+      .Case<IntegerType>([](auto intType) { return intType.isSignless(); })
+      .Case<VectorType>([&](auto vecType) {
+        return vecType.getRank() == 1 && isCompatible(vecType.getElementType());
+      })
+      // clang-format off
+      .Case<
+          LLVMPointerType,
+          LLVMFixedVectorType,
+          LLVMScalableVectorType,
+          LLVMArrayType
+      >([&](auto containerType) {
+        return isCompatible(containerType.getElementType());
+      })
+      .Case<
+        BFloat16Type,
+        Float16Type,
+        Float32Type,
+        Float64Type,
+        Float80Type,
+        Float128Type,
+        LLVMLabelType,
+        LLVMMetadataType,
+        LLVMPPCFP128Type,
+        LLVMTokenType,
+        LLVMVoidType,
+        LLVMX86MMXType
+      >([](Type) { return true; })
+      // clang-format on
+      .Default([](Type) { return false; });
+}
+
+bool mlir::LLVM::isCompatibleType(Type type) {
+  SetVector<Type> callstack;
+  return isCompatibleImpl(type, callstack);
 }
 
 bool mlir::LLVM::isCompatibleFloatingPointType(Type type) {

diff  --git a/mlir/test/Conversion/StandardToLLVM/convert-types.mlir b/mlir/test/Conversion/StandardToLLVM/convert-types.mlir
index 298ba1a6a099..8c214a8f03e9 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-types.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-types.mlir
@@ -16,6 +16,10 @@ func private @struct_ptr() -> !llvm.struct<(ptr<!test.smpla>)>
 // CHECK: !llvm.struct<"_Converted_named", (ptr<i42>)>
 func private @named_struct_ptr() -> !llvm.struct<"named", (ptr<!test.smpla>)>
 
+// CHECK-LABEL: @named_no_convert
+// CHECK: !llvm.struct<"no_convert", (ptr<struct<"no_convert">>)>
+func private @named_no_convert() -> !llvm.struct<"no_convert", (ptr<struct<"no_convert">>)>
+
 // CHECK-LABEL: @array_ptr()
 // CHECK: !llvm.array<10 x ptr<i42>> 
 func private @array_ptr() -> !llvm.array<10 x ptr<!test.smpla>>


        


More information about the Mlir-commits mailing list