[Mlir-commits] [mlir] [mlir][xegpu] Add dynamic memref support in transpose optimization. (PR #170218)

Charitha Saumya llvmlistbot at llvm.org
Tue Dec 2 15:51:32 PST 2025


https://github.com/charithaintc updated https://github.com/llvm/llvm-project/pull/170218

>From 912b0fec54a3ffc3d20d1623fde2acd10b3b0c1a Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Mon, 1 Dec 2025 22:28:07 +0000
Subject: [PATCH 1/3] add tests and code changes

---
 .../Transforms/XeGPUOptimizeBlockLoads.cpp    |  2 +-
 ...anspose.mlir => optimize-block-loads.mlir} | 29 +++++++++++++++++++
 2 files changed, 30 insertions(+), 1 deletion(-)
 rename mlir/test/Dialect/XeGPU/{optimize-transpose.mlir => optimize-block-loads.mlir} (91%)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp
index 4dc5ea4f7bb24..f1956bd75bcf2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp
@@ -284,7 +284,7 @@ class XeGPUCreateNdDescOpPattern final
 
     // If the source is a static memref, we need to extract the pointer to
     // base address.
-    if (memrefType && memrefType.hasStaticShape()) {
+    if (memrefType) {
       auto extractOp = memref::ExtractAlignedPointerAsIndexOp::create(
           rewriter, createNdOp.getLoc(), source);
       source = arith::IndexCastOp::create(rewriter, createNdOp.getLoc(),
diff --git a/mlir/test/Dialect/XeGPU/optimize-transpose.mlir b/mlir/test/Dialect/XeGPU/optimize-block-loads.mlir
similarity index 91%
rename from mlir/test/Dialect/XeGPU/optimize-transpose.mlir
rename to mlir/test/Dialect/XeGPU/optimize-block-loads.mlir
index 24a0de6ed48a5..6eaa82f42d02c 100644
--- a/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
+++ b/mlir/test/Dialect/XeGPU/optimize-block-loads.mlir
@@ -278,3 +278,32 @@ gpu.func @array_length(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %arg
   gpu.return
 }
 }
+
+// -----
+// CHECK-LABEL: gpu.func @dynamic_memref(
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf16>, %{{.*}}: vector<8x16xf16>) -> vector<8x16xf32> {
+// CHECK-DAG:     %[[C32:.*]] = arith.constant 16 : index
+// CHECK-DAG:     %[[C32:.*]] = arith.constant 32 : index
+// CHECK-NEXT:    %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<?x?xf16> -> index
+// CHECK-NEXT:    %[[T0:.*]] = arith.index_cast %[[PTR]] : index to i64
+// CHECK-NEXT:    %[[T1:.*]] = xegpu.create_nd_tdesc %[[T0]], shape : [64, %[[C32]]], strides : [%[[C32]], 1] : i64
+// CHECK-SAME:      -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK-NEXT:    %[[T2:.*]] = xegpu.load_nd %[[T1]][%{{.*}}, %[[C16]]]  {layout_result_0 =
+// CHECK-SAME:      #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} : !xegpu.tensor_desc<16x8xi32,
+// CHECK-SAME:      #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32>
+// CHECK-NEXT:    %{{.*}} = vector.bitcast %[[T2]] {layout_result_0 =
+// CHECK-SAME:      #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : vector<16x8xi32> to vector<16x16xf16>
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
+gpu.module @xevm_module {
+gpu.func @dynamic_memref(%arg0: memref<?x?xf16>, %arg1: vector<8x16xf16>) -> vector<8x16xf32> {
+  %c0 = arith.constant 0 : index
+  %c32 = arith.constant 32 : index
+  %0 = xegpu.create_nd_tdesc %arg0, shape : [64, 64], strides : [64, 1] : memref<?x?xf16> -> !xegpu.tensor_desc<16x16xf16, #b>
+  %1 = xegpu.load_nd %0[%c0, %c32] { result_layout = #b } : !xegpu.tensor_desc<16x16xf16, #b> -> vector<16x16xf16>
+  %2 = vector.transpose %1, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+  %6 = xegpu.dpas %arg1, %2 { layout_result_0 = #a } : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  gpu.return %6 : vector<8x16xf32>
+}
+}

>From b83f9dca1acd38cfceb8d418589a38feac92a97c Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Mon, 1 Dec 2025 22:39:12 +0000
Subject: [PATCH 2/3] add tests and code changes

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp
index f1956bd75bcf2..1642e9829a79d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp
@@ -282,7 +282,7 @@ class XeGPUCreateNdDescOpPattern final
                        modifiedStrides[modifiedStrides.size() - 2]),
         innerLaneData);
 
-    // If the source is a static memref, we need to extract the pointer to
+    // If the source is a memref, we need to extract the pointer to
     // base address.
     if (memrefType) {
       auto extractOp = memref::ExtractAlignedPointerAsIndexOp::create(

>From 5bbac3b1ea850011955cca06683812c95aca808e Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Mon, 1 Dec 2025 23:06:28 +0000
Subject: [PATCH 3/3] fix test

---
 mlir/test/Dialect/XeGPU/optimize-block-loads.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/XeGPU/optimize-block-loads.mlir b/mlir/test/Dialect/XeGPU/optimize-block-loads.mlir
index 6eaa82f42d02c..526adc5a95d10 100644
--- a/mlir/test/Dialect/XeGPU/optimize-block-loads.mlir
+++ b/mlir/test/Dialect/XeGPU/optimize-block-loads.mlir
@@ -282,7 +282,7 @@ gpu.func @array_length(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %arg
 // -----
 // CHECK-LABEL: gpu.func @dynamic_memref(
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf16>, %{{.*}}: vector<8x16xf16>) -> vector<8x16xf32> {
-// CHECK-DAG:     %[[C32:.*]] = arith.constant 16 : index
+// CHECK-DAG:     %[[C16:.*]] = arith.constant 16 : index
 // CHECK-DAG:     %[[C32:.*]] = arith.constant 32 : index
 // CHECK-NEXT:    %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<?x?xf16> -> index
 // CHECK-NEXT:    %[[T0:.*]] = arith.index_cast %[[PTR]] : index to i64



More information about the Mlir-commits mailing list