[Mlir-commits] [mlir] [mlir][LLVM][NFC] Simplify `computeSizes` function (PR #153588)

Matthias Springer llvmlistbot at llvm.org
Thu Aug 14 07:46:51 PDT 2025


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/153588

Rename `computeSizes` to `computeSize` and make it compute just a single size. This is in preparation for adding 1:N support to the Func->LLVM lowering patterns.


From aa2cff134958965cd55f2311f5fcb1ab2fff48e9 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Thu, 14 Aug 2025 14:45:27 +0000
Subject: [PATCH] [mlir][LLVM] Simplify `computeSizes` function

---
 .../Conversion/LLVMCommon/MemRefBuilder.h     | 14 ++---
 .../Conversion/LLVMCommon/MemRefBuilder.cpp   | 62 ++++++++-----------
 mlir/lib/Conversion/LLVMCommon/Pattern.cpp    | 10 +--
 .../Conversion/MemRefToLLVM/MemRefToLLVM.cpp  | 13 ++--
 4 files changed, 41 insertions(+), 58 deletions(-)

diff --git a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
index d5055f023cdc8..8e86808cc424a 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h
@@ -189,15 +189,13 @@ class UnrankedMemRefDescriptor : public StructBuilder {
   /// `unpack`.
   static unsigned getNumUnpackedValues() { return 2; }
 
-  /// Builds IR computing the sizes in bytes (suitable for opaque allocation)
-  /// and appends the corresponding values into `sizes`. `addressSpaces`
-  /// which must have the same length as `values`, is needed to handle layouts
-  /// where sizeof(ptr addrspace(N)) != sizeof(ptr addrspace(0)).
-  static void computeSizes(OpBuilder &builder, Location loc,
+  /// Builds and returns IR computing the size in bytes (suitable for opaque
+  /// allocation). `addressSpace` is needed to handle layouts where
+  /// sizeof(ptr addrspace(N)) != sizeof(ptr addrspace(0)).
+  static Value computeSize(OpBuilder &builder, Location loc,
                            const LLVMTypeConverter &typeConverter,
-                           ArrayRef<UnrankedMemRefDescriptor> values,
-                           ArrayRef<unsigned> addressSpaces,
-                           SmallVectorImpl<Value> &sizes);
+                           UnrankedMemRefDescriptor desc,
+                           unsigned addressSpace);
 
   /// TODO: The following accessors don't take alignment rules between elements
   /// of the descriptor struct into account. For some architectures, it might be
diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
index fce7a3f324b86..522e91421ff55 100644
--- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
@@ -353,14 +353,9 @@ void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc,
   results.push_back(d.memRefDescPtr(builder, loc));
 }
 
-void UnrankedMemRefDescriptor::computeSizes(
+Value UnrankedMemRefDescriptor::computeSize(
     OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter,
-    ArrayRef<UnrankedMemRefDescriptor> values, ArrayRef<unsigned> addressSpaces,
-    SmallVectorImpl<Value> &sizes) {
-  if (values.empty())
-    return;
-  assert(values.size() == addressSpaces.size() &&
-         "must provide address space for each descriptor");
+    UnrankedMemRefDescriptor desc, unsigned addressSpace) {
   // Cache the index type.
   Type indexType = typeConverter.getIndexType();
 
@@ -371,34 +366,31 @@ void UnrankedMemRefDescriptor::computeSizes(
       builder, loc, indexType,
       llvm::divideCeil(typeConverter.getIndexTypeBitwidth(), 8));
 
-  sizes.reserve(sizes.size() + values.size());
-  for (auto [desc, addressSpace] : llvm::zip(values, addressSpaces)) {
-    // Emit IR computing the memory necessary to store the descriptor. This
-    // assumes the descriptor to be
-    //   { type*, type*, index, index[rank], index[rank] }
-    // and densely packed, so the total size is
-    //   2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index).
-    // TODO: consider including the actual size (including eventual padding due
-    // to data layout) into the unranked descriptor.
-    Value pointerSize = createIndexAttrConstant(
-        builder, loc, indexType,
-        llvm::divideCeil(typeConverter.getPointerBitwidth(addressSpace), 8));
-    Value doublePointerSize =
-        LLVM::MulOp::create(builder, loc, indexType, two, pointerSize);
-
-    // (1 + 2 * rank) * sizeof(index)
-    Value rank = desc.rank(builder, loc);
-    Value doubleRank = LLVM::MulOp::create(builder, loc, indexType, two, rank);
-    Value doubleRankIncremented =
-        LLVM::AddOp::create(builder, loc, indexType, doubleRank, one);
-    Value rankIndexSize = LLVM::MulOp::create(builder, loc, indexType,
-                                              doubleRankIncremented, indexSize);
-
-    // Total allocation size.
-    Value allocationSize = LLVM::AddOp::create(
-        builder, loc, indexType, doublePointerSize, rankIndexSize);
-    sizes.push_back(allocationSize);
-  }
+  // Emit IR computing the memory necessary to store the descriptor. This
+  // assumes the descriptor to be
+  //   { type*, type*, index, index[rank], index[rank] }
+  // and densely packed, so the total size is
+  //   2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index).
+  // TODO: consider including the actual size (including eventual padding due
+  // to data layout) into the unranked descriptor.
+  Value pointerSize = createIndexAttrConstant(
+      builder, loc, indexType,
+      llvm::divideCeil(typeConverter.getPointerBitwidth(addressSpace), 8));
+  Value doublePointerSize =
+      LLVM::MulOp::create(builder, loc, indexType, two, pointerSize);
+
+  // (1 + 2 * rank) * sizeof(index)
+  Value rank = desc.rank(builder, loc);
+  Value doubleRank = LLVM::MulOp::create(builder, loc, indexType, two, rank);
+  Value doubleRankIncremented =
+      LLVM::AddOp::create(builder, loc, indexType, doubleRank, one);
+  Value rankIndexSize = LLVM::MulOp::create(builder, loc, indexType,
+                                            doubleRankIncremented, indexSize);
+
+  // Total allocation size.
+  Value allocationSize = LLVM::AddOp::create(builder, loc, indexType,
+                                             doublePointerSize, rankIndexSize);
+  return allocationSize;
 }
 
 Value UnrankedMemRefDescriptor::allocatedPtr(
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 2568044f1fd32..72f41fd01fe7c 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -239,12 +239,6 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
   if (unrankedMemrefs.empty())
     return success();
 
-  // Compute allocation sizes.
-  SmallVector<Value> sizes;
-  UnrankedMemRefDescriptor::computeSizes(builder, loc, *getTypeConverter(),
-                                         unrankedMemrefs, unrankedAddressSpaces,
-                                         sizes);
-
   // Get frequently used types.
   Type indexType = getTypeConverter()->getIndexType();
 
@@ -267,8 +261,10 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
     Type type = origTypes[i];
     if (!isa<UnrankedMemRefType>(type))
       continue;
-    Value allocationSize = sizes[unrankedMemrefPos++];
     UnrankedMemRefDescriptor desc(operands[i]);
+    Value allocationSize = UnrankedMemRefDescriptor::computeSize(
+        builder, loc, *getTypeConverter(), desc,
+        unrankedAddressSpaces[unrankedMemrefPos++]);
 
     // Allocate memory, copy, and free the source if necessary.
     Value memory =
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 9216e2a35a5ae..262e0e7a30c63 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1246,10 +1246,8 @@ struct MemorySpaceCastOpLowering
       auto result = UnrankedMemRefDescriptor::poison(
           rewriter, loc, typeConverter->convertType(resultTypeU));
       result.setRank(rewriter, loc, rank);
-      SmallVector<Value, 1> sizes;
-      UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
-                                             result, resultAddrSpace, sizes);
-      Value resultUnderlyingSize = sizes.front();
+      Value resultUnderlyingSize = UnrankedMemRefDescriptor::computeSize(
+          rewriter, loc, *getTypeConverter(), result, resultAddrSpace);
       Value resultUnderlyingDesc =
           LLVM::AllocaOp::create(rewriter, loc, getPtrType(),
                                  rewriter.getI8Type(), resultUnderlyingSize);
@@ -1530,12 +1528,11 @@ struct MemRefReshapeOpLowering
     auto targetDesc = UnrankedMemRefDescriptor::poison(
         rewriter, loc, typeConverter->convertType(targetType));
     targetDesc.setRank(rewriter, loc, resultRank);
-    SmallVector<Value, 4> sizes;
-    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
-                                           targetDesc, addressSpace, sizes);
+    Value allocationSize = UnrankedMemRefDescriptor::computeSize(
+        rewriter, loc, *getTypeConverter(), targetDesc, addressSpace);
     Value underlyingDescPtr = LLVM::AllocaOp::create(
         rewriter, loc, getPtrType(), IntegerType::get(getContext(), 8),
-        sizes.front());
+        allocationSize);
     targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
 
     // Extract pointers and offset from the source memref.



More information about the Mlir-commits mailing list