[Mlir-commits] [mlir] [MLIR][XeGPU] Add unroll patterns for scatter ops (PR #143602)
Jianhui Li
llvmlistbot at llvm.org
Thu Jun 12 20:59:57 PDT 2025
https://github.com/Jianhui-Li updated https://github.com/llvm/llvm-project/pull/143602
>From bac8bc607a0b5f9171472375e5f42c24a5ad429d Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 9 Jun 2025 14:40:41 +0000
Subject: [PATCH 1/5] add unroll support for reduce and broadcast
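
This patch teaches XeGPUBlockingPass::getTileShape to also handle
vector.multi_reduction (tile shape taken from the source operand's layout)
and vector.transpose / vector.broadcast (tile shape taken from the result
layout), so these ops participate in blocking. A minimal sketch of the
intended rewrite, using illustrative value names (%src, %acc, %r0) and the
16x16 inst_data from the tests below:

  // Before blocking: the source carries inst_data = [16, 16], the result [16].
  %r = vector.multi_reduction <add>, %src, %acc {layout_result_0 = #r} [0]
        : vector<16x64xf32> to vector<64xf32>

  // After blocking: the reduction is rewritten over 16x16 tiles, e.g.
  %r0 = vector.multi_reduction <add>, %src0, %acc0 [0]
        : vector<16x16xf32> to vector<16xf32>
  // ... and similarly for the remaining 16x16 tiles.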
---
.../XeGPU/Transforms/XeGPUBlocking.cpp | 6 +
mlir/test/Dialect/XeGPU/xegpu-blocking.mlir | 110 ++++++++++++++++++
2 files changed, 116 insertions(+)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 7cd998eed2e08..a3826c56e1f62 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -169,6 +169,12 @@ XeGPUBlockingPass::getTileShape(Operation *op) const {
if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)
return getTileShape(op->getOpResult(0));
+ if (isa<vector::MultiDimReductionOp>(op))
+ return getTileShape(op->getOpOperand(0));
+
+ if (isa<vector::TransposeOp, vector::BroadcastOp>(op))
+ return getTileShape(op->getOpResult(0));
+
return std::nullopt;
}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index f9114988686c8..8e3673d04eacb 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -246,3 +246,113 @@ gpu.module @test_kernel {
gpu.return
}
}
+
+// -----
+#l = #xegpu.layout<inst_data = [16, 16]>
+#r = #xegpu.layout<inst_data = [16]>
+
+gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+ gpu.func @reduce_dim_0(%a: memref<16x512xf32>, %b: memref<512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+ %acc = arith.constant dense<0.0> : vector<64xf32>
+ %c64 = arith.constant 64 : index
+ %block_id_x = gpu.block_id x
+ %m = arith.muli %block_id_x, %c64 : index
+ %0 = xegpu.create_nd_tdesc %a[0, %m] : memref<16x512xf32> -> !xegpu.tensor_desc<16x64xf32, #l>
+ %1 = xegpu.load_nd %0: !xegpu.tensor_desc<16x64xf32, #l> -> vector<16x64xf32>
+ // CHECK: vector.multi_reduction <add>, {{.*}}, [[ACC:%[0-9A-Za-z]+]] [0] : vector<16x16xf32> to vector<16xf32>
+ // CHECK-COUNT-3: vector.multi_reduction <add>, {{.*}}, [[ACC]] [0] : vector<16x16xf32> to vector<16xf32>
+ %2 = vector.multi_reduction <add>, %1, %acc {layout_result_0 = #r} [0]: vector<16x64xf32> to vector<64xf32>
+ %3 = xegpu.create_nd_tdesc %b[%m] : memref<512xf32> -> !xegpu.tensor_desc<64xf32, #r>
+ xegpu.store_nd %2, %3: vector<64xf32>, !xegpu.tensor_desc<64xf32, #r>
+ gpu.return
+ }
+}
+
+// -----
+#l = #xegpu.layout<inst_data = [16, 16]>
+#r = #xegpu.layout<inst_data = [16]>
+
+gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+ gpu.func @reduce_dim_1(%a: memref<512x32xf32>, %b: memref<512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+ %c1 = arith.constant 1 : index
+ %c32 = arith.constant 32 : index
+ %acc = arith.constant dense<0.0> : vector<32xf32>
+
+ %block_id_x = gpu.block_id x
+ %block_id_y = gpu.block_id y
+
+ %m = arith.muli %block_id_x, %c32 : index
+ %n = arith.muli %block_id_y, %c32 : index
+ %0 = xegpu.create_nd_tdesc %a[%m, %n] : memref<512x32xf32> -> !xegpu.tensor_desc<32x128xf32, #l>
+ %1 = xegpu.load_nd %0: !xegpu.tensor_desc<32x128xf32, #l> -> vector<32x128xf32>
+
+ // CHECK: vector.multi_reduction <add>, {{.*}}, [[INIT:%[0-9A-Za-z]+]] [1] : vector<16x16xf32> to vector<16xf32>
+ // CHECK-COUNT-1: vector.multi_reduction <add>, {{.*}}, [[INIT]] [1] : vector<16x16xf32> to vector<16xf32>
+
+ %2 = vector.multi_reduction <add>, %1, %acc {layout_result_0 = #r} [1]: vector<32x128xf32> to vector<32xf32>
+ %3 = xegpu.create_nd_tdesc %b[%n] : memref<512xf32> -> !xegpu.tensor_desc<32xf32, #r>
+ xegpu.store_nd %2, %3: vector<32xf32>, !xegpu.tensor_desc<32xf32, #r>
+ gpu.return
+ }
+}
+
+// -----
+#r = #xegpu.layout<inst_data = [16]>
+#l = #xegpu.layout<inst_data = [16, 16]>
+
+gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+ gpu.func @broadcast_dim_0(%a: memref<512xf32>, %b: memref<16x512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+
+ %c64 = arith.constant 64 : index
+ %block_id_x = gpu.block_id x
+ %m = arith.muli %block_id_x, %c64 : index
+ %0 = xegpu.create_nd_tdesc %a[%m] : memref<512xf32> -> !xegpu.tensor_desc<64xf32, #r>
+ %1 = xegpu.load_nd %0: !xegpu.tensor_desc<64xf32, #r> -> vector<64xf32>
+ // CHECK-COUNT-4: vector.broadcast {{.*}} : vector<16xf32> to vector<16x16xf32>
+ %2 = vector.broadcast %1 {layout_result_0 = #l} : vector<64xf32> to vector<16x64xf32>
+ %3 = xegpu.create_nd_tdesc %b[0, %m] : memref<16x512xf32> -> !xegpu.tensor_desc<16x64xf32, #l>
+ xegpu.store_nd %2, %3: vector<16x64xf32>, !xegpu.tensor_desc<16x64xf32, #l>
+ gpu.return
+ }
+}
+
+// -----
+#r = #xegpu.layout<inst_data = [16]>
+#l = #xegpu.layout<inst_data = [16, 16]>
+
+gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+ gpu.func @broadcast_dim_1(%a: memref<512xf32>, %b: memref<16x512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+
+ %c32 = arith.constant 32 : index
+ %block_id_x = gpu.block_id x
+ %m = arith.muli %block_id_x, %c32 : index
+ %0 = xegpu.create_nd_tdesc %a[%m] : memref<512xf32> -> !xegpu.tensor_desc<32xf32, #r>
+ %1 = xegpu.load_nd %0: !xegpu.tensor_desc<32xf32, #r> -> vector<32xf32>
+ %11 = vector.shape_cast %1 : vector<32xf32> to vector<32x1xf32>
+ // CHECK-COUNT-8: vector.broadcast {{.*}}: vector<16x1xf32> to vector<16x16xf32>
+ %2 = vector.broadcast %11 {layout_result_0 = #l} : vector<32x1xf32> to vector<32x64xf32>
+ %3 = xegpu.create_nd_tdesc %b[0, %m] : memref<16x512xf32> -> !xegpu.tensor_desc<32x64xf32, #l>
+ xegpu.store_nd %2, %3: vector<32x64xf32>, !xegpu.tensor_desc<32x64xf32, #l>
+ gpu.return
+ }
+}
+
+// -----
+#l = #xegpu.layout<inst_data = [16, 8]>
+#t = #xegpu.layout<inst_data = [8, 16]>
+
+gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+ gpu.func @transpose(%a: memref<512x8xf32>, %b: memref<8x512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+
+ %c32 = arith.constant 32 : index
+ %block_id_x = gpu.block_id x
+ %m = arith.muli %block_id_x, %c32 : index
+ %0 = xegpu.create_nd_tdesc %a[%m, 0] : memref<512x8xf32> -> !xegpu.tensor_desc<32x8xf32, #l>
+ %1 = xegpu.load_nd %0: !xegpu.tensor_desc<32x8xf32, #l> -> vector<32x8xf32>
+ // CHECK-COUNT-2: vector.transpose {{.*}} [1, 0] : vector<16x8xf32> to vector<8x16xf32>
+ %2 = vector.transpose %1, [1, 0] {layout_result_0 = #t} : vector<32x8xf32> to vector<8x32xf32>
+ %3 = xegpu.create_nd_tdesc %b[0, %m] : memref<8x512xf32> -> !xegpu.tensor_desc<8x32xf32, #t>
+ xegpu.store_nd %2, %3: vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32, #t>
+ gpu.return
+ }
+}
\ No newline at end of file
>From 30b099ef6d1f8d036290590c73786bd45ea51b57 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 9 Jun 2025 23:14:58 +0000
Subject: [PATCH 2/5] add create_desc unrolling and test
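
The new UnrollCreateDescOp pattern splits the offsets vector according to the
descriptor's inst_data, emits one xegpu.create_tdesc per slice, and packs the
pieces back into the original descriptor type. TestXeGPUTransforms is extended
so the scatter-style ops (create_tdesc, update_offset, prefetch, load, store)
pick their target shape from the tensor descriptor's inst_data. A rough
before/after sketch, with illustrative names (%off0, %off1, %t0, %t1) and
inst_data = [16] as in the tests added in the follow-up commit:

  // Before unrolling:
  %t = xegpu.create_tdesc %src, %offsets : ui64, vector<32xindex>
       -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>

  // After unrolling: one create_tdesc per 16-element slice of the offsets.
  %t0 = xegpu.create_tdesc %src, %off0 : ui64, vector<16xindex>
        -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
  %t1 = xegpu.create_tdesc %src, %off1 : ui64, vector<16xindex>
        -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>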
---
.../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 41 ++++++++++++++++++-
.../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 24 +++++++++++
2 files changed, 64 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 885477fe4cbd5..672b0fb731f31 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -396,11 +396,50 @@ struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
}
};
+struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
+ using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
+ LogicalResult matchAndRewrite(xegpu::CreateDescOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ xegpu::TensorDescType tdescTy = op.getType();
+
+ std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+ if (!targetShape)
+ return failure();
+
+ auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
+
+
+ TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
+
+ VectorType indiceVecTy = indiceVec.getType();
+
+ SmallVector<Type> convertedIndiceTypes =
+ getUnrolledTypes(indiceVecTy, *targetShape);
+
+ SmallVector<Value> convertedIndiceVec =
+ pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
+
+ SmallVector<Value> newOps;
+
+ for (auto indice : convertedIndiceVec) {
+ auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy, op.getSource(), indice);
+ newOps.push_back(newOp);
+ }
+
+ Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
+ rewriter.replaceOp(op, castOp);
+
+ return success();
+ }
+};
+
} // namespace
void mlir::xegpu::populateXeGPUUnrollPatterns(
RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
- UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp>(
+ UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp,
+ UnrollCreateDescOp>(
patterns.getContext(), options);
}
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 3f3461e92bc08..abdee098ab430 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -71,6 +71,30 @@ struct TestXeGPUUnrollingPatterns
}
}
+ if (isa<xegpu::CreateDescOp, xegpu::UpdateOffsetOp,
+ xegpu::PrefetchOp, xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
+ xegpu::TensorDescType tdescTy;
+ if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
+ tdescTy = createOp.getType();
+ } else if (auto updateOp =
+ dyn_cast<xegpu::UpdateOffsetOp>(op)) {
+ tdescTy = updateOp.getTensorDescType();
+ } else if (auto prefetchOp = dyn_cast<xegpu::PrefetchOp>(op)) {
+ tdescTy = prefetchOp.getTensorDescType();
+ } else if (auto loadOp = dyn_cast<xegpu::LoadGatherOp>(op)) {
+ tdescTy = loadOp.getTensorDescType();
+ } else if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(op)) {
+ tdescTy = storeOp.getTensorDescType();
+ }
+
+ if (auto layout = tdescTy.getLayoutAttr()) {
+ auto inst_data = layout.getInstData();
+ if (inst_data && layout.isSgLayout())
+ return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
+ inst_data.asArrayRef().end());
+ }
+ }
+
if (isa<xegpu::DpasOp>(op))
return SmallVector<int64_t>{8, 16, 16};
>From 8a0e1455683d2a04744bb653348e39c353060704 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 10 Jun 2025 19:49:39 +0000
Subject: [PATCH 3/5] add unrolling support for scatter operations
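
This commit adds unroll patterns for the remaining scatter ops:
UnrollLoadGatherOp, UnrollStoreScatterOp, UnrollPrefetchOp, and
UnrollUpdateOffsetOp. Each pattern packs the tensor descriptor (and the mask,
offsets, or value operands where present) into inst_data-sized pieces and
emits one op per piece; load and update_offset results are recombined with
unpack, while stores and prefetches simply erase the original op. A hedged
sketch of the load case, with illustrative names (%t0, %m0, %v0) and shapes
matching the tests below:

  // Before unrolling:
  %v = xegpu.load %tdesc, %mask
       : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>,
         vector<32xi1> -> vector<32xf32>

  // After unrolling: one load per 16-lane piece, results packed back to vector<32xf32>.
  %v0 = xegpu.load %t0, %m0
        : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
  %v1 = xegpu.load %t1, %m1
        : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>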
---
.../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 159 +++++++++++++++++-
.../Dialect/XeGPU/xegpu-unroll-patterns.mlir | 143 ++++++++++++++++
2 files changed, 296 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 672b0fb731f31..96b86f6509419 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -409,19 +409,14 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
-
TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
-
VectorType indiceVecTy = indiceVec.getType();
-
SmallVector<Type> convertedIndiceTypes =
getUnrolledTypes(indiceVecTy, *targetShape);
-
SmallVector<Value> convertedIndiceVec =
pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
SmallVector<Value> newOps;
-
for (auto indice : convertedIndiceVec) {
auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy, op.getSource(), indice);
newOps.push_back(newOp);
@@ -434,12 +429,164 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
}
};
+struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
+ using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
+ LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
+ PatternRewriter &rewriter) const override {
+
+ Location loc = op.getLoc();
+ VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
+ xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+ VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+
+ std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+ if (!targetShape)
+ return failure();
+
+ Type elemTy = tdescTy.getElementType();
+ VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+
+ SmallVector<Type> convertedTdescTypes =
+ getUnrolledTypes(tdescTy, *targetShape);
+ SmallVector<Value> convertedTdescs = pack(
+ op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+ SmallVector<Type> convertedMaskTypes =
+ getUnrolledTypes(maskTy, *targetShape);
+ SmallVector<Value> convertedMasks = pack(
+ op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+
+ SmallVector<Value> newOps;
+ for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
+ auto newOp =
+ rewriter.create<xegpu::LoadGatherOp>(loc, newValueTy, t, m,
+ op.getTransposeAttr(), op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
+ newOps.push_back(newOp);
+ }
+
+ Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
+
+ rewriter.replaceOp(op, castOp);
+ return success();
+ }
+};
+
+struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
+ using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
+ LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+ std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+ if (!targetShape)
+ return failure();
+
+ SmallVector<Type> convertedTdescTypes =
+ getUnrolledTypes(tdescTy, *targetShape);
+ SmallVector<Value> convertedTdesc = pack(
+ op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+ for (auto t : convertedTdesc)
+ rewriter.create<xegpu::PrefetchOp>(loc, TypeRange(), t, op->getAttrs());
+
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
+struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
+ using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
+ LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
+ PatternRewriter &rewriter) const override {
+
+ Location loc = op.getLoc();
+ VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
+ xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+ VectorType maskTy;
+ if (op.getMask())
+ maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+
+ std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+ if (!targetShape)
+ return failure();
+
+ SmallVector<Type> convertedValTypes =
+ getUnrolledTypes(valueTy, *targetShape);
+ SmallVector<Type> convertedTdescTypes =
+ getUnrolledTypes(tdescTy, *targetShape);
+
+ SmallVector<Value> convertedValues =
+ pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
+ SmallVector<Value> convertedTdescs =
+ pack(op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+ SmallVector<Value> convertedMasks;
+ if (op.getMask()) {
+ SmallVector<Type> convertedMaskTypes =
+ getUnrolledTypes(maskTy, *targetShape);
+ convertedMasks =
+ pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+ }
+
+ for (size_t i = 0; i < convertedValues.size(); ++i) {
+ Value v = convertedValues[i];
+ Value t = convertedTdescs[i];
+ Value m = op.getMask() ? convertedMasks[i] : nullptr;
+ rewriter.create<xegpu::StoreScatterOp>(
+ loc, v, t, m, op.getTransposeAttr(), op.getL1HintAttr(),
+ op.getL2HintAttr(), op.getL3HintAttr());
+ }
+
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
+struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
+ using UnrollPattern<xegpu::UpdateOffsetOp>::UnrollPattern;
+ LogicalResult matchAndRewrite(xegpu::UpdateOffsetOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+ std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+ if (!targetShape)
+ return failure();
+
+ SmallVector<Type> convertedTdescTypes =
+ getUnrolledTypes(tdescTy, *targetShape);
+ SmallVector<Value> convertedTdesc = pack(
+ op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+ TypedValue<::mlir::VectorType> offsetVec = op.getOffsets();
+ VectorType offsetVecTy = offsetVec.getType();
+ SmallVector<Type> convertedOffsetTypes =
+ getUnrolledTypes(offsetVecTy, *targetShape);
+ SmallVector<Value> convertedOffsetVec =
+ pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
+
+ SmallVector<Value> newOps;
+ for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
+ auto newOp = rewriter.create<xegpu::UpdateOffsetOp>(
+ loc, t.getType(), t, o);
+ newOps.push_back(newOp);
+ }
+ Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
+ rewriter.replaceOp(op, castOp);
+ return success();
+ }
+};
+
} // namespace
void mlir::xegpu::populateXeGPUUnrollPatterns(
RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp,
- UnrollCreateDescOp>(
+ UnrollCreateDescOp, UnrollLoadGatherOp,
+ UnrollStoreScatterOp, UnrollPrefetchOp, UnrollUpdateOffsetOp>(
patterns.getContext(), options);
}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
index b911bb3bbdc1c..47c54bfcb89d0 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -158,4 +158,147 @@ gpu.module @test {
%c = xegpu.dpas %a, %b : vector<32x32xf16>, vector<32x32xf16> -> vector<32x32xf32>
gpu.return %c : vector<32x32xf32>
}
+
+//-----
+
+ // CHECK-LABEL: test_create_tdesc_vec
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+ gpu.func @test_create_tdesc_vec(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
+ %cst = arith.constant dense<[
+ 0, 8, 16, 24, 32, 40, 48, 56,
+ 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184,
+ 192, 200, 208, 216, 224, 232, 240, 248
+ ]> : vector<32xindex>
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+ gpu.return %tdesc : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+ }
+
+//-----
+
+ // CHECK-LABEL: test_create_tdesc_step
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+ gpu.func @test_create_tdesc_step(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
+ %step = arith.constant dense<8> : vector<32xindex>
+ %seq = vector.step : vector<32xindex>
+ %cst = arith.muli %seq, %step : vector<32xindex>
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+ gpu.return %tdesc : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+ }
+
+//-----
+
+ // CHECK-LABEL: test_load
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+ // CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+ gpu.func @test_load(%src: ui64) -> vector<32xf32> {
+ %cst = arith.constant dense<[
+ 0, 8, 16, 24, 32, 40, 48, 56,
+ 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184,
+ 192, 200, 208, 216, 224, 232, 240, 248
+ ]> : vector<32xindex>
+
+ %c17 = arith.constant 17: index
+ %mask = vector.create_mask %c17: vector<32xi1>
+
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+ %ld = xegpu.load %tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
+
+ gpu.return %ld : vector<32xf32>
+ }
+
+//-----
+
+ // CHECK-LABEL: test_prefetch
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+ // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+ gpu.func @test_prefetch(%src: ui64) {
+
+ %cst = arith.constant dense<[
+ 0, 8, 16, 24, 32, 40, 48, 56,
+ 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184,
+ 192, 200, 208, 216, 224, 232, 240, 248
+ ]> : vector<32xindex>
+
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+
+ xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+ gpu.return
+ }
+
+//-----
+
+ // CHECK-LABEL: test_store
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+ // CHECK-COUNT-2: xegpu.store {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
+ gpu.func @test_store(%src: ui64) {
+ %cst = arith.constant dense<[
+ 0, 8, 16, 24, 32, 40, 48, 56,
+ 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184,
+ 192, 200, 208, 216, 224, 232, 240, 248
+ ]> : vector<32xindex>
+
+ %c17 = arith.constant 17: index
+ %mask = vector.create_mask %c17: vector<32xi1>
+
+ %st_vec = arith.constant dense<1023.>: vector<32xf32>
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+ xegpu.store %st_vec, %tdesc, %mask: vector<32xf32>, !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1>
+
+ gpu.return
+ }
+
+//-----
+
+ // CHECK-LABEL: test_prefetch_load_store_update
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+ // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+ // CHECK-COUNT-2: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex>
+ // CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+ // CHECK-COUNT-2: xegpu.store {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
+
+ gpu.func @test_prefetch_load_store_update(%src: ui64) {
+
+ %cst = arith.constant dense<[
+ 0, 8, 16, 24, 32, 40, 48, 56,
+ 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184,
+ 192, 200, 208, 216, 224, 232, 240, 248
+ ]> : vector<32xindex>
+
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+
+ xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+
+ %delta = arith.constant dense<[
+ 32, 32, 32, 32, 32, 32, 32, 32,
+ 32, 32, 32, 32, 32, 32, 32, 64,
+ 128, 128, 128, 128, 128, 128, 128, 128,
+ 128, 128, 128, 128, 128, 128, 128, 256
+ ]> : vector<32xindex>
+ %new_tdesc = xegpu.update_offset %tdesc, %delta
+ : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xindex>
+
+ %c17 = arith.constant 17: index
+ %mask = vector.create_mask %c17: vector<32xi1>
+
+ %ld_vec = xegpu.load %new_tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
+
+ %st_vec = arith.addf %ld_vec, %ld_vec : vector<32xf32>
+ xegpu.store %st_vec, %tdesc, %mask:
+ vector<32xf32>,
+ !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>,
+ vector<32xi1>
+
+ gpu.return
+ }
}
>From c91156c7fc789548676f367919ddd577ba06c5ed Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 10 Jun 2025 19:51:32 +0000
Subject: [PATCH 4/5] clang format
---
.../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 31 ++++++++++---------
.../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 7 ++---
2 files changed, 19 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 96b86f6509419..900ade8c171d5 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -418,7 +418,8 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
SmallVector<Value> newOps;
for (auto indice : convertedIndiceVec) {
- auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy, op.getSource(), indice);
+ auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
+ op.getSource(), indice);
newOps.push_back(newOp);
}
@@ -454,14 +455,14 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
SmallVector<Type> convertedMaskTypes =
getUnrolledTypes(maskTy, *targetShape);
- SmallVector<Value> convertedMasks = pack(
- op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+ SmallVector<Value> convertedMasks =
+ pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
SmallVector<Value> newOps;
for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
- auto newOp =
- rewriter.create<xegpu::LoadGatherOp>(loc, newValueTy, t, m,
- op.getTransposeAttr(), op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
+ auto newOp = rewriter.create<xegpu::LoadGatherOp>(
+ loc, newValueTy, t, m, op.getTransposeAttr(), op.getL1HintAttr(),
+ op.getL2HintAttr(), op.getL3HintAttr());
newOps.push_back(newOp);
}
@@ -520,8 +521,8 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
SmallVector<Value> convertedValues =
pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
- SmallVector<Value> convertedTdescs =
- pack(op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+ SmallVector<Value> convertedTdescs = pack(
+ op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
SmallVector<Value> convertedMasks;
if (op.getMask()) {
@@ -566,12 +567,12 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
SmallVector<Type> convertedOffsetTypes =
getUnrolledTypes(offsetVecTy, *targetShape);
SmallVector<Value> convertedOffsetVec =
- pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
+ pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
SmallVector<Value> newOps;
for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
- auto newOp = rewriter.create<xegpu::UpdateOffsetOp>(
- loc, t.getType(), t, o);
+ auto newOp =
+ rewriter.create<xegpu::UpdateOffsetOp>(loc, t.getType(), t, o);
newOps.push_back(newOp);
}
Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
@@ -585,8 +586,8 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
void mlir::xegpu::populateXeGPUUnrollPatterns(
RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
- UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp,
- UnrollCreateDescOp, UnrollLoadGatherOp,
- UnrollStoreScatterOp, UnrollPrefetchOp, UnrollUpdateOffsetOp>(
- patterns.getContext(), options);
+ UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp,
+ UnrollCreateDescOp, UnrollLoadGatherOp, UnrollStoreScatterOp,
+ UnrollPrefetchOp, UnrollUpdateOffsetOp>(patterns.getContext(),
+ options);
}
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index abdee098ab430..57aaecbd7962f 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -71,13 +71,12 @@ struct TestXeGPUUnrollingPatterns
}
}
- if (isa<xegpu::CreateDescOp, xegpu::UpdateOffsetOp,
- xegpu::PrefetchOp, xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
+ if (isa<xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
+ xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
xegpu::TensorDescType tdescTy;
if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
tdescTy = createOp.getType();
- } else if (auto updateOp =
- dyn_cast<xegpu::UpdateOffsetOp>(op)) {
+ } else if (auto updateOp = dyn_cast<xegpu::UpdateOffsetOp>(op)) {
tdescTy = updateOp.getTensorDescType();
} else if (auto prefetchOp = dyn_cast<xegpu::PrefetchOp>(op)) {
tdescTy = prefetchOp.getTensorDescType();
>From 30cb8d816352b0d29578f408777a151c94b5f3be Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 12 Jun 2025 20:59:47 -0700
Subject: [PATCH 5/5] Update mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
Co-authored-by: Adam Siemieniuk <adam.siemieniuk at intel.com>
---
mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
index 47c54bfcb89d0..cca5339b086e4 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -249,7 +249,7 @@ gpu.module @test {
%c17 = arith.constant 17: index
%mask = vector.create_mask %c17: vector<32xi1>
- %st_vec = arith.constant dense<1023.>: vector<32xf32>
+ %st_vec = arith.constant dense<1023.0>: vector<32xf32>
%tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
xegpu.store %st_vec, %tdesc, %mask: vector<32xf32>, !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1>