[Mlir-commits] [mlir] [mlir] Fix use-after-free in #117513 (PR #120968)

Matthias Springer llvmlistbot at llvm.org
Mon Dec 23 06:07:23 PST 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/120968

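Fix a use-after-free in #117513. The materialization callbacks registered in the `LLVMTypeConverter` constructor captured helper lambdas by reference, but those helpers were local variables of the constructor and went out of scope when it returned, while the callbacks remained stored in the type converter. The helpers are now static functions, so nothing with automatic storage duration is captured by the stored callbacks.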


>From 935d441c9e7353e8c91e0dfd75b1418ad9c3af9d Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Mon, 23 Dec 2024 15:04:45 +0100
Subject: [PATCH] [mlir] Fix use-after-free in #117513

---
 .../Conversion/LLVMCommon/TypeConverter.h     |  70 ++++----
 .../Conversion/LLVMCommon/TypeConverter.cpp   | 160 ++++++++++--------
 2 files changed, 123 insertions(+), 107 deletions(-)

diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index d79b90f840ce83..38b5e492a8ed8f 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -161,6 +161,41 @@ class LLVMTypeConverter : public TypeConverter {
   /// Check if a memref type can be converted to a bare pointer.
   static bool canConvertToBarePtr(BaseMemRefType type);
 
+  /// Convert a memref type into a list of LLVM IR types that will form the
+  /// memref descriptor. If `unpackAggregates` is true the `sizes` and `strides`
+  /// arrays in the descriptors are unpacked to individual index-typed elements,
+  /// else they are kept as rank-sized arrays of index type. In particular,
+  /// the list will contain:
+  /// - two pointers to the memref element type, followed by
+  /// - an index-typed offset, followed by
+  /// - (if unpackAggregates = true)
+  ///    - one index-typed size per dimension of the memref, followed by
+  ///    - one index-typed stride per dimension of the memref.
+  /// - (if unpackAggregates = false)
+  ///   - one rank-sized array of index-type for the size of each dimension
+  ///   - one rank-sized array of index-type for the stride of each dimension
+  ///
+  /// For example, memref<?x?xf32> is converted to the following list:
+  /// - `!llvm<"float*">` (allocated pointer),
+  /// - `!llvm<"float*">` (aligned pointer),
+  /// - `i64` (offset),
+  /// - `i64`, `i64` (sizes),
+  /// - `i64`, `i64` (strides).
+  /// These types can be recomposed to a memref descriptor struct.
+  SmallVector<Type, 5> getMemRefDescriptorFields(MemRefType type,
+                                                 bool unpackAggregates) const;
+
+  /// Convert an unranked memref type into a list of non-aggregate LLVM IR types
+  /// that will form the unranked memref descriptor. In particular, this list
+  /// contains:
+  /// - an integer rank, followed by
+  /// - a pointer to the memref descriptor struct.
+  /// For example, memref<*xf32> is converted to the following list:
+  /// - `i64` (rank),
+  /// - `!llvm<"i8*">` (type-erased pointer).
+  /// These types can be recomposed to an unranked memref descriptor struct.
+  SmallVector<Type, 2> getUnrankedMemRefDescriptorFields() const;
+
 protected:
   /// Pointer to the LLVM dialect.
   LLVM::LLVMDialect *llvmDialect;
@@ -213,41 +248,6 @@ class LLVMTypeConverter : public TypeConverter {
   /// Convert a memref type into an LLVM type that captures the relevant data.
   Type convertMemRefType(MemRefType type) const;
 
-  /// Convert a memref type into a list of LLVM IR types that will form the
-  /// memref descriptor. If `unpackAggregates` is true the `sizes` and `strides`
-  /// arrays in the descriptors are unpacked to individual index-typed elements,
-  /// else they are kept as rank-sized arrays of index type. In particular,
-  /// the list will contain:
-  /// - two pointers to the memref element type, followed by
-  /// - an index-typed offset, followed by
-  /// - (if unpackAggregates = true)
-  ///    - one index-typed size per dimension of the memref, followed by
-  ///    - one index-typed stride per dimension of the memref.
-  /// - (if unpackArrregates = false)
-  ///   - one rank-sized array of index-type for the size of each dimension
-  ///   - one rank-sized array of index-type for the stride of each dimension
-  ///
-  /// For example, memref<?x?xf32> is converted to the following list:
-  /// - `!llvm<"float*">` (allocated pointer),
-  /// - `!llvm<"float*">` (aligned pointer),
-  /// - `i64` (offset),
-  /// - `i64`, `i64` (sizes),
-  /// - `i64`, `i64` (strides).
-  /// These types can be recomposed to a memref descriptor struct.
-  SmallVector<Type, 5> getMemRefDescriptorFields(MemRefType type,
-                                                 bool unpackAggregates) const;
-
-  /// Convert an unranked memref type into a list of non-aggregate LLVM IR types
-  /// that will form the unranked memref descriptor. In particular, this list
-  /// contains:
-  /// - an integer rank, followed by
-  /// - a pointer to the memref descriptor struct.
-  /// For example, memref<*xf32> is converted to the following list:
-  /// i64 (rank)
-  /// !llvm<"i8*"> (type-erased pointer).
-  /// These types can be recomposed to a unranked memref descriptor struct.
-  SmallVector<Type, 2> getUnrankedMemRefDescriptorFields() const;
-
   /// Convert an unranked memref type to an LLVM type that captures the
   /// runtime rank and a pointer to the static ranked memref desc
   Type convertUnrankedMemRefType(UnrankedMemRefType type) const;
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index e2ab0ed6f66cc5..1a7951282d3f78 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -44,6 +44,74 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
                                      const DataLayoutAnalysis *analysis)
     : LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), analysis) {}
 
+/// Helper function that checks if the given value range is a bare pointer.
+static bool isBarePointer(ValueRange values) {
+  return values.size() == 1 &&
+         isa<LLVM::LLVMPointerType>(values.front().getType());
+}
+
+/// Pack SSA values into an unranked memref descriptor struct.
+static Value packUnrankedMemRefDesc(OpBuilder &builder,
+                                    UnrankedMemRefType resultType,
+                                    ValueRange inputs, Location loc,
+                                    const LLVMTypeConverter &converter) {
+  // Note: Bare pointers are not supported for unranked memrefs because a
+  // memref descriptor cannot be built just from a bare pointer.
+  if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields())
+    return Value();
+  return UnrankedMemRefDescriptor::pack(builder, loc, converter, resultType,
+                                        inputs);
+}
+
+/// Pack SSA values into a ranked memref descriptor struct.
+static Value packRankedMemRefDesc(OpBuilder &builder, MemRefType resultType,
+                                  ValueRange inputs, Location loc,
+                                  const LLVMTypeConverter &converter) {
+  assert(resultType && "expected non-null result type");
+  if (isBarePointer(inputs))
+    return MemRefDescriptor::fromStaticShape(builder, loc, converter,
+                                             resultType, inputs[0]);
+  if (TypeRange(inputs) ==
+      converter.getMemRefDescriptorFields(resultType,
+                                          /*unpackAggregates=*/true))
+    return MemRefDescriptor::pack(builder, loc, converter, resultType, inputs);
+  // The inputs are neither a bare pointer nor an unpacked memref descriptor.
+  // This materialization function cannot be used.
+  return Value();
+}
+
+/// MemRef descriptor elements -> UnrankedMemRefType
+static Value unrankedMemRefMaterialization(OpBuilder &builder,
+                                           UnrankedMemRefType resultType,
+                                           ValueRange inputs, Location loc,
+                                           const LLVMTypeConverter &converter) {
+  // Both argument and source materializations must return a value of type
+  // `resultType`, so insert a cast from the memref descriptor type
+  // (!llvm.struct) to the original memref type.
+  Value packed =
+      packUnrankedMemRefDesc(builder, resultType, inputs, loc, converter);
+  if (!packed)
+    return Value();
+  return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
+      .getResult(0);
+}
+
+/// MemRef descriptor elements -> MemRefType
+static Value rankedMemRefMaterialization(OpBuilder &builder,
+                                         MemRefType resultType,
+                                         ValueRange inputs, Location loc,
+                                         const LLVMTypeConverter &converter) {
+  // Both argument and source materializations must return a value of type
+  // `resultType`, so insert a cast from the memref descriptor type
+  // (!llvm.struct) to the original memref type.
+  Value packed =
+      packRankedMemRefDesc(builder, resultType, inputs, loc, converter);
+  if (!packed)
+    return Value();
+  return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
+      .getResult(0);
+}
+
 /// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
 LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
                                      const LowerToLLVMOptions &options,
@@ -166,81 +234,29 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
         .getResult(0);
   });
 
-  // Helper function that checks if the given value range is a bare pointer.
-  auto isBarePointer = [](ValueRange values) {
-    return values.size() == 1 &&
-           isa<LLVM::LLVMPointerType>(values.front().getType());
-  };
-
-  // TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
-  // must be passed explicitly.
-  auto packUnrankedMemRefDesc =
-      [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
-          Location loc, LLVMTypeConverter &converter) -> Value {
-    // Note: Bare pointers are not supported for unranked memrefs because a
-    // memref descriptor cannot be built just from a bare pointer.
-    if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields())
-      return Value();
-    return UnrankedMemRefDescriptor::pack(builder, loc, converter, resultType,
-                                          inputs);
-  };
-
-  // MemRef descriptor elements -> UnrankedMemRefType
-  auto unrakedMemRefMaterialization = [&](OpBuilder &builder,
-                                          UnrankedMemRefType resultType,
-                                          ValueRange inputs, Location loc) {
-    // An argument materialization must return a value of type
-    // `resultType`, so insert a cast from the memref descriptor type
-    // (!llvm.struct) to the original memref type.
-    Value packed =
-        packUnrankedMemRefDesc(builder, resultType, inputs, loc, *this);
-    if (!packed)
-      return Value();
-    return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
-        .getResult(0);
-  };
-
-  // TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
-  // must be passed explicitly.
-  auto packRankedMemRefDesc = [&](OpBuilder &builder, MemRefType resultType,
-                                  ValueRange inputs, Location loc,
-                                  LLVMTypeConverter &converter) -> Value {
-    assert(resultType && "expected non-null result type");
-    if (isBarePointer(inputs))
-      return MemRefDescriptor::fromStaticShape(builder, loc, converter,
-                                               resultType, inputs[0]);
-    if (TypeRange(inputs) ==
-        converter.getMemRefDescriptorFields(resultType,
-                                            /*unpackAggregates=*/true))
-      return MemRefDescriptor::pack(builder, loc, converter, resultType,
-                                    inputs);
-    // The inputs are neither a bare pointer nor an unpacked memref descriptor.
-    // This materialization function cannot be used.
-    return Value();
-  };
-
-  // MemRef descriptor elements -> MemRefType
-  auto rankedMemRefMaterialization = [&](OpBuilder &builder,
-                                         MemRefType resultType,
-                                         ValueRange inputs, Location loc) {
-    // An argument materialization must return a value of type `resultType`,
-    // so insert a cast from the memref descriptor type (!llvm.struct) to the
-    // original memref type.
-    Value packed =
-        packRankedMemRefDesc(builder, resultType, inputs, loc, *this);
-    if (!packed)
-      return Value();
-    return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
-        .getResult(0);
-  };
-
   // Argument materializations convert from the new block argument types
   // (multiple SSA values that make up a memref descriptor) back to the
   // original block argument type.
-  addArgumentMaterialization(unrakedMemRefMaterialization);
-  addArgumentMaterialization(rankedMemRefMaterialization);
-  addSourceMaterialization(unrakedMemRefMaterialization);
-  addSourceMaterialization(rankedMemRefMaterialization);
+  addArgumentMaterialization([this](OpBuilder &builder,
+                                    UnrankedMemRefType resultType,
+                                    ValueRange inputs, Location loc) {
+    return unrankedMemRefMaterialization(builder, resultType, inputs, loc,
+                                         *this);
+  });
+  addArgumentMaterialization([this](OpBuilder &builder, MemRefType resultType,
+                                    ValueRange inputs, Location loc) {
+    return rankedMemRefMaterialization(builder, resultType, inputs, loc, *this);
+  });
+  addSourceMaterialization([this](OpBuilder &builder,
+                                  UnrankedMemRefType resultType,
+                                  ValueRange inputs, Location loc) {
+    return unrankedMemRefMaterialization(builder, resultType, inputs, loc,
+                                         *this);
+  });
+  addSourceMaterialization([this](OpBuilder &builder, MemRefType resultType,
+                                  ValueRange inputs, Location loc) {
+    return rankedMemRefMaterialization(builder, resultType, inputs, loc, *this);
+  });
 
   // Bare pointer -> Packed MemRef descriptor
   addTargetMaterialization([&](OpBuilder &builder, Type resultType,



More information about the Mlir-commits mailing list