[Mlir-commits] [mlir] 87e094d - [MLIR][Conversion] XeGPU to XeVM: Add handler for 1D block ops (#165894)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Nov 10 09:02:37 PST 2025
Author: Sang Ik Lee
Date: 2025-11-10T09:02:32-08:00
New Revision: 87e094da0f2bfc5b9e40532bd75db9a4a9da6a71
URL: https://github.com/llvm/llvm-project/commit/87e094da0f2bfc5b9e40532bd75db9a4a9da6a71
DIFF: https://github.com/llvm/llvm-project/commit/87e094da0f2bfc5b9e40532bd75db9a4a9da6a71.diff
LOG: [MLIR][Conversion] XeGPU to XeVM: Add handler for 1D block ops (#165894)
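Add lowering for xegpu create_nd_tdesc, load_nd, and store_nd with a 1D tensor descriptor. A 1D descriptor lowers to its i64 base address; prefetch_nd with a 1D descriptor is not supported yet.
Add a conversion test case.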
Added:
mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir
Modified:
mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index de552ce22ef62..705298f497d20 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -186,9 +186,6 @@ class CreateNdDescToXeVMPattern
SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
// Descriptor shape is expected to be 2D.
int64_t rank = mixedSizes.size();
- if (rank != 2)
- return rewriter.notifyMatchFailure(op, "Expected 2D shape.");
-
auto sourceTy = source.getType();
auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
// If source is a memref, we need to extract the aligned pointer as index.
@@ -199,8 +196,19 @@ class CreateNdDescToXeVMPattern
}
baseAddr =
memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
+ // Cast index to i64.
+ baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
} else {
baseAddr = adaptor.getSource();
+ if (baseAddr.getType() != i64Ty) {
+ // Pointer type may be i32. Cast to i64 if needed.
+ baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
+ }
+ }
+ // 1D tensor descriptor is just the base address.
+ if (rank == 1) {
+ rewriter.replaceOp(op, baseAddr);
+ return success();
}
// Utility for creating offset values from op fold result.
auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
@@ -215,13 +223,6 @@ class CreateNdDescToXeVMPattern
// Get shape values from op fold results.
baseShapeW = createOffset(mixedSizes, 1);
baseShapeH = createOffset(mixedSizes, 0);
- if (sourceMemrefTy) {
- // Cast index to i64.
- baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
- } else if (baseAddr.getType() != i64Ty) {
- // Pointer type may be i32. Cast to i64 if needed.
- baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
- }
// Populate payload.
Value payLoadAsI64 =
vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
@@ -257,108 +258,175 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
ConversionPatternRewriter &rewriter) const override {
auto mixedOffsets = op.getMixedOffsets();
int64_t opOffsetsSize = mixedOffsets.size();
- if (opOffsetsSize != 2)
- return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
auto loc = op.getLoc();
auto ctxt = rewriter.getContext();
auto tdesc = adaptor.getTensorDesc();
auto tdescTy = op.getTensorDescType();
- if (tdescTy.getRank() != 2)
- return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor.");
+ auto tileRank = tdescTy.getRank();
+ if (opOffsetsSize != tileRank)
+ return rewriter.notifyMatchFailure(
+ op, "Expected offset rank to match descriptor rank.");
auto elemType = tdescTy.getElementType();
auto elemBitSize = elemType.getIntOrFloatBitWidth();
if (elemBitSize % 8 != 0)
return rewriter.notifyMatchFailure(
op, "Expected element type bit width to be multiple of 8.");
- VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
- Value payLoadAsI64 =
- vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
- Value basePtr = vector::ExtractOp::create(
- rewriter, loc, payLoadAsI64, static_cast<int>(NdTdescOffset::BasePtr));
- Value baseShapeW = vector::ExtractOp::create(
- rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
- Value baseShapeH = vector::ExtractOp::create(
- rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
- // Offsets are provided by the op.
- // convert them to i32.
- Value offsetW =
- getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
- offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
- rewriter.getI32Type(), offsetW);
- Value offsetH =
- getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
- offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
- rewriter.getI32Type(), offsetH);
// Get address space from tensor descriptor memory space.
auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
- // Convert base pointer (i64) to LLVM pointer type.
- Value basePtrLLVM =
- LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
- // Compute element byte size and surface width in bytes.
- Value elemByteSize = arith::ConstantIntOp::create(
- rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
- Value surfaceW =
- arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
-
- // Get tile sizes and vblocks from the tensor descriptor type.
- auto tileW = tdescTy.getDimSize(1);
- auto tileH = tdescTy.getDimSize(0);
- int32_t vblocks = tdescTy.getArrayLength();
- if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
- Value src = adaptor.getValue();
- // If store value is a scalar, get value from op instead of adaptor.
- // Adaptor might have optimized away single element vector
- if (src.getType().isIntOrFloat()) {
- src = op.getValue();
- }
- VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
- if (!srcVecTy)
- return rewriter.notifyMatchFailure(
- op, "Expected store value to be a vector type.");
- // Get flat vector type of integer type with matching element bit size.
- VectorType newSrcVecTy =
- encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
- if (srcVecTy != newSrcVecTy)
- src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
- auto storeCacheControl =
- translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
- xevm::BlockStore2dOp::create(
- rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
- offsetH, elemBitSize, tileW, tileH, src,
- xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
- rewriter.eraseOp(op);
- } else {
- auto loadCacheControl =
- translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
- if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
- xevm::BlockPrefetch2dOp::create(
+ if (tileRank == 2) {
+ // Compute element byte size.
+ Value elemByteSize = arith::ConstantIntOp::create(
+ rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
+ VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
+ Value payLoadAsI64 =
+ vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
+ Value basePtr =
+ vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
+ static_cast<int>(NdTdescOffset::BasePtr));
+ Value baseShapeW = vector::ExtractOp::create(
+ rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
+ Value baseShapeH = vector::ExtractOp::create(
+ rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
+ // Offsets are provided by the op.
+ // convert them to i32.
+ Value offsetW =
+ getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
+ offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
+ rewriter.getI32Type(), offsetW);
+ Value offsetH =
+ getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+ offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
+ rewriter.getI32Type(), offsetH);
+ // Convert base pointer (i64) to LLVM pointer type.
+ Value basePtrLLVM =
+ LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
+ // Compute width in bytes.
+ Value surfaceW =
+ arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
+
+ // Get tile width from the tensor descriptor type.
+ auto tileW = tdescTy.getDimSize(tileRank - 1);
+ // Get tile height from the tensor descriptor type.
+ auto tileH = tdescTy.getDimSize(0);
+ // Get vblocks from the tensor descriptor type.
+ int32_t vblocks = tdescTy.getArrayLength();
+ if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
+ Value src = adaptor.getValue();
+ // If store value is a scalar, get value from op instead of adaptor.
+ // Adaptor might have optimized away single element vector
+ if (src.getType().isIntOrFloat()) {
+ src = op.getValue();
+ }
+ VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
+ if (!srcVecTy)
+ return rewriter.notifyMatchFailure(
+ op, "Expected store value to be a vector type.");
+ // Get flat vector type of integer type with matching element bit size.
+ VectorType newSrcVecTy =
+ encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
+ if (srcVecTy != newSrcVecTy)
+ src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
+ auto storeCacheControl =
+ translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+ xevm::BlockStore2dOp::create(
rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
- offsetH, elemBitSize, tileW, tileH, vblocks,
- xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+ offsetH, elemBitSize, tileW, tileH, src,
+ xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
rewriter.eraseOp(op);
} else {
- VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
- const bool vnni = op.getPacked().value_or(false);
- auto transposeValue = op.getTranspose();
- bool transpose =
- transposeValue.has_value() && transposeValue.value()[0] == 1;
- VectorType loadedTy = encodeVectorTypeTo(
- dstVecTy, vnni ? rewriter.getI32Type()
- : rewriter.getIntegerType(elemBitSize));
-
- Value resultFlatVec = xevm::BlockLoad2dOp::create(
- rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
- surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
- transpose, vnni,
+ auto loadCacheControl =
+ translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+ if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
+ xevm::BlockPrefetch2dOp::create(
+ rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW,
+ offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
+ xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+ rewriter.eraseOp(op);
+ } else {
+ VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
+ const bool vnni = op.getPacked().value_or(false);
+ auto transposeValue = op.getTranspose();
+ bool transpose =
+ transposeValue.has_value() && transposeValue.value()[0] == 1;
+ VectorType loadedTy = encodeVectorTypeTo(
+ dstVecTy, vnni ? rewriter.getI32Type()
+ : rewriter.getIntegerType(elemBitSize));
+
+ Value resultFlatVec = xevm::BlockLoad2dOp::create(
+ rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
+ surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
+ transpose, vnni,
+ xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+ resultFlatVec = vector::BitCastOp::create(
+ rewriter, loc,
+ encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
+ resultFlatVec);
+ rewriter.replaceOp(op, resultFlatVec);
+ }
+ }
+ } else {
+ // 1D tensor descriptor.
+ // `tdesc` represents base address as i64
+ // Offset in number of elements, need to multiply by element byte size.
+ // Compute byte offset.
+ // byteOffset = offset * elementByteSize
+ Value offset =
+ getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+ offset = getValueOrCreateCastToIndexLike(rewriter, loc,
+ rewriter.getI64Type(), offset);
+ // Compute element byte size.
+ Value elemByteSize = arith::ConstantIntOp::create(
+ rewriter, loc, rewriter.getI64Type(), elemBitSize / 8);
+ Value byteOffset =
+ rewriter.createOrFold<arith::MulIOp>(loc, offset, elemByteSize);
+ // Final address = basePtr + byteOffset
+ Value finalAddrI64 = rewriter.createOrFold<arith::AddIOp>(
+ loc, tdesc,
+ getValueOrCreateCastToIndexLike(rewriter, loc, rewriter.getI64Type(),
+ byteOffset));
+ // Convert base pointer (i64) to LLVM pointer type.
+ Value finalPtrLLVM =
+ LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, finalAddrI64);
+ if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
+ Value src = adaptor.getValue();
+ // If store value is a scalar, get value from op instead of adaptor.
+ // Adaptor might have optimized away single element vector
+ if (src.getType().isIntOrFloat()) {
+ src = op.getValue();
+ }
+ VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
+ if (!srcVecTy)
+ return rewriter.notifyMatchFailure(
+ op, "Expected store value to be a vector type.");
+ // Get flat vector type of integer type with matching element bit size.
+ VectorType newSrcVecTy =
+ encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
+ if (srcVecTy != newSrcVecTy)
+ src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
+ auto storeCacheControl =
+ translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+ rewriter.replaceOpWithNewOp<xevm::BlockStoreOp>(
+ op, finalPtrLLVM, src,
+ xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
+ } else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
+ auto loadCacheControl =
+ translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+ VectorType resTy = cast<VectorType>(op.getValue().getType());
+ VectorType loadedTy =
+ encodeVectorTypeTo(resTy, rewriter.getIntegerType(elemBitSize));
+ Value load = xevm::BlockLoadOp::create(
+ rewriter, loc, loadedTy, finalPtrLLVM,
xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
- resultFlatVec = vector::BitCastOp::create(
- rewriter, loc,
- encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
- resultFlatVec);
- rewriter.replaceOp(op, resultFlatVec);
+ if (loadedTy != resTy)
+ load = vector::BitCastOp::create(rewriter, loc, resTy, load);
+ rewriter.replaceOp(op, load);
+ } else {
+ return rewriter.notifyMatchFailure(
+ op, "Unsupported operation: xegpu.prefetch_nd with tensor "
+ "descriptor rank == 1");
}
}
return success();
@@ -929,7 +997,10 @@ struct ConvertXeGPUToXeVMPass
return VectorType::get(sum, elemType);
});
typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
+ // Scattered descriptors are not supported in XeVM lowering.
if (type.isScattered())
+ return {};
+ if (type.getRank() == 1)
return IntegerType::get(&getContext(), 64);
auto i32Type = IntegerType::get(&getContext(), 32);
return VectorType::get(8, i32Type);
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index 09ef76c9d1740..109312218afae 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -29,13 +29,13 @@ gpu.module @create_nd_tdesc {
// CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
// CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
+ // CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
// CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32
// CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32
// CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64
// CHECK: %[[SHAPE_W2:.*]] = arith.trunci %[[C32_I64]] : i64 to i32
// CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
// CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
- // CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
// CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
// CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
// CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
@@ -53,11 +53,11 @@ gpu.module @create_nd_tdesc {
%BLOCK_DMODEL = arith.constant 16 : index
// CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
// CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref<?x?xf16> -> index
+ // CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
// CHECK: %[[C0_I32_6:.*]] = arith.constant 0 : i32
// CHECK: %[[C0_I32_7:.*]] = arith.constant 0 : i32
// CHECK: %[[VAR21:.*]] = arith.index_cast %[[C16]] : index to i32
// CHECK: %[[VAR22:.*]] = arith.index_cast %[[C64]] : index to i32
- // CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
// CHECK: %[[VAR24:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
// CHECK: %[[VAR25:.*]] = vector.insert %[[VAR23]], %[[VAR24]] [0] : i64 into vector<4xi64>
// CHECK: %[[VAR26:.*]] = vector.bitcast %[[VAR25]] : vector<4xi64> to vector<8xi32>
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir
new file mode 100644
index 0000000000000..7b4ad9ec2df03
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
+
+gpu.module @load_store_check {
+ // CHECK-LABEL: @load_store(
+ // CHECK-SAME: %[[SRC:.*]]: memref<512xf32, 1>, %[[DST:.*]]: memref<256xf32, 1>
+ gpu.func @load_store(%src: memref<512xf32, 1>, %dst: memref<256xf32, 1>) kernel {
+ // CHECK: %[[C512:.*]] = arith.constant 512 : i64
+ // CHECK: %[[C384:.*]] = arith.constant 384 : i64
+
+ // CHECK: %[[SRCCE:.*]] = memref.memory_space_cast %[[SRC]] : memref<512xf32, 1> to memref<512xf32>
+ %srcce = memref.memory_space_cast %src : memref<512xf32, 1> to memref<512xf32>
+ // CHECK: %[[DSTTE:.*]] = memref.memory_space_cast %[[DST]] : memref<256xf32, 1> to memref<256xf32>
+ %dstte = memref.memory_space_cast %dst : memref<256xf32, 1> to memref<256xf32>
+
+ // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[SRCCE]] : memref<512xf32> -> index
+ // CHECK: %[[INTPTR_I64:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+ %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<512xf32> -> !xegpu.tensor_desc<32xf32>
+ // CHECK: %[[ADDR:.*]] = arith.addi %[[INTPTR_I64]], %[[C384]] : i64
+ // CHECK: %[[PTR:.*]] = llvm.inttoptr %[[ADDR]] : i64 to !llvm.ptr<1>
+ // CHECK: %[[LOAD:.*]] = xevm.blockload %[[PTR]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}>
+ // CHECK-SAME: : (!llvm.ptr<1>) -> vector<2xi32>
+ %loaded = xegpu.load_nd %src_tdesc[96] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : !xegpu.tensor_desc<32xf32> -> vector<2xf32>
+
+ // CHECK: %[[INTPTR1:.*]] = memref.extract_aligned_pointer_as_index %[[DSTTE]] : memref<256xf32> -> index
+ // CHECK: %[[INTPTR1_I64:.*]] = arith.index_castui %[[INTPTR1]] : index to i64
+ %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<256xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.block_tdesc_attr<memory_space = global>>
+ // CHECK: %[[ADDR1:.*]] = arith.addi %[[INTPTR1_I64]], %[[C512]] : i64
+ // CHECK: %[[PTR1:.*]] = llvm.inttoptr %[[ADDR1]] : i64 to !llvm.ptr<1>
+ // CHECK: xevm.blockstore %[[PTR1]], %[[LOAD]] <{cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>}>
+ // CHECK-SAME: : (!llvm.ptr<1>, vector<2xi32>)
+ xegpu.store_nd %loaded, %dst_tdesc[128] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
+ : vector<2xf32>, !xegpu.tensor_desc<32xf32, #xegpu.block_tdesc_attr<memory_space = global>>
+ gpu.return
+ }
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
index 4c6bbf25b4728..95774ca67c4f2 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
@@ -16,6 +16,7 @@ gpu.module @load_store_check {
%src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+ //CHECK: %[[LD_SIZEOF_F32:.*]] = arith.constant 4 : i32
//CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64>
//CHECK: %[[LD_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
//CHECK: %[[LD_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
@@ -25,7 +26,6 @@ gpu.module @load_store_check {
//CHECK: %[[LD_TILE_H64:.*]] = arith.constant 0 : i64
//CHECK: %[[LD_TILE_H:.*]] = arith.trunci %[[LD_TILE_H64]] : i64 to i32
//CHECK: %[[LD_LLVMPTR:.*]] = llvm.inttoptr %[[LD_INTPTR]] : i64 to !llvm.ptr<1>
- //CHECK: %[[LD_SIZEOF_F32:.*]] = arith.constant 4 : i32
//CHECK: %[[LD_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[LD_BASE_W]], %[[LD_SIZEOF_F32]] : i32
//CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %[[LD_LLVMPTR]], %[[LD_BASE_ROW_IN_BYTES]],
//CHECK-SAME: %[[LD_BASE_H]], %[[LD_BASE_ROW_IN_BYTES]], %[[LD_TILE_W]], %[[LD_TILE_H]]
@@ -52,6 +52,7 @@ gpu.module @load_store_check {
// CHECK: %[[DESC:.*]] = vector.insert {{.*}}, %[[DESC_4]] [5] : i32 into vector<8xi32>
%dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
+ //CHECK: %[[SIZEOF_F32:.*]] = arith.constant 4 : i32
//CHECK: %[[DESC_I64:.*]] = vector.bitcast %[[DESC]] : vector<8xi32> to vector<4xi64>
//CHECK: %[[INTPTR:.*]] = vector.extract %[[DESC_I64]][0] : i64 from vector<4xi64>
//CHECK: %[[BASE_W:.*]] = vector.extract %[[DESC]][2] : i32 from vector<8xi32>
@@ -61,7 +62,6 @@ gpu.module @load_store_check {
//CHECK: %[[TILE_H64:.*]] = arith.constant 0 : i64
//CHECK: %[[TILE_H:.*]] = arith.trunci %[[TILE_H64]] : i64 to i32
//CHECK: %[[LLVMPTR:.*]] = llvm.inttoptr %[[INTPTR]] : i64 to !llvm.ptr<1>
- //CHECK: %[[SIZEOF_F32:.*]] = arith.constant 4 : i32
//CHECK: %[[BASE_ROW_IN_BYTES:.*]] = arith.muli %[[BASE_W]], %[[SIZEOF_F32]] : i32
//CHECK: %[[FLAT_VALUE_I32:.*]] = vector.bitcast %[[LOADED_F32_MODIFIED]] : vector<8xf32> to vector<8xi32>
//CHECK: xevm.blockstore2d %[[LLVMPTR]], %[[BASE_ROW_IN_BYTES]], %[[BASE_H]], %[[BASE_ROW_IN_BYTES]],
More information about the Mlir-commits mailing list