[Mlir-commits] [mlir] [MLIR] Fix issues with XeGPU to XeVM pass. (PR #155946)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Aug 28 16:23:10 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-gpu

Author: Sang Ik Lee (silee2)

<details>
<summary>Changes</summary>

Fixes two issues with the XeGPU to XeVM pass:

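1. Lowering of the xegpu.update_nd_offset op generated an incorrect code sequence: both offset updates were inserted into the original tensor-descriptor payload, so the second (W) update discarded the first (H) update. The W update is now applied to the already-updated payload.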
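2. xegpu.store_nd did not lower a store whose value is a single-element vector: the adaptor may present that value as a scalar, which the pattern previously rejected.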

---
Full diff: https://github.com/llvm/llvm-project/pull/155946.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp (+15-10) 
- (modified) mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir (+30-16) 


``````````diff
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index d8dd09a6280c0..a7f2dc2d6a43e 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -259,7 +259,7 @@ class UpdateNdOffsetToXeVMPattern
     // Only 2D offsets are supported for now.
     if (mixedOffsets.size() != 2)
       return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
-    auto tdesc = adaptor.getTensorDesc();
+    auto payload = adaptor.getTensorDesc();
     // Utility for updating payload offset values from op fold result.
     auto updateOffset = [&](unsigned idx, int payloadPos) -> Value {
       Value offset =
@@ -267,15 +267,15 @@ class UpdateNdOffsetToXeVMPattern
       offset = getValueOrCreateCastToIndexLike(rewriter, loc,
                                                rewriter.getI32Type(), offset);
       Value oldOffset =
-          vector::ExtractOp::create(rewriter, loc, tdesc, payloadPos);
+          vector::ExtractOp::create(rewriter, loc, payload, payloadPos);
       Value newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset);
-      return vector::InsertOp::create(rewriter, loc, newOffset, tdesc,
+      return vector::InsertOp::create(rewriter, loc, newOffset, payload,
                                       payloadPos);
     };
     // Update offsets in the payload.
-    auto val = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH));
-    val = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW));
-    rewriter.replaceOp(op, val);
+    payload = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH));
+    payload = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW));
+    rewriter.replaceOp(op, payload);
     return success();
   }
 };
@@ -354,18 +354,23 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
     auto tileH = tdescTy.getDimSize(0);
     int32_t vblocks = tdescTy.getArrayLength();
     if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
-      VectorType srcVecTy = dyn_cast<VectorType>(adaptor.getValue().getType());
+      Value src = adaptor.getValue();
+      // If store value is a scalar, get value from op instead of adaptor.
+      // Adaptor might have optimized away single element vector
+      if (src.getType().isIntOrFloat()) {
+        src = op.getValue();
+      }
+      VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
       if (!srcVecTy)
         return rewriter.notifyMatchFailure(
             op, "Expected store value to be a vector type.");
-      auto storeCacheControl =
-          translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
-      Value src = adaptor.getValue();
       // Get flat vector type of integer type with matching element bit size.
       VectorType newSrcVecTy =
           encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
       if (srcVecTy != newSrcVecTy)
         src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
+      auto storeCacheControl =
+          translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
       xevm::BlockStore2dOp::create(
           rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
           offsetH, elemBitSize, tileW, tileH, src,
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index 4ff95b40fe68c..ed664a739d134 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -2,9 +2,9 @@
 
 gpu.module @create_nd_tdesc {
   // CHECK-LABEL: gpu.func @create_nd_tdesc
-  // CHECK-SAME: %[[ARG0:.*]]: memref<8x16xf32, 1>, %[[ARG1:.*]]: ui64,
+  // CHECK-SAME: %[[ARG0:.*]]: memref<16x32xf32, 1>, %[[ARG1:.*]]: ui64,
   // CHECK-SAME: %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index, %[[ARG7:.*]]: index
-  gpu.func @create_nd_tdesc(%src: memref<8x16xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
+  gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
   %stride1: index, %stride2: index, %offset1: index, %offset2: index) kernel {
         // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
         // CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
@@ -23,17 +23,17 @@ gpu.module @create_nd_tdesc {
         %ptr_tdesc = xegpu.create_nd_tdesc %ptr, shape:[%shape1, %shape2], strides:[%stride1, %stride2]
             : ui64 -> !xegpu.tensor_desc<8x16xf32>
 
-        // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<8x16xf32, 1> to memref<8x16xf32>
-        %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
+        // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<16x32xf32, 1> to memref<16x32xf32>
+        %srcce = memref.memory_space_cast %src : memref<16x32xf32, 1> to memref<16x32xf32>
 
         // CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
-        // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
+        // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
         // CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32
         // CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32
+        // CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64
+        // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %[[C32_I64]] : i64 to i32
         // CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
-        // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %c16_i64 : i64 to i32
-        // CHECK: %[[C8_I64:.*]] = arith.constant 8 : i64
-        // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %c8_i64 : i64 to i32
+        // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
         // CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
         // CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
         // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
@@ -41,17 +41,17 @@ gpu.module @create_nd_tdesc {
         // CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
         // CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
         // CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
-        // CHECK: %[[VAR20:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
-        %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+        // CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
+        %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
 
         // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
-        // CHECK: %[[INTPTR_2:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
+        // CHECK: %[[INTPTR_2:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
         // CHECK: %[[OFFSET_W3:.*]] = arith.index_cast %[[ARG7]] : index to i32
         // CHECK: %[[OFFSET_H3:.*]] = arith.index_cast %[[ARG6]] : index to i32
-        // CHECK: %[[C16_I64_6:.*]] = arith.constant 16 : i64
-        // CHECK: %[[SHAPE_W3:.*]] = arith.trunci %[[C16_I64_6]] : i64 to i32
-        // CHECK: %[[C8_I64_7:.*]] = arith.constant 8 : i64
-        // CHECK: %[[SHAPE_H3:.*]] = arith.trunci %[[C8_I64_7]] : i64 to i32
+        // CHECK: %[[C32_I64_6:.*]] = arith.constant 32 : i64
+        // CHECK: %[[SHAPE_W3:.*]] = arith.trunci %[[C32_I64_6]] : i64 to i32
+        // CHECK: %[[C16_I64_7:.*]] = arith.constant 16 : i64
+        // CHECK: %[[SHAPE_H3:.*]] = arith.trunci %[[C16_I64_7]] : i64 to i32
         // CHECK: %[[BASE_ADDR3:.*]] = arith.index_castui %[[INTPTR_2]] : index to i64
         // CHECK: %[[VAR26:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
         // CHECK: %[[VAR27:.*]] = vector.insert %[[BASE_ADDR3]], %[[VAR26]] [0] : i64 into vector<4xi64>
@@ -60,7 +60,21 @@ gpu.module @create_nd_tdesc {
         // CHECK: %[[VAR30:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR29]] [3] : i32 into vector<8xi32>
         // CHECK: %[[VAR31:.*]] = vector.insert %[[OFFSET_W3]], %[[VAR30]] [4] : i32 into vector<8xi32>
         // CHECK: %[[VAR32:.*]] = vector.insert %[[OFFSET_H3]], %[[VAR31]] [5] : i32 into vector<8xi32>
-        %src_tdesc2 = xegpu.create_nd_tdesc %srcce[%offset1, %offset2] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+        %src_tdesc2 = xegpu.create_nd_tdesc %srcce[%offset1, %offset2] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+        // CHECK: %[[C8:.*]] = arith.constant 8 : index
+        %c8 = arith.constant 8 : index
+        // CHECK: %[[C16:.*]] = arith.constant 16 : index
+        %c16 = arith.constant 16 : index
+        // CHECK: %[[VAR33:.*]] = arith.index_cast %[[C8]] : index to i32
+        // CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[PAYLOAD]][5] : i32 from vector<8xi32>
+        // CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR33]] : i32
+        // CHECK: %[[NEW_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[PAYLOAD]] [5] : i32 into vector<8xi32>
+        // CHECK: %[[VAR37:.*]] = arith.index_cast %[[C16]] : index to i32
+        // CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[NEW_PAYLOAD]][4] : i32 from vector<8xi32>
+        // CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR37]] : i32
+        // CHECK: %[[FINAL_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[NEW_PAYLOAD]] [4] : i32 into vector<8xi32>
+        %updated_tdesc = xegpu.update_nd_offset %src_tdesc, [%c8, %c16] : !xegpu.tensor_desc<8x16xf32>
         gpu.return
     }
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/155946


More information about the Mlir-commits mailing list