[Mlir-commits] [mlir] dc297cb - [mlir][memref][spirv] Add conversion for memref.extract_aligned_pointer_as_index to SPIR-V (#86750)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jun 3 07:39:18 PDT 2025
Author: Md Abdullah Shahneous Bari
Date: 2025-06-03T09:39:14-05:00
New Revision: dc297cbc9ad2a0a18b530708c39b13bd431bf237
URL: https://github.com/llvm/llvm-project/commit/dc297cbc9ad2a0a18b530708c39b13bd431bf237
DIFF: https://github.com/llvm/llvm-project/commit/dc297cbc9ad2a0a18b530708c39b13bd431bf237.diff
LOG: [mlir][memref][spirv] Add conversion for memref.extract_aligned_pointer_as_index to SPIR-V (#86750)
Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU.
The result's index type is chosen by the 'use-64bit-index' option: i32 by default, i64 when the option is enabled.
Added:
Modified:
mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index fdf799a20efdd..ff5b762a969d8 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -307,6 +307,17 @@ class CastPattern final : public OpConversionPattern<memref::CastOp> {
}
};
+/// Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU.
+class ExtractAlignedPointerAsIndexOpPattern final
+ : public OpConversionPattern<memref::ExtractAlignedPointerAsIndexOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
+ OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
} // namespace
//===----------------------------------------------------------------------===//
@@ -921,6 +932,20 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
return success();
}
+//===----------------------------------------------------------------------===//
+// ExtractAlignedPointerAsIndexOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
+ memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+ Type indexType = typeConverter.getIndexType();
+ rewriter.replaceOpWithNewOp<spirv::ConvertPtrToUOp>(extractOp, indexType,
+ adaptor.getSource());
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//
@@ -928,10 +953,11 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
namespace mlir {
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
- patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
- DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
- LoadOpPattern, MemorySpaceCastOpPattern, StoreOpPattern,
- ReinterpretCastPattern, CastPattern>(typeConverter,
- patterns.getContext());
+ patterns
+ .add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
+ DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
+ MemorySpaceCastOpPattern, StoreOpPattern, ReinterpretCastPattern,
+ CastPattern, ExtractAlignedPointerAsIndexOpPattern>(
+ typeConverter, patterns.getContext());
}
} // namespace mlir
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index 8906de9db3724..d0ddac8cd801c 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -1,4 +1,6 @@
-// RUN: mlir-opt --split-input-file --convert-memref-to-spirv="bool-num-bits=8" --cse %s | FileCheck %s
+// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8}, cse)" %s | FileCheck %s
+// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8 use-64bit-index=true}, cse)" %s \
+// RUN: | FileCheck --check-prefix=CHECK64 %s
// Check that with proper compute and storage extensions, we don't need to
// perform special tricks.
@@ -420,6 +422,43 @@ func.func @cast_to_static_zero_elems(%arg: memref<?xf32, #spirv.storage_class<Cr
}
+// -----
+
+module attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.5, [Kernel, Int64, Addresses, PhysicalStorageBufferAddresses], []>, #spirv.resource_limits<>>
+} {
+// CHECK-LABEL: func @extract_aligned_pointer_as_index_kernel
+func.func @extract_aligned_pointer_as_index_kernel(%m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> index {
+ %0 = memref.extract_aligned_pointer_as_index %m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>> -> index
+ // CHECK: %[[I32:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<f32, CrossWorkgroup> to i32
+ // CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I32]] : i32 to index
+ // CHECK64: %[[I64:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<f32, CrossWorkgroup> to i64
+ // CHECK64: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index
+
+ // CHECK: return %[[R:.*]] : index
+ return %0: index
+}
+}
+
+// -----
+
+module attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.5, [Shader, Int64, Addresses, PhysicalStorageBufferAddresses], []>, #spirv.resource_limits<>>
+} {
+// CHECK-LABEL: func @extract_aligned_pointer_as_index_shader
+func.func @extract_aligned_pointer_as_index_shader(%m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> index {
+ %0 = memref.extract_aligned_pointer_as_index %m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>> -> index
+ // CHECK: %[[I32:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32>)>, CrossWorkgroup> to i32
+ // CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I32]] : i32 to index
+ // CHECK64: %[[I64:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32>)>, CrossWorkgroup> to i64
+ // CHECK64: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index
+
+ // CHECK: return %[[R:.*]] : index
+ return %0: index
+}
+}
+
+
// -----
// Check nontemporal attribute
More information about the Mlir-commits
mailing list