[Mlir-commits] [mlir] Unranked support for extract_aligned_pointer_as_index (PR #93908)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu May 30 17:40:55 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-memref
Author: Spenser Bauman (sabauma)
<details>
<summary>Changes</summary>
`memref.extract_aligned_pointer_as_index` currently does not support unranked inputs. This lack of support interferes with the folding operations in the expand-strided-metadata pass. For example:
```mlir
%r = memref.reinterpret_cast %arg0 to
    offset: [0],
    sizes: [],
    strides: [] : memref<*xf32> to memref<f32>
%i = memref.extract_aligned_pointer_as_index %r : memref<f32> -> index
```
Patterns like this occur when bufferizing operations on unranked tensors.
This change modifies the `memref.extract_aligned_pointer_as_index` operation to support unranked inputs, with corresponding support in the MemRef->LLVM conversion.
---
Full diff: https://github.com/llvm/llvm-project/pull/93908.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+1-1)
- (modified) mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp (+19-3)
- (modified) mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir (+15)
- (modified) mlir/test/Dialect/MemRef/expand-strided-metadata.mlir (+13)
``````````diff
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 63e6ed059deb1..df40e7a17a15f 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -892,7 +892,7 @@ def MemRef_ExtractAlignedPointerAsIndexOp :
}];
let arguments = (ins
- AnyStridedMemRef:$source
+ AnyRankedOrUnrankedMemRef:$source
);
let results = (outs Index:$aligned_pointer);
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 2dc42f0a85e66..82c4b04656b33 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1590,10 +1590,26 @@ class ConvertExtractAlignedPointerAsIndex
matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- MemRefDescriptor desc(adaptor.getSource());
+ BaseMemRefType sourceTy = extractOp.getSource().getType();
+
+ Value alignedPtr;
+ if (sourceTy.hasRank()) {
+ MemRefDescriptor desc(adaptor.getSource());
+ alignedPtr = desc.alignedPtr(rewriter, extractOp->getLoc());
+ } else {
+ auto elementPtrTy = LLVM::LLVMPointerType::get(
+ rewriter.getContext(), sourceTy.getMemorySpaceAsInt());
+
+ UnrankedMemRefDescriptor desc(adaptor.getSource());
+ Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc());
+
+ alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
+ rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
+ elementPtrTy);
+ }
+
rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
- extractOp, getTypeConverter()->getIndexType(),
- desc.alignedPtr(rewriter, extractOp->getLoc()));
+ extractOp, getTypeConverter()->getIndexType(), alignedPtr);
return success();
}
};
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index baf9cfe610a5a..882804132e66d 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -598,6 +598,21 @@ func.func @extract_aligned_pointer_as_index(%m: memref<?xf32>) -> index {
// -----
+// CHECK-LABEL: func @extract_aligned_pointer_as_index_unranked
+func.func @extract_aligned_pointer_as_index_unranked(%m: memref<*xf32>) -> index {
+ %0 = memref.extract_aligned_pointer_as_index %m: memref<*xf32> -> index
+ // CHECK: %[[PTR:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(i64, ptr)>
+ // CHECK: %[[ALIGNED_FIELD:.*]] = llvm.getelementptr %[[PTR]][1] : (!llvm.ptr) -> !llvm.ptr, !llvm.ptr
+ // CHECK: %[[ALIGNED_PTR:.*]] = llvm.load %[[ALIGNED_FIELD]] : !llvm.ptr -> !llvm.ptr
+ // CHECK: %[[I64:.*]] = llvm.ptrtoint %[[ALIGNED_PTR]] : !llvm.ptr to i64
+ // CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index
+
+ // CHECK: return %[[R]] : index
+ return %0: index
+}
+
+// -----
+
// CHECK-LABEL: func @extract_strided_metadata(
// CHECK-SAME: %[[ARG:.*]]: memref
// CHECK: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<?x?xf32, strided<[?, ?], offset: ?>> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
index 3bd6b7c1fd791..d884ade319532 100644
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -899,6 +899,19 @@ func.func @extract_aligned_pointer_as_index(%arg0: memref<f32>) -> index {
// -----
+// CHECK-LABEL: extract_aligned_pointer_as_index_of_unranked_source
+// CHECK-SAME: (%[[ARG0:.*]]: memref<*xf32>
+func.func @extract_aligned_pointer_as_index_of_unranked_source(%arg0: memref<*xf32>) -> index {
+ // CHECK: %[[I:.+]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<*xf32> -> index
+ // CHECK: return %[[I]]
+
+ %r = memref.reinterpret_cast %arg0 to offset: [0], sizes: [], strides: [] : memref<*xf32> to memref<f32>
+ %i = memref.extract_aligned_pointer_as_index %r : memref<f32> -> index
+ return %i : index
+}
+
+// -----
+
// Check that we simplify collapse_shape into
// reinterpret_cast(extract_strided_metadata) + <some math>
//
``````````
</details>
https://github.com/llvm/llvm-project/pull/93908
More information about the Mlir-commits mailing list