[Mlir-commits] [mlir] [mlir][xegpu] Add support for `vector.transfer_read/write` on SLM buffers (PR #192757)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 17 17:14:34 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-gpu

Author: Charitha Saumya (charithaintc)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/192757.diff


7 Files Affected:

- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td (+9) 
- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td (+1-1) 
- (modified) mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp (+83-15) 
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+17) 
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (-12) 
- (modified) mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir (+49-5) 
- (modified) mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir (+46) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
index 84fd8f9e0060c..dc6e972e3b5a6 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
@@ -38,6 +38,15 @@ def XeGPU_Dialect : Dialect {
     let useDefaultAttributePrinterParser = true;
 
     let extraClassDeclaration = [{
+      /// Checks if the given shape can be evenly distributed based on the layout
+      /// and data factors provided by the LayoutAttr.
+      static bool isEvenlyDistributable(llvm::ArrayRef<int64_t> shape, xegpu::DistributeLayoutAttr attr);
+
+      /// Checks if the given memref type represents shared local memory (SLM).
+      /// Returns true if the memory space is address space 3, MemorySpace::SLM,
+      /// xevm::AddrSpace::SHARED, or a GPU workgroup memory address space.
+      static bool isSharedMemory(const MemRefType &memrefTy);
+
       /// drops/slices the shape in the specified dims, and return the rest. e.g.,
       /// for shape = [32, 64, 8], dims = [0, 2], it will return [64]
       template<typename T, typename U>
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 33eab14e9dfd8..c76c9e5b3752d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -39,7 +39,7 @@ class XeGPUTypeDef<string name, string typeMnemonic, list<Trait> traits = [],
   let mnemonic = typeMnemonic;
 }
 
-def isSharedPred : CPred<"isSharedMemory(llvm::cast<mlir::MemRefType>($_self))">;
+def isSharedPred : CPred<"XeGPUDialect::isSharedMemory(llvm::cast<mlir::MemRefType>($_self))">;
 class StaticShared1DMemRefOf<list<Type> allowedTypes> :
   ConfinedType<MemRefRankOf<allowedTypes, [1]>, [HasStaticShapePred, isSharedPred],
      "reside in share memory and statically 1d shaped " # MemRefOf<allowedTypes>.summary # " ",
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index bbb6340f14c51..5769489aa2d43 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -547,31 +547,68 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
 
     if (failed(transferPreconditions(rewriter, readOp)))
       return failure();
+    auto readMemTy = cast<MemRefType>(readOp.getShapedType());
+    VectorType loadedVecTy = readOp.getVectorType();
+    bool isOutOfBounds = readOp.hasOutOfBoundsDim();
+    // Check if the memref lives in shared local memory (SLM).
+    bool isSharedMemory = xegpu::XeGPUDialect::isSharedMemory(readMemTy);
 
     // TODO:This check needs to be replaced with proper uArch capability check
     auto chip = xegpu::getChipStr(readOp);
-    if (chip != "pvc" && chip != "bmg") {
-      // lower to scattered load Op if the target HW doesn't have 2d block load
-      // support
+    // Lower to scattered load Op if the target HW doesn't have 2d block load
+    // support and the load is not from shared memory.
+    if (chip != "pvc" && chip != "bmg" && !isSharedMemory) {
+
       // TODO: add support for OutOfBound access
       if (readOp.hasOutOfBoundsDim())
         return failure();
       return lowerToScatteredLoadOp(readOp, rewriter);
     }
 
-    VectorType loadedVecTy = readOp.getVectorType();
-
-    // Lower using load.gather in 1D case
-    if (loadedVecTy.getRank() == 1 && !readOp.hasOutOfBoundsDim())
+    // Handle the 1D in-bounds non-SLM case using load.gather; SLM memrefs
+    // take the mem_desc path further below.
+    if (loadedVecTy.getRank() == 1 && !isOutOfBounds && !isSharedMemory)
       return lowerToScatteredLoadOp(readOp, rewriter);
 
     // Perform common data transfer checks.
-    auto readMemTy = cast<MemRefType>(readOp.getShapedType());
+    // TODO: Maybe too strict for SLM case.
     if (failed(
             storeLoadPreconditions(rewriter, readOp, loadedVecTy, readMemTy)))
       return failure();
 
-    bool isOutOfBounds = readOp.hasOutOfBoundsDim();
+    // Handle the SLM case.
+    if (isSharedMemory) {
+      // If the memref is SLM only support 2D case for now.
+      if (loadedVecTy.getRank() != 2)
+        return rewriter.notifyMatchFailure(
+            readOp, "Only 2D vector loads are supported for SLM");
+      AffineMap readMap = readOp.getPermutationMap();
+      if (!readMap.isMinorIdentity())
+        return rewriter.notifyMatchFailure(
+            readOp, "Transpose not supported for SLM loads");
+      // Out of bounds case is not supported for SLM loads.
+      if (isOutOfBounds)
+        return rewriter.notifyMatchFailure(
+            readOp, "Out-of-bounds access is not supported for SLM loads");
+
+      // Create mem_desc for SLM
+      auto memDescType =
+          xegpu::MemDescType::get(rewriter.getContext(), readMemTy.getShape(),
+                                  readMemTy.getElementType(),
+                                  /*mem_layout=*/nullptr);
+      auto createMemDescOp = xegpu::CreateMemDescOp::create(
+          rewriter, loc, memDescType, readOp.getBase());
+      // Convert indices to OpFoldResult for LoadMatrixOp
+      SmallVector<OpFoldResult> indices =
+          getAsOpFoldResult(readOp.getIndices());
+      auto loadMatrixOp = xegpu::LoadMatrixOp::create(
+          rewriter, loc, loadedVecTy, createMemDescOp.getResult(), indices,
+          /*layout=*/nullptr);
+
+      rewriter.replaceOp(readOp, loadMatrixOp.getResult());
+      return success();
+    }
+
     if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
       return rewriter.notifyMatchFailure(
           readOp, "Unsupported non-zero padded out-of-bounds read");
@@ -631,21 +668,24 @@ struct TransferWriteLowering
 
     if (failed(transferPreconditions(rewriter, writeOp)))
       return failure();
+    // Gather the vector/memref types used by the checks and lowerings below.
+    VectorType vecTy = writeOp.getVectorType();
+    auto writeMemTy = cast<MemRefType>(writeOp.getShapedType());
+    // Check if the memref lives in shared local memory (SLM).
+    bool isSharedMemory = xegpu::XeGPUDialect::isSharedMemory(writeMemTy);
 
     // TODO:This check needs to be replaced with proper uArch capability check
     auto chip = xegpu::getChipStr(writeOp);
-    if (chip != "pvc" && chip != "bmg") {
-      // lower to scattered store Op if the target HW doesn't have 2d block
-      // store support
+    // Lower to scattered store Op if the target HW doesn't have 2d block
+    // store support and the memref is not SLM.
+    if (chip != "pvc" && chip != "bmg" && !isSharedMemory) {
+
       // TODO: add support for OutOfBound access
       if (writeOp.hasOutOfBoundsDim())
         return failure();
       return lowerToScatteredStoreOp(writeOp, rewriter);
     }
 
-    // Perform common data transfer checks.
-    VectorType vecTy = writeOp.getVectorType();
-    auto writeMemTy = cast<MemRefType>(writeOp.getShapedType());
     if (failed(storeLoadPreconditions(rewriter, writeOp, vecTy, writeMemTy)))
       return failure();
 
@@ -653,6 +693,39 @@ struct TransferWriteLowering
     if (!map.isMinorIdentity())
       return rewriter.notifyMatchFailure(writeOp, "Expects identity map");
 
+    // For shared local memory (SLM), lower to create_mem_desc +
+    // store_matrix.
+    if (isSharedMemory) {
+      // Only support 2D case for now.
+      if (vecTy.getRank() != 2)
+        return rewriter.notifyMatchFailure(
+            writeOp, "Only 2D vector stores are supported for SLM");
+      // The store_matrix emitted below has no mask/bounds handling, so reject
+      // out-of-bounds writes (mirrors the SLM transfer_read lowering).
+      if (writeOp.hasOutOfBoundsDim())
+        return rewriter.notifyMatchFailure(
+            writeOp, "Out-of-bounds access is not supported for SLM stores");
+      // Create mem_desc for SLM
+      auto memDescType =
+          xegpu::MemDescType::get(rewriter.getContext(), writeMemTy.getShape(),
+                                  writeMemTy.getElementType(),
+                                  /*mem_layout=*/nullptr);
+
+      auto createMemDescOp = xegpu::CreateMemDescOp::create(
+          rewriter, loc, memDescType, writeOp.getBase());
+
+      // Convert indices to OpFoldResult for StoreMatrixOp
+      SmallVector<OpFoldResult> indices =
+          getAsOpFoldResult(writeOp.getIndices());
+
+      xegpu::StoreMatrixOp::create(rewriter, loc, writeOp.getVector(),
+                                   createMemDescOp.getResult(), indices,
+                                   /*layout=*/nullptr);
+
+      rewriter.eraseOp(writeOp);
+      return success();
+    }
+
     auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
         rewriter, loc, writeOp.getBase(),
         getAsOpFoldResult(writeOp.getIndices()), vecTy.getRank());
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 811b09b011e47..1b38a1a27ce70 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -8,6 +8,8 @@
 
 #include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
@@ -15,6 +17,7 @@
 #include "mlir/IR/DialectImplementation.h"
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
 
 using std::optional;
@@ -121,6 +124,20 @@ static SmallVector<SmallVector<int64_t>> genStaticCoordinates(
   return coordinates;
 }
 
+// Checks if the given memref type represents shared local memory (SLM).
+bool XeGPUDialect::isSharedMemory(const MemRefType &memrefTy) {
+  Attribute attr = memrefTy.getMemorySpace();
+  if (!attr)
+    return false; // No memory space attribute: default space, not SLM.
+  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr)) // attr is non-null.
+    return intAttr.getInt() == 3;
+  if (auto memrefSpace = llvm::dyn_cast<MemorySpaceAttr>(attr))
+    return memrefSpace.getValue() == MemorySpace::SLM;
+  if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
+    return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
+  return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_BlockTensorDescAttr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 51ce6ce53a2fe..3b214fd583b1a 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -8,7 +8,6 @@
 
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
@@ -23,17 +22,6 @@
 using namespace mlir;
 using namespace mlir::xegpu;
 
-static bool isSharedMemory(const MemRefType &memrefTy) {
-  Attribute attr = memrefTy.getMemorySpace();
-  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
-    return intAttr.getInt() == 3;
-  if (auto memrefSpace = llvm::dyn_cast<MemorySpaceAttr>(attr))
-    return memrefSpace.getValue() == MemorySpace::SLM;
-  if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
-    return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
-  return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
-}
-
 template <typename T>
 static std::string makeString(T array, bool breakline = false) {
   std::string buf;
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index 1a19c8a13f120..f7bcbe17c542e 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -116,8 +116,8 @@ gpu.func @load_transposed(%source: memref<32x64xf32>,
 
 // LOAD-ND-LABEL:  @load_transposed(
 // LOAD-ND-SAME:   %[[SRC:.+]]: memref<32x64xf32>,
-// LOAD-ND-SAME:   %[[OFFSET1:.+]]: index, 
-// LOAD-ND-SAME:   %[[OFFSET2:.+]]: index  
+// LOAD-ND-SAME:   %[[OFFSET1:.+]]: index,
+// LOAD-ND-SAME:   %[[OFFSET2:.+]]: index
 // LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
 // LOAD-ND-SAME:     memref<32x64xf32> -> !xegpu.tensor_desc<16x8xf32
 // LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET1]], %[[OFFSET2]]]
@@ -221,7 +221,7 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
 // LOAD-GATHER-DAG:    %[[OFFSETS:.+]] = arith.addi %[[BCASTIDX]], {{.*}} : vector<8x16xindex>
 // LOAD-GATHER-DAG:    %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %arg0 : memref<?x8x16xf32> -> index
 // LOAD-GATHER-DAG:    %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
-// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[OFFSETS]]{{\]}}, %[[CST_0]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32> 
+// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[OFFSETS]]{{\]}}, %[[CST_0]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
 
 }
 
@@ -439,7 +439,7 @@ gpu.func @load_from_subview_1D(%source: memref<4096x4096xf16>, %off1: index, %of
 // LOAD-ND-SAME:   %[[SRC:.+]]: memref<4096x4096xf16>,
 // LOAD-ND-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
 // LOAD-ND:        %[[CST:.+]] = arith.constant dense<true> : vector<8xi1>
-// LOAD-ND:        %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>> 
+// LOAD-ND:        %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
 // LOAD-ND:        %[[BB:.+]], %[[OFFSET:.+]],{{.*}},{{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index
 // LOAD-ND:        %[[STEP:.+]] = vector.step : vector<8xindex>
 // LOAD-ND:        arith.muli {{.*}} : index
@@ -455,7 +455,7 @@ gpu.func @load_from_subview_1D(%source: memref<4096x4096xf16>, %off1: index, %of
 // LOAD-GATHER-SAME:   %[[SRC:.+]]: memref<4096x4096xf16>,
 // LOAD-GATHER-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
 // LOAD-GATHER:        %[[CST:.+]] = arith.constant dense<true> : vector<8xi1>
-// LOAD-GATHER:        %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>> 
+// LOAD-GATHER:        %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
 // LOAD-GATHER:        %[[BB:.+]], %[[OFFSET:.+]],{{.*}},{{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index
 // LOAD-GATHER:        %[[STEP:.+]] = vector.step : vector<8xindex>
 // LOAD-GATHER:        arith.muli {{.*}} : index
@@ -510,3 +510,47 @@ gpu.func @load_from_subview_2D(%source: memref<4096x4096xf16>, %off1: index, %of
 // LOAD-GATHER:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
 // LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf16>
 }
+
+// -----
+gpu.module @xevm_module {
+gpu.func @load_2D_vector_addrspace3(%source: memref<16x32xf32, 3>,
+    %offset: index) -> vector<8x16xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %source[%offset, %offset], %c0
+    {in_bounds = [true, true]} : memref<16x32xf32, 3>, vector<8x16xf32>
+  gpu.return %0 : vector<8x16xf32>
+}
+
+// LOAD-ND-LABEL: @load_2D_vector_addrspace3
+// LOAD-ND-SAME: %[[SOURCE:.+]]: memref<16x32xf32, 3>
+// LOAD-ND-SAME: %[[OFFSET:.+]]: index
+// LOAD-ND: %[[MEM_DESC:.+]] = xegpu.create_mem_desc %[[SOURCE]] : memref<16x32xf32, 3> -> !xegpu.mem_desc<16x32xf32>
+// LOAD-ND: %[[DATA:.+]] = xegpu.load_matrix %[[MEM_DESC]][%[[OFFSET]], %[[OFFSET]]] : !xegpu.mem_desc<16x32xf32>, index, index -> vector<8x16xf32>
+// LOAD-ND: gpu.return %[[DATA]] : vector<8x16xf32>
+
+// LOAD-GATHER-LABEL: @load_2D_vector_addrspace3
+// LOAD-GATHER-SAME: %[[SOURCE:.+]]: memref<16x32xf32, 3>
+// LOAD-GATHER-SAME: %[[OFFSET:.+]]: index
+// LOAD-GATHER: %[[MEM_DESC:.+]] = xegpu.create_mem_desc %[[SOURCE]] : memref<16x32xf32, 3> -> !xegpu.mem_desc<16x32xf32>
+// LOAD-GATHER: %[[DATA:.+]] = xegpu.load_matrix %[[MEM_DESC]][%[[OFFSET]], %[[OFFSET]]] : !xegpu.mem_desc<16x32xf32>, index, index -> vector<8x16xf32>
+// LOAD-GATHER: gpu.return %[[DATA]] : vector<8x16xf32>
+
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @load_1D_vector_addrspace3_unsupported(%source: memref<32xf32, 3>,
+    %offset: index) -> vector<8xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %source[%offset], %c0
+    {in_bounds = [true]} : memref<32xf32, 3>, vector<8xf32>
+  gpu.return %0 : vector<8xf32>
+}
+
+// LOAD-ND-LABEL: @load_1D_vector_addrspace3_unsupported
+// LOAD-ND: vector.transfer_read
+
+// LOAD-GATHER-LABEL: @load_1D_vector_addrspace3_unsupported
+// LOAD-GATHER: vector.transfer_read
+
+}
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index 66da64225678e..9dab33bc78e89 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -344,3 +344,49 @@ gpu.func @store_to_subview(%vec: vector<8xf16>,
 // STORE-SCATTER:        %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
 // STORE-SCATTER:        xegpu.store %[[VEC]], %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8xf16>, i64, vector<8xindex>, vector<8xi1>
 }
+
+// -----
+gpu.module @xevm_module {
+gpu.func @store_2D_vector_addrspace3(%vec: vector<8x16xf32>,
+    %source: memref<16x32xf32, 3>, %offset: index) {
+  vector.transfer_write %vec, %source[%offset, %offset]
+    {in_bounds = [true, true]}
+    : vector<8x16xf32>, memref<16x32xf32, 3>
+  gpu.return
+}
+
+// STORE-ND-LABEL: @store_2D_vector_addrspace3
+// STORE-ND-SAME: %[[VEC:.+]]: vector<8x16xf32>
+// STORE-ND-SAME: %[[SOURCE:.+]]: memref<16x32xf32, 3>
+// STORE-ND-SAME: %[[OFFSET:.+]]: index
+// STORE-ND: %[[MEM_DESC:.+]] = xegpu.create_mem_desc %[[SOURCE]] : memref<16x32xf32, 3> -> !xegpu.mem_desc<16x32xf32>
+// STORE-ND: xegpu.store_matrix %[[VEC]], %[[MEM_DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>, !xegpu.mem_desc<16x32xf32>, index, index
+// STORE-ND: gpu.return
+
+// STORE-SCATTER-LABEL: @store_2D_vector_addrspace3
+// STORE-SCATTER-SAME: %[[VEC:.+]]: vector<8x16xf32>
+// STORE-SCATTER-SAME: %[[SOURCE:.+]]: memref<16x32xf32, 3>
+// STORE-SCATTER-SAME: %[[OFFSET:.+]]: index
+// STORE-SCATTER: %[[MEM_DESC:.+]] = xegpu.create_mem_desc %[[SOURCE]] : memref<16x32xf32, 3> -> !xegpu.mem_desc<16x32xf32>
+// STORE-SCATTER: xegpu.store_matrix %[[VEC]], %[[MEM_DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>, !xegpu.mem_desc<16x32xf32>, index, index
+// STORE-SCATTER: gpu.return
+
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @store_1D_vector_addrspace3_unsupported(%vec: vector<8xf32>,
+    %source: memref<32xf32, 3>, %offset: index) {
+  vector.transfer_write %vec, %source[%offset]
+    {in_bounds = [true]}
+    : vector<8xf32>, memref<32xf32, 3>
+  gpu.return
+}
+
+// STORE-ND-LABEL: @store_1D_vector_addrspace3_unsupported
+// STORE-ND: vector.transfer_write
+
+// STORE-SCATTER-LABEL: @store_1D_vector_addrspace3_unsupported
+// STORE-SCATTER: vector.transfer_write
+
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/192757


More information about the Mlir-commits mailing list