[Mlir-commits] [mlir] [mlir] Fix use-after-free in #117513 (PR #120968)
Matthias Springer
llvmlistbot at llvm.org
Mon Dec 23 06:07:23 PST 2024
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/120968
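Fix a use-after-free in #117513. The materialization callbacks registered in the `LLVMTypeConverter` constructor captured constructor-local helper lambdas by reference; those lambdas go out of scope when the constructor returns, so the registered callbacks were left with dangling references. The helpers are now free-standing static functions instead of lambdas defined inside the constructor.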
From 935d441c9e7353e8c91e0dfd75b1418ad9c3af9d Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Mon, 23 Dec 2024 15:04:45 +0100
Subject: [PATCH] [mlir] Fix use-after-free in #117513
---
.../Conversion/LLVMCommon/TypeConverter.h | 70 ++++----
.../Conversion/LLVMCommon/TypeConverter.cpp | 160 ++++++++++--------
2 files changed, 123 insertions(+), 107 deletions(-)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index d79b90f840ce83..38b5e492a8ed8f 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -161,6 +161,41 @@ class LLVMTypeConverter : public TypeConverter {
/// Check if a memref type can be converted to a bare pointer.
static bool canConvertToBarePtr(BaseMemRefType type);
+ /// Convert a memref type into a list of LLVM IR types that will form the
+ /// memref descriptor. If `unpackAggregates` is true the `sizes` and `strides`
+ /// arrays in the descriptors are unpacked to individual index-typed elements,
+ /// else they are kept as rank-sized arrays of index type. In particular,
+ /// the list will contain:
+ /// - two pointers to the memref element type, followed by
+ /// - an index-typed offset, followed by
+ /// - (if unpackAggregates = true)
+ /// - one index-typed size per dimension of the memref, followed by
+ /// - one index-typed stride per dimension of the memref.
+ /// - (if unpackAggregates = false)
+ /// - one rank-sized array of index-type for the size of each dimension
+ /// - one rank-sized array of index-type for the stride of each dimension
+ ///
+ /// For example, memref<?x?xf32> is converted to the following list:
+ /// - `!llvm<"float*">` (allocated pointer),
+ /// - `!llvm<"float*">` (aligned pointer),
+ /// - `i64` (offset),
+ /// - `i64`, `i64` (sizes),
+ /// - `i64`, `i64` (strides).
+ /// These types can be recomposed to a memref descriptor struct.
+ SmallVector<Type, 5> getMemRefDescriptorFields(MemRefType type,
+ bool unpackAggregates) const;
+
+ /// Convert an unranked memref type into a list of non-aggregate LLVM IR types
+ /// that will form the unranked memref descriptor. In particular, this list
+ /// contains:
+ /// - an integer rank, followed by
+ /// - a pointer to the memref descriptor struct.
+ /// For example, memref<*xf32> is converted to the following list:
+ /// i64 (rank)
+ /// !llvm<"i8*"> (type-erased pointer).
+ /// These types can be recomposed to an unranked memref descriptor struct.
+ SmallVector<Type, 2> getUnrankedMemRefDescriptorFields() const;
+
protected:
/// Pointer to the LLVM dialect.
LLVM::LLVMDialect *llvmDialect;
@@ -213,41 +248,6 @@ class LLVMTypeConverter : public TypeConverter {
/// Convert a memref type into an LLVM type that captures the relevant data.
Type convertMemRefType(MemRefType type) const;
- /// Convert a memref type into a list of LLVM IR types that will form the
- /// memref descriptor. If `unpackAggregates` is true the `sizes` and `strides`
- /// arrays in the descriptors are unpacked to individual index-typed elements,
- /// else they are kept as rank-sized arrays of index type. In particular,
- /// the list will contain:
- /// - two pointers to the memref element type, followed by
- /// - an index-typed offset, followed by
- /// - (if unpackAggregates = true)
- /// - one index-typed size per dimension of the memref, followed by
- /// - one index-typed stride per dimension of the memref.
- /// - (if unpackArrregates = false)
- /// - one rank-sized array of index-type for the size of each dimension
- /// - one rank-sized array of index-type for the stride of each dimension
- ///
- /// For example, memref<?x?xf32> is converted to the following list:
- /// - `!llvm<"float*">` (allocated pointer),
- /// - `!llvm<"float*">` (aligned pointer),
- /// - `i64` (offset),
- /// - `i64`, `i64` (sizes),
- /// - `i64`, `i64` (strides).
- /// These types can be recomposed to a memref descriptor struct.
- SmallVector<Type, 5> getMemRefDescriptorFields(MemRefType type,
- bool unpackAggregates) const;
-
- /// Convert an unranked memref type into a list of non-aggregate LLVM IR types
- /// that will form the unranked memref descriptor. In particular, this list
- /// contains:
- /// - an integer rank, followed by
- /// - a pointer to the memref descriptor struct.
- /// For example, memref<*xf32> is converted to the following list:
- /// i64 (rank)
- /// !llvm<"i8*"> (type-erased pointer).
- /// These types can be recomposed to a unranked memref descriptor struct.
- SmallVector<Type, 2> getUnrankedMemRefDescriptorFields() const;
-
/// Convert an unranked memref type to an LLVM type that captures the
/// runtime rank and a pointer to the static ranked memref desc
Type convertUnrankedMemRefType(UnrankedMemRefType type) const;
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index e2ab0ed6f66cc5..1a7951282d3f78 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -44,6 +44,74 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
const DataLayoutAnalysis *analysis)
: LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), analysis) {}
+/// Helper function that checks if the given value range is a bare pointer.
+static bool isBarePointer(ValueRange values) {
+ return values.size() == 1 &&
+ isa<LLVM::LLVMPointerType>(values.front().getType());
+};
+
+/// Pack SSA values into an unranked memref descriptor struct.
+static Value packUnrankedMemRefDesc(OpBuilder &builder,
+ UnrankedMemRefType resultType,
+ ValueRange inputs, Location loc,
+ const LLVMTypeConverter &converter) {
+ // Note: Bare pointers are not supported for unranked memrefs because a
+ // memref descriptor cannot be built just from a bare pointer.
+ if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields())
+ return Value();
+ return UnrankedMemRefDescriptor::pack(builder, loc, converter, resultType,
+ inputs);
+}
+
+/// Pack SSA values into a ranked memref descriptor struct.
+static Value packRankedMemRefDesc(OpBuilder &builder, MemRefType resultType,
+ ValueRange inputs, Location loc,
+ const LLVMTypeConverter &converter) {
+ assert(resultType && "expected non-null result type");
+ if (isBarePointer(inputs))
+ return MemRefDescriptor::fromStaticShape(builder, loc, converter,
+ resultType, inputs[0]);
+ if (TypeRange(inputs) ==
+ converter.getMemRefDescriptorFields(resultType,
+ /*unpackAggregates=*/true))
+ return MemRefDescriptor::pack(builder, loc, converter, resultType, inputs);
+ // The inputs are neither a bare pointer nor an unpacked memref descriptor.
+ // This materialization function cannot be used.
+ return Value();
+}
+
+/// MemRef descriptor elements -> UnrankedMemRefType
+static Value unrankedMemRefMaterialization(OpBuilder &builder,
+ UnrankedMemRefType resultType,
+ ValueRange inputs, Location loc,
+ const LLVMTypeConverter &converter) {
+ // An argument materialization must return a value of type
+ // `resultType`, so insert a cast from the memref descriptor type
+ // (!llvm.struct) to the original memref type.
+ Value packed =
+ packUnrankedMemRefDesc(builder, resultType, inputs, loc, converter);
+ if (!packed)
+ return Value();
+ return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
+ .getResult(0);
+};
+
+/// MemRef descriptor elements -> MemRefType
+static Value rankedMemRefMaterialization(OpBuilder &builder,
+ MemRefType resultType,
+ ValueRange inputs, Location loc,
+ const LLVMTypeConverter &converter) {
+ // An argument materialization must return a value of type `resultType`,
+ // so insert a cast from the memref descriptor type (!llvm.struct) to the
+ // original memref type.
+ Value packed =
+ packRankedMemRefDesc(builder, resultType, inputs, loc, converter);
+ if (!packed)
+ return Value();
+ return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
+ .getResult(0);
+}
+
/// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
const LowerToLLVMOptions &options,
@@ -166,81 +234,29 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
.getResult(0);
});
- // Helper function that checks if the given value range is a bare pointer.
- auto isBarePointer = [](ValueRange values) {
- return values.size() == 1 &&
- isa<LLVM::LLVMPointerType>(values.front().getType());
- };
-
- // TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
- // must be passed explicitly.
- auto packUnrankedMemRefDesc =
- [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
- Location loc, LLVMTypeConverter &converter) -> Value {
- // Note: Bare pointers are not supported for unranked memrefs because a
- // memref descriptor cannot be built just from a bare pointer.
- if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields())
- return Value();
- return UnrankedMemRefDescriptor::pack(builder, loc, converter, resultType,
- inputs);
- };
-
- // MemRef descriptor elements -> UnrankedMemRefType
- auto unrakedMemRefMaterialization = [&](OpBuilder &builder,
- UnrankedMemRefType resultType,
- ValueRange inputs, Location loc) {
- // An argument materialization must return a value of type
- // `resultType`, so insert a cast from the memref descriptor type
- // (!llvm.struct) to the original memref type.
- Value packed =
- packUnrankedMemRefDesc(builder, resultType, inputs, loc, *this);
- if (!packed)
- return Value();
- return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
- .getResult(0);
- };
-
- // TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
- // must be passed explicitly.
- auto packRankedMemRefDesc = [&](OpBuilder &builder, MemRefType resultType,
- ValueRange inputs, Location loc,
- LLVMTypeConverter &converter) -> Value {
- assert(resultType && "expected non-null result type");
- if (isBarePointer(inputs))
- return MemRefDescriptor::fromStaticShape(builder, loc, converter,
- resultType, inputs[0]);
- if (TypeRange(inputs) ==
- converter.getMemRefDescriptorFields(resultType,
- /*unpackAggregates=*/true))
- return MemRefDescriptor::pack(builder, loc, converter, resultType,
- inputs);
- // The inputs are neither a bare pointer nor an unpacked memref descriptor.
- // This materialization function cannot be used.
- return Value();
- };
-
- // MemRef descriptor elements -> MemRefType
- auto rankedMemRefMaterialization = [&](OpBuilder &builder,
- MemRefType resultType,
- ValueRange inputs, Location loc) {
- // An argument materialization must return a value of type `resultType`,
- // so insert a cast from the memref descriptor type (!llvm.struct) to the
- // original memref type.
- Value packed =
- packRankedMemRefDesc(builder, resultType, inputs, loc, *this);
- if (!packed)
- return Value();
- return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
- .getResult(0);
- };
-
// Argument materializations convert from the new block argument types
// (multiple SSA values that make up a memref descriptor) back to the
// original block argument type.
- addArgumentMaterialization(unrakedMemRefMaterialization);
- addArgumentMaterialization(rankedMemRefMaterialization);
- addSourceMaterialization(unrakedMemRefMaterialization);
- addSourceMaterialization(rankedMemRefMaterialization);
+ addArgumentMaterialization([&](OpBuilder &builder,
+ UnrankedMemRefType resultType,
+ ValueRange inputs, Location loc) {
+ return unrankedMemRefMaterialization(builder, resultType, inputs, loc,
+ *this);
+ });
+ addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
+ ValueRange inputs, Location loc) {
+ return rankedMemRefMaterialization(builder, resultType, inputs, loc, *this);
+ });
+ addSourceMaterialization([&](OpBuilder &builder,
+ UnrankedMemRefType resultType, ValueRange inputs,
+ Location loc) {
+ return unrankedMemRefMaterialization(builder, resultType, inputs, loc,
+ *this);
+ });
+ addSourceMaterialization([&](OpBuilder &builder, MemRefType resultType,
+ ValueRange inputs, Location loc) {
+ return rankedMemRefMaterialization(builder, resultType, inputs, loc, *this);
+ });
// Bare pointer -> Packed MemRef descriptor
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
More information about the Mlir-commits mailing list