[Mlir-commits] [mlir] b258f5c - [MLIR][Conversion] XeGPU to XeVM: Lower ranked dynamic base memory for create_nd_tdesc. (#164283)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 29 09:55:54 PDT 2025


Author: Sang Ik Lee
Date: 2025-10-29T09:55:50-07:00
New Revision: b258f5c2527f895dfa98921b61482978aba407e0

URL: https://github.com/llvm/llvm-project/commit/b258f5c2527f895dfa98921b61482978aba407e0
DIFF: https://github.com/llvm/llvm-project/commit/b258f5c2527f895dfa98921b61482978aba407e0.diff

LOG: [MLIR][Conversion] XeGPU to XeVM: Lower ranked dynamic base memory for create_nd_tdesc. (#164283)

The current lowering pattern for create_nd_tdesc restricts the source memref
to a static shape.
For a ranked memref with a dynamic shape, create_nd_tdesc already takes the
shape as an argument.
The lowering can use those values instead of returning a match failure.

Added: 
    

Modified: 
    mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
    mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index fcbf66dbe9e45..33e8f2ed1f6ed 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -194,8 +194,8 @@ class CreateNdDescToXeVMPattern
     // If source is a memref, we need to extract the aligned pointer as index.
     // Pointer type is passed as i32 or i64 by type converter.
     if (sourceMemrefTy) {
-      if (!sourceMemrefTy.hasStaticShape()) {
-        return rewriter.notifyMatchFailure(op, "Expected static memref shape.");
+      if (!sourceMemrefTy.hasRank()) {
+        return rewriter.notifyMatchFailure(op, "Expected ranked Memref.");
       }
       baseAddr =
           memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);

diff  --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index d6e36fa73bf04..09ef76c9d1740 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -4,8 +4,9 @@ gpu.module @create_nd_tdesc {
   // CHECK-LABEL: gpu.func @create_nd_tdesc
   // CHECK-SAME: %[[ARG0:.*]]: memref<16x32xf32, 1>, %[[ARG1:.*]]: ui64,
   // CHECK-SAME: %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index, %[[ARG7:.*]]: index
+  // CHECK-SAME: %[[DYN:.*]]: memref<?x?xf16>) kernel {
   gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
-  %stride1: index, %stride2: index, %offset1: index, %offset2: index) kernel {
+  %stride1: index, %stride2: index, %offset1: index, %offset2: index, %dyn: memref<?x?xf16>) kernel {
         // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
         // CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
         // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
@@ -43,6 +44,28 @@ gpu.module @create_nd_tdesc {
         // CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
         // CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
         %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+        // CHECK: %[[C1:.*]] = arith.constant 1 : index
+        %c1 = arith.constant 1 : index
+        // CHECK: %[[C64:.*]] = arith.constant 64 : index
+        %size_x = arith.constant 64 : index
+        // CHECK: %[[C16:.*]] = arith.constant 16 : index
+        %BLOCK_DMODEL = arith.constant 16 : index
+        // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
+        // CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref<?x?xf16> -> index
+        // CHECK: %[[C0_I32_6:.*]] = arith.constant 0 : i32
+        // CHECK: %[[C0_I32_7:.*]] = arith.constant 0 : i32
+        // CHECK: %[[VAR21:.*]] = arith.index_cast %[[C16]] : index to i32
+        // CHECK: %[[VAR22:.*]] = arith.index_cast %[[C64]] : index to i32
+        // CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
+        // CHECK: %[[VAR24:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
+        // CHECK: %[[VAR25:.*]] = vector.insert %[[VAR23]], %[[VAR24]] [0] : i64 into vector<4xi64>
+        // CHECK: %[[VAR26:.*]] = vector.bitcast %[[VAR25]] : vector<4xi64> to vector<8xi32>
+        // CHECK: %[[VAR27:.*]] = vector.insert %[[VAR21]], %[[VAR26]] [2] : i32 into vector<8xi32>
+        // CHECK: %[[VAR28:.*]] = vector.insert %[[VAR22]], %[[VAR27]] [3] : i32 into vector<8xi32>
+        // CHECK: %[[VAR29:.*]] = vector.insert %[[C0_I32_6]], %[[VAR28]] [4] : i32 into vector<8xi32>
+        // CHECK: %[[VAR30:.*]] = vector.insert %[[C0_I32_7]], %[[VAR29]] [5] : i32 into vector<8xi32>
+        %dyn_tdesc  = xegpu.create_nd_tdesc %dyn, shape: [%size_x, %BLOCK_DMODEL], strides: [%BLOCK_DMODEL, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<16x16xf16>
         gpu.return
     }
 }


        


More information about the Mlir-commits mailing list