[Mlir-commits] [mlir] [mlir][LLVM] `LLVMTypeConverter`: Tighten materialization checks (PR #116532)

Matthias Springer llvmlistbot at llvm.org
Sun Nov 17 00:11:28 PST 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/116532

This commit adds extra checks to the MemRef argument materializations in the LLVM type converter. These materializations construct a value of type `MemRefType` or `UnrankedMemRefType` from either the unpacked elements of a MemRef descriptor or (for ranked MemRefs only) a bare pointer.
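
To make the two input formats concrete, here is a simplified, hypothetical model of the values an argument materialization receives. It assumes the upstream layout of a fully unpacked ranked descriptor (allocated pointer, aligned pointer, offset, then `rank` sizes and `rank` strides, all index-typed) and of an unranked descriptor (the rank plus a pointer to the underlying ranked descriptor). `FieldKind` and the helper names are invented stand-ins for `mlir::Type`, `getMemRefDescriptorFields` and `getUnrankedMemRefDescriptorFields`; this is not the MLIR API.

```cpp
// Hypothetical model of MemRef descriptor layouts. Not the MLIR API:
// field types are modeled with a plain enum instead of mlir::Type.
#include <cassert>
#include <cstdint>
#include <vector>

enum class FieldKind { Ptr, Index };

// Unpacked fields of a ranked memref descriptor of the given rank:
// allocated pointer, aligned pointer, offset, then `rank` sizes and
// `rank` strides.
std::vector<FieldKind> rankedDescriptorFields(int64_t rank) {
  assert(rank >= 0 && "rank must be non-negative");
  std::vector<FieldKind> fields = {FieldKind::Ptr, FieldKind::Ptr,
                                   FieldKind::Index};
  fields.insert(fields.end(), 2 * rank, FieldKind::Index);
  return fields;
}

// Fields of an unranked memref descriptor: the rank and an opaque pointer
// to the underlying ranked descriptor.
std::vector<FieldKind> unrankedDescriptorFields() {
  return {FieldKind::Index, FieldKind::Ptr};
}

// A bare pointer is exactly one pointer-typed value.
std::vector<FieldKind> barePointerFields() { return {FieldKind::Ptr}; }
```

With this model, a rank-2 MemRef unpacks into 3 + 2 * 2 = 7 values and a rank-0 MemRef into 3, so neither can be confused with a single bare pointer.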

The extra checks verify that the inputs to a materialization function have the types it expects. A user may have added type conversion rules that convert MemRef types differently; the checks ensure that a MemRef descriptor is constructed only if the inputs match the expected format. If they do not, the materialization function returns an empty `Value()` and does not apply.
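
The following hypothetical sketch contrasts the old and new guards in the unranked case. Suppose a user-provided rule converts an unranked MemRef into three values (modeled here as `{Ptr, Ptr, Index}`, an arbitrary assumption). The old guard only rejected single-value inputs, so such a list would reach `UnrankedMemRefDescriptor::pack`; the new guard, which mirrors `TypeRange(inputs) != getUnrankedMemRefDescriptorFields()`, declines it. `FieldKind`, `kUnrankedFields` and the two function names are invented for illustration.

```cpp
// Hypothetical model of the unranked MemRef check before and after this
// change. Not the MLIR API: FieldKind stands in for mlir::Type.
#include <cassert>
#include <vector>

enum class FieldKind { Ptr, Index };

// Unpacked unranked descriptor: the rank followed by a pointer to the
// underlying ranked descriptor.
const std::vector<FieldKind> kUnrankedFields = {FieldKind::Index,
                                                FieldKind::Ptr};

// Old check: any input list that is not a single value was assumed to be
// an unpacked descriptor and handed to the packing code.
bool oldUnrankedCheckAccepts(const std::vector<FieldKind> &inputs) {
  return inputs.size() != 1;
}

// New check: the input types must equal the descriptor field types.
bool newUnrankedCheckAccepts(const std::vector<FieldKind> &inputs) {
  return inputs == kUnrankedFields;
}
```

For example, `oldUnrankedCheckAccepts({FieldKind::Ptr, FieldKind::Ptr, FieldKind::Index})` is `true`, while `newUnrankedCheckAccepts` returns `false` for the same list and `true` only for `{FieldKind::Index, FieldKind::Ptr}`.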

This commit also drops a check around bare pointer materializations:
```
// This is a bare pointer. We allow bare pointers only for function entry
// blocks.
```
This check does not belong in the materialization function. Whether a MemRef block argument is converted into a MemRef descriptor or a bare pointer is decided by the lowering pattern. By the time the materialization functions run, that decision has already been made, so they should simply materialize whatever format is requested.
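
The decision flow that remains in the ranked materialization can be modeled as below. This is a hypothetical sketch: `Input`, `Action` and `materializeRankedMemRef` are invented for illustration, `Action::FromStaticShape` and `Action::Pack` stand in for `MemRefDescriptor::fromStaticShape` and `MemRefDescriptor::pack`, and `std::nullopt` stands in for returning an empty `Value()`. The `isFuncEntryBlockArg` flag is carried along only to show that it no longer affects the outcome.

```cpp
// Hypothetical model of the ranked MemRef argument materialization after
// this change. Not the MLIR API: std::nullopt stands in for returning an
// empty Value(), and Action names the MemRefDescriptor helper that would run.
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

enum class FieldKind { Ptr, Index };

// Stand-in for an SSA value passed to the materialization.
struct Input {
  FieldKind kind;
  // Whether the value is an entry-block argument of a function. The old
  // code required this for bare pointers; this model ignores it.
  bool isFuncEntryBlockArg;
};

enum class Action { FromStaticShape, Pack };

// Allocated ptr, aligned ptr, offset, then `rank` sizes and `rank` strides.
std::vector<FieldKind> rankedDescriptorFields(int64_t rank) {
  assert(rank >= 0 && "rank must be non-negative");
  std::vector<FieldKind> fields = {FieldKind::Ptr, FieldKind::Ptr,
                                   FieldKind::Index};
  fields.insert(fields.end(), 2 * rank, FieldKind::Index);
  return fields;
}

std::optional<Action>
materializeRankedMemRef(int64_t rank, const std::vector<Input> &inputs) {
  // Bare pointer: exactly one pointer-typed value, wherever it is defined.
  if (inputs.size() == 1 && inputs.front().kind == FieldKind::Ptr)
    return Action::FromStaticShape;
  // Unpacked descriptor: the input types must match field by field.
  std::vector<FieldKind> kinds;
  for (const Input &in : inputs)
    kinds.push_back(in.kind);
  if (kinds == rankedDescriptorFields(rank))
    return Action::Pack;
  // Neither format: decline so that this materialization does not apply.
  return std::nullopt;
}
```

In this model, a single pointer that is not a function entry-block argument yields `Action::FromStaticShape`, whereas the removed check would have rejected it.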

From df3d0f2e3ea35938c2470eacc294eb30534fda3b Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sun, 17 Nov 2024 09:00:45 +0100
Subject: [PATCH] [mlir][LLVM] `LLVMTypeConverter`: Tighten materialization
 checks

---
 .../Conversion/LLVMCommon/TypeConverter.cpp   | 32 ++++++++++---------
 1 file changed, 17 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index ce91424e7a577e..59b0f5c9b09bcd 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -153,6 +153,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
                                        type.isVarArg());
   });
 
+  // Helper function that checks if the given value range is a bare pointer.
+  auto isBarePointer = [](ValueRange values) {
+    return values.size() == 1 &&
+           isa<LLVM::LLVMPointerType>(values.front().getType());
+  };
+
   // Argument materializations convert from the new block argument types
   // (multiple SSA values that make up a memref descriptor) back to the
   // original block argument type. The dialect conversion framework will then
@@ -161,11 +167,10 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   addArgumentMaterialization([&](OpBuilder &builder,
                                  UnrankedMemRefType resultType,
                                  ValueRange inputs, Location loc) {
-    if (inputs.size() == 1) {
-      // Bare pointers are not supported for unranked memrefs because a
-      // memref descriptor cannot be built just from a bare pointer.
+    // Note: Bare pointers are not supported for unranked memrefs because a
+    // memref descriptor cannot be built just from a bare pointer.
+    if (TypeRange(inputs) != getUnrankedMemRefDescriptorFields())
       return Value();
-    }
     Value desc =
         UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
     // An argument materialization must return a value of type
@@ -177,20 +182,17 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
                                  ValueRange inputs, Location loc) {
     Value desc;
-    if (inputs.size() == 1) {
-      // This is a bare pointer. We allow bare pointers only for function entry
-      // blocks.
-      BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
-      if (!barePtr)
-        return Value();
-      Block *block = barePtr.getOwner();
-      if (!block->isEntryBlock() ||
-          !isa<FunctionOpInterface>(block->getParentOp()))
-        return Value();
+    if (isBarePointer(inputs)) {
       desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
                                                inputs[0]);
-    } else {
+    } else if (TypeRange(inputs) ==
+               getMemRefDescriptorFields(resultType,
+                                         /*unpackAggregates=*/true)) {
       desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
+    } else {
+      // The inputs are neither a bare pointer nor an unpacked memref
+      // descriptor. This materialization function cannot be used.
+      return Value();
     }
     // An argument materialization must return a value of type `resultType`,
     // so insert a cast from the memref descriptor type (!llvm.struct) to the