[Mlir-commits] [mlir] fc7f726 - [mlir][LLVMIR] Memorize compatible LLVM types

Min-Yih Hsu llvmlistbot at llvm.org
Mon Jun 27 09:47:50 PDT 2022


Author: Min-Yih Hsu
Date: 2022-06-27T09:46:41-07:00
New Revision: fc7f7260a609b2092260b6a5e867a20cf364808c

URL: https://github.com/llvm/llvm-project/commit/fc7f7260a609b2092260b6a5e867a20cf364808c
DIFF: https://github.com/llvm/llvm-project/commit/fc7f7260a609b2092260b6a5e867a20cf364808c.diff

LOG: [mlir][LLVMIR] Memorize compatible LLVM types

This patch memoizes (caches) compatible LLVM types in
`LLVM::isCompatibleType` in order to avoid redundant work.

This is especially useful when the program is large and contains
multiple occurrences of deeply nested LLVM struct types, in which
case this patch can yield a noticeable speedup.

Differential Revision: https://reviews.llvm.org/D127918

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
    mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
    mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index d97a267f4aba..d9b29e0e306c 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -25,6 +25,7 @@
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/ThreadLocalCache.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index 9ede9fd5931f..37e474462d3d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -61,6 +61,15 @@ def LLVM_Dialect : Dialect {
     static StringRef getEmitCWrapperAttrName() {
       return "llvm.emit_c_interface";
     }
+
+    /// Returns `true` if the given type is compatible with the LLVM dialect.
+    static bool isCompatibleType(Type);
+
+  private:
+    /// A cache storing compatible LLVM types that have been verified. This
+    /// can save us lots of verification time if there are many occurrences
+    /// of some deeply-nested aggregate types in the program.
+    ThreadLocalCache<DenseSet<Type>> compatibleTypes;
   }];
 
   let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index 9c975289c18c..c3a244f021c7 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -451,7 +451,8 @@ void printType(Type type, AsmPrinter &printer);
 // Utility functions.
 //===----------------------------------------------------------------------===//
 
-/// Returns `true` if the given type is compatible with the LLVM dialect.
+/// Returns `true` if the given type is compatible with the LLVM dialect. This
+/// is an alias to `LLVMDialect::isCompatibleType`.
 bool isCompatibleType(Type type);
 
 /// Returns `true` if the given outer type is compatible with the LLVM dialect

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 4dbfc8293a05..b02d53a2efae 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -721,63 +721,75 @@ bool mlir::LLVM::isCompatibleOuterType(Type type) {
   return false;
 }
 
-static bool isCompatibleImpl(Type type, SetVector<Type> &callstack) {
-  if (callstack.contains(type))
+static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
+  if (!compatibleTypes.insert(type).second)
     return true;
 
-  callstack.insert(type);
-  auto stackPopper = llvm::make_scope_exit([&] { callstack.pop_back(); });
-
   auto isCompatible = [&](Type type) {
-    return isCompatibleImpl(type, callstack);
+    return isCompatibleImpl(type, compatibleTypes);
   };
 
-  return llvm::TypeSwitch<Type, bool>(type)
-      .Case<LLVMStructType>([&](auto structType) {
-        return llvm::all_of(structType.getBody(), isCompatible);
-      })
-      .Case<LLVMFunctionType>([&](auto funcType) {
-        return isCompatible(funcType.getReturnType()) &&
-               llvm::all_of(funcType.getParams(), isCompatible);
-      })
-      .Case<IntegerType>([](auto intType) { return intType.isSignless(); })
-      .Case<VectorType>([&](auto vecType) {
-        return vecType.getRank() == 1 && isCompatible(vecType.getElementType());
-      })
-      .Case<LLVMPointerType>([&](auto pointerType) {
-        if (pointerType.isOpaque())
-          return true;
-        return isCompatible(pointerType.getElementType());
-      })
-      // clang-format off
-      .Case<
-          LLVMFixedVectorType,
-          LLVMScalableVectorType,
-          LLVMArrayType
-      >([&](auto containerType) {
-        return isCompatible(containerType.getElementType());
-      })
-      .Case<
-        BFloat16Type,
-        Float16Type,
-        Float32Type,
-        Float64Type,
-        Float80Type,
-        Float128Type,
-        LLVMLabelType,
-        LLVMMetadataType,
-        LLVMPPCFP128Type,
-        LLVMTokenType,
-        LLVMVoidType,
-        LLVMX86MMXType
-      >([](Type) { return true; })
-      // clang-format on
-      .Default([](Type) { return false; });
+  bool result =
+      llvm::TypeSwitch<Type, bool>(type)
+          .Case<LLVMStructType>([&](auto structType) {
+            return llvm::all_of(structType.getBody(), isCompatible);
+          })
+          .Case<LLVMFunctionType>([&](auto funcType) {
+            return isCompatible(funcType.getReturnType()) &&
+                   llvm::all_of(funcType.getParams(), isCompatible);
+          })
+          .Case<IntegerType>([](auto intType) { return intType.isSignless(); })
+          .Case<VectorType>([&](auto vecType) {
+            return vecType.getRank() == 1 &&
+                   isCompatible(vecType.getElementType());
+          })
+          .Case<LLVMPointerType>([&](auto pointerType) {
+            if (pointerType.isOpaque())
+              return true;
+            return isCompatible(pointerType.getElementType());
+          })
+          // clang-format off
+          .Case<
+              LLVMFixedVectorType,
+              LLVMScalableVectorType,
+              LLVMArrayType
+          >([&](auto containerType) {
+            return isCompatible(containerType.getElementType());
+          })
+          .Case<
+            BFloat16Type,
+            Float16Type,
+            Float32Type,
+            Float64Type,
+            Float80Type,
+            Float128Type,
+            LLVMLabelType,
+            LLVMMetadataType,
+            LLVMPPCFP128Type,
+            LLVMTokenType,
+            LLVMVoidType,
+            LLVMX86MMXType
+          >([](Type) { return true; })
+          // clang-format on
+          .Default([](Type) { return false; });
+
+  if (!result)
+    compatibleTypes.erase(type);
+
+  return result;
+}
+
+bool LLVMDialect::isCompatibleType(Type type) {
+  if (auto *llvmDialect =
+          type.getContext()->getLoadedDialect<LLVM::LLVMDialect>())
+    return isCompatibleImpl(type, llvmDialect->compatibleTypes.get());
+
+  DenseSet<Type> localCompatibleTypes;
+  return isCompatibleImpl(type, localCompatibleTypes);
 }
 
 bool mlir::LLVM::isCompatibleType(Type type) {
-  SetVector<Type> callstack;
-  return isCompatibleImpl(type, callstack);
+  return LLVMDialect::isCompatibleType(type);
 }
 
 bool mlir::LLVM::isCompatibleFloatingPointType(Type type) {


        


More information about the Mlir-commits mailing list