[Mlir-commits] [mlir] [MLIR][Conversion] XeGPU to XeVM 1D block (PR #165894)
Sang Ik Lee
llvmlistbot at llvm.org
Fri Oct 31 11:03:58 PDT 2025
https://github.com/silee2 created https://github.com/llvm/llvm-project/pull/165894
(No PR description provided.)
>From 319205c9384078d663e375ed5fae57ea1db48b3a Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Thu, 30 Oct 2025 23:07:05 +0000
Subject: [PATCH 1/2] Isolate 2D code.
---
.../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 128 +++++++++---------
1 file changed, 64 insertions(+), 64 deletions(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 33e8f2ed1f6ed..a869ce09181a8 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -264,8 +264,6 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
auto tdesc = adaptor.getTensorDesc();
auto tdescTy = op.getTensorDescType();
- if (tdescTy.getRank() != 2)
- return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor.");
auto elemType = tdescTy.getElementType();
auto elemBitSize = elemType.getIntOrFloatBitWidth();
if (elemBitSize % 8 != 0)
@@ -294,71 +292,73 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
// Get address space from tensor descriptor memory space.
auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
- // Convert base pointer (i64) to LLVM pointer type.
- Value basePtrLLVM =
- LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
- // Compute element byte size and surface width in bytes.
- Value elemByteSize = arith::ConstantIntOp::create(
- rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
- Value surfaceW =
- arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
-
- // Get tile sizes and vblocks from the tensor descriptor type.
- auto tileW = tdescTy.getDimSize(1);
- auto tileH = tdescTy.getDimSize(0);
- int32_t vblocks = tdescTy.getArrayLength();
- if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
- Value src = adaptor.getValue();
- // If store value is a scalar, get value from op instead of adaptor.
- // Adaptor might have optimized away single element vector
- if (src.getType().isIntOrFloat()) {
- src = op.getValue();
- }
- VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
- if (!srcVecTy)
- return rewriter.notifyMatchFailure(
- op, "Expected store value to be a vector type.");
- // Get flat vector type of integer type with matching element bit size.
- VectorType newSrcVecTy =
- encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
- if (srcVecTy != newSrcVecTy)
- src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
- auto storeCacheControl =
- translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
- xevm::BlockStore2dOp::create(
- rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
- offsetH, elemBitSize, tileW, tileH, src,
- xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
- rewriter.eraseOp(op);
- } else {
- auto loadCacheControl =
- translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
- if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
- xevm::BlockPrefetch2dOp::create(
+ if (tdescTy.getRank() == 2) {
+ // Convert base pointer (i64) to LLVM pointer type.
+ Value basePtrLLVM =
+ LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
+ // Compute element byte size and surface width in bytes.
+ Value elemByteSize = arith::ConstantIntOp::create(
+ rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
+ Value surfaceW =
+ arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
+
+ // Get tile sizes and vblocks from the tensor descriptor type.
+ auto tileW = tdescTy.getDimSize(1);
+ auto tileH = tdescTy.getDimSize(0);
+ int32_t vblocks = tdescTy.getArrayLength();
+ if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
+ Value src = adaptor.getValue();
+ // If store value is a scalar, get value from op instead of adaptor.
+ // Adaptor might have optimized away single element vector
+ if (src.getType().isIntOrFloat()) {
+ src = op.getValue();
+ }
+ VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
+ if (!srcVecTy)
+ return rewriter.notifyMatchFailure(
+ op, "Expected store value to be a vector type.");
+ // Get flat vector type of integer type with matching element bit size.
+ VectorType newSrcVecTy =
+ encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
+ if (srcVecTy != newSrcVecTy)
+ src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
+ auto storeCacheControl =
+ translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+ xevm::BlockStore2dOp::create(
rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
- offsetH, elemBitSize, tileW, tileH, vblocks,
- xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+ offsetH, elemBitSize, tileW, tileH, src,
+ xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
rewriter.eraseOp(op);
} else {
- VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
- const bool vnni = op.getPacked().value_or(false);
- auto transposeValue = op.getTranspose();
- bool transpose =
- transposeValue.has_value() && transposeValue.value()[0] == 1;
- VectorType loadedTy = encodeVectorTypeTo(
- dstVecTy, vnni ? rewriter.getI32Type()
- : rewriter.getIntegerType(elemBitSize));
-
- Value resultFlatVec = xevm::BlockLoad2dOp::create(
- rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
- surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
- transpose, vnni,
- xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
- resultFlatVec = vector::BitCastOp::create(
- rewriter, loc,
- encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
- resultFlatVec);
- rewriter.replaceOp(op, resultFlatVec);
+ auto loadCacheControl =
+ translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+ if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
+ xevm::BlockPrefetch2dOp::create(
+ rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW,
+ offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
+ xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+ rewriter.eraseOp(op);
+ } else {
+ VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
+ const bool vnni = op.getPacked().value_or(false);
+ auto transposeValue = op.getTranspose();
+ bool transpose =
+ transposeValue.has_value() && transposeValue.value()[0] == 1;
+ VectorType loadedTy = encodeVectorTypeTo(
+ dstVecTy, vnni ? rewriter.getI32Type()
+ : rewriter.getIntegerType(elemBitSize));
+
+ Value resultFlatVec = xevm::BlockLoad2dOp::create(
+ rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
+ surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
+ transpose, vnni,
+ xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+ resultFlatVec = vector::BitCastOp::create(
+ rewriter, loc,
+ encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
+ resultFlatVec);
+ rewriter.replaceOp(op, resultFlatVec);
+ }
}
}
return success();
>From fd3186e3311fbc314963fd2372aad117113311c8 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Thu, 30 Oct 2025 23:33:40 +0000
Subject: [PATCH 2/2] Add handler for 1D block load_nd/store_nd and test case.
---
.../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 73 ++++++++++++++++++-
1 file changed, 69 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index a869ce09181a8..bc360c1a6929b 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -292,19 +292,25 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
// Get address space from tensor descriptor memory space.
auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
- if (tdescTy.getRank() == 2) {
+ // Compute element byte size.
+ Value elemByteSize = arith::ConstantIntOp::create(
+ rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
+ auto tileRank = tdescTy.getRank();
+ // Get tile width from the tensor descriptor type.
+ auto tileW = tdescTy.getDimSize(tileRank - 1);
+ if (tileRank == 2) {
// Convert base pointer (i64) to LLVM pointer type.
Value basePtrLLVM =
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
- // Compute element byte size and surface width in bytes.
+ // Compute width in bytes.
Value elemByteSize = arith::ConstantIntOp::create(
rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
Value surfaceW =
arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
- // Get tile sizes and vblocks from the tensor descriptor type.
- auto tileW = tdescTy.getDimSize(1);
+ // Get tile height from the tensor descriptor type.
auto tileH = tdescTy.getDimSize(0);
+ // Get vblocks from the tensor descriptor type.
int32_t vblocks = tdescTy.getArrayLength();
if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
Value src = adaptor.getValue();
@@ -360,6 +366,65 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
rewriter.replaceOp(op, resultFlatVec);
}
}
+ } else {
+ // Get address from base address and offsets.
+ // Offset in number of elements, need to multiply by element byte size.
+ // Compute linear offset.
+ // linearOffset = offsetH * baseShapeW + offsetW
+ Value offsetHInElems =
+ rewriter.createOrFold<arith::MulIOp>(loc, offsetH, baseShapeW);
+ Value linearOffset =
+ rewriter.createOrFold<arith::AddIOp>(loc, offsetHInElems, offsetW);
+ // Then compute byte offset by multiplying with element byte size.
+ // byteOffset = linearOffset * elemByteSize
+ Value byteOffset =
+ rewriter.createOrFold<arith::MulIOp>(loc, linearOffset, elemByteSize);
+ // Final address = basePtr + byteOffset
+ Value finalAddrI64 = rewriter.createOrFold<arith::AddIOp>(
+ loc, basePtr,
+ getValueOrCreateCastToIndexLike(rewriter, loc, rewriter.getI64Type(),
+ byteOffset));
+ // Convert base pointer (i64) to LLVM pointer type.
+ Value finalPtrLLVM =
+ LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, finalAddrI64);
+ if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
+ Value src = adaptor.getValue();
+ // If store value is a scalar, get value from op instead of adaptor.
+ // Adaptor might have optimized away single element vector
+ if (src.getType().isIntOrFloat()) {
+ src = op.getValue();
+ }
+ VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
+ if (!srcVecTy)
+ return rewriter.notifyMatchFailure(
+ op, "Expected store value to be a vector type.");
+ // Get flat vector type of integer type with matching element bit size.
+ VectorType newSrcVecTy =
+ encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
+ if (srcVecTy != newSrcVecTy)
+ src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
+ auto storeCacheControl =
+ translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+ rewriter.replaceOpWithNewOp<xevm::BlockStoreOp>(
+ op, finalPtrLLVM, src,
+ xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
+ } else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
+ auto loadCacheControl =
+ translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+ VectorType resTy = cast<VectorType>(op.getValue().getType());
+ VectorType loadedTy =
+ encodeVectorTypeTo(resTy, rewriter.getIntegerType(elemBitSize));
+ Value load = xevm::BlockLoadOp::create(
+ rewriter, loc, loadedTy, finalPtrLLVM,
+ xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+ if (loadedTy != resTy)
+ load = vector::BitCastOp::create(rewriter, loc, resTy, load);
+ rewriter.replaceOp(op, load);
+ } else {
+ return rewriter.notifyMatchFailure(
+ op, "Unsupported operation: xegpu.prefetch_nd with tensor "
+ "descriptor rank == 1");
+ }
}
return success();
}
More information about the Mlir-commits
mailing list