[Mlir-commits] [mlir] [MLIR] Fix issues with XeGPU to XeVM pass. (PR #155946)
llvmlistbot at llvm.org
Thu Aug 28 16:23:10 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-gpu
Author: Sang Ik Lee (silee2)
Fixes two issues with the XeGPU to XeVM pass:
1. The lowering of xegpu.update_nd_offset generated an incorrect code sequence: both offset updates read from and wrote into the original descriptor payload, so the first update was discarded (a minimal reproducer follows this list).
2. xegpu.store_nd did not lower a single-element vector store, because the conversion adaptor may present such a value as a scalar.
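
For reference, here is a minimal input that exercises the first fix, adapted from the updated test in this PR (the module and function names and the constant offsets are illustrative):

```mlir
gpu.module @example {
  gpu.func @update_offsets(%src: memref<16x32xf32>) kernel {
    %c8 = arith.constant 8 : index
    %c16 = arith.constant 16 : index
    // Create a 2D tensor descriptor, then shift its H/W offsets.
    %tdesc = xegpu.create_nd_tdesc %src : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
    %updated = xegpu.update_nd_offset %tdesc, [%c8, %c16] : !xegpu.tensor_desc<8x16xf32>
    gpu.return
  }
}
```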
---
Full diff: https://github.com/llvm/llvm-project/pull/155946.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp (+15-10)
- (modified) mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir (+30-16)
``````````diff
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index d8dd09a6280c0..a7f2dc2d6a43e 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -259,7 +259,7 @@ class UpdateNdOffsetToXeVMPattern
// Only 2D offsets are supported for now.
if (mixedOffsets.size() != 2)
return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
- auto tdesc = adaptor.getTensorDesc();
+ auto payload = adaptor.getTensorDesc();
// Utility for updating payload offset values from op fold result.
auto updateOffset = [&](unsigned idx, int payloadPos) -> Value {
Value offset =
@@ -267,15 +267,15 @@ class UpdateNdOffsetToXeVMPattern
offset = getValueOrCreateCastToIndexLike(rewriter, loc,
rewriter.getI32Type(), offset);
Value oldOffset =
- vector::ExtractOp::create(rewriter, loc, tdesc, payloadPos);
+ vector::ExtractOp::create(rewriter, loc, payload, payloadPos);
Value newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset);
- return vector::InsertOp::create(rewriter, loc, newOffset, tdesc,
+ return vector::InsertOp::create(rewriter, loc, newOffset, payload,
payloadPos);
};
// Update offsets in the payload.
- auto val = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH));
- val = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW));
- rewriter.replaceOp(op, val);
+ payload = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH));
+ payload = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW));
+ rewriter.replaceOp(op, payload);
return success();
}
};
@@ -354,18 +354,23 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
auto tileH = tdescTy.getDimSize(0);
int32_t vblocks = tdescTy.getArrayLength();
if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
- VectorType srcVecTy = dyn_cast<VectorType>(adaptor.getValue().getType());
+ Value src = adaptor.getValue();
+ // If store value is a scalar, get value from op instead of adaptor.
+ // Adaptor might have optimized away single element vector
+ if (src.getType().isIntOrFloat()) {
+ src = op.getValue();
+ }
+ VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
if (!srcVecTy)
return rewriter.notifyMatchFailure(
op, "Expected store value to be a vector type.");
- auto storeCacheControl =
- translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
- Value src = adaptor.getValue();
// Get flat vector type of integer type with matching element bit size.
VectorType newSrcVecTy =
encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
if (srcVecTy != newSrcVecTy)
src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
+ auto storeCacheControl =
+ translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
xevm::BlockStore2dOp::create(
rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
offsetH, elemBitSize, tileW, tileH, src,
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index 4ff95b40fe68c..ed664a739d134 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -2,9 +2,9 @@
gpu.module @create_nd_tdesc {
// CHECK-LABEL: gpu.func @create_nd_tdesc
- // CHECK-SAME: %[[ARG0:.*]]: memref<8x16xf32, 1>, %[[ARG1:.*]]: ui64,
+ // CHECK-SAME: %[[ARG0:.*]]: memref<16x32xf32, 1>, %[[ARG1:.*]]: ui64,
// CHECK-SAME: %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index, %[[ARG7:.*]]: index
- gpu.func @create_nd_tdesc(%src: memref<8x16xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
+ gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
%stride1: index, %stride2: index, %offset1: index, %offset2: index) kernel {
// CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
// CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
@@ -23,17 +23,17 @@ gpu.module @create_nd_tdesc {
%ptr_tdesc = xegpu.create_nd_tdesc %ptr, shape:[%shape1, %shape2], strides:[%stride1, %stride2]
: ui64 -> !xegpu.tensor_desc<8x16xf32>
- // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<8x16xf32, 1> to memref<8x16xf32>
- %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
+ // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<16x32xf32, 1> to memref<16x32xf32>
+ %srcce = memref.memory_space_cast %src : memref<16x32xf32, 1> to memref<16x32xf32>
// CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
- // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
+ // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
// CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32
// CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32
+ // CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64
+ // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %[[C32_I64]] : i64 to i32
// CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
- // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %c16_i64 : i64 to i32
- // CHECK: %[[C8_I64:.*]] = arith.constant 8 : i64
- // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %c8_i64 : i64 to i32
+ // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
// CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
// CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
// CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
@@ -41,17 +41,17 @@ gpu.module @create_nd_tdesc {
// CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
// CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
// CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
- // CHECK: %[[VAR20:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
- %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+ // CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
+ %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
- // CHECK: %[[INTPTR_2:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
+ // CHECK: %[[INTPTR_2:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
// CHECK: %[[OFFSET_W3:.*]] = arith.index_cast %[[ARG7]] : index to i32
// CHECK: %[[OFFSET_H3:.*]] = arith.index_cast %[[ARG6]] : index to i32
- // CHECK: %[[C16_I64_6:.*]] = arith.constant 16 : i64
- // CHECK: %[[SHAPE_W3:.*]] = arith.trunci %[[C16_I64_6]] : i64 to i32
- // CHECK: %[[C8_I64_7:.*]] = arith.constant 8 : i64
- // CHECK: %[[SHAPE_H3:.*]] = arith.trunci %[[C8_I64_7]] : i64 to i32
+ // CHECK: %[[C32_I64_6:.*]] = arith.constant 32 : i64
+ // CHECK: %[[SHAPE_W3:.*]] = arith.trunci %[[C32_I64_6]] : i64 to i32
+ // CHECK: %[[C16_I64_7:.*]] = arith.constant 16 : i64
+ // CHECK: %[[SHAPE_H3:.*]] = arith.trunci %[[C16_I64_7]] : i64 to i32
// CHECK: %[[BASE_ADDR3:.*]] = arith.index_castui %[[INTPTR_2]] : index to i64
// CHECK: %[[VAR26:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
// CHECK: %[[VAR27:.*]] = vector.insert %[[BASE_ADDR3]], %[[VAR26]] [0] : i64 into vector<4xi64>
@@ -60,7 +60,21 @@ gpu.module @create_nd_tdesc {
// CHECK: %[[VAR30:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR29]] [3] : i32 into vector<8xi32>
// CHECK: %[[VAR31:.*]] = vector.insert %[[OFFSET_W3]], %[[VAR30]] [4] : i32 into vector<8xi32>
// CHECK: %[[VAR32:.*]] = vector.insert %[[OFFSET_H3]], %[[VAR31]] [5] : i32 into vector<8xi32>
- %src_tdesc2 = xegpu.create_nd_tdesc %srcce[%offset1, %offset2] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+ %src_tdesc2 = xegpu.create_nd_tdesc %srcce[%offset1, %offset2] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+ // CHECK: %[[C8:.*]] = arith.constant 8 : index
+ %c8 = arith.constant 8 : index
+ // CHECK: %[[C16:.*]] = arith.constant 16 : index
+ %c16 = arith.constant 16 : index
+ // CHECK: %[[VAR33:.*]] = arith.index_cast %[[C8]] : index to i32
+ // CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[PAYLOAD]][5] : i32 from vector<8xi32>
+ // CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR33]] : i32
+ // CHECK: %[[NEW_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[PAYLOAD]] [5] : i32 into vector<8xi32>
+ // CHECK: %[[VAR37:.*]] = arith.index_cast %[[C16]] : index to i32
+ // CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[NEW_PAYLOAD]][4] : i32 from vector<8xi32>
+ // CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR37]] : i32
+ // CHECK: %[[FINAL_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[NEW_PAYLOAD]] [4] : i32 into vector<8xi32>
+ %updated_tdesc = xegpu.update_nd_offset %src_tdesc, [%c8, %c16] : !xegpu.tensor_desc<8x16xf32>
gpu.return
}
}
``````````
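
To make the first fix concrete, the corrected lowering threads each vector.insert result into the next vector.extract, as the new CHECK lines above verify. A distilled sketch of that sequence (the wrapper function and SSA names are illustrative; per the generated code, the H offset sits at lane 5 of the vector<8xi32> payload and the W offset at lane 4):

```mlir
func.func @update_payload(%payload: vector<8xi32>, %dh: i32, %dw: i32) -> vector<8xi32> {
  // Update the H offset (payload lane 5).
  %off_h = vector.extract %payload[5] : i32 from vector<8xi32>
  %new_h = arith.addi %off_h, %dh : i32
  %p1 = vector.insert %new_h, %payload [5] : i32 into vector<8xi32>
  // The W offset update must start from %p1, not the original %payload;
  // re-reading the stale payload here is exactly the bug being fixed.
  %off_w = vector.extract %p1[4] : i32 from vector<8xi32>
  %new_w = arith.addi %off_w, %dw : i32
  %p2 = vector.insert %new_w, %p1 [4] : i32 into vector<8xi32>
  return %p2 : vector<8xi32>
}
```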
https://github.com/llvm/llvm-project/pull/155946