[Mlir-commits] [mlir] [mlir][memref][spirv] Add conversion for memref.extract_aligned_point… (PR #86750)

Md Abdullah Shahneous Bari llvmlistbot at llvm.org
Mon Jun 2 14:26:04 PDT 2025


https://github.com/mshahneo updated https://github.com/llvm/llvm-project/pull/86750

>From dbc399a20a0f1ffd56078db99ca23e2c31fac154 Mon Sep 17 00:00:00 2001
From: Md Abdullah Shahneous Bari <md.abdullah.shahneous.bari at intel.com>
Date: Tue, 26 Mar 2024 23:50:43 +0000
Subject: [PATCH 1/2] [mlir][memref][spirv] Add conversion for
 memref.extract_aligned_pointer_as_index to SPIR-V

Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU.
The integer result type follows the 'use-64bit-index' option: i32 by
default, i64 when the option is enabled.
---
 .../MemRefToSPIRV/MemRefToSPIRV.cpp           | 36 ++++++++++++++---
 .../MemRefToSPIRV/memref-to-spirv.mlir        | 40 ++++++++++++++++++-
 2 files changed, 70 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index fdf799a20efdd..d6f40aebf1cb9 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -307,6 +307,17 @@ class CastPattern final : public OpConversionPattern<memref::CastOp> {
   }
 };
 
+/// Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU.
+class ExtractAlignedPointerAsIndexOpPattern
+    : public OpConversionPattern<memref::ExtractAlignedPointerAsIndexOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  /// Replaces `extractOp` with a ConvertPtrToU of the converted source pointer.
+  LogicalResult matchAndRewrite(
+      memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override;
+};
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -921,6 +932,20 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ExtractAlignedPointerAsIndexOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
+    memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  // Result is the converter's index type: i32, or i64 with use-64bit-index.
+  Type resultTy = getTypeConverter<SPIRVTypeConverter>()->getIndexType();
+  rewriter.replaceOpWithNewOp<spirv::ConvertPtrToUOp>(
+      extractOp, resultTy, adaptor.getSource());
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Pattern population
 //===----------------------------------------------------------------------===//
@@ -928,10 +953,11 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
 namespace mlir {
 void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
                                    RewritePatternSet &patterns) {
-  patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
-               DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
-               LoadOpPattern, MemorySpaceCastOpPattern, StoreOpPattern,
-               ReinterpretCastPattern, CastPattern>(typeConverter,
-                                                    patterns.getContext());
+  patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
+               DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
+               LoadOpPattern, MemorySpaceCastOpPattern, StoreOpPattern,
+               ReinterpretCastPattern, CastPattern,
+               ExtractAlignedPointerAsIndexOpPattern>(typeConverter,
+                                                      patterns.getContext());
 }
 } // namespace mlir
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index 8906de9db3724..b81440bc151a2 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt --split-input-file --convert-memref-to-spirv="bool-num-bits=8" --cse %s | FileCheck %s
+// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8}, cse)" %s | FileCheck %s
+// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8 use-64bit-index=true}, cse)" %s | FileCheck --check-prefix=CHECK64 %s
 
 // Check that with proper compute and storage extensions, we don't need to
 // perform special tricks.
@@ -420,6 +421,43 @@ func.func @cast_to_static_zero_elems(%arg: memref<?xf32, #spirv.storage_class<Cr
 
 }
 
