[Mlir-commits] [mlir] [MLIR][XeGPU] Extend unrolling support for scatter ops with chunk_size (PR #144447)
Jianhui Li
llvmlistbot at llvm.org
Tue Jun 17 14:51:25 PDT 2025
https://github.com/Jianhui-Li updated https://github.com/llvm/llvm-project/pull/144447
>From bac8bc607a0b5f9171472375e5f42c24a5ad429d Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 9 Jun 2025 14:40:41 +0000
Subject: [PATCH 01/16] add unroll support for reduce and broadcast
---
.../XeGPU/Transforms/XeGPUBlocking.cpp | 6 +
mlir/test/Dialect/XeGPU/xegpu-blocking.mlir | 110 ++++++++++++++++++
2 files changed, 116 insertions(+)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 7cd998eed2e08..a3826c56e1f62 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -169,6 +169,12 @@ XeGPUBlockingPass::getTileShape(Operation *op) const {
if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)
return getTileShape(op->getOpResult(0));
+ if (isa<vector::MultiDimReductionOp>(op))
+ return getTileShape(op->getOpOperand(0));
+
+ if (isa<vector::TransposeOp, vector::BroadcastOp>(op))
+ return getTileShape(op->getOpResult(0));
+
return std::nullopt;
}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index f9114988686c8..8e3673d04eacb 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -246,3 +246,113 @@ gpu.module @test_kernel {
gpu.return
}
}
+
+// -----
+#l = #xegpu.layout<inst_data = [16, 16]>
+#r = #xegpu.layout<inst_data = [16]>
+
+gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+ gpu.func @reduce_dim_0(%a: memref<16x512xf32>, %b: memref<512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+ %acc = arith.constant dense<0.0> : vector<64xf32>
+ %c64 = arith.constant 64 : index
+ %block_id_x = gpu.block_id x
+ %m = arith.muli %block_id_x, %c64 : index
+ %0 = xegpu.create_nd_tdesc %a[0, %m] : memref<16x512xf32> -> !xegpu.tensor_desc<16x64xf32, #l>
+ %1 = xegpu.load_nd %0: !xegpu.tensor_desc<16x64xf32, #l> -> vector<16x64xf32>
+ // CHECK: vector.multi_reduction <add>, {{.*}}, [[ACC:%[0-9A-Za-z]+]] [0] : vector<16x16xf32> to vector<16xf32>
+ // CHECK-COUNT-3: vector.multi_reduction <add>, {{.*}}, [[ACC]] [0] : vector<16x16xf32> to vector<16xf32>
+ %2 = vector.multi_reduction <add>, %1, %acc {layout_result_0 = #r} [0]: vector<16x64xf32> to vector<64xf32>
+ %3 = xegpu.create_nd_tdesc %b[%m] : memref<512xf32> -> !xegpu.tensor_desc<64xf32, #r>
+ xegpu.store_nd %2, %3: vector<64xf32>, !xegpu.tensor_desc<64xf32, #r>
+ gpu.return
+ }
+}
+
+// -----
+#l = #xegpu.layout<inst_data = [16, 16]>
+#r = #xegpu.layout<inst_data = [16]>
+
+gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+ gpu.func @reduce_dim_1(%a: memref<512x32xf32>, %b: memref<512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+ %c1 = arith.constant 1 : index
+ %c32 = arith.constant 32 : index
+ %acc = arith.constant dense<0.0> : vector<32xf32>
+
+ %block_id_x = gpu.block_id x
+ %block_id_y = gpu.block_id y
+
+ %m = arith.muli %block_id_x, %c32 : index
+ %n = arith.muli %block_id_y, %c32 : index
+ %0 = xegpu.create_nd_tdesc %a[%m, %n] : memref<512x32xf32> -> !xegpu.tensor_desc<32x128xf32, #l>
+ %1 = xegpu.load_nd %0: !xegpu.tensor_desc<32x128xf32, #l> -> vector<32x128xf32>
+
+ // CHECK: vector.multi_reduction <add>, {{.*}}, [[INIT:%[0-9A-Za-z]+]] [1] : vector<16x16xf32> to vector<16xf32>
+ // CHECK-COUNT-1: vector.multi_reduction <add>, {{.*}}, [[INIT]] [1] : vector<16x16xf32> to vector<16xf32>
+
+ %2 = vector.multi_reduction <add>, %1, %acc {layout_result_0 = #r} [1]: vector<32x128xf32> to vector<32xf32>
+ %3 = xegpu.create_nd_tdesc %b[%n] : memref<512xf32> -> !xegpu.tensor_desc<32xf32, #r>
+ xegpu.store_nd %2, %3: vector<32xf32>, !xegpu.tensor_desc<32xf32, #r>
+ gpu.return
+ }
+}
+
+// -----
+#r = #xegpu.layout<inst_data = [16]>
+#l = #xegpu.layout<inst_data = [16, 16]>
+
+gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+ gpu.func @broadcast_dim_0(%a: memref<512xf32>, %b: memref<16x512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+
+ %c64 = arith.constant 64 : index
+ %block_id_x = gpu.block_id x
+ %m = arith.muli %block_id_x, %c64 : index
+ %0 = xegpu.create_nd_tdesc %a[%m] : memref<512xf32> -> !xegpu.tensor_desc<64xf32, #r>
+ %1 = xegpu.load_nd %0: !xegpu.tensor_desc<64xf32, #r> -> vector<64xf32>
+ // CHECK-COUNT-4: vector.broadcast {{.*}} : vector<16xf32> to vector<16x16xf32>
+ %2 = vector.broadcast %1 {layout_result_0 = #l} : vector<64xf32> to vector<16x64xf32>
+ %3 = xegpu.create_nd_tdesc %b[0, %m] : memref<16x512xf32> -> !xegpu.tensor_desc<16x64xf32, #l>
+ xegpu.store_nd %2, %3: vector<16x64xf32>, !xegpu.tensor_desc<16x64xf32, #l>
+ gpu.return
+ }
+}
+
+// -----
+#r = #xegpu.layout<inst_data = [16]>
+#l = #xegpu.layout<inst_data = [16, 16]>
+
+gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+ gpu.func @broadcast_dim_1(%a: memref<512xf32>, %b: memref<16x512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+
+ %c32 = arith.constant 32 : index
+ %block_id_x = gpu.block_id x
+ %m = arith.muli %block_id_x, %c32 : index
+ %0 = xegpu.create_nd_tdesc %a[%m] : memref<512xf32> -> !xegpu.tensor_desc<32xf32, #r>
+ %1 = xegpu.load_nd %0: !xegpu.tensor_desc<32xf32, #r> -> vector<32xf32>
+ %11 = vector.shape_cast %1 : vector<32xf32> to vector<32x1xf32>
+ // CHECK-COUNT-8: vector.broadcast {{.*}}: vector<16x1xf32> to vector<16x16xf32>
+ %2 = vector.broadcast %11 {layout_result_0 = #l} : vector<32x1xf32> to vector<32x64xf32>
+ %3 = xegpu.create_nd_tdesc %b[0, %m] : memref<16x512xf32> -> !xegpu.tensor_desc<32x64xf32, #l>
+ xegpu.store_nd %2, %3: vector<32x64xf32>, !xegpu.tensor_desc<32x64xf32, #l>
+ gpu.return
+ }
+}
+
+// -----
+#l = #xegpu.layout<inst_data = [16, 8]>
+#t = #xegpu.layout<inst_data = [8, 16]>
+
+gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+ gpu.func @transpose(%a: memref<512x8xf32>, %b: memref<8x512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+
+ %c32 = arith.constant 32 : index
+ %block_id_x = gpu.block_id x
+ %m = arith.muli %block_id_x, %c32 : index
+ %0 = xegpu.create_nd_tdesc %a[%m, 0] : memref<512x8xf32> -> !xegpu.tensor_desc<32x8xf32, #l>
+ %1 = xegpu.load_nd %0: !xegpu.tensor_desc<32x8xf32, #l> -> vector<32x8xf32>
+ // CHECK-COUNT-2: vector.transpose {{.*}} [1, 0] : vector<16x8xf32> to vector<8x16xf32>
+ %2 = vector.transpose %1, [1, 0] {layout_result_0 = #t} : vector<32x8xf32> to vector<8x32xf32>
+ %3 = xegpu.create_nd_tdesc %b[0, %m] : memref<8x512xf32> -> !xegpu.tensor_desc<8x32xf32, #t>
+ xegpu.store_nd %2, %3: vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32, #t>
+ gpu.return
+ }
+}
\ No newline at end of file
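For context, a minimal before/after sketch of what the blocking pass is expected to do for the dim-0 reduction test above with inst_data = [16, 16]. The value names and the extract/insert plumbing produced by pack/unpack are illustrative assumptions, not the literal pass output:

  // before blocking: one reduction over the whole 16x64 tile
  %r = vector.multi_reduction <add>, %v, %acc [0] : vector<16x64xf32> to vector<64xf32>

  // after blocking: four independent reductions, one per 16-wide slice of %v,
  // each accumulating into the matching 16-element slice of %acc
  %r0 = vector.multi_reduction <add>, %v_0, %acc_0 [0] : vector<16x16xf32> to vector<16xf32>
  %r1 = vector.multi_reduction <add>, %v_1, %acc_1 [0] : vector<16x16xf32> to vector<16xf32>
  %r2 = vector.multi_reduction <add>, %v_2, %acc_2 [0] : vector<16x16xf32> to vector<16xf32>
  %r3 = vector.multi_reduction <add>, %v_3, %acc_3 [0] : vector<16x16xf32> to vector<16xf32>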
>From 30b099ef6d1f8d036290590c73786bd45ea51b57 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 9 Jun 2025 23:14:58 +0000
Subject: [PATCH 02/16] add create_desc unrolling and test
---
.../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 41 ++++++++++++++++++-
.../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 24 +++++++++++
2 files changed, 64 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 885477fe4cbd5..672b0fb731f31 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -396,11 +396,50 @@ struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
}
};
+struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
+ using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
+ LogicalResult matchAndRewrite(xegpu::CreateDescOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ xegpu::TensorDescType tdescTy = op.getType();
+
+ std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+ if (!targetShape)
+ return failure();
+
+ auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
+
+
+ TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
+
+ VectorType indiceVecTy = indiceVec.getType();
+
+ SmallVector<Type> convertedIndiceTypes =
+ getUnrolledTypes(indiceVecTy, *targetShape);
+
+ SmallVector<Value> convertedIndiceVec =
+ pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
+
+ SmallVector<Value> newOps;
+
+ for (auto indice : convertedIndiceVec) {
+ auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy, op.getSource(), indice);
+ newOps.push_back(newOp);
+ }
+
+ Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
+ rewriter.replaceOp(op, castOp);
+
+ return success();
+ }
+};
+
} // namespace
void mlir::xegpu::populateXeGPUUnrollPatterns(
RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
- UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp>(
+ UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp,
+ UnrollCreateDescOp>(
patterns.getContext(), options);
}
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 3f3461e92bc08..abdee098ab430 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -71,6 +71,30 @@ struct TestXeGPUUnrollingPatterns
}
}
+ if (isa<xegpu::CreateDescOp, xegpu::UpdateOffsetOp,
+ xegpu::PrefetchOp, xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
+ xegpu::TensorDescType tdescTy;
+ if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
+ tdescTy = createOp.getType();
+ } else if (auto updateOp =
+ dyn_cast<xegpu::UpdateOffsetOp>(op)) {
+ tdescTy = updateOp.getTensorDescType();
+ } else if (auto prefetchOp = dyn_cast<xegpu::PrefetchOp>(op)) {
+ tdescTy = prefetchOp.getTensorDescType();
+ } else if (auto loadOp = dyn_cast<xegpu::LoadGatherOp>(op)) {
+ tdescTy = loadOp.getTensorDescType();
+ } else if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(op)) {
+ tdescTy = storeOp.getTensorDescType();
+ }
+
+ if (auto layout = tdescTy.getLayoutAttr()) {
+ auto inst_data = layout.getInstData();
+ if (inst_data && layout.isSgLayout())
+ return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
+ inst_data.asArrayRef().end());
+ }
+ }
+
if (isa<xegpu::DpasOp>(op))
return SmallVector<int64_t>{8, 16, 16};
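To illustrate what UnrollCreateDescOp produces for a scattered descriptor whose layout requests inst_data = [16] (sketch only; the unrealized casts inserted by pack/unpack are omitted and the value names are assumptions):

  // before unrolling
  %t = xegpu.create_tdesc %src, %off : ui64, vector<32xindex>
         -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>

  // after unrolling: one op per 16-element slice of the offset vector
  %t0 = xegpu.create_tdesc %src, %off_0 : ui64, vector<16xindex>
          -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
  %t1 = xegpu.create_tdesc %src, %off_1 : ui64, vector<16xindex>
          -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>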
>From 8a0e1455683d2a04744bb653348e39c353060704 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 10 Jun 2025 19:49:39 +0000
Subject: [PATCH 03/16] add unrolling support for scatter operations
---
.../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 159 +++++++++++++++++-
.../Dialect/XeGPU/xegpu-unroll-patterns.mlir | 143 ++++++++++++++++
2 files changed, 296 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 672b0fb731f31..96b86f6509419 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -409,19 +409,14 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
-
TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
-
VectorType indiceVecTy = indiceVec.getType();
-
SmallVector<Type> convertedIndiceTypes =
getUnrolledTypes(indiceVecTy, *targetShape);
-
SmallVector<Value> convertedIndiceVec =
pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
SmallVector<Value> newOps;
-
for (auto indice : convertedIndiceVec) {
auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy, op.getSource(), indice);
newOps.push_back(newOp);
@@ -434,12 +429,164 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
}
};
+struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
+ using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
+ LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
+ PatternRewriter &rewriter) const override {
+
+ Location loc = op.getLoc();
+ VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
+ xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+ VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+
+ std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+ if (!targetShape)
+ return failure();
+
+ Type elemTy = tdescTy.getElementType();
+ VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+
+ SmallVector<Type> convertedTdescTypes =
+ getUnrolledTypes(tdescTy, *targetShape);
+ SmallVector<Value> convertedTdescs = pack(
+ op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+ SmallVector<Type> convertedMaskTypes =
+ getUnrolledTypes(maskTy, *targetShape);
+ SmallVector<Value> convertedMasks = pack(
+ op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+
+ SmallVector<Value> newOps;
+ for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
+ auto newOp =
+ rewriter.create<xegpu::LoadGatherOp>(loc, newValueTy, t, m,
+ op.getTransposeAttr(), op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
+ newOps.push_back(newOp);
+ }
+
+ Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
+
+ rewriter.replaceOp(op, castOp);
+ return success();
+ }
+};
+
+struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
+ using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
+ LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+ std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+ if (!targetShape)
+ return failure();
+
+ SmallVector<Type> convertedTdescTypes =
+ getUnrolledTypes(tdescTy, *targetShape);
+ SmallVector<Value> convertedTdesc = pack(
+ op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+ for (auto t : convertedTdesc)
+ rewriter.create<xegpu::PrefetchOp>(loc, TypeRange(), t, op->getAttrs());
+
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
+struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
+ using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
+ LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
+ PatternRewriter &rewriter) const override {
+
+ Location loc = op.getLoc();
+ VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
+ xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+ VectorType maskTy;
+ if (op.getMask())
+ maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+
+ std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+ if (!targetShape)
+ return failure();
+
+ SmallVector<Type> convertedValTypes =
+ getUnrolledTypes(valueTy, *targetShape);
+ SmallVector<Type> convertedTdescTypes =
+ getUnrolledTypes(tdescTy, *targetShape);
+
+ SmallVector<Value> convertedValues =
+ pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
+ SmallVector<Value> convertedTdescs =
+ pack(op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+ SmallVector<Value> convertedMasks;
+ if (op.getMask()) {
+ SmallVector<Type> convertedMaskTypes =
+ getUnrolledTypes(maskTy, *targetShape);
+ convertedMasks =
+ pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+ }
+
+ for (size_t i = 0; i < convertedValues.size(); ++i) {
+ Value v = convertedValues[i];
+ Value t = convertedTdescs[i];
+ Value m = op.getMask() ? convertedMasks[i] : nullptr;
+ rewriter.create<xegpu::StoreScatterOp>(
+ loc, v, t, m, op.getTransposeAttr(), op.getL1HintAttr(),
+ op.getL2HintAttr(), op.getL3HintAttr());
+ }
+
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
+struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
+ using UnrollPattern<xegpu::UpdateOffsetOp>::UnrollPattern;
+ LogicalResult matchAndRewrite(xegpu::UpdateOffsetOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+ std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+ if (!targetShape)
+ return failure();
+
+ SmallVector<Type> convertedTdescTypes =
+ getUnrolledTypes(tdescTy, *targetShape);
+ SmallVector<Value> convertedTdesc = pack(
+ op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+ TypedValue<::mlir::VectorType> offsetVec = op.getOffsets();
+ VectorType offsetVecTy = offsetVec.getType();
+ SmallVector<Type> convertedOffsetTypes =
+ getUnrolledTypes(offsetVecTy, *targetShape);
+ SmallVector<Value> convertedOffsetVec =
+ pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
+
+ SmallVector<Value> newOps;
+ for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
+ auto newOp = rewriter.create<xegpu::UpdateOffsetOp>(
+ loc, t.getType(), t, o);
+ newOps.push_back(newOp);
+ }
+ Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
+ rewriter.replaceOp(op, castOp);
+ return success();
+ }
+};
+
} // namespace
void mlir::xegpu::populateXeGPUUnrollPatterns(
RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp,
- UnrollCreateDescOp>(
+ UnrollCreateDescOp, UnrollLoadGatherOp,
+ UnrollStoreScatterOp, UnrollPrefetchOp, UnrollUpdateOffsetOp>(
patterns.getContext(), options);
}
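The load/store/prefetch/update_offset patterns follow the same recipe: pack the descriptor, mask, and value operands into inst_data-sized pieces, emit one small op per piece, and unpack the results. Roughly, for the gather load (value names are assumptions, casts omitted):

  // before unrolling
  %v = xegpu.load %t, %m
         : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>,
           vector<32xi1> -> vector<32xf32>

  // after unrolling
  %v0 = xegpu.load %t0, %m0
          : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
  %v1 = xegpu.load %t1, %m1
          : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>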
diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
index b911bb3bbdc1c..47c54bfcb89d0 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -158,4 +158,147 @@ gpu.module @test {
%c = xegpu.dpas %a, %b : vector<32x32xf16>, vector<32x32xf16> -> vector<32x32xf32>
gpu.return %c : vector<32x32xf32>
}
+
+//-----
+
+ // CHECK-LABEL: test_create_tdesc_vec
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+ gpu.func @test_create_tdesc_vec(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
+ %cst = arith.constant dense<[
+ 0, 8, 16, 24, 32, 40, 48, 56,
+ 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184,
+ 192, 200, 208, 216, 224, 232, 240, 248
+ ]> : vector<32xindex>
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+ gpu.return %tdesc : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+ }
+
+//-----
+
+ // CHECK-LABEL: test_create_tdesc_step
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+ gpu.func @test_create_tdesc_step(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
+ %step = arith.constant dense<8> : vector<32xindex>
+ %seq = vector.step : vector<32xindex>
+ %cst = arith.muli %seq, %step : vector<32xindex>
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+ gpu.return %tdesc : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+ }
+
+//-----
+
+ // CHECK-LABEL: test_load
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+ // CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+ gpu.func @test_load(%src: ui64) -> vector<32xf32> {
+ %cst = arith.constant dense<[
+ 0, 8, 16, 24, 32, 40, 48, 56,
+ 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184,
+ 192, 200, 208, 216, 224, 232, 240, 248
+ ]> : vector<32xindex>
+
+ %c17 = arith.constant 17: index
+ %mask = vector.create_mask %c17: vector<32xi1>
+
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+ %ld = xegpu.load %tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
+
+ gpu.return %ld : vector<32xf32>
+ }
+
+//-----
+
+ // CHECK-LABEL: test_prefetch
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+ // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+ gpu.func @test_prefetch(%src: ui64) {
+
+ %cst = arith.constant dense<[
+ 0, 8, 16, 24, 32, 40, 48, 56,
+ 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184,
+ 192, 200, 208, 216, 224, 232, 240, 248
+ ]> : vector<32xindex>
+
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+
+ xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+ gpu.return
+ }
+
+//-----
+
+ // CHECK-LABEL: test_store
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+ // CHECK-COUNT-2: xegpu.store {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
+ gpu.func @test_store(%src: ui64) {
+ %cst = arith.constant dense<[
+ 0, 8, 16, 24, 32, 40, 48, 56,
+ 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184,
+ 192, 200, 208, 216, 224, 232, 240, 248
+ ]> : vector<32xindex>
+
+ %c17 = arith.constant 17: index
+ %mask = vector.create_mask %c17: vector<32xi1>
+
+ %st_vec = arith.constant dense<1023.>: vector<32xf32>
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+ xegpu.store %st_vec, %tdesc, %mask: vector<32xf32>, !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1>
+
+ gpu.return
+ }
+
+//-----
+
+ // CHECK-LABEL: test_prefetch_load_store_update
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+ // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+ // CHECK-COUNT-2: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex>
+ // CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+ // CHECK-COUNT-2: xegpu.store {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
+
+ gpu.func @test_prefetch_load_store_update(%src: ui64) {
+
+ %cst = arith.constant dense<[
+ 0, 8, 16, 24, 32, 40, 48, 56,
+ 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184,
+ 192, 200, 208, 216, 224, 232, 240, 248
+ ]> : vector<32xindex>
+
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+
+ xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+
+ %delta = arith.constant dense<[
+ 32, 32, 32, 32, 32, 32, 32, 32,
+ 32, 32, 32, 32, 32, 32, 32, 64,
+ 128, 128, 128, 128, 128, 128, 128, 128,
+ 128, 128, 128, 128, 128, 128, 128, 256
+ ]> : vector<32xindex>
+ %new_tdesc = xegpu.update_offset %tdesc, %delta
+ : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xindex>
+
+ %c17 = arith.constant 17: index
+ %mask = vector.create_mask %c17: vector<32xi1>
+
+ %ld_vec = xegpu.load %new_tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
+
+ %st_vec = arith.addf %ld_vec, %ld_vec : vector<32xf32>
+ xegpu.store %st_vec, %tdesc, %mask:
+ vector<32xf32>,
+ !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>,
+ vector<32xi1>
+
+ gpu.return
+ }
}
>From c91156c7fc789548676f367919ddd577ba06c5ed Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 10 Jun 2025 19:51:32 +0000
Subject: [PATCH 04/16] clang-format
---
.../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 31 ++++++++++---------
.../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 7 ++---
2 files changed, 19 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 96b86f6509419..900ade8c171d5 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -418,7 +418,8 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
SmallVector<Value> newOps;
for (auto indice : convertedIndiceVec) {
- auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy, op.getSource(), indice);
+ auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
+ op.getSource(), indice);
newOps.push_back(newOp);
}
@@ -454,14 +455,14 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
SmallVector<Type> convertedMaskTypes =
getUnrolledTypes(maskTy, *targetShape);
- SmallVector<Value> convertedMasks = pack(
- op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+ SmallVector<Value> convertedMasks =
+ pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
SmallVector<Value> newOps;
for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
- auto newOp =
- rewriter.create<xegpu::LoadGatherOp>(loc, newValueTy, t, m,
- op.getTransposeAttr(), op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
+ auto newOp = rewriter.create<xegpu::LoadGatherOp>(
+ loc, newValueTy, t, m, op.getTransposeAttr(), op.getL1HintAttr(),
+ op.getL2HintAttr(), op.getL3HintAttr());
newOps.push_back(newOp);
}
@@ -520,8 +521,8 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
SmallVector<Value> convertedValues =
pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
- SmallVector<Value> convertedTdescs =
- pack(op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+ SmallVector<Value> convertedTdescs = pack(
+ op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
SmallVector<Value> convertedMasks;
if (op.getMask()) {
@@ -566,12 +567,12 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
SmallVector<Type> convertedOffsetTypes =
getUnrolledTypes(offsetVecTy, *targetShape);
SmallVector<Value> convertedOffsetVec =
- pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
+ pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
SmallVector<Value> newOps;
for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
- auto newOp = rewriter.create<xegpu::UpdateOffsetOp>(
- loc, t.getType(), t, o);
+ auto newOp =
+ rewriter.create<xegpu::UpdateOffsetOp>(loc, t.getType(), t, o);
newOps.push_back(newOp);
}
Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
@@ -585,8 +586,8 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
void mlir::xegpu::populateXeGPUUnrollPatterns(
RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
- UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp,
- UnrollCreateDescOp, UnrollLoadGatherOp,
- UnrollStoreScatterOp, UnrollPrefetchOp, UnrollUpdateOffsetOp>(
- patterns.getContext(), options);
+ UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp,
+ UnrollCreateDescOp, UnrollLoadGatherOp, UnrollStoreScatterOp,
+ UnrollPrefetchOp, UnrollUpdateOffsetOp>(patterns.getContext(),
+ options);
}
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index abdee098ab430..57aaecbd7962f 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -71,13 +71,12 @@ struct TestXeGPUUnrollingPatterns
}
}
- if (isa<xegpu::CreateDescOp, xegpu::UpdateOffsetOp,
- xegpu::PrefetchOp, xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
+ if (isa<xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
+ xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
xegpu::TensorDescType tdescTy;
if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
tdescTy = createOp.getType();
- } else if (auto updateOp =
- dyn_cast<xegpu::UpdateOffsetOp>(op)) {
+ } else if (auto updateOp = dyn_cast<xegpu::UpdateOffsetOp>(op)) {
tdescTy = updateOp.getTensorDescType();
} else if (auto prefetchOp = dyn_cast<xegpu::PrefetchOp>(op)) {
tdescTy = prefetchOp.getTensorDescType();
>From 30cb8d816352b0d29578f408777a151c94b5f3be Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 12 Jun 2025 20:59:47 -0700
Subject: [PATCH 05/16] Update
mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
Co-authored-by: Adam Siemieniuk <adam.siemieniuk at intel.com>
---
mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
index 47c54bfcb89d0..cca5339b086e4 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -249,7 +249,7 @@ gpu.module @test {
%c17 = arith.constant 17: index
%mask = vector.create_mask %c17: vector<32xi1>
- %st_vec = arith.constant dense<1023.>: vector<32xf32>
+ %st_vec = arith.constant dense<1023.0>: vector<32xf32>
%tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
xegpu.store %st_vec, %tdesc, %mask: vector<32xf32>, !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1>
>From f493e52fefa594aa5918471bf0aa57b2cc0331a4 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 12 Jun 2025 21:06:13 -0700
Subject: [PATCH 06/16] Update xegpu-unroll-patterns.mlir
correct indentation
---
.../Dialect/XeGPU/xegpu-unroll-patterns.mlir | 100 +++++++++---------
1 file changed, 49 insertions(+), 51 deletions(-)
diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
index cca5339b086e4..52ec3b856da49 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -190,26 +190,25 @@ gpu.module @test {
//-----
- // CHECK-LABEL: test_load
- // CHECK-SAME: [[arg0:%.+]]: ui64
- // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
- // CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
- gpu.func @test_load(%src: ui64) -> vector<32xf32> {
- %cst = arith.constant dense<[
- 0, 8, 16, 24, 32, 40, 48, 56,
- 64, 72, 80, 88, 96, 104, 112, 120,
- 128, 136, 144, 152, 160, 168, 176, 184,
- 192, 200, 208, 216, 224, 232, 240, 248
- ]> : vector<32xindex>
+ // CHECK-LABEL: test_load
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+ // CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+ gpu.func @test_load(%src: ui64) -> vector<32xf32> {
+ %cst = arith.constant dense<[
+ 0, 8, 16, 24, 32, 40, 48, 56,
+ 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184,
+ 192, 200, 208, 216, 224, 232, 240, 248
+ ]> : vector<32xindex>
- %c17 = arith.constant 17: index
- %mask = vector.create_mask %c17: vector<32xi1>
-
- %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
- %ld = xegpu.load %tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
+ %c17 = arith.constant 17: index
+ %mask = vector.create_mask %c17: vector<32xi1>
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+ %ld = xegpu.load %tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
- gpu.return %ld : vector<32xf32>
- }
+ gpu.return %ld : vector<32xf32>
+ }
//-----
@@ -219,17 +218,17 @@ gpu.module @test {
// CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
gpu.func @test_prefetch(%src: ui64) {
- %cst = arith.constant dense<[
- 0, 8, 16, 24, 32, 40, 48, 56,
- 64, 72, 80, 88, 96, 104, 112, 120,
- 128, 136, 144, 152, 160, 168, 176, 184,
- 192, 200, 208, 216, 224, 232, 240, 248
- ]> : vector<32xindex>
+ %cst = arith.constant dense<[
+ 0, 8, 16, 24, 32, 40, 48, 56,
+ 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184,
+ 192, 200, 208, 216, 224, 232, 240, 248
+ ]> : vector<32xindex>
- %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
- xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
- gpu.return
+ xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+ gpu.return
}
//-----
@@ -268,37 +267,36 @@ gpu.module @test {
gpu.func @test_prefetch_load_store_update(%src: ui64) {
- %cst = arith.constant dense<[
- 0, 8, 16, 24, 32, 40, 48, 56,
- 64, 72, 80, 88, 96, 104, 112, 120,
- 128, 136, 144, 152, 160, 168, 176, 184,
- 192, 200, 208, 216, 224, 232, 240, 248
- ]> : vector<32xindex>
-
- %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+ %cst = arith.constant dense<[
+ 0, 8, 16, 24, 32, 40, 48, 56,
+ 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184,
+ 192, 200, 208, 216, 224, 232, 240, 248
+ ]> : vector<32xindex>
- xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+ xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
- %delta = arith.constant dense<[
- 32, 32, 32, 32, 32, 32, 32, 32,
- 32, 32, 32, 32, 32, 32, 32, 64,
- 128, 128, 128, 128, 128, 128, 128, 128,
- 128, 128, 128, 128, 128, 128, 128, 256
- ]> : vector<32xindex>
- %new_tdesc = xegpu.update_offset %tdesc, %delta
- : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xindex>
+ %delta = arith.constant dense<[
+ 32, 32, 32, 32, 32, 32, 32, 32,
+ 32, 32, 32, 32, 32, 32, 32, 64,
+ 128, 128, 128, 128, 128, 128, 128, 128,
+ 128, 128, 128, 128, 128, 128, 128, 256
+ ]> : vector<32xindex>
+ %new_tdesc = xegpu.update_offset %tdesc, %delta
+ : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xindex>
- %c17 = arith.constant 17: index
- %mask = vector.create_mask %c17: vector<32xi1>
+ %c17 = arith.constant 17: index
+ %mask = vector.create_mask %c17: vector<32xi1>
- %ld_vec = xegpu.load %new_tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
+ %ld_vec = xegpu.load %new_tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
- %st_vec = arith.addf %ld_vec, %ld_vec : vector<32xf32>
- xegpu.store %st_vec, %tdesc, %mask:
+ %st_vec = arith.addf %ld_vec, %ld_vec : vector<32xf32>
+ xegpu.store %st_vec, %tdesc, %mask:
vector<32xf32>,
!xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>,
vector<32xi1>
- gpu.return
- }
+ gpu.return
+ }
}
>From 2606a4bf20ab35ba9b46054c4dc4255d30619e75 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 13 Jun 2025 21:21:42 +0000
Subject: [PATCH 07/16] reject unrolling for scatter descriptors with chunk_size
---
.../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 36 ++++++++++++++-----
1 file changed, 28 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 900ade8c171d5..c8e1479f49f11 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -403,6 +403,11 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getType();
+ //check if the tensor descriptor type is a 1d vector type
+ if (tdescTy.getRank() > 1) {
+ return failure();
+ }
+
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
return failure();
@@ -439,6 +444,11 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
xegpu::TensorDescType tdescTy = op.getTensorDescType();
+ //check if the tensor descriptor type is a 1d vector type
+ if (tdescTy.getRank() > 1) {
+ return failure();
+ }
+
VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -480,6 +490,11 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getTensorDescType();
+ //check if the tensor descriptor type is a 1d vector type
+ if (tdescTy.getRank() > 1) {
+ return failure();
+ }
+
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
return failure();
@@ -506,9 +521,12 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
xegpu::TensorDescType tdescTy = op.getTensorDescType();
- VectorType maskTy;
- if (op.getMask())
- maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+ //check if the tensor descriptor type is a 1d vector type
+ if (tdescTy.getRank() > 1) {
+ return failure();
+ }
+
+ VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
@@ -524,13 +542,10 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
SmallVector<Value> convertedTdescs = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
- SmallVector<Value> convertedMasks;
- if (op.getMask()) {
- SmallVector<Type> convertedMaskTypes =
+ SmallVector<Type> convertedMaskTypes =
getUnrolledTypes(maskTy, *targetShape);
- convertedMasks =
+ SmallVector<Value> convertedMasks =
pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
- }
for (size_t i = 0; i < convertedValues.size(); ++i) {
Value v = convertedValues[i];
@@ -553,6 +568,11 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getTensorDescType();
+ //check if the tensor descriptor type is a 1d vector type
+ if (tdescTy.getRank() > 1) {
+ return failure();
+ }
+
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
return failure();
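For reference, the rank-2 case rejected by these early bail-outs is the chunked form of a scattered descriptor, e.g. (values are illustrative assumptions):

  %t = xegpu.create_tdesc %src, %off : ui64, vector<16xindex>
         -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>,
                               #xegpu.layout<inst_data = [16, 2]>>

Support for this shape is added by the chunk_size-aware patches later in the series.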
>From a3e064d4f051a46dbee8960d86685aca0700bbcc Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 13 Jun 2025 21:30:40 +0000
Subject: [PATCH 08/16] small fixes
---
.../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 29 ++++++++-----------
1 file changed, 12 insertions(+), 17 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index c8e1479f49f11..8a0048fbca8ed 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -403,10 +403,9 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getType();
- //check if the tensor descriptor type is a 1d vector type
- if (tdescTy.getRank() > 1) {
+ // check if the tensor descriptor type is a 1d vector type
+ if (tdescTy.getRank() > 1)
return failure();
- }
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
@@ -444,10 +443,9 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
xegpu::TensorDescType tdescTy = op.getTensorDescType();
- //check if the tensor descriptor type is a 1d vector type
- if (tdescTy.getRank() > 1) {
+ // check if the tensor descriptor type is a 1d vector type
+ if (tdescTy.getRank() > 1)
return failure();
- }
VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
@@ -490,10 +488,9 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getTensorDescType();
- //check if the tensor descriptor type is a 1d vector type
- if (tdescTy.getRank() > 1) {
+ // check if the tensor descriptor type is a 1d vector type
+ if (tdescTy.getRank() > 1)
return failure();
- }
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
@@ -521,10 +518,9 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
xegpu::TensorDescType tdescTy = op.getTensorDescType();
- //check if the tensor descriptor type is a 1d vector type
- if (tdescTy.getRank() > 1) {
+ // check if the tensor descriptor type is a 1d vector type
+ if (tdescTy.getRank() > 1)
return failure();
- }
VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
@@ -543,9 +539,9 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
SmallVector<Type> convertedMaskTypes =
- getUnrolledTypes(maskTy, *targetShape);
+ getUnrolledTypes(maskTy, *targetShape);
SmallVector<Value> convertedMasks =
- pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+ pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
for (size_t i = 0; i < convertedValues.size(); ++i) {
Value v = convertedValues[i];
@@ -568,10 +564,9 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getTensorDescType();
- //check if the tensor descriptor type is a 1d vector type
- if (tdescTy.getRank() > 1) {
+ // check if the tensor descriptor type is a 1d vector type
+ if (tdescTy.getRank() > 1)
return failure();
- }
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
>From b3d59370fde678cff0c34130f2e5cc71900e0d66 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Sat, 14 Jun 2025 01:38:51 +0000
Subject: [PATCH 09/16] small fix
---
mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 1 +
1 file changed, 1 insertion(+)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 8a0048fbca8ed..9c234c1e866b9 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -415,6 +415,7 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
VectorType indiceVecTy = indiceVec.getType();
+
SmallVector<Type> convertedIndiceTypes =
getUnrolledTypes(indiceVecTy, *targetShape);
SmallVector<Value> convertedIndiceVec =
>From 248981b2f39a4c8b462735d6d894c93e5ff0c9da Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Sun, 15 Jun 2025 04:05:18 +0000
Subject: [PATCH 10/16] add support for create_tdesc with chunk_size
---
.../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 53 +++++++++++++----
.../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 58 +++++++++++++------
2 files changed, 81 insertions(+), 30 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 9c234c1e866b9..8241866c005bb 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -404,7 +404,7 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
xegpu::TensorDescType tdescTy = op.getType();
// check if the tensor descriptor type is a 1d vector type
- if (tdescTy.getRank() > 1)
+ if (tdescTy.getRank() > 2)
return failure();
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -416,16 +416,47 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
VectorType indiceVecTy = indiceVec.getType();
- SmallVector<Type> convertedIndiceTypes =
- getUnrolledTypes(indiceVecTy, *targetShape);
- SmallVector<Value> convertedIndiceVec =
- pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
-
+ SmallVector<Type> convertedIndiceTypes;
+ SmallVector<Value> convertedIndiceVec;
SmallVector<Value> newOps;
- for (auto indice : convertedIndiceVec) {
- auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
+
+ if (tdescTy.getRank() == 2) {
+ SmallVector<int64_t> oneDShape(targetShape->begin(), targetShape->end() - 1);
+ convertedIndiceTypes = getUnrolledTypes(indiceVecTy, oneDShape);
+ convertedIndiceVec = pack(indiceVec, convertedIndiceTypes, oneDShape, loc, rewriter);
+ // Assume tdescTy, targetShape, and convertedIndiceVec are defined
+ int64_t outerDim = tdescTy.getShape().back();
+ int64_t innerDim = targetShape->back();
+ int64_t numInnerLoops = outerDim / innerDim;
+
+ // Get element size in bytes
+ int64_t elemSize = tdescTy.getElementType().getIntOrFloatBitWidth() / 8;
+
+ for (auto [indice, indiceType] : llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
+ for (int64_t i = 0; i < numInnerLoops; ++i) {
+ // Compute the offset
+ Value inc = rewriter.create<arith::ConstantIndexOp>(loc, i * innerDim);
+ Value incVec = rewriter.create<vector::SplatOp>(loc, indiceType, inc);
+ Value offsetIndice = rewriter.create<arith::AddIOp>(loc, indice, incVec);
+
+ auto chunkSizeAttr = rewriter.getI64IntegerAttr(innerDim);
+ auto newOp = rewriter.create<xegpu::CreateDescOp>(
+ loc, newTdescTy, op.getSource(), offsetIndice);
+
+ newOps.push_back(newOp);
+ }
+ }
+ } else if (tdescTy.getRank() == 1) {
+ convertedIndiceTypes = getUnrolledTypes(indiceVecTy, *targetShape);
+ convertedIndiceVec = pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
+ for (auto indice : convertedIndiceVec) {
+ auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
op.getSource(), indice);
- newOps.push_back(newOp);
+ newOps.push_back(newOp);
+ }
+ } else {
+ // Unsupported rank for tensor descriptor
+ return failure();
}
Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
@@ -445,9 +476,9 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
xegpu::TensorDescType tdescTy = op.getTensorDescType();
// check if the tensor descriptor type is a 1d vector type
- if (tdescTy.getRank() > 1)
+ if (tdescTy.getRank() > 2)
return failure();
-
+
VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
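To make the rank-2 branch concrete: for a descriptor blocked from chunk_size = 8 down to inst_data = [16, 2], wholeChunk = 8 and blockedChunk = 2, so numInnerLoops = 4, and each 16-element offset slice yields four descriptors whose offsets advance by one blocked chunk (0, 2, 4, 6 elements). A sketch for one slice, with assumed names and attribute printing:

  %c2  = arith.constant 2 : index
  %inc = vector.splat %c2 : vector<16xindex>
  %o1  = arith.addi %o0, %inc : vector<16xindex>
  %t0  = xegpu.create_tdesc %src, %o0 : ui64, vector<16xindex>
           -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
  %t1  = xegpu.create_tdesc %src, %o1 : ui64, vector<16xindex>
           -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
  // ... and likewise for offsets %o0 + 4 and %o0 + 6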
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 57aaecbd7962f..758e774eb8e01 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -19,6 +19,12 @@ using namespace mlir::xegpu;
namespace {
+
+#define DEBUG_TYPE "test-xegpu-unroll"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
+
struct TestXeGPUUnrollingPatterns
: public PassWrapper<TestXeGPUUnrollingPatterns,
OperationPass<gpu::GPUModuleOp>> {
@@ -48,12 +54,13 @@ struct TestXeGPUUnrollingPatterns
options.setNativeShapeFn(
[&](Operation *op) -> std::optional<SmallVector<int64_t>> {
if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp,
- xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::StoreNdOp>(op)) {
+ xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
+ xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
+ xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
xegpu::TensorDescType tdescTy;
if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
tdescTy = createNdOp.getType();
- } else if (auto updateNdOp =
- dyn_cast<xegpu::UpdateNdOffsetOp>(op)) {
+ } else if (auto updateNdOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op)) {
tdescTy = updateNdOp.getTensorDescType();
} else if (auto prefetchNdOp = dyn_cast<xegpu::PrefetchNdOp>(op)) {
tdescTy = prefetchNdOp.getTensorDescType();
@@ -61,20 +68,7 @@ struct TestXeGPUUnrollingPatterns
tdescTy = loadNdOp.getTensorDescType();
} else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
tdescTy = storeNdOp.getTensorDescType();
- }
-
- if (auto layout = tdescTy.getLayoutAttr()) {
- auto inst_data = layout.getInstData();
- if (inst_data && layout.isSgLayout())
- return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
- inst_data.asArrayRef().end());
- }
- }
-
- if (isa<xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
- xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
- xegpu::TensorDescType tdescTy;
- if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
+ } else if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
tdescTy = createOp.getType();
} else if (auto updateOp = dyn_cast<xegpu::UpdateOffsetOp>(op)) {
tdescTy = updateOp.getTensorDescType();
@@ -111,14 +105,40 @@ struct TestXeGPUUnrollingPatterns
Attribute encoding = tdescTy.getEncoding();
auto layout = llvm::dyn_cast_if_present<xegpu::LayoutAttr>(
tdescTy.getLayout());
+
+ int64_t newChunkSize = 0;
+ auto instData = layout.getInstData();
+ if (!instData.empty())
+ newChunkSize = instData.asArrayRef().back();
+
if (layout) {
if (layout.getLaneLayout() == nullptr)
layout = xegpu::LayoutAttr();
else
layout = layout.dropInstData();
}
- newTy = xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
- layout);
+
+ SmallVector<NamedAttribute> attrs;
+ auto scatterAttr = mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
+ if (scatterAttr) {
+ int64_t chunkSize = scatterAttr.getChunkSize().getInt();
+
+ if (chunkSize > 1) {
+
+ auto chunkSizeAttr = mlir::IntegerAttr::get(
+ mlir::IntegerType::get(ctx, 64), newChunkSize);
+
+ // To create a new attribute with a different chunk_size:
+ auto newEncoding = xegpu::ScatterTensorDescAttr::get(
+ ctx, scatterAttr.getMemorySpace(), chunkSizeAttr);
+
+ encoding = newEncoding;
+
+ }
+
+ }
+ newTy = xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding, layout);
+
} else {
newTy = type.clone(tileShape, elemTy);
}
>From 66ab4aa1c4e6b25cb26040a145236217ea0b114f Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Sun, 15 Jun 2025 16:08:36 +0000
Subject: [PATCH 11/16] add unrolling support for load/store/prefetch/update
with chunk_size
---
.../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 131 ++++++++++++------
1 file changed, 90 insertions(+), 41 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 8241866c005bb..5d0c1095bef4f 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -421,32 +421,28 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
SmallVector<Value> newOps;
if (tdescTy.getRank() == 2) {
- SmallVector<int64_t> oneDShape(targetShape->begin(), targetShape->end() - 1);
- convertedIndiceTypes = getUnrolledTypes(indiceVecTy, oneDShape);
- convertedIndiceVec = pack(indiceVec, convertedIndiceTypes, oneDShape, loc, rewriter);
- // Assume tdescTy, targetShape, and convertedIndiceVec are defined
- int64_t outerDim = tdescTy.getShape().back();
- int64_t innerDim = targetShape->back();
- int64_t numInnerLoops = outerDim / innerDim;
+ SmallVector<int64_t> shape1D(targetShape->begin(), targetShape->end() - 1);
+ convertedIndiceTypes = getUnrolledTypes(indiceVecTy, shape1D);
+ convertedIndiceVec = pack(indiceVec, convertedIndiceTypes, shape1D, loc, rewriter);
- // Get element size in bytes
- int64_t elemSize = tdescTy.getElementType().getIntOrFloatBitWidth() / 8;
+ int64_t wholeChunk = tdescTy.getShape().back();
+ int64_t blockedChunk = targetShape->back();
+ int64_t numInnerLoops = wholeChunk / blockedChunk;
for (auto [indice, indiceType] : llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
for (int64_t i = 0; i < numInnerLoops; ++i) {
// Compute the offset
- Value inc = rewriter.create<arith::ConstantIndexOp>(loc, i * innerDim);
+ Value inc = rewriter.create<arith::ConstantIndexOp>(loc, i * blockedChunk);
Value incVec = rewriter.create<vector::SplatOp>(loc, indiceType, inc);
Value offsetIndice = rewriter.create<arith::AddIOp>(loc, indice, incVec);
- auto chunkSizeAttr = rewriter.getI64IntegerAttr(innerDim);
- auto newOp = rewriter.create<xegpu::CreateDescOp>(
+ auto newOp = rewriter.create<xegpu::CreateDescOp>(
loc, newTdescTy, op.getSource(), offsetIndice);
newOps.push_back(newOp);
}
}
- } else if (tdescTy.getRank() == 1) {
+ } else {
convertedIndiceTypes = getUnrolledTypes(indiceVecTy, *targetShape);
convertedIndiceVec = pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
for (auto indice : convertedIndiceVec) {
@@ -454,10 +450,7 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
op.getSource(), indice);
newOps.push_back(newOp);
}
- } else {
- // Unsupported rank for tensor descriptor
- return failure();
- }
+ }
Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
rewriter.replaceOp(op, castOp);
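
The splitting above generates one descriptor per (index sub-vector, chunk split) pair; the i-th split reuses the same base indices shifted by i * innerDim, which is what the splat-and-add of the constant does. A standalone sketch of that offset arithmetic (plain C++, toy sizes, outside the rewriter):

  #include <cstdint>
  #include <iostream>
  #include <vector>

  int main() {
    const int64_t wholeChunk = 4;   // tdescTy.getShape().back()
    const int64_t blockedChunk = 2; // targetShape->back()
    const int64_t numInnerLoops = wholeChunk / blockedChunk;

    // One unrolled 1-D index sub-vector (base addresses, shortened to 4 lanes).
    const std::vector<int64_t> indices = {0, 8, 16, 24};

    for (int64_t i = 0; i < numInnerLoops; ++i) {
      std::cout << "descriptor " << i << ":";
      // Splat of (i * blockedChunk) added element-wise to the index vector.
      for (int64_t idx : indices)
        std::cout << " " << idx + i * blockedChunk;
      std::cout << "\n";
    }
    return 0;
  }
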
@@ -493,10 +486,29 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
SmallVector<Value> convertedTdescs = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
- SmallVector<Type> convertedMaskTypes =
- getUnrolledTypes(maskTy, *targetShape);
- SmallVector<Value> convertedMasks =
- pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+ SmallVector<Type> convertedMaskTypes;
+ SmallVector<Value> convertedMasks;
+
+ if (tdescTy.getRank() == 2) {
+ convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]);
+ SmallVector<Value> convertedMasks1D = pack(op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);
+ int64_t wholeChunk = tdescTy.getShape().back();
+ int64_t blockedChunk = targetShape->back();
+ int64_t numInnerLoops = wholeChunk / blockedChunk;
+
+ for (auto mask : convertedMasks1D) {
+ for (int64_t i = 0; i < numInnerLoops; ++i) {
+ convertedMasks.push_back(mask);
+ }
+ }
+ if (targetShape && targetShape->size() > 1) {
+ std::swap((*targetShape)[0], (*targetShape)[1]);
+ newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+ }
+ } else {
+ convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape);
+ convertedMasks = pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+ }
SmallVector<Value> newOps;
for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
@@ -505,9 +517,8 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
op.getL2HintAttr(), op.getL3HintAttr());
newOps.push_back(newOp);
}
-
+
Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
-
rewriter.replaceOp(op, castOp);
return success();
}
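
For the gather path, the mask is one bit per address, so the 2-D tile shape cannot be used to unroll it directly: the mask is unrolled along the address dimension only and each 1-D piece is reused for every chunk split, while the loaded value comes back transposed ([chunk, addresses]). A standalone sketch of that bookkeeping (plain C++, toy sizes; names are illustrative):

  #include <algorithm>
  #include <cstdint>
  #include <iostream>
  #include <string>
  #include <vector>

  int main() {
    const int64_t wholeChunk = 4, blockedChunk = 2;
    const int64_t numInnerLoops = wholeChunk / blockedChunk;

    // 32 addresses unrolled into two 16-lane masks.
    const std::vector<std::string> masks1D = {"mask[0..15]", "mask[16..31]"};

    // Each 1-D mask is reused for every chunk split of those addresses.
    std::vector<std::string> masks;
    for (const std::string &m : masks1D)
      for (int64_t i = 0; i < numInnerLoops; ++i)
        masks.push_back(m);

    // The per-load result shape is the tile shape with its dims swapped.
    std::vector<int64_t> targetShape = {16, blockedChunk};
    std::swap(targetShape[0], targetShape[1]);

    std::cout << masks.size() << " loads, each " << targetShape[0] << "x"
              << targetShape[1] << " with a 16-lane mask\n"; // 4 loads of 2x16
    return 0;
  }
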
@@ -521,7 +532,7 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
xegpu::TensorDescType tdescTy = op.getTensorDescType();
// check if the tensor descriptor type is a 1d vector type
- if (tdescTy.getRank() > 1)
+ if (tdescTy.getRank() > 2)
return failure();
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -551,7 +562,7 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
xegpu::TensorDescType tdescTy = op.getTensorDescType();
// check if the tensor descriptor type is a 1d vector type
- if (tdescTy.getRank() > 1)
+ if (tdescTy.getRank() > 2)
return failure();
VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
@@ -559,21 +570,40 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
return failure();
-
- SmallVector<Type> convertedValTypes =
- getUnrolledTypes(valueTy, *targetShape);
+
SmallVector<Type> convertedTdescTypes =
getUnrolledTypes(tdescTy, *targetShape);
-
- SmallVector<Value> convertedValues =
- pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
SmallVector<Value> convertedTdescs = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
- SmallVector<Type> convertedMaskTypes =
- getUnrolledTypes(maskTy, *targetShape);
- SmallVector<Value> convertedMasks =
- pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+
+ SmallVector<Type> convertedMaskTypes;
+ SmallVector<Value> convertedMasks;
+
+ if (tdescTy.getRank() == 2) {
+ convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]);
+ SmallVector<Value> convertedMasks1D = pack(op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);
+ int64_t wholeChunk = tdescTy.getShape().back();
+ int64_t blockedChunk = targetShape->back();
+ int64_t numInnerLoops = wholeChunk / blockedChunk;
+
+ for (auto mask : convertedMasks1D) {
+ for (int64_t i = 0; i < numInnerLoops; ++i) {
+ convertedMasks.push_back(mask);
+ }
+ }
+
+ std::swap((*targetShape)[0], (*targetShape)[1]);
+
+ } else {
+ convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape);
+ convertedMasks = pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+ }
+
+ SmallVector<Type> convertedValTypes =
+ getUnrolledTypes(valueTy, *targetShape);
+ SmallVector<Value> convertedValues =
+ pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
for (size_t i = 0; i < convertedValues.size(); ++i) {
Value v = convertedValues[i];
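
The store side mirrors the gather: descriptors and masks are unrolled against the untransposed [addresses, chunk] tile, and only after the shape is swapped is the value operand split, since a chunked store consumes its data in the transposed [chunk, addresses] layout. A small sketch of the resulting tile shapes (plain C++, toy sizes, illustrative only):

  #include <cstdint>
  #include <iostream>
  #include <utility>
  #include <vector>

  int main() {
    // Tile shape used to unroll the tensor_desc pieces: [addresses, chunk].
    std::vector<int64_t> targetShape = {16, 2};
    const int64_t maskLanes = targetShape[0]; // one mask bit per address

    // The value operand of a chunked store is laid out transposed, so the
    // shape is swapped before the value vector is split into pieces.
    std::vector<int64_t> valueTile = targetShape;
    std::swap(valueTile[0], valueTile[1]);

    std::cout << "tdesc tile: " << targetShape[0] << "x" << targetShape[1]
              << ", value tile: " << valueTile[0] << "x" << valueTile[1]
              << ", mask lanes: " << maskLanes << "\n"; // 16x2, 2x16, 16
    return 0;
  }
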
@@ -597,7 +627,7 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
xegpu::TensorDescType tdescTy = op.getTensorDescType();
// check if the tensor descriptor type is a 1d vector type
- if (tdescTy.getRank() > 1)
+ if (tdescTy.getRank() > 2)
return failure();
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -611,17 +641,36 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
TypedValue<::mlir::VectorType> offsetVec = op.getOffsets();
VectorType offsetVecTy = offsetVec.getType();
- SmallVector<Type> convertedOffsetTypes =
- getUnrolledTypes(offsetVecTy, *targetShape);
- SmallVector<Value> convertedOffsetVec =
- pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
-
+ SmallVector<Type> convertedOffsetTypes;
+ SmallVector<Value> convertedOffsetVec;
SmallVector<Value> newOps;
+
+ if (tdescTy.getRank() == 2) {
+ SmallVector<int64_t> shape1D(targetShape->begin(), targetShape->end() - 1);
+ convertedOffsetTypes = getUnrolledTypes(offsetVecTy, shape1D);
+ SmallVector<Value> convertedOffsetVec1D = pack(offsetVec, convertedOffsetTypes, shape1D, loc, rewriter);
+
+ int64_t wholeChunk = tdescTy.getShape().back();
+ int64_t blockedChunk = targetShape->back();
+ int64_t numInnerLoops = wholeChunk / blockedChunk;
+
+ for (auto offset : convertedOffsetVec1D) {
+ for (int64_t i = 0; i < numInnerLoops; ++i) {
+ convertedOffsetVec.push_back(offset);
+ }
+ }
+
+ } else {
+ convertedOffsetTypes = getUnrolledTypes(offsetVecTy, *targetShape);
+ convertedOffsetVec = pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
+ }
+
for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
auto newOp =
rewriter.create<xegpu::UpdateOffsetOp>(loc, t.getType(), t, o);
newOps.push_back(newOp);
}
+
Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
rewriter.replaceOp(op, castOp);
return success();
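
Unlike create_desc, update_offset needs no per-chunk adjustment: an offset applies to an address as a whole, so every chunk split of that address is updated by the same 1-D offset vector, and each unrolled piece is simply repeated to line up one-to-one with the unrolled descriptors. A small sketch of that pairing (plain C++, toy sizes; names are illustrative):

  #include <cstdint>
  #include <iostream>
  #include <string>
  #include <vector>

  int main() {
    const int64_t wholeChunk = 4, blockedChunk = 2;
    const int64_t numInnerLoops = wholeChunk / blockedChunk;

    // 32 addresses unrolled into two 16-lane offset vectors.
    const std::vector<std::string> offsets1D = {"off[0..15]", "off[16..31]"};

    // Repeat each 1-D offset once per chunk split; the result zips 1:1 with
    // the four unrolled 16x2 tensor descriptors.
    std::vector<std::string> perTdescOffsets;
    for (const std::string &o : offsets1D)
      for (int64_t i = 0; i < numInnerLoops; ++i)
        perTdescOffsets.push_back(o);

    for (size_t i = 0; i < perTdescOffsets.size(); ++i)
      std::cout << "tdesc[" << i << "] += " << perTdescOffsets[i] << "\n";
    return 0;
  }
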
>From 7cbeebba475cdf33f8ba0d4256eeb248f53a056f Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 16 Jun 2025 19:30:41 +0000
Subject: [PATCH 12/16] small bug fixes
---
.../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 158 ++++++++++++------
.../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 16 +-
2 files changed, 110 insertions(+), 64 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 5d0c1095bef4f..33ebd6219abd7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -267,7 +267,7 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
return success();
}
};
-
+/*
struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
@@ -298,6 +298,49 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
return success();
}
};
+*/
+
+struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
+ using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
+ LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ VectorType valueTy = op.getValueType();
+ xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+ std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+ LDBG("UnrollStoreNdOp: targetShape present? " << (targetShape.has_value() ? "yes" : "no"));
+ if (!targetShape)
+ return failure();
+
+ LDBG("targetShape: ");
+ for (auto v : *targetShape) LDBG(" " << v);
+
+ SmallVector<Type> convertedValTypes =
+ getUnrolledTypes(valueTy, *targetShape);
+ LDBG("convertedValTypes size: " << convertedValTypes.size());
+ SmallVector<Type> convertedTdescTypes =
+ getUnrolledTypes(tdescTy, *targetShape);
+ LDBG("convertedTdescTypes size: " << convertedTdescTypes.size());
+
+ SmallVector<Value> convertedValues =
+ pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
+ LDBG("convertedValues size: " << convertedValues.size());
+ SmallVector<Value> convertedTdescs = pack(
+ op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+ LDBG("convertedTdescs size: " << convertedTdescs.size());
+
+ for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs)) {
+ LDBG("Creating StoreNdOp with value: " << v << ", tdesc: " << t);
+ rewriter.create<xegpu::StoreNdOp>(loc, v, t, op.getL1HintAttr(),
+ op.getL2HintAttr(), op.getL3HintAttr());
+ }
+
+ LDBG("Erasing original StoreNdOp: " << op);
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
@@ -402,37 +445,40 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getType();
+ TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
+ VectorType indiceVecTy = indiceVec.getType();
- // check if the tensor descriptor type is a 1d vector type
- if (tdescTy.getRank() > 2)
+ if (!tdescTy.isScattered())
return failure();
-
+
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
return failure();
+
+ SmallVector<int64_t> targetIndiceShape(*targetShape);
+ int64_t originalChunkSize = tdescTy.getChunkSize();
+ // IndiceVec has one fewer dimension than tdescTy when chunkSize is larger than 1.

+ if (originalChunkSize > 1)
+ targetIndiceShape.pop_back();
auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
-
- TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
- VectorType indiceVecTy = indiceVec.getType();
-
- SmallVector<Type> convertedIndiceTypes;
- SmallVector<Value> convertedIndiceVec;
+ SmallVector<Type> convertedIndiceTypes =
+ getUnrolledTypes(indiceVecTy, targetIndiceShape);
+ SmallVector<Value> convertedIndiceVec =
+ pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);
+
SmallVector<Value> newOps;
- if (tdescTy.getRank() == 2) {
- SmallVector<int64_t> shape1D(targetShape->begin(), targetShape->end() - 1);
- convertedIndiceTypes = getUnrolledTypes(indiceVecTy, shape1D);
- convertedIndiceVec = pack(indiceVec, convertedIndiceTypes, shape1D, loc, rewriter);
-
- int64_t wholeChunk = tdescTy.getShape().back();
- int64_t blockedChunk = targetShape->back();
- int64_t numInnerLoops = wholeChunk / blockedChunk;
+ // more indices is need when chunkSize > 1. Since a big load from one
+ // address could be break into multiple small loads.
+ if (originalChunkSize > 1) {
+ int64_t blockedChunkSize = targetShape->back();
+ int64_t numNewChunks = originalChunkSize/blockedChunkSize;
for (auto [indice, indiceType] : llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
- for (int64_t i = 0; i < numInnerLoops; ++i) {
+ for (int64_t i = 0; i < numNewChunks; ++i) {
// Compute the offset
- Value inc = rewriter.create<arith::ConstantIndexOp>(loc, i * blockedChunk);
+ Value inc = rewriter.create<arith::ConstantIndexOp>(loc, i * blockedChunkSize);
Value incVec = rewriter.create<vector::SplatOp>(loc, indiceType, inc);
Value offsetIndice = rewriter.create<arith::AddIOp>(loc, indice, incVec);
@@ -443,8 +489,6 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
}
}
} else {
- convertedIndiceTypes = getUnrolledTypes(indiceVecTy, *targetShape);
- convertedIndiceVec = pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
for (auto indice : convertedIndiceVec) {
auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
op.getSource(), indice);
@@ -468,15 +512,17 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
xegpu::TensorDescType tdescTy = op.getTensorDescType();
- // check if the tensor descriptor type is a 1d vector type
- if (tdescTy.getRank() > 2)
+ if (!tdescTy.isScattered())
return failure();
- VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
-
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
return failure();
+
+ SmallVector<int64_t> targetMaskShape(*targetShape);
+ int64_t originalChunkSize = tdescTy.getChunkSize();
+
+ VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
Type elemTy = tdescTy.getElementType();
VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
@@ -489,25 +535,26 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
SmallVector<Type> convertedMaskTypes;
SmallVector<Value> convertedMasks;
- if (tdescTy.getRank() == 2) {
- convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]);
- SmallVector<Value> convertedMasks1D = pack(op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);
- int64_t wholeChunk = tdescTy.getShape().back();
- int64_t blockedChunk = targetShape->back();
- int64_t numInnerLoops = wholeChunk / blockedChunk;
+ if (originalChunkSize > 1) {
+ targetMaskShape.pop_back();
+ convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+ SmallVector<Value> convertedMasks1D = pack(op.getMask(), convertedMaskTypes, targetMaskShape, loc, rewriter);
+ int64_t blockedChunkSize = targetShape->back();
+ int64_t numNewChunks = originalChunkSize/blockedChunkSize;
for (auto mask : convertedMasks1D) {
- for (int64_t i = 0; i < numInnerLoops; ++i) {
+ for (int64_t i = 0; i < numNewChunks; ++i) {
convertedMasks.push_back(mask);
}
}
+ // This is to handle the transpose effect when chunkSize > 1.
if (targetShape && targetShape->size() > 1) {
std::swap((*targetShape)[0], (*targetShape)[1]);
newValueTy = valueTy.cloneWith(*targetShape, elemTy);
}
} else {
- convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape);
- convertedMasks = pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+ convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+ convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape, loc, rewriter);
}
SmallVector<Value> newOps;
@@ -561,38 +608,38 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
xegpu::TensorDescType tdescTy = op.getTensorDescType();
- // check if the tensor descriptor type is a 1d vector type
- if (tdescTy.getRank() > 2)
+ if (!tdescTy.isScattered())
return failure();
- VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
-
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
return failure();
-
+
+ SmallVector<int64_t> targetIndiceShape(*targetShape);
+ int64_t originalChunkSize = tdescTy.getChunkSize();
+
+ VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+
SmallVector<Type> convertedTdescTypes =
getUnrolledTypes(tdescTy, *targetShape);
SmallVector<Value> convertedTdescs = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
-
SmallVector<Type> convertedMaskTypes;
SmallVector<Value> convertedMasks;
- if (tdescTy.getRank() == 2) {
+ if (originalChunkSize > 1) {
+ int64_t blockedChunkSize = targetShape->back();
+ int64_t numNewChunks = originalChunkSize/blockedChunkSize;
convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]);
SmallVector<Value> convertedMasks1D = pack(op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);
- int64_t wholeChunk = tdescTy.getShape().back();
- int64_t blockedChunk = targetShape->back();
- int64_t numInnerLoops = wholeChunk / blockedChunk;
for (auto mask : convertedMasks1D) {
- for (int64_t i = 0; i < numInnerLoops; ++i) {
+ for (int64_t i = 0; i < numNewChunks; ++i) {
convertedMasks.push_back(mask);
}
}
-
+ // This is to handle the transpose effect when chunkSize > 1.
std::swap((*targetShape)[0], (*targetShape)[1]);
} else {
@@ -626,8 +673,10 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getTensorDescType();
- // check if the tensor descriptor type is a 1d vector type
- if (tdescTy.getRank() > 2)
+ if (tdescTy.getRank() >2)
+ return failure();
+
+ if (!tdescTy.isScattered())
return failure();
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -644,18 +693,17 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
SmallVector<Type> convertedOffsetTypes;
SmallVector<Value> convertedOffsetVec;
SmallVector<Value> newOps;
-
- if (tdescTy.getRank() == 2) {
+ int64_t originalChunkSize = tdescTy.getChunkSize();
+ if (originalChunkSize > 1) {
SmallVector<int64_t> shape1D(targetShape->begin(), targetShape->end() - 1);
convertedOffsetTypes = getUnrolledTypes(offsetVecTy, shape1D);
SmallVector<Value> convertedOffsetVec1D = pack(offsetVec, convertedOffsetTypes, shape1D, loc, rewriter);
- int64_t wholeChunk = tdescTy.getShape().back();
- int64_t blockedChunk = targetShape->back();
- int64_t numInnerLoops = wholeChunk / blockedChunk;
+ int64_t blockedChunkSize = targetShape->back();
+ int64_t numNewChunks = originalChunkSize/blockedChunkSize;
for (auto offset : convertedOffsetVec1D) {
- for (int64_t i = 0; i < numInnerLoops; ++i) {
+ for (int64_t i = 0; i < numNewChunks; ++i) {
convertedOffsetVec.push_back(offset);
}
}
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 758e774eb8e01..78e2b7601a1bb 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -105,12 +105,7 @@ struct TestXeGPUUnrollingPatterns
Attribute encoding = tdescTy.getEncoding();
auto layout = llvm::dyn_cast_if_present<xegpu::LayoutAttr>(
tdescTy.getLayout());
-
- int64_t newChunkSize = 0;
- auto instData = layout.getInstData();
- if (!instData.empty())
- newChunkSize = instData.asArrayRef().back();
-
+
if (layout) {
if (layout.getLaneLayout() == nullptr)
layout = xegpu::LayoutAttr();
@@ -118,12 +113,15 @@ struct TestXeGPUUnrollingPatterns
layout = layout.dropInstData();
}
- SmallVector<NamedAttribute> attrs;
- auto scatterAttr = mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
- if (scatterAttr) {
+ if (encoding && mlir::isa<xegpu::ScatterTensorDescAttr>(encoding)) {
+ auto scatterAttr = mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
int64_t chunkSize = scatterAttr.getChunkSize().getInt();
if (chunkSize > 1) {
+ int64_t newChunkSize = chunkSize;
+ auto instData = layout.getInstData();
+ if (!instData.empty())
+ newChunkSize = instData.asArrayRef().back();
auto chunkSizeAttr = mlir::IntegerAttr::get(
mlir::IntegerType::get(ctx, 64), newChunkSize);
>From 47fe1438afdb37724b94e54383721dac608a35f5 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 16 Jun 2025 20:56:01 +0000
Subject: [PATCH 13/16] add tests
---
.../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 45 +------
.../Dialect/XeGPU/xegpu-unroll-patterns.mlir | 110 ++++++++++++++++++
.../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 14 +--
3 files changed, 118 insertions(+), 51 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 33ebd6219abd7..18b6c38b7f5e4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -267,38 +267,6 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
return success();
}
};
-/*
-struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
- using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
- LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
- PatternRewriter &rewriter) const override {
- Location loc = op.getLoc();
- VectorType valueTy = op.getValueType();
- xegpu::TensorDescType tdescTy = op.getTensorDescType();
-
- std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
- if (!targetShape)
- return failure();
-
- SmallVector<Type> convertedValTypes =
- getUnrolledTypes(valueTy, *targetShape);
- SmallVector<Type> convertedTdescTypes =
- getUnrolledTypes(tdescTy, *targetShape);
-
- SmallVector<Value> convertedValues =
- pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
- SmallVector<Value> convertedTdescs = pack(
- op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
-
- for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
- rewriter.create<xegpu::StoreNdOp>(loc, v, t, op.getL1HintAttr(),
- op.getL2HintAttr(), op.getL3HintAttr());
-
- rewriter.eraseOp(op);
- return success();
- }
-};
-*/
struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
@@ -309,34 +277,23 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
xegpu::TensorDescType tdescTy = op.getTensorDescType();
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
- LDBG("UnrollStoreNdOp: targetShape present? " << (targetShape.has_value() ? "yes" : "no"));
if (!targetShape)
return failure();
- LDBG("targetShape: ");
- for (auto v : *targetShape) LDBG(" " << v);
-
SmallVector<Type> convertedValTypes =
getUnrolledTypes(valueTy, *targetShape);
- LDBG("convertedValTypes size: " << convertedValTypes.size());
SmallVector<Type> convertedTdescTypes =
getUnrolledTypes(tdescTy, *targetShape);
- LDBG("convertedTdescTypes size: " << convertedTdescTypes.size());
SmallVector<Value> convertedValues =
pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
- LDBG("convertedValues size: " << convertedValues.size());
SmallVector<Value> convertedTdescs = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
- LDBG("convertedTdescs size: " << convertedTdescs.size());
- for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs)) {
- LDBG("Creating StoreNdOp with value: " << v << ", tdesc: " << t);
+ for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
rewriter.create<xegpu::StoreNdOp>(loc, v, t, op.getL1HintAttr(),
op.getL2HintAttr(), op.getL3HintAttr());
- }
- LDBG("Erasing original StoreNdOp: " << op);
rewriter.eraseOp(op);
return success();
}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
index 52ec3b856da49..0d3bd4fdb311e 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -299,4 +299,114 @@ gpu.module @test {
gpu.return
}
+
+//-----
+ // CHECK-LABEL: test_create_tdesc_step_chunk
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x4xf32, #xegpu.scatter_tdesc_attr<chunk_size = 4 : i64>>
+ gpu.func @test_create_tdesc_step_chunk(%src: ui64) -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 4]>> {
+ %step = arith.constant dense<8> : vector<32xindex>
+ %seq = vector.step : vector<32xindex>
+ %cst = arith.muli %seq, %step : vector<32xindex>
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 4]>>
+ gpu.return %tdesc : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 4]>>
+ }
+
+//-----
+ // CHECK-LABEL: test_create_tdesc_step_chunk2
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+ gpu.func @test_create_tdesc_step_chunk2(%src: ui64) -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>> {
+ %step = arith.constant dense<8> : vector<32xindex>
+ %seq = vector.step : vector<32xindex>
+ %cst = arith.muli %seq, %step : vector<32xindex>
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+ gpu.return %tdesc : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+ }
+
+//-----
+ // CHECK-LABEL: test_load_chunk
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+ // CHECK-COUNT-4: xegpu.load {{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1> -> vector<2x16xf32>
+
+ gpu.func @test_load_chunk(%src: ui64) -> vector<4x32xf32> {
+ %cst = arith.constant dense<[
+ 0, 8, 16, 24, 32, 40, 48, 56,
+ 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184,
+ 192, 200, 208, 216, 224, 232, 240, 248
+ ]> : vector<32xindex>
+
+ %c17 = arith.constant 17: index
+ %mask = vector.create_mask %c17: vector<32xi1>
+
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+ %ld = xegpu.load %tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xi1> -> vector<4x32xf32>
+
+ gpu.return %ld : vector<4x32xf32>
+ }
+
+//-----
+ // CHECK-LABEL: test_store_chunk
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+ // CHECK-COUNT-4: xegpu.store {{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x16xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1>
+ gpu.func @test_store_chunk(%src: ui64) {
+ %cst = arith.constant dense<[
+ 0, 8, 16, 24, 32, 40, 48, 56,
+ 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184,
+ 192, 200, 208, 216, 224, 232, 240, 248
+ ]> : vector<32xindex>
+
+ %c17 = arith.constant 17: index
+ %mask = vector.create_mask %c17: vector<32xi1>
+
+ %st_vec = arith.constant dense<1023.>: vector<4x32xf32>
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+ xegpu.store %st_vec, %tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>: vector<4x32xf32>, !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16,2]>>, vector<32xi1>
+
+ gpu.return
+ }
+
+//-----
+ // CHECK-LABEL: test_prefetch_chunk
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+ // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+ gpu.func @test_prefetch_chunk(%src: ui64) {
+ %cst = arith.constant dense<[
+ 0, 8, 16, 24, 32, 40, 48, 56,
+ 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184,
+ 192, 200, 208, 216, 224, 232, 240, 248
+ ]> : vector<32xindex>
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+ xegpu.prefetch %tdesc: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+
+ gpu.return
+ }
+
+//-----
+ // CHECK-LABEL: test_update_chunk
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+ // CHECK-COUNT-4: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xindex>
+ gpu.func @test_update_chunk(%src: ui64) -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>> {
+ %cst = arith.constant dense<[
+ 0, 8, 16, 24, 32, 40, 48, 56,
+ 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184,
+ 192, 200, 208, 216, 224, 232, 240, 248
+ ]> : vector<32xindex>
+ %delta = arith.constant dense<32>: vector<32xindex>
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+
+ %new_tdesc = xegpu.update_offset %tdesc, %delta
+ : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xindex>
+
+ gpu.return %new_tdesc : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+ }
}
+
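
The CHECK-COUNT values in the chunked load/store/update tests above follow directly from the blocking arithmetic: a 32x4 descriptor with inst_data = [16, 2] splits into (32/16) * (4/2) = 4 pieces, and each chunked load or store moves a transposed 2x16 sub-vector. A quick standalone check of that arithmetic (plain C++; the shapes are taken from the tests, the rest is illustrative):

  #include <cstdint>
  #include <iostream>

  int main() {
    const int64_t numAddrs = 32, chunk = 4;      // tensor_desc<32x4xf32, chunk_size = 4>
    const int64_t instAddrs = 16, instChunk = 2; // inst_data = [16, 2]

    const int64_t pieces = (numAddrs / instAddrs) * (chunk / instChunk);
    std::cout << pieces << " unrolled ops, each " << instChunk << "x"
              << instAddrs << " after the chunked-access transpose\n";
    return 0; // prints: 4 unrolled ops, each 2x16 ...
  }
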
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 78e2b7601a1bb..8be37609fc806 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -106,13 +106,6 @@ struct TestXeGPUUnrollingPatterns
auto layout = llvm::dyn_cast_if_present<xegpu::LayoutAttr>(
tdescTy.getLayout());
- if (layout) {
- if (layout.getLaneLayout() == nullptr)
- layout = xegpu::LayoutAttr();
- else
- layout = layout.dropInstData();
- }
-
if (encoding && mlir::isa<xegpu::ScatterTensorDescAttr>(encoding)) {
auto scatterAttr = mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
int64_t chunkSize = scatterAttr.getChunkSize().getInt();
@@ -135,6 +128,13 @@ struct TestXeGPUUnrollingPatterns
}
}
+ if (layout) {
+ if (layout.getLaneLayout() == nullptr)
+ layout = xegpu::LayoutAttr();
+ else
+ layout = layout.dropInstData();
+ }
+
newTy = xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding, layout);
} else {
>From a569bc8acd38337a3f52fbb21ade42e9595e6586 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 16 Jun 2025 21:05:10 +0000
Subject: [PATCH 14/16] clang format fix
---
.../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 80 +++++++++++--------
.../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 21 +++--
2 files changed, 55 insertions(+), 46 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 18b6c38b7f5e4..ed8de1639c35e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -407,11 +407,11 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
if (!tdescTy.isScattered())
return failure();
-
+
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
return failure();
-
+
SmallVector<int64_t> targetIndiceShape(*targetShape);
int64_t originalChunkSize = tdescTy.getChunkSize();
// IndiceVec has one fewer dimension than tdescTy when chunkSize is larger than 1.
@@ -419,25 +419,28 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
targetIndiceShape.pop_back();
auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
- SmallVector<Type> convertedIndiceTypes =
+ SmallVector<Type> convertedIndiceTypes =
getUnrolledTypes(indiceVecTy, targetIndiceShape);
- SmallVector<Value> convertedIndiceVec =
+ SmallVector<Value> convertedIndiceVec =
pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);
-
+
SmallVector<Value> newOps;
// more indices is need when chunkSize > 1. Since a big load from one
// address could be break into multiple small loads.
if (originalChunkSize > 1) {
int64_t blockedChunkSize = targetShape->back();
- int64_t numNewChunks = originalChunkSize/blockedChunkSize;
+ int64_t numNewChunks = originalChunkSize / blockedChunkSize;
- for (auto [indice, indiceType] : llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
+ for (auto [indice, indiceType] :
+ llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
for (int64_t i = 0; i < numNewChunks; ++i) {
// Compute the offset
- Value inc = rewriter.create<arith::ConstantIndexOp>(loc, i * blockedChunkSize);
+ Value inc = rewriter.create<arith::ConstantIndexOp>(
+ loc, i * blockedChunkSize);
Value incVec = rewriter.create<vector::SplatOp>(loc, indiceType, inc);
- Value offsetIndice = rewriter.create<arith::AddIOp>(loc, indice, incVec);
+ Value offsetIndice =
+ rewriter.create<arith::AddIOp>(loc, indice, incVec);
auto newOp = rewriter.create<xegpu::CreateDescOp>(
loc, newTdescTy, op.getSource(), offsetIndice);
@@ -447,11 +450,11 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
}
} else {
for (auto indice : convertedIndiceVec) {
- auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
- op.getSource(), indice);
+ auto newOp = rewriter.create<xegpu::CreateDescOp>(
+ loc, newTdescTy, op.getSource(), indice);
newOps.push_back(newOp);
}
- }
+ }
Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
rewriter.replaceOp(op, castOp);
@@ -471,11 +474,11 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
if (!tdescTy.isScattered())
return failure();
-
+
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
return failure();
-
+
SmallVector<int64_t> targetMaskShape(*targetShape);
int64_t originalChunkSize = tdescTy.getChunkSize();
@@ -489,29 +492,31 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
SmallVector<Value> convertedTdescs = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
- SmallVector<Type> convertedMaskTypes;
- SmallVector<Value> convertedMasks;
+ SmallVector<Type> convertedMaskTypes;
+ SmallVector<Value> convertedMasks;
if (originalChunkSize > 1) {
targetMaskShape.pop_back();
convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
- SmallVector<Value> convertedMasks1D = pack(op.getMask(), convertedMaskTypes, targetMaskShape, loc, rewriter);
+ SmallVector<Value> convertedMasks1D = pack(
+ op.getMask(), convertedMaskTypes, targetMaskShape, loc, rewriter);
int64_t blockedChunkSize = targetShape->back();
- int64_t numNewChunks = originalChunkSize/blockedChunkSize;
+ int64_t numNewChunks = originalChunkSize / blockedChunkSize;
for (auto mask : convertedMasks1D) {
for (int64_t i = 0; i < numNewChunks; ++i) {
convertedMasks.push_back(mask);
}
}
- // This is to handle the transpose effect when chunkSize > 1.
+ // This is to handle the transpose effect when chunkSize > 1.
if (targetShape && targetShape->size() > 1) {
std::swap((*targetShape)[0], (*targetShape)[1]);
newValueTy = valueTy.cloneWith(*targetShape, elemTy);
}
} else {
convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
- convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape, loc, rewriter);
+ convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
+ loc, rewriter);
}
SmallVector<Value> newOps;
@@ -521,7 +526,7 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
op.getL2HintAttr(), op.getL3HintAttr());
newOps.push_back(newOp);
}
-
+
Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
rewriter.replaceOp(op, castOp);
return success();
@@ -576,38 +581,40 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
int64_t originalChunkSize = tdescTy.getChunkSize();
VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
-
+
SmallVector<Type> convertedTdescTypes =
getUnrolledTypes(tdescTy, *targetShape);
SmallVector<Value> convertedTdescs = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
- SmallVector<Type> convertedMaskTypes;
- SmallVector<Value> convertedMasks;
+ SmallVector<Type> convertedMaskTypes;
+ SmallVector<Value> convertedMasks;
if (originalChunkSize > 1) {
int64_t blockedChunkSize = targetShape->back();
- int64_t numNewChunks = originalChunkSize/blockedChunkSize;
+ int64_t numNewChunks = originalChunkSize / blockedChunkSize;
convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]);
- SmallVector<Value> convertedMasks1D = pack(op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);
+ SmallVector<Value> convertedMasks1D = pack(
+ op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);
for (auto mask : convertedMasks1D) {
for (int64_t i = 0; i < numNewChunks; ++i) {
convertedMasks.push_back(mask);
}
}
- // This is to handle the transpose effect when chunkSize > 1.
+ // This is to handle the transpose effect when chunkSize > 1.
std::swap((*targetShape)[0], (*targetShape)[1]);
} else {
convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape);
- convertedMasks = pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+ convertedMasks =
+ pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
}
SmallVector<Type> convertedValTypes =
getUnrolledTypes(valueTy, *targetShape);
SmallVector<Value> convertedValues =
- pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
+ pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
for (size_t i = 0; i < convertedValues.size(); ++i) {
Value v = convertedValues[i];
@@ -630,7 +637,7 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getTensorDescType();
- if (tdescTy.getRank() >2)
+ if (tdescTy.getRank() > 2)
return failure();
if (!tdescTy.isScattered())
@@ -652,12 +659,14 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
SmallVector<Value> newOps;
int64_t originalChunkSize = tdescTy.getChunkSize();
if (originalChunkSize > 1) {
- SmallVector<int64_t> shape1D(targetShape->begin(), targetShape->end() - 1);
+ SmallVector<int64_t> shape1D(targetShape->begin(),
+ targetShape->end() - 1);
convertedOffsetTypes = getUnrolledTypes(offsetVecTy, shape1D);
- SmallVector<Value> convertedOffsetVec1D = pack(offsetVec, convertedOffsetTypes, shape1D, loc, rewriter);
+ SmallVector<Value> convertedOffsetVec1D =
+ pack(offsetVec, convertedOffsetTypes, shape1D, loc, rewriter);
int64_t blockedChunkSize = targetShape->back();
- int64_t numNewChunks = originalChunkSize/blockedChunkSize;
+ int64_t numNewChunks = originalChunkSize / blockedChunkSize;
for (auto offset : convertedOffsetVec1D) {
for (int64_t i = 0; i < numNewChunks; ++i) {
@@ -667,8 +676,9 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
} else {
convertedOffsetTypes = getUnrolledTypes(offsetVecTy, *targetShape);
- convertedOffsetVec = pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
- }
+ convertedOffsetVec =
+ pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
+ }
for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
auto newOp =
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 8be37609fc806..4fe81a01d3544 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -19,12 +19,10 @@ using namespace mlir::xegpu;
namespace {
-
#define DEBUG_TYPE "test-xegpu-unroll"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
-
struct TestXeGPUUnrollingPatterns
: public PassWrapper<TestXeGPUUnrollingPatterns,
OperationPass<gpu::GPUModuleOp>> {
@@ -60,7 +58,8 @@ struct TestXeGPUUnrollingPatterns
xegpu::TensorDescType tdescTy;
if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
tdescTy = createNdOp.getType();
- } else if (auto updateNdOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op)) {
+ } else if (auto updateNdOp =
+ dyn_cast<xegpu::UpdateNdOffsetOp>(op)) {
tdescTy = updateNdOp.getTensorDescType();
} else if (auto prefetchNdOp = dyn_cast<xegpu::PrefetchNdOp>(op)) {
tdescTy = prefetchNdOp.getTensorDescType();
@@ -105,28 +104,27 @@ struct TestXeGPUUnrollingPatterns
Attribute encoding = tdescTy.getEncoding();
auto layout = llvm::dyn_cast_if_present<xegpu::LayoutAttr>(
tdescTy.getLayout());
-
+
if (encoding && mlir::isa<xegpu::ScatterTensorDescAttr>(encoding)) {
- auto scatterAttr = mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
+ auto scatterAttr =
+ mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
int64_t chunkSize = scatterAttr.getChunkSize().getInt();
-
+
if (chunkSize > 1) {
int64_t newChunkSize = chunkSize;
auto instData = layout.getInstData();
if (!instData.empty())
- newChunkSize = instData.asArrayRef().back();
+ newChunkSize = instData.asArrayRef().back();
auto chunkSizeAttr = mlir::IntegerAttr::get(
- mlir::IntegerType::get(ctx, 64), newChunkSize);
+ mlir::IntegerType::get(ctx, 64), newChunkSize);
// To create a new attribute with a different chunk_size:
auto newEncoding = xegpu::ScatterTensorDescAttr::get(
ctx, scatterAttr.getMemorySpace(), chunkSizeAttr);
encoding = newEncoding;
-
}
-
}
if (layout) {
if (layout.getLaneLayout() == nullptr)
@@ -135,7 +133,8 @@ struct TestXeGPUUnrollingPatterns
layout = layout.dropInstData();
}
- newTy = xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding, layout);
+ newTy = xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
+ layout);
} else {
newTy = type.clone(tileShape, elemTy);
>From 41b839de279e94689fe628e45238ae3e6d56f5a3 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 16 Jun 2025 22:34:05 +0000
Subject: [PATCH 15/16] remove 1 complex test
---
.../Dialect/XeGPU/xegpu-unroll-patterns.mlir | 45 -------------------
1 file changed, 45 deletions(-)
diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
index 0d3bd4fdb311e..1726b654c4746 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -255,51 +255,6 @@ gpu.module @test {
gpu.return
}
-//-----
-
- // CHECK-LABEL: test_prefetch_load_store_update
- // CHECK-SAME: [[arg0:%.+]]: ui64
- // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
- // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
- // CHECK-COUNT-2: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex>
- // CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
- // CHECK-COUNT-2: xegpu.store {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
-
- gpu.func @test_prefetch_load_store_update(%src: ui64) {
-
- %cst = arith.constant dense<[
- 0, 8, 16, 24, 32, 40, 48, 56,
- 64, 72, 80, 88, 96, 104, 112, 120,
- 128, 136, 144, 152, 160, 168, 176, 184,
- 192, 200, 208, 216, 224, 232, 240, 248
- ]> : vector<32xindex>
-
- %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
- xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
-
- %delta = arith.constant dense<[
- 32, 32, 32, 32, 32, 32, 32, 32,
- 32, 32, 32, 32, 32, 32, 32, 64,
- 128, 128, 128, 128, 128, 128, 128, 128,
- 128, 128, 128, 128, 128, 128, 128, 256
- ]> : vector<32xindex>
- %new_tdesc = xegpu.update_offset %tdesc, %delta
- : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xindex>
-
- %c17 = arith.constant 17: index
- %mask = vector.create_mask %c17: vector<32xi1>
-
- %ld_vec = xegpu.load %new_tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
-
- %st_vec = arith.addf %ld_vec, %ld_vec : vector<32xf32>
- xegpu.store %st_vec, %tdesc, %mask:
- vector<32xf32>,
- !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>,
- vector<32xi1>
-
- gpu.return
- }
-
//-----
// CHECK-LABEL: test_create_tdesc_step_chunk
// CHECK-SAME: [[arg0:%.+]]: ui64
>From 3c3023c36e8e73b6eb1b3feac575f302fed6a09c Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 17 Jun 2025 21:50:59 +0000
Subject: [PATCH 16/16] address review feedback
---
.../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 14 +--
.../Dialect/XeGPU/xegpu-unroll-patterns.mlir | 109 ++++++++++--------
.../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 31 +----
3 files changed, 73 insertions(+), 81 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index faeec1b27501c..91e0f8196f898 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -426,7 +426,7 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
SmallVector<Value> newOps;
- // more indices is need when chunkSize > 1. Since a big load from one
- // address could be break into multiple small loads.
+ // More indices are needed when chunkSize > 1, since a big load from one
+ // address could be broken into multiple small loads.
if (originalChunkSize > 1) {
int64_t blockedChunkSize = targetShape->back();
@@ -504,15 +504,12 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
int64_t numNewChunks = originalChunkSize / blockedChunkSize;
for (auto mask : convertedMasks1D) {
- for (int64_t i = 0; i < numNewChunks; ++i) {
+ for (int64_t i = 0; i < numNewChunks; ++i)
convertedMasks.push_back(mask);
- }
}
// This is to handle the transpose effect when chunkSize > 1.
- if (targetShape && targetShape->size() > 1) {
- std::swap((*targetShape)[0], (*targetShape)[1]);
- newValueTy = valueTy.cloneWith(*targetShape, elemTy);
- }
+ std::swap((*targetShape)[0], (*targetShape)[1]);
+ newValueTy = valueTy.cloneWith(*targetShape, elemTy);
} else {
convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
@@ -540,8 +537,7 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getTensorDescType();
- // check if the tensor descriptor type is a 1d vector type
- if (tdescTy.getRank() > 2)
+ if (!tdescTy.isScattered())
return failure();
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
index 1726b654c4746..41414d802f212 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -2,7 +2,7 @@
gpu.module @test {
- // CHECK-LABEL: test_create_nd_tdesc
+ // CHECK-LABEL: create_nd_tdesc
// CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
// CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: [[cast:%.+]] = builtin.unrealized_conversion_cast
@@ -10,31 +10,31 @@ gpu.module @test {
// CHECK-SAME: !xegpu.tensor_desc<8x16xf32>, !xegpu.tensor_desc<8x16xf32>,
// CHECK-SAME: !xegpu.tensor_desc<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
// CHECK-SAME: to !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {__xegpu_blocking_tile_shape__ = array<i64: 8, 16>, __xegpu_blocking_unpack__}
- gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
+ gpu.func @create_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
gpu.return %tdesc : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
}
//-----
- // CHECK-LABEL: test_create_nd_tdesc_1d
+ // CHECK-LABEL: create_nd_tdesc_1d
// CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
// CHECK-COUNT-2: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
// CHECK: [[cast:%.+]] = builtin.unrealized_conversion_cast
// CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
// CHECK-SAME: to !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {__xegpu_blocking_tile_shape__ = array<i64: 16>, __xegpu_blocking_unpack__}
- gpu.func @test_create_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {
+ gpu.func @create_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {
%tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
gpu.return %tdesc : !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
}
//-----
- // CHECK-LABEL: test_update_nd_tdesc
+ // CHECK-LABEL: update_nd_tdesc
// CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
// CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK-COUNT-6: [[update:%.+]] = xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<8x16xf32>
- gpu.func @test_update_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
+ gpu.func @update_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
%update = xegpu.update_nd_offset %tdesc, [0, 16] : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
gpu.return %update : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
@@ -42,11 +42,11 @@ gpu.module @test {
//-----
- // CHECK-LABEL: test_update_nd_tdesc_1d
+ // CHECK-LABEL: update_nd_tdesc_1d
// CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
// CHECK-COUNT-2: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
// CHECK-COUNT-2: [[update:%.+]] = xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<16xf32>
- gpu.func @test_update_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {
+ gpu.func @update_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {
%tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
%update = xegpu.update_nd_offset %tdesc, [32] : !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
gpu.return %update : !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
@@ -54,11 +54,11 @@ gpu.module @test {
//-----
- // CHECK-LABEL: test_prefetch_nd_tdesc
+ // CHECK-LABEL: prefetch_nd_tdesc
// CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
// CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK-COUNT-6: xegpu.prefetch_nd {{.*}} : !xegpu.tensor_desc<8x16xf32>
- gpu.func @test_prefetch_nd_tdesc(%src: memref<24x32xf32>) {
+ gpu.func @prefetch_nd_tdesc(%src: memref<24x32xf32>) {
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
xegpu.prefetch_nd %tdesc : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
gpu.return
@@ -66,23 +66,23 @@ gpu.module @test {
//-----
- // CHECK-LABEL: test_prefetch_nd_tdesc_1d
+ // CHECK-LABEL: prefetch_nd_tdesc_1d
// CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
// CHECK-COUNT-4: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
// CHECK-COUNT-4: xegpu.prefetch_nd {{.*}} : !xegpu.tensor_desc<16xf32>
- gpu.func @test_prefetch_nd_tdesc_1d(%src: memref<64xf32>) {
+ gpu.func @prefetch_nd_tdesc_1d(%src: memref<64xf32>) {
%tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
xegpu.prefetch_nd %tdesc : !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
gpu.return
}
//-----
- // CHECK-LABEL: test_load_nd
+ // CHECK-LABEL: load_nd
// CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
// CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK-COUNT-6: [[ld:%.+]] = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
// CHECK-COUNT-6: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<24x32xf32>
- gpu.func @test_load_nd(%src: memref<24x32xf32>) -> vector<24x32xf32> {
+ gpu.func @load_nd(%src: memref<24x32xf32>) -> vector<24x32xf32> {
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
%ld = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
gpu.return %ld : vector<24x32xf32>
@@ -90,12 +90,12 @@ gpu.module @test {
//-----
- // CHECK-LABEL: test_load_nd_1d
+ // CHECK-LABEL: load_nd_1d
// CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
// CHECK-COUNT-4: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
// CHECK-COUNT-4: [[ld:%.+]] = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<16xf32> -> vector<16xf32>
// CHECK-COUNT-4: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<16xf32> into vector<64xf32>
- gpu.func @test_load_nd_1d(%src: memref<64xf32>) -> vector<64xf32> {
+ gpu.func @load_nd_1d(%src: memref<64xf32>) -> vector<64xf32> {
%tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
%data = xegpu.load_nd %tdesc: !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>> -> vector<64xf32>
gpu.return %data : vector<64xf32>
@@ -103,11 +103,11 @@ gpu.module @test {
//-----
- // CHECK-LABEL: test_store_nd
+ // CHECK-LABEL: store_nd
// CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
// CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK-COUNT-6: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
- gpu.func @test_store_nd(%src: memref<24x32xf32>) {
+ gpu.func @store_nd(%src: memref<24x32xf32>) {
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
%data = arith.constant dense<9.0> : vector<24x32xf32>
xegpu.store_nd %data, %tdesc: vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
@@ -116,11 +116,11 @@ gpu.module @test {
//-----
- // CHECK-LABEL: test_store_nd_1d
+ // CHECK-LABEL: store_nd_1d
// CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
// CHECK-COUNT-4: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
// CHECK-COUNT-4: xegpu.store_nd {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32>
- gpu.func @test_store_nd_1d(%src: memref<64xf32>) {
+ gpu.func @store_nd_1d(%src: memref<64xf32>) {
%tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
%data = arith.constant dense<9.0> : vector<64xf32>
xegpu.store_nd %data, %tdesc: vector<64xf32>, !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
@@ -129,7 +129,7 @@ gpu.module @test {
//-----
- // CHECK-LABEL: test_createNd_loadNd_storeNd
+ // CHECK-LABEL: createNd_loadNd_storeNd
// CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
//CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
//CHECK-COUNT-6: [[data:%.+]] = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
@@ -137,7 +137,7 @@ gpu.module @test {
//CHECK: [[add:%.+]] = arith.addf {{.*}} : vector<24x32xf32>
//CHECK-COUNT-6: [[extract:%.+]] = vector.extract_strided_slice {{.*}} : vector<24x32xf32> to vector<8x16xf32>
//CHECK-COUNT-6: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
- gpu.func @test_createNd_loadNd_storeNd(%src: memref<24x32xf32>) {
+ gpu.func @createNd_loadNd_storeNd(%src: memref<24x32xf32>) {
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
%data = arith.constant dense<9.0> : vector<24x32xf32>
%ld = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
@@ -148,23 +148,23 @@ gpu.module @test {
//-----
- // CHECK-LABEL: test_dpas
+ // CHECK-LABEL: dpas
// CHECK-SAME: [[arg0:%.+]]: vector<32x32xf16>, [[arg1:%.+]]: vector<32x32xf16>
//CHECK-COUNT-8: [[extract1:%.+]] = vector.extract_strided_slice [[arg0]] {{.*}} : vector<32x32xf16> to vector<8x16xf16>
//CHECK-COUNT-4: [[extract2:%.+]] = vector.extract_strided_slice [[arg1]] {{.*}} : vector<32x32xf16> to vector<16x16xf16>
//CHECK-COUNT-16: [[dpas:%.+]] = xegpu.dpas {{.*}} -> vector<8x16xf32>
//CHECK-COUNT-8: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<32x32xf32>
- gpu.func @test_dpas(%a: vector<32x32xf16>, %b: vector<32x32xf16>) -> vector<32x32xf32> {
+ gpu.func @dpas(%a: vector<32x32xf16>, %b: vector<32x32xf16>) -> vector<32x32xf32> {
%c = xegpu.dpas %a, %b : vector<32x32xf16>, vector<32x32xf16> -> vector<32x32xf32>
gpu.return %c : vector<32x32xf32>
}
//-----
- // CHECK-LABEL: test_create_tdesc_vec
+ // CHECK-LABEL: create_tdesc_vec
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
- gpu.func @test_create_tdesc_vec(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
+ gpu.func @create_tdesc_vec(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
%cst = arith.constant dense<[
0, 8, 16, 24, 32, 40, 48, 56,
64, 72, 80, 88, 96, 104, 112, 120,
@@ -177,10 +177,10 @@ gpu.module @test {
//-----
- // CHECK-LABEL: test_create_tdesc_step
+ // CHECK-LABEL: create_tdesc_step
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
- gpu.func @test_create_tdesc_step(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
+ gpu.func @create_tdesc_step(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
%step = arith.constant dense<8> : vector<32xindex>
%seq = vector.step : vector<32xindex>
%cst = arith.muli %seq, %step : vector<32xindex>
@@ -190,11 +190,11 @@ gpu.module @test {
//-----
- // CHECK-LABEL: test_load
+ // CHECK-LABEL: load
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
// CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
- gpu.func @test_load(%src: ui64) -> vector<32xf32> {
+ gpu.func @load(%src: ui64) -> vector<32xf32> {
%cst = arith.constant dense<[
0, 8, 16, 24, 32, 40, 48, 56,
64, 72, 80, 88, 96, 104, 112, 120,
@@ -212,11 +212,11 @@ gpu.module @test {
//-----
- // CHECK-LABEL: test_prefetch
+ // CHECK-LABEL: prefetch
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
// CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
- gpu.func @test_prefetch(%src: ui64) {
+ gpu.func @prefetch(%src: ui64) {
%cst = arith.constant dense<[
0, 8, 16, 24, 32, 40, 48, 56,
@@ -233,11 +233,11 @@ gpu.module @test {
//-----
- // CHECK-LABEL: test_store
+ // CHECK-LABEL: store
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
// CHECK-COUNT-2: xegpu.store {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
- gpu.func @test_store(%src: ui64) {
+ gpu.func @store(%src: ui64) {
%cst = arith.constant dense<[
0, 8, 16, 24, 32, 40, 48, 56,
64, 72, 80, 88, 96, 104, 112, 120,
@@ -256,10 +256,10 @@ gpu.module @test {
}
//-----
- // CHECK-LABEL: test_create_tdesc_step_chunk
+ // CHECK-LABEL: create_tdesc_step_chunk
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x4xf32, #xegpu.scatter_tdesc_attr<chunk_size = 4 : i64>>
- gpu.func @test_create_tdesc_step_chunk(%src: ui64) -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 4]>> {
+ gpu.func @create_tdesc_step_chunk(%src: ui64) -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 4]>> {
%step = arith.constant dense<8> : vector<32xindex>
%seq = vector.step : vector<32xindex>
%cst = arith.muli %seq, %step : vector<32xindex>
@@ -268,10 +268,10 @@ gpu.module @test {
}
//-----
- // CHECK-LABEL: test_create_tdesc_step_chunk2
+ // CHECK-LABEL: create_tdesc_step_chunk2
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
- gpu.func @test_create_tdesc_step_chunk2(%src: ui64) -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>> {
+ gpu.func @create_tdesc_step_chunk2(%src: ui64) -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>> {
%step = arith.constant dense<8> : vector<32xindex>
%seq = vector.step : vector<32xindex>
%cst = arith.muli %seq, %step : vector<32xindex>
@@ -279,13 +279,30 @@ gpu.module @test {
gpu.return %tdesc : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
}
+ // CHECK-LABEL: create_tdesc_step_chunk3
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+ // CHECK: arith.addi %{{.*}}, %{{.*}} : vector<16xindex>
+ // CHECK: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+ // CHECK: arith.addi %{{.*}}, %{{.*}} : vector<16xindex>
+ // CHECK: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+ // CHECK: arith.addi %{{.*}}, %{{.*}} : vector<16xindex>
+ // CHECK: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+ gpu.func @create_tdesc_step_chunk3(%src: ui64) -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size=8>, #xegpu.layout<inst_data = [16, 2]>> {
+ %step = arith.constant dense<8> : vector<16xindex>
+ %seq = vector.step : vector<16xindex>
+ %cst = arith.muli %seq, %step : vector<16xindex>
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size=8>, #xegpu.layout<inst_data = [16, 2]>>
+ gpu.return %tdesc : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size=8>, #xegpu.layout<inst_data = [16, 2]>>
+ }
+
//-----
- // CHECK-LABEL: test_load_chunk
+ // CHECK-LABEL: load_chunk
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
// CHECK-COUNT-4: xegpu.load {{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1> -> vector<2x16xf32>
- gpu.func @test_load_chunk(%src: ui64) -> vector<4x32xf32> {
+ gpu.func @load_chunk(%src: ui64) -> vector<4x32xf32> {
%cst = arith.constant dense<[
0, 8, 16, 24, 32, 40, 48, 56,
64, 72, 80, 88, 96, 104, 112, 120,
@@ -303,11 +320,11 @@ gpu.module @test {
}
//-----
- // CHECK-LABEL: test_store_chunk
+ // CHECK-LABEL: store_chunk
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
// CHECK-COUNT-4: xegpu.store {{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x16xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1>
- gpu.func @test_store_chunk(%src: ui64) {
+ gpu.func @store_chunk(%src: ui64) {
%cst = arith.constant dense<[
0, 8, 16, 24, 32, 40, 48, 56,
64, 72, 80, 88, 96, 104, 112, 120,
@@ -326,11 +343,11 @@ gpu.module @test {
}
//-----
- // CHECK-LABEL: test_prefetch_chunk
+ // CHECK-LABEL: prefetch_chunk
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
// CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
- gpu.func @test_prefetch_chunk(%src: ui64) {
+ gpu.func @prefetch_chunk(%src: ui64) {
%cst = arith.constant dense<[
0, 8, 16, 24, 32, 40, 48, 56,
64, 72, 80, 88, 96, 104, 112, 120,
@@ -344,11 +361,11 @@ gpu.module @test {
}
//-----
- // CHECK-LABEL: test_update_chunk
+ // CHECK-LABEL: update_chunk
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
// CHECK-COUNT-4: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xindex>
- gpu.func @test_update_chunk(%src: ui64) -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>> {
+ gpu.func @update_chunk(%src: ui64) -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>> {
%cst = arith.constant dense<[
0, 8, 16, 24, 32, 40, 48, 56,
64, 72, 80, 88, 96, 104, 112, 120,
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 8110ca237dcd3..4400d6d9625f7 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -87,29 +87,6 @@ struct TestXeGPUUnrollingPatterns
}
}
- if (isa<xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
- xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
- xegpu::TensorDescType tdescTy;
- if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
- tdescTy = createOp.getType();
- } else if (auto updateOp = dyn_cast<xegpu::UpdateOffsetOp>(op)) {
- tdescTy = updateOp.getTensorDescType();
- } else if (auto prefetchOp = dyn_cast<xegpu::PrefetchOp>(op)) {
- tdescTy = prefetchOp.getTensorDescType();
- } else if (auto loadOp = dyn_cast<xegpu::LoadGatherOp>(op)) {
- tdescTy = loadOp.getTensorDescType();
- } else if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(op)) {
- tdescTy = storeOp.getTensorDescType();
- }
-
- if (auto layout = tdescTy.getLayoutAttr()) {
- auto inst_data = layout.getInstData();
- if (inst_data && layout.isSgLayout())
- return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
- inst_data.asArrayRef().end());
- }
- }
-
if (isa<xegpu::DpasOp>(op))
return SmallVector<int64_t>{8, 16, 16};
@@ -128,19 +105,21 @@ struct TestXeGPUUnrollingPatterns
auto layout = llvm::dyn_cast_if_present<xegpu::LayoutAttr>(
tdescTy.getLayout());
+ // If the encoding is a ScatterTensorDescAttr, we need to
+ // potentially adjust the chunk size based on the inst_data.
if (encoding && mlir::isa<xegpu::ScatterTensorDescAttr>(encoding)) {
auto scatterAttr =
mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
int64_t chunkSize = scatterAttr.getChunkSize().getInt();
if (chunkSize > 1) {
- int64_t newChunkSize = chunkSize;
+ int64_t blockedChunkSize = chunkSize;
auto instData = layout.getInstData();
if (!instData.empty())
- newChunkSize = instData.asArrayRef().back();
+ blockedChunkSize = instData.asArrayRef().back();
auto chunkSizeAttr = mlir::IntegerAttr::get(
- mlir::IntegerType::get(ctx, 64), newChunkSize);
+ mlir::IntegerType::get(ctx, 64), blockedChunkSize);
// To create a new attribute with a different chunk_size:
auto newEncoding = xegpu::ScatterTensorDescAttr::get(
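For reference, a minimal sketch (not part of the patch) of the blocked form that the create_tdesc_step_chunk3 CHECK lines above anticipate: a scattered descriptor whose chunk_size exceeds the last inst_data entry is split along the chunk dimension, each tile is re-created with the reduced chunk_size, and the per-lane offsets are presumably advanced between tiles. The value names and the offset-advance constant (a splat of the blocked chunk size, 2 elements) are assumptions inferred from the CHECK lines, not taken from the patch.

    // Assumed blocked form for a 16x8 f32 descriptor with inst_data = [16, 2]:
    // four 16x2 descriptors, with offsets advanced by 2 elements per tile.
    // %src : ui64 and %off : vector<16xindex> are as in the test above.
    %c2 = arith.constant dense<2> : vector<16xindex>
    %t0 = xegpu.create_tdesc %src, %off : ui64, vector<16xindex>
        -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
    %off1 = arith.addi %off, %c2 : vector<16xindex>
    %t1 = xegpu.create_tdesc %src, %off1 : ui64, vector<16xindex>
        -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
    %off2 = arith.addi %off1, %c2 : vector<16xindex>
    %t2 = xegpu.create_tdesc %src, %off2 : ui64, vector<16xindex>
        -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
    %off3 = arith.addi %off2, %c2 : vector<16xindex>
    %t3 = xegpu.create_tdesc %src, %off3 : ui64, vector<16xindex>
        -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>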