[Mlir-commits] [mlir] f25f2f7 - [MLIR][XeGPU] Extend unrolling support for scatter ops with chunk_size (#144447)
llvmlistbot at llvm.org
Tue Jun 17 15:46:38 PDT 2025
Author: Jianhui Li
Date: 2025-06-17T17:46:35-05:00
New Revision: f25f2f7de4f8264d89ba3c4dc9daddb10a90c13f
URL: https://github.com/llvm/llvm-project/commit/f25f2f7de4f8264d89ba3c4dc9daddb10a90c13f
DIFF: https://github.com/llvm/llvm-project/commit/f25f2f7de4f8264d89ba3c4dc9daddb10a90c13f.diff
LOG: [MLIR][XeGPU] Extend unrolling support for scatter ops with chunk_size (#144447)
Add unrolling support for load/store with chunk_size. This requires special
handling during operand blocking, since the offsets and masks are n-D while
the tensor descriptor and value are (n+1)-D. Supported operations include
create_tdesc, update_offset, load, store, and prefetch.
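
As an illustration (a minimal sketch, not the literal pass output): for a
descriptor of shape 32x4 with chunk_size = 4 and inst_data = [16, 2], the 32
offsets are blocked into two vector<16xindex> slices, and each slice is reused
numNewChunks = 4 / 2 = 2 times with its offsets shifted by the blocked chunk
size, giving four 16x2 descriptors. For one slice the blocked IR looks roughly
like this (SSA names are illustrative; cache-hint and layout attributes, as
well as the pack/unpack casts shown in the tests, are omitted):

  %c2   = arith.constant 2 : index              // i * blockedChunkSize
  %inc  = vector.splat %c2 : vector<16xindex>
  %off1 = arith.addi %off0, %inc : vector<16xindex>
  %t0 = xegpu.create_tdesc %src, %off0 : ui64, vector<16xindex>
          -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
  %t1 = xegpu.create_tdesc %src, %off1 : ui64, vector<16xindex>
          -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
  %v0 = xegpu.load %t0, %m0 <{transpose}>
          : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>,
            vector<16xi1> -> vector<2x16xf32>

Masks need no per-chunk shift: each vector<16xi1> slice is reused for every
chunk block derived from its offset slice, and the loaded pieces come back
transposed (2x16) when chunk_size > 1.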
---------
Co-authored-by: Adam Siemieniuk <adam.siemieniuk at intel.com>
Added:
Modified:
mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 9c234c1e866b9..0457f8128b908 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -402,30 +402,58 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getType();
+ TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
+ VectorType indiceVecTy = indiceVec.getType();
- // check if the tensor descriptor type is a 1d vector type
- if (tdescTy.getRank() > 1)
+ if (!tdescTy.isScattered())
return failure();
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
return failure();
- auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
-
- TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
- VectorType indiceVecTy = indiceVec.getType();
+ SmallVector<int64_t> targetIndiceShape(*targetShape);
+ int64_t originalChunkSize = tdescTy.getChunkSize();
+ // indiceVec has one fewer dimension than tdescTy when chunkSize is larger than 1.
+ if (originalChunkSize > 1)
+ targetIndiceShape.pop_back();
+ auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
SmallVector<Type> convertedIndiceTypes =
- getUnrolledTypes(indiceVecTy, *targetShape);
+ getUnrolledTypes(indiceVecTy, targetIndiceShape);
SmallVector<Value> convertedIndiceVec =
- pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
+ pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);
SmallVector<Value> newOps;
- for (auto indice : convertedIndiceVec) {
- auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
- op.getSource(), indice);
- newOps.push_back(newOp);
+
+ // More indices are needed when chunkSize > 1, since a large load from one
+ // address may be broken into multiple smaller loads.
+ if (originalChunkSize > 1) {
+ int64_t blockedChunkSize = targetShape->back();
+ int64_t numNewChunks = originalChunkSize / blockedChunkSize;
+
+ for (auto [indice, indiceType] :
+ llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
+ for (int64_t i = 0; i < numNewChunks; ++i) {
+ // Compute the offset
+ Value inc = rewriter.create<arith::ConstantIndexOp>(
+ loc, i * blockedChunkSize);
+ Value incVec = rewriter.create<vector::SplatOp>(loc, indiceType, inc);
+ Value offsetIndice =
+ rewriter.create<arith::AddIOp>(loc, indice, incVec);
+
+ auto newOp = rewriter.create<xegpu::CreateDescOp>(
+ loc, newTdescTy, op.getSource(), offsetIndice);
+
+ newOps.push_back(newOp);
+ }
+ }
+ } else {
+ for (auto indice : convertedIndiceVec) {
+ auto newOp = rewriter.create<xegpu::CreateDescOp>(
+ loc, newTdescTy, op.getSource(), indice);
+ newOps.push_back(newOp);
+ }
}
Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
@@ -444,16 +472,18 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
xegpu::TensorDescType tdescTy = op.getTensorDescType();
- // check if the tensor descriptor type is a 1d vector type
- if (tdescTy.getRank() > 1)
+ if (!tdescTy.isScattered())
return failure();
- VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
-
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
return failure();
+ SmallVector<int64_t> targetMaskShape(*targetShape);
+ int64_t originalChunkSize = tdescTy.getChunkSize();
+
+ VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+
Type elemTy = tdescTy.getElementType();
VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
@@ -462,10 +492,29 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
SmallVector<Value> convertedTdescs = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
- SmallVector<Type> convertedMaskTypes =
- getUnrolledTypes(maskTy, *targetShape);
- SmallVector<Value> convertedMasks =
- pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+ SmallVector<Type> convertedMaskTypes;
+ SmallVector<Value> convertedMasks;
+
+ if (originalChunkSize > 1) {
+ targetMaskShape.pop_back();
+ convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+ SmallVector<Value> convertedMasks1D = pack(
+ op.getMask(), convertedMaskTypes, targetMaskShape, loc, rewriter);
+ int64_t blockedChunkSize = targetShape->back();
+ int64_t numNewChunks = originalChunkSize / blockedChunkSize;
+
+ for (auto mask : convertedMasks1D) {
+ for (int64_t i = 0; i < numNewChunks; ++i)
+ convertedMasks.push_back(mask);
+ }
+ // This is to handle the transpose effect when chunkSize > 1.
+ std::swap((*targetShape)[0], (*targetShape)[1]);
+ newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+ } else {
+ convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+ convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
+ loc, rewriter);
+ }
SmallVector<Value> newOps;
for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
@@ -476,7 +525,6 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
}
Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
-
rewriter.replaceOp(op, castOp);
return success();
}
@@ -489,8 +537,7 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getTensorDescType();
- // check if the tensor descriptor type is a 1d vector type
- if (tdescTy.getRank() > 1)
+ if (!tdescTy.isScattered())
return failure();
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -519,30 +566,51 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
xegpu::TensorDescType tdescTy = op.getTensorDescType();
- // check if the tensor descriptor type is a 1d vector type
- if (tdescTy.getRank() > 1)
+ if (!tdescTy.isScattered())
return failure();
- VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
-
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
return failure();
- SmallVector<Type> convertedValTypes =
- getUnrolledTypes(valueTy, *targetShape);
+ SmallVector<int64_t> targetIndiceShape(*targetShape);
+ int64_t originalChunkSize = tdescTy.getChunkSize();
+
+ VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+
SmallVector<Type> convertedTdescTypes =
getUnrolledTypes(tdescTy, *targetShape);
-
- SmallVector<Value> convertedValues =
- pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
SmallVector<Value> convertedTdescs = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
- SmallVector<Type> convertedMaskTypes =
- getUnrolledTypes(maskTy, *targetShape);
- SmallVector<Value> convertedMasks =
- pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+ SmallVector<Type> convertedMaskTypes;
+ SmallVector<Value> convertedMasks;
+
+ if (originalChunkSize > 1) {
+ int64_t blockedChunkSize = targetShape->back();
+ int64_t numNewChunks = originalChunkSize / blockedChunkSize;
+ convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]);
+ SmallVector<Value> convertedMasks1D = pack(
+ op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);
+
+ for (auto mask : convertedMasks1D) {
+ for (int64_t i = 0; i < numNewChunks; ++i) {
+ convertedMasks.push_back(mask);
+ }
+ }
+ // This is to handle the transpose effect when chunkSize > 1.
+ std::swap((*targetShape)[0], (*targetShape)[1]);
+
+ } else {
+ convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape);
+ convertedMasks =
+ pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+ }
+
+ SmallVector<Type> convertedValTypes =
+ getUnrolledTypes(valueTy, *targetShape);
+ SmallVector<Value> convertedValues =
+ pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
for (size_t i = 0; i < convertedValues.size(); ++i) {
Value v = convertedValues[i];
@@ -565,8 +633,10 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getTensorDescType();
- // check if the tensor descriptor type is a 1d vector type
- if (tdescTy.getRank() > 1)
+ if (tdescTy.getRank() > 2)
+ return failure();
+
+ if (!tdescTy.isScattered())
return failure();
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -580,12 +650,32 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
TypedValue<::mlir::VectorType> offsetVec = op.getOffsets();
VectorType offsetVecTy = offsetVec.getType();
- SmallVector<Type> convertedOffsetTypes =
- getUnrolledTypes(offsetVecTy, *targetShape);
- SmallVector<Value> convertedOffsetVec =
- pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
-
+ SmallVector<Type> convertedOffsetTypes;
+ SmallVector<Value> convertedOffsetVec;
SmallVector<Value> newOps;
+ int64_t originalChunkSize = tdescTy.getChunkSize();
+ if (originalChunkSize > 1) {
+ SmallVector<int64_t> shape1D(targetShape->begin(),
+ targetShape->end() - 1);
+ convertedOffsetTypes = getUnrolledTypes(offsetVecTy, shape1D);
+ SmallVector<Value> convertedOffsetVec1D =
+ pack(offsetVec, convertedOffsetTypes, shape1D, loc, rewriter);
+
+ int64_t blockedChunkSize = targetShape->back();
+ int64_t numNewChunks = originalChunkSize / blockedChunkSize;
+
+ for (auto offset : convertedOffsetVec1D) {
+ for (int64_t i = 0; i < numNewChunks; ++i) {
+ convertedOffsetVec.push_back(offset);
+ }
+ }
+
+ } else {
+ convertedOffsetTypes = getUnrolledTypes(offsetVecTy, *targetShape);
+ convertedOffsetVec =
+ pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
+ }
+
for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
auto newOp =
rewriter.create<xegpu::UpdateOffsetOp>(loc, t.getType(), t, o);
diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
index 52ec3b856da49..41414d802f212 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -2,7 +2,7 @@
gpu.module @test {
- // CHECK-LABEL: test_create_nd_tdesc
+ // CHECK-LABEL: create_nd_tdesc
// CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
// CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: [[cast:%.+]] = builtin.unrealized_conversion_cast
@@ -10,31 +10,31 @@ gpu.module @test {
// CHECK-SAME: !xegpu.tensor_desc<8x16xf32>, !xegpu.tensor_desc<8x16xf32>,
// CHECK-SAME: !xegpu.tensor_desc<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
// CHECK-SAME: to !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {__xegpu_blocking_tile_shape__ = array<i64: 8, 16>, __xegpu_blocking_unpack__}
- gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
+ gpu.func @create_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
gpu.return %tdesc : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
}
//-----
- // CHECK-LABEL: test_create_nd_tdesc_1d
+ // CHECK-LABEL: create_nd_tdesc_1d
// CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
// CHECK-COUNT-2: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
// CHECK: [[cast:%.+]] = builtin.unrealized_conversion_cast
// CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
// CHECK-SAME: to !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {__xegpu_blocking_tile_shape__ = array<i64: 16>, __xegpu_blocking_unpack__}
- gpu.func @test_create_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {
+ gpu.func @create_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {
%tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
gpu.return %tdesc : !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
}
//-----
- // CHECK-LABEL: test_update_nd_tdesc
+ // CHECK-LABEL: update_nd_tdesc
// CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
// CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK-COUNT-6: [[update:%.+]] = xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<8x16xf32>
- gpu.func @test_update_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
+ gpu.func @update_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
%update = xegpu.update_nd_offset %tdesc, [0, 16] : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
gpu.return %update : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
@@ -42,11 +42,11 @@ gpu.module @test {
//-----
- // CHECK-LABEL: test_update_nd_tdesc_1d
+ // CHECK-LABEL: update_nd_tdesc_1d
// CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
// CHECK-COUNT-2: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
// CHECK-COUNT-2: [[update:%.+]] = xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<16xf32>
- gpu.func @test_update_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {
+ gpu.func @update_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {
%tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
%update = xegpu.update_nd_offset %tdesc, [32] : !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
gpu.return %update : !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
@@ -54,11 +54,11 @@ gpu.module @test {
//-----
- // CHECK-LABEL: test_prefetch_nd_tdesc
+ // CHECK-LABEL: prefetch_nd_tdesc
// CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
// CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK-COUNT-6: xegpu.prefetch_nd {{.*}} : !xegpu.tensor_desc<8x16xf32>
- gpu.func @test_prefetch_nd_tdesc(%src: memref<24x32xf32>) {
+ gpu.func @prefetch_nd_tdesc(%src: memref<24x32xf32>) {
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
xegpu.prefetch_nd %tdesc : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
gpu.return
@@ -66,23 +66,23 @@ gpu.module @test {
//-----
- // CHECK-LABEL: test_prefetch_nd_tdesc_1d
+ // CHECK-LABEL: prefetch_nd_tdesc_1d
// CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
// CHECK-COUNT-4: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
// CHECK-COUNT-4: xegpu.prefetch_nd {{.*}} : !xegpu.tensor_desc<16xf32>
- gpu.func @test_prefetch_nd_tdesc_1d(%src: memref<64xf32>) {
+ gpu.func @prefetch_nd_tdesc_1d(%src: memref<64xf32>) {
%tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
xegpu.prefetch_nd %tdesc : !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
gpu.return
}
//-----
- // CHECK-LABEL: test_load_nd
+ // CHECK-LABEL: load_nd
// CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
// CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK-COUNT-6: [[ld:%.+]] = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
// CHECK-COUNT-6: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<24x32xf32>
- gpu.func @test_load_nd(%src: memref<24x32xf32>) -> vector<24x32xf32> {
+ gpu.func @load_nd(%src: memref<24x32xf32>) -> vector<24x32xf32> {
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
%ld = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
gpu.return %ld : vector<24x32xf32>
@@ -90,12 +90,12 @@ gpu.module @test {
//-----
- // CHECK-LABEL: test_load_nd_1d
+ // CHECK-LABEL: load_nd_1d
// CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
// CHECK-COUNT-4: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
// CHECK-COUNT-4: [[ld:%.+]] = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<16xf32> -> vector<16xf32>
// CHECK-COUNT-4: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<16xf32> into vector<64xf32>
- gpu.func @test_load_nd_1d(%src: memref<64xf32>) -> vector<64xf32> {
+ gpu.func @load_nd_1d(%src: memref<64xf32>) -> vector<64xf32> {
%tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
%data = xegpu.load_nd %tdesc: !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>> -> vector<64xf32>
gpu.return %data : vector<64xf32>
@@ -103,11 +103,11 @@ gpu.module @test {
//-----
- // CHECK-LABEL: test_store_nd
+ // CHECK-LABEL: store_nd
// CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
// CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK-COUNT-6: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
- gpu.func @test_store_nd(%src: memref<24x32xf32>) {
+ gpu.func @store_nd(%src: memref<24x32xf32>) {
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
%data = arith.constant dense<9.0> : vector<24x32xf32>
xegpu.store_nd %data, %tdesc: vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
@@ -116,11 +116,11 @@ gpu.module @test {
//-----
- // CHECK-LABEL: test_store_nd_1d
+ // CHECK-LABEL: store_nd_1d
// CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
// CHECK-COUNT-4: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
// CHECK-COUNT-4: xegpu.store_nd {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32>
- gpu.func @test_store_nd_1d(%src: memref<64xf32>) {
+ gpu.func @store_nd_1d(%src: memref<64xf32>) {
%tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
%data = arith.constant dense<9.0> : vector<64xf32>
xegpu.store_nd %data, %tdesc: vector<64xf32>, !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
@@ -129,7 +129,7 @@ gpu.module @test {
//-----
- // CHECK-LABEL: test_createNd_loadNd_storeNd
+ // CHECK-LABEL: createNd_loadNd_storeNd
// CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
//CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
//CHECK-COUNT-6: [[data:%.+]] = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
@@ -137,7 +137,7 @@ gpu.module @test {
//CHECK: [[add:%.+]] = arith.addf {{.*}} : vector<24x32xf32>
//CHECK-COUNT-6: [[extract:%.+]] = vector.extract_strided_slice {{.*}} : vector<24x32xf32> to vector<8x16xf32>
//CHECK-COUNT-6: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
- gpu.func @test_createNd_loadNd_storeNd(%src: memref<24x32xf32>) {
+ gpu.func @createNd_loadNd_storeNd(%src: memref<24x32xf32>) {
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
%data = arith.constant dense<9.0> : vector<24x32xf32>
%ld = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
@@ -148,23 +148,23 @@ gpu.module @test {
//-----
- // CHECK-LABEL: test_dpas
+ // CHECK-LABEL: dpas
// CHECK-SAME: [[arg0:%.+]]: vector<32x32xf16>, [[arg1:%.+]]: vector<32x32xf16>
//CHECK-COUNT-8: [[extract1:%.+]] = vector.extract_strided_slice [[arg0]] {{.*}} : vector<32x32xf16> to vector<8x16xf16>
//CHECK-COUNT-4: [[extract2:%.+]] = vector.extract_strided_slice [[arg1]] {{.*}} : vector<32x32xf16> to vector<16x16xf16>
//CHECK-COUNT-16: [[dpas:%.+]] = xegpu.dpas {{.*}} -> vector<8x16xf32>
//CHECK-COUNT-8: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<32x32xf32>
- gpu.func @test_dpas(%a: vector<32x32xf16>, %b: vector<32x32xf16>) -> vector<32x32xf32> {
+ gpu.func @dpas(%a: vector<32x32xf16>, %b: vector<32x32xf16>) -> vector<32x32xf32> {
%c = xegpu.dpas %a, %b : vector<32x32xf16>, vector<32x32xf16> -> vector<32x32xf32>
gpu.return %c : vector<32x32xf32>
}
//-----
- // CHECK-LABEL: test_create_tdesc_vec
+ // CHECK-LABEL: create_tdesc_vec
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
- gpu.func @test_create_tdesc_vec(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
+ gpu.func @create_tdesc_vec(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
%cst = arith.constant dense<[
0, 8, 16, 24, 32, 40, 48, 56,
64, 72, 80, 88, 96, 104, 112, 120,
@@ -177,10 +177,10 @@ gpu.module @test {
//-----
- // CHECK-LABEL: test_create_tdesc_step
+ // CHECK-LABEL: create_tdesc_step
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
- gpu.func @test_create_tdesc_step(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
+ gpu.func @create_tdesc_step(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
%step = arith.constant dense<8> : vector<32xindex>
%seq = vector.step : vector<32xindex>
%cst = arith.muli %seq, %step : vector<32xindex>
@@ -190,11 +190,11 @@ gpu.module @test {
//-----
- // CHECK-LABEL: test_load
+ // CHECK-LABEL: load
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
// CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
- gpu.func @test_load(%src: ui64) -> vector<32xf32> {
+ gpu.func @load(%src: ui64) -> vector<32xf32> {
%cst = arith.constant dense<[
0, 8, 16, 24, 32, 40, 48, 56,
64, 72, 80, 88, 96, 104, 112, 120,
@@ -212,11 +212,11 @@ gpu.module @test {
//-----
- // CHECK-LABEL: test_prefetch
+ // CHECK-LABEL: prefetch
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
// CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
- gpu.func @test_prefetch(%src: ui64) {
+ gpu.func @prefetch(%src: ui64) {
%cst = arith.constant dense<[
0, 8, 16, 24, 32, 40, 48, 56,
@@ -233,11 +233,11 @@ gpu.module @test {
//-----
- // CHECK-LABEL: test_store
+ // CHECK-LABEL: store
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
// CHECK-COUNT-2: xegpu.store {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
- gpu.func @test_store(%src: ui64) {
+ gpu.func @store(%src: ui64) {
%cst = arith.constant dense<[
0, 8, 16, 24, 32, 40, 48, 56,
64, 72, 80, 88, 96, 104, 112, 120,
@@ -256,47 +256,129 @@ gpu.module @test {
}
//-----
+ // CHECK-LABEL: create_tdesc_step_chunk
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x4xf32, #xegpu.scatter_tdesc_attr<chunk_size = 4 : i64>>
+ gpu.func @create_tdesc_step_chunk(%src: ui64) -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 4]>> {
+ %step = arith.constant dense<8> : vector<32xindex>
+ %seq = vector.step : vector<32xindex>
+ %cst = arith.muli %seq, %step : vector<32xindex>
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 4]>>
+ gpu.return %tdesc : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 4]>>
+ }
- // CHECK-LABEL: test_prefetch_load_store_update
+//-----
+ // CHECK-LABEL: create_tdesc_step_chunk2
// CHECK-SAME: [[arg0:%.+]]: ui64
- // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
- // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
- // CHECK-COUNT-2: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex>
- // CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
- // CHECK-COUNT-2: xegpu.store {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
+ // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+ gpu.func @create_tdesc_step_chunk2(%src: ui64) -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>> {
+ %step = arith.constant dense<8> : vector<32xindex>
+ %seq = vector.step : vector<32xindex>
+ %cst = arith.muli %seq, %step : vector<32xindex>
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+ gpu.return %tdesc : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+ }
- gpu.func @test_prefetch_load_store_update(%src: ui64) {
+// CHECK-LABEL: create_tdesc_step_chunk3
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+ // CHECK: arith.addi %{{.*}}, %{{.*}} : vector<16xindex>
+ // CHECK: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+ // CHECK: arith.addi %{{.*}}, %{{.*}} : vector<16xindex>
+ // CHECK: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+ // CHECK: arith.addi %{{.*}}, %{{.*}} : vector<16xindex>
+ // CHECK: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+ gpu.func @create_tdesc_step_chunk3(%src: ui64) -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size=8>, #xegpu.layout<inst_data = [16, 2]>> {
+ %step = arith.constant dense<8> : vector<16xindex>
+ %seq = vector.step : vector<16xindex>
+ %cst = arith.muli %seq, %step : vector<16xindex>
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size=8>, #xegpu.layout<inst_data = [16, 2]>>
+ gpu.return %tdesc : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size=8>, #xegpu.layout<inst_data = [16, 2]>>
+ }
+
+//-----
+ // CHECK-LABEL: load_chunk
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+ // CHECK-COUNT-4: xegpu.load {{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1> -> vector<2x16xf32>
+ gpu.func @load_chunk(%src: ui64) -> vector<4x32xf32> {
%cst = arith.constant dense<[
- 0, 8, 16, 24, 32, 40, 48, 56,
- 64, 72, 80, 88, 96, 104, 112, 120,
- 128, 136, 144, 152, 160, 168, 176, 184,
- 192, 200, 208, 216, 224, 232, 240, 248
+ 0, 8, 16, 24, 32, 40, 48, 56,
+ 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184,
+ 192, 200, 208, 216, 224, 232, 240, 248
]> : vector<32xindex>
+
+ %c17 = arith.constant 17: index
+ %mask = vector.create_mask %c17: vector<32xi1>
- %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
- xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
-
- %delta = arith.constant dense<[
- 32, 32, 32, 32, 32, 32, 32, 32,
- 32, 32, 32, 32, 32, 32, 32, 64,
- 128, 128, 128, 128, 128, 128, 128, 128,
- 128, 128, 128, 128, 128, 128, 128, 256
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+ %ld = xegpu.load %tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xi1> -> vector<4x32xf32>
+
+ gpu.return %ld : vector<4x32xf32>
+ }
+
+//-----
+ // CHECK-LABEL: store_chunk
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+ // CHECK-COUNT-4: xegpu.store {{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x16xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1>
+ gpu.func @store_chunk(%src: ui64) {
+ %cst = arith.constant dense<[
+ 0, 8, 16, 24, 32, 40, 48, 56,
+ 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184,
+ 192, 200, 208, 216, 224, 232, 240, 248
]> : vector<32xindex>
- %new_tdesc = xegpu.update_offset %tdesc, %delta
- : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xindex>
-
+
%c17 = arith.constant 17: index
%mask = vector.create_mask %c17: vector<32xi1>
- %ld_vec = xegpu.load %new_tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
+ %st_vec = arith.constant dense<1023.>: vector<4x32xf32>
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+ xegpu.store %st_vec, %tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>: vector<4x32xf32>, !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16,2]>>, vector<32xi1>
+
+ gpu.return
+ }
- %st_vec = arith.addf %ld_vec, %ld_vec : vector<32xf32>
- xegpu.store %st_vec, %tdesc, %mask:
- vector<32xf32>,
- !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>,
- vector<32xi1>
-
+//-----
+ // CHECK-LABEL: prefetch_chunk
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+ // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+ gpu.func @prefetch_chunk(%src: ui64) {
+ %cst = arith.constant dense<[
+ 0, 8, 16, 24, 32, 40, 48, 56,
+ 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184,
+ 192, 200, 208, 216, 224, 232, 240, 248
+ ]> : vector<32xindex>
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+ xegpu.prefetch %tdesc: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+
gpu.return
}
+
+//-----
+ // CHECK-LABEL: update_chunk
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+ // CHECK-COUNT-4: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xindex>
+ gpu.func @update_chunk(%src: ui64) -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>> {
+ %cst = arith.constant dense<[
+ 0, 8, 16, 24, 32, 40, 48, 56,
+ 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184,
+ 192, 200, 208, 216, 224, 232, 240, 248
+ ]> : vector<32xindex>
+ %delta = arith.constant dense<32>: vector<32xindex>
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+
+ %new_tdesc = xegpu.update_offset %tdesc, %delta
+ : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xindex>
+
+ gpu.return %new_tdesc : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+ }
}
+
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 57aaecbd7962f..4400d6d9625f7 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -19,6 +19,10 @@ using namespace mlir::xegpu;
namespace {
+#define DEBUG_TYPE "test-xegpu-unroll"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
struct TestXeGPUUnrollingPatterns
: public PassWrapper<TestXeGPUUnrollingPatterns,
OperationPass<gpu::GPUModuleOp>> {
@@ -48,7 +52,9 @@ struct TestXeGPUUnrollingPatterns
options.setNativeShapeFn(
[&](Operation *op) -> std::optional<SmallVector<int64_t>> {
if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp,
- xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::StoreNdOp>(op)) {
+ xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
+ xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
+ xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
xegpu::TensorDescType tdescTy;
if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
tdescTy = createNdOp.getType();
@@ -61,20 +67,7 @@ struct TestXeGPUUnrollingPatterns
tdescTy = loadNdOp.getTensorDescType();
} else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
tdescTy = storeNdOp.getTensorDescType();
- }
-
- if (auto layout = tdescTy.getLayoutAttr()) {
- auto inst_data = layout.getInstData();
- if (inst_data && layout.isSgLayout())
- return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
- inst_data.asArrayRef().end());
- }
- }
-
- if (isa<xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
- xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
- xegpu::TensorDescType tdescTy;
- if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
+ } else if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
tdescTy = createOp.getType();
} else if (auto updateOp = dyn_cast<xegpu::UpdateOffsetOp>(op)) {
tdescTy = updateOp.getTensorDescType();
@@ -111,14 +104,40 @@ struct TestXeGPUUnrollingPatterns
Attribute encoding = tdescTy.getEncoding();
auto layout = llvm::dyn_cast_if_present<xegpu::LayoutAttr>(
tdescTy.getLayout());
+
+ // If the encoding is a ScatterTensorDescAttr, we need to
+ // potentially adjust the chunk size based on the inst_data.
+ if (encoding && mlir::isa<xegpu::ScatterTensorDescAttr>(encoding)) {
+ auto scatterAttr =
+ mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
+ int64_t chunkSize = scatterAttr.getChunkSize().getInt();
+
+ if (chunkSize > 1) {
+ int64_t blockedChunkSize = chunkSize;
+ auto instData = layout.getInstData();
+ if (!instData.empty())
+ blockedChunkSize = instData.asArrayRef().back();
+
+ auto chunkSizeAttr = mlir::IntegerAttr::get(
+ mlir::IntegerType::get(ctx, 64), blockedChunkSize);
+
+ // To create a new attribute with a different chunk_size:
+ auto newEncoding = xegpu::ScatterTensorDescAttr::get(
+ ctx, scatterAttr.getMemorySpace(), chunkSizeAttr);
+
+ encoding = newEncoding;
+ }
+ }
if (layout) {
if (layout.getLaneLayout() == nullptr)
layout = xegpu::LayoutAttr();
else
layout = layout.dropInstData();
}
+
newTy = xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
layout);
+
} else {
newTy = type.clone(tileShape, elemTy);
}