+// -----
+
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Kernel, Int64, Addresses], []>, #spirv.resource_limits<>>
+} {
+// CHECK-LABEL: func @extract_aligned_pointer_as_index_kernel
+func.func @extract_aligned_pointer_as_index_kernel(%m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> index {
+  %0 = memref.extract_aligned_pointer_as_index %m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>> -> index
+  // CHECK: %[[I32:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<f32, CrossWorkgroup> to i32
+  // CHECK64: %[[I64:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<f32, CrossWorkgroup> to i64
+  // CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I32]] : i32 to index
+  // CHECK64: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index
+
+  // CHECK: return %[[R:.*]] : index
+  return %0: index
+}
+}
+
+// -----
+
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader, Int64, Addresses], []>, #spirv.resource_limits<>>
+} {
+// CHECK-LABEL: func @extract_aligned_pointer_as_index_shader
+func.func @extract_aligned_pointer_as_index_shader(%m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> index {
+  %0 = memref.extract_aligned_pointer_as_index %m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>> -> index
+  // CHECK: %[[I32:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32>)>, CrossWorkgroup> to i32
+  // CHECK64: %[[I64:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32>)>, CrossWorkgroup> to i64
+  // CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I32]] : i32 to index
+  // CHECK64: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index
+
+  // CHECK: return %[[R:.*]] : index
+  return %0: index
+}
+}
+// NOTE(review): both `return` lines above bind a fresh R (%[[R:.*]]) instead
+// of matching the cast result (%[[R]]); the 64-bit run checks no return.
 // -----
 
 // Check nontemporal attribute

>From 0a84e910e25d50723d64ad737d71f6979514c50c Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah" <md.abdullah.shahneous.bari at intel.com>
Date: Tue, 6 May 2025 17:15:57 +0000
Subject: [PATCH 2/2] Address review comments.

---
 mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp   |  2 +-
 .../Conversion/MemRefToSPIRV/memref-to-spirv.mlir     | 11 ++++++-----
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index d6f40aebf1cb9..ff5b762a969d8 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -308,7 +308,7 @@ class CastPattern final : public OpConversionPattern<memref::CastOp> {
 };
 
 /// Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU.
-class ExtractAlignedPointerAsIndexOpPattern
+class ExtractAlignedPointerAsIndexOpPattern final
     : public OpConversionPattern<memref::ExtractAlignedPointerAsIndexOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index b81440bc151a2..d0ddac8cd801c 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8}, cse)" %s | FileCheck %s
-// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8 use-64bit-index=true}, cse)" %s | FileCheck --check-prefix=CHECK64 %s
+// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8 use-64bit-index=true}, cse)" %s \
+// RUN: | FileCheck --check-prefix=CHECK64 %s
 
 // Check that with proper compute and storage extensions, we don't need to
 // perform special tricks.
@@ -424,14 +425,14 @@ func.func @cast_to_static_zero_elems(%arg: memref<?xf32, #spirv.storage_class<Cr
 // -----
 
 module attributes {
-  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Kernel, Int64, Addresses], []>, #spirv.resource_limits<>>
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.5, [Kernel, Int64, Addresses, PhysicalStorageBufferAddresses], []>, #spirv.resource_limits<>>
 } {
 // CHECK-LABEL: func @extract_aligned_pointer_as_index_kernel
 func.func @extract_aligned_pointer_as_index_kernel(%m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> index {
   %0 = memref.extract_aligned_pointer_as_index %m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>> -> index
   // CHECK: %[[I32:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<f32, CrossWorkgroup> to i32
-  // CHECK64: %[[I64:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<f32, CrossWorkgroup> to i64
   // CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I32]] : i32 to index
+  // CHECK64: %[[I64:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<f32, CrossWorkgroup> to i64
   // CHECK64: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index
 
   // CHECK: return %[[R:.*]] : index
@@ -442,14 +443,14 @@ func.func @extract_aligned_pointer_as_index_kernel(%m: memref<?xf32, #spirv.stor
 // -----
 
 module attributes {
-  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader, Int64, Addresses], []>, #spirv.resource_limits<>>
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.5, [Shader, Int64, Addresses, PhysicalStorageBufferAddresses], []>, #spirv.resource_limits<>>
 } {
 // CHECK-LABEL: func @extract_aligned_pointer_as_index_shader
 func.func @extract_aligned_pointer_as_index_shader(%m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> index {
   %0 = memref.extract_aligned_pointer_as_index %m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>> -> index
   // CHECK: %[[I32:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32>)>, CrossWorkgroup> to i32
-  // CHECK64: %[[I64:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32>)>, CrossWorkgroup> to i64
   // CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I32]] : i32 to index
+  // CHECK64: %[[I64:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32>)>, CrossWorkgroup> to i64
   // CHECK64: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index
 
   // CHECK: return %[[R:.*]] : index



More information about the Mlir-commits mailing list