[Mlir-commits] [mlir] 846103c - [mlir][memref] Unranked support for extract_aligned_pointer_as_index (#93908)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jun 13 06:12:22 PDT 2024


Author: Spenser Bauman
Date: 2024-06-13T09:12:18-04:00
New Revision: 846103c7e3899a654c470fa360b22d35df888427

URL: https://github.com/llvm/llvm-project/commit/846103c7e3899a654c470fa360b22d35df888427
DIFF: https://github.com/llvm/llvm-project/commit/846103c7e3899a654c470fa360b22d35df888427.diff

LOG: [mlir][memref] Unranked support for extract_aligned_pointer_as_index (#93908)

memref.extract_aligned_pointer_as_index currently does not support
unranked inputs. This lack of support interferes with the folding
patterns in the expand-strided-metadata pass, for example on IR such as:

    %r = memref.reinterpret_cast %arg0 to
        offset: [0],
        sizes: [],
        strides: [] : memref<*xf32> to memref<f32>

    %i = memref.extract_aligned_pointer_as_index %r : memref<f32> -> index

Patterns like this occur when bufferizing operations on unranked
tensors.

This change modifies the extract_aligned_pointer_as_index operation to
support unranked inputs with corresponding support in the MemRef->LLVM
conversion.

Co-authored-by: Spenser Bauman <sabauma at fastmail>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
    mlir/test/Dialect/MemRef/expand-strided-metadata.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 63e6ed059deb1..df40e7a17a15f 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -892,7 +892,7 @@ def MemRef_ExtractAlignedPointerAsIndexOp :
   }];
 
   let arguments = (ins
-    AnyStridedMemRef:$source
+    AnyRankedOrUnrankedMemRef:$source
   );
   let results = (outs Index:$aligned_pointer);
 

diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 7ef5a77fcb42c..1933c1dfcfba4 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1590,10 +1590,26 @@ class ConvertExtractAlignedPointerAsIndex
   matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
                   OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    MemRefDescriptor desc(adaptor.getSource());
+    BaseMemRefType sourceTy = extractOp.getSource().getType();
+
+    Value alignedPtr;
+    if (sourceTy.hasRank()) {
+      MemRefDescriptor desc(adaptor.getSource());
+      alignedPtr = desc.alignedPtr(rewriter, extractOp->getLoc());
+    } else {
+      auto elementPtrTy = LLVM::LLVMPointerType::get(
+          rewriter.getContext(), sourceTy.getMemorySpaceAsInt());
+
+      UnrankedMemRefDescriptor desc(adaptor.getSource());
+      Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc());
+
+      alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
+          rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
+          elementPtrTy);
+    }
+
     rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
-        extractOp, getTypeConverter()->getIndexType(),
-        desc.alignedPtr(rewriter, extractOp->getLoc()));
+        extractOp, getTypeConverter()->getIndexType(), alignedPtr);
     return success();
   }
 };

diff  --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index baf9cfe610a5a..882804132e66d 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -598,6 +598,21 @@ func.func @extract_aligned_pointer_as_index(%m: memref<?xf32>) -> index {
 
 // -----
 
+// CHECK-LABEL: func @extract_aligned_pointer_as_index_unranked
+func.func @extract_aligned_pointer_as_index_unranked(%m: memref<*xf32>) -> index {
+  %0 = memref.extract_aligned_pointer_as_index %m: memref<*xf32> -> index
+  // CHECK: %[[PTR:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(i64, ptr)> 
+  // CHECK: %[[ALIGNED_FIELD:.*]] = llvm.getelementptr %[[PTR]][1] : (!llvm.ptr) -> !llvm.ptr, !llvm.ptr
+  // CHECK: %[[ALIGNED_PTR:.*]] = llvm.load %[[ALIGNED_FIELD]] : !llvm.ptr -> !llvm.ptr
+  // CHECK: %[[I64:.*]] = llvm.ptrtoint %[[ALIGNED_PTR]] : !llvm.ptr to i64
+  // CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index
+
+  // CHECK: return %[[R]] : index
+  return %0: index
+}
+
+// -----
+
 // CHECK-LABEL: func @extract_strided_metadata(
 // CHECK-SAME: %[[ARG:.*]]: memref
 // CHECK: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<?x?xf32, strided<[?, ?], offset: ?>> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>

diff  --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
index 3bd6b7c1fd791..d884ade319532 100644
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -899,6 +899,19 @@ func.func @extract_aligned_pointer_as_index(%arg0: memref<f32>) -> index {
 
 // -----
 
+// CHECK-LABEL: extract_aligned_pointer_as_index_of_unranked_source
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<*xf32>
+func.func @extract_aligned_pointer_as_index_of_unranked_source(%arg0: memref<*xf32>) -> index {
+  // CHECK: %[[I:.+]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<*xf32> -> index
+  // CHECK: return %[[I]]
+
+  %r = memref.reinterpret_cast %arg0 to offset: [0], sizes: [], strides: [] : memref<*xf32> to memref<f32>
+  %i = memref.extract_aligned_pointer_as_index %r : memref<f32> -> index
+  return %i : index
+}
+
+// -----
+
 // Check that we simplify collapse_shape into
 // reinterpret_cast(extract_strided_metadata) + <some math>
 //


        


More information about the Mlir-commits mailing list