[Mlir-commits] [mlir] Unranked support for extract_aligned_pointer_as_index (PR #93908)

Spenser Bauman llvmlistbot at llvm.org
Thu May 30 17:40:26 PDT 2024


https://github.com/sabauma created https://github.com/llvm/llvm-project/pull/93908

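memref.extract_aligned_pointer_as_index currently does not accept unranked memref operands. This prevents the expand-strided-metadata pass from folding the op through a memref.reinterpret_cast whose source is unranked (the fold would make the unranked memref the op's new operand), e.g.: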

    %r = memref.reinterpret_cast %arg0 to
        offset: [0],
        sizes: [],
        strides: [] : memref<*xf32> to memref<f32>

    %i = memref.extract_aligned_pointer_as_index %r : memref<f32> -> index

Patterns like this occur when bufferizing operations on unranked tensors.

This change extends the extract_aligned_pointer_as_index operation to accept unranked memref operands and adds the corresponding lowering to the MemRef-to-LLVM conversion.

>From 17ec1a587e8ddff1a62bdb93ca256b8e34d5ac79 Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sabauma at fastmail>
Date: Thu, 30 May 2024 19:06:32 -0400
Subject: [PATCH] Unranked support for extract_aligned_pointer_as_index

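memref.extract_aligned_pointer_as_index currently does not accept
unranked memref operands. This prevents the expand-strided-metadata
pass from folding the op through a memref.reinterpret_cast whose
source is unranked (the fold would make the unranked memref the op's
new operand), e.g.: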

    %r = memref.reinterpret_cast %arg0 to
        offset: [0],
        sizes: [],
        strides: [] : memref<*xf32> to memref<f32>

    %i = memref.extract_aligned_pointer_as_index %r : memref<f32> -> index

Patterns like this occur when bufferizing operations on unranked
tensors.

This change extends the extract_aligned_pointer_as_index operation to
accept unranked memref operands and adds the corresponding lowering to
the MemRef-to-LLVM conversion.
---
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       |  2 +-
 .../Conversion/MemRefToLLVM/MemRefToLLVM.cpp  | 22 ++++++++++++++++---
 .../MemRefToLLVM/memref-to-llvm.mlir          | 15 +++++++++++++
 .../MemRef/expand-strided-metadata.mlir       | 13 +++++++++++
 4 files changed, 48 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 63e6ed059deb1..df40e7a17a15f 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -892,7 +892,7 @@ def MemRef_ExtractAlignedPointerAsIndexOp :
   }];
 
   let arguments = (ins
-    AnyStridedMemRef:$source
+    AnyRankedOrUnrankedMemRef:$source
   );
   let results = (outs Index:$aligned_pointer);
 
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 2dc42f0a85e66..82c4b04656b33 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1590,10 +1590,26 @@ class ConvertExtractAlignedPointerAsIndex
   matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
                   OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    MemRefDescriptor desc(adaptor.getSource());
+    BaseMemRefType sourceTy = extractOp.getSource().getType();
+
+    Value alignedPtr;
+    if (sourceTy.hasRank()) {
+      MemRefDescriptor desc(adaptor.getSource());
+      alignedPtr = desc.alignedPtr(rewriter, extractOp->getLoc());
+    } else {
+      auto elementPtrTy = LLVM::LLVMPointerType::get(
+          rewriter.getContext(), sourceTy.getMemorySpaceAsInt());
+
+      UnrankedMemRefDescriptor desc(adaptor.getSource());
+      Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc());
+
+      alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
+          rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
+          elementPtrTy);
+    }
+
     rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
-        extractOp, getTypeConverter()->getIndexType(),
-        desc.alignedPtr(rewriter, extractOp->getLoc()));
+        extractOp, getTypeConverter()->getIndexType(), alignedPtr);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index baf9cfe610a5a..882804132e66d 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -598,6 +598,21 @@ func.func @extract_aligned_pointer_as_index(%m: memref<?xf32>) -> index {
 
 // -----
 
+// CHECK-LABEL: func @extract_aligned_pointer_as_index_unranked
+func.func @extract_aligned_pointer_as_index_unranked(%m: memref<*xf32>) -> index {
+  %0 = memref.extract_aligned_pointer_as_index %m: memref<*xf32> -> index
+  // CHECK: %[[PTR:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(i64, ptr)> 
+  // CHECK: %[[ALIGNED_FIELD:.*]] = llvm.getelementptr %[[PTR]][1] : (!llvm.ptr) -> !llvm.ptr, !llvm.ptr
+  // CHECK: %[[ALIGNED_PTR:.*]] = llvm.load %[[ALIGNED_FIELD]] : !llvm.ptr -> !llvm.ptr
+  // CHECK: %[[I64:.*]] = llvm.ptrtoint %[[ALIGNED_PTR]] : !llvm.ptr to i64
+  // CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index
+
+  // CHECK: return %[[R]] : index
+  return %0: index
+}
+
+// -----
+
 // CHECK-LABEL: func @extract_strided_metadata(
 // CHECK-SAME: %[[ARG:.*]]: memref
 // CHECK: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<?x?xf32, strided<[?, ?], offset: ?>> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
index 3bd6b7c1fd791..d884ade319532 100644
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -899,6 +899,19 @@ func.func @extract_aligned_pointer_as_index(%arg0: memref<f32>) -> index {
 
 // -----
 
+// CHECK-LABEL: extract_aligned_pointer_as_index_of_unranked_source
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<*xf32>
+func.func @extract_aligned_pointer_as_index_of_unranked_source(%arg0: memref<*xf32>) -> index {
+  // CHECK: %[[I:.+]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<*xf32> -> index
+  // CHECK: return %[[I]]
+
+  %r = memref.reinterpret_cast %arg0 to offset: [0], sizes: [], strides: [] : memref<*xf32> to memref<f32>
+  %i = memref.extract_aligned_pointer_as_index %r : memref<f32> -> index
+  return %i : index
+}
+
+// -----
+
 // Check that we simplify collapse_shape into
 // reinterpret_cast(extract_strided_metadata) + <some math>
 //



More information about the Mlir-commits mailing list