[Mlir-commits] [mlir] 118bfcd - [MLIR][XEGPU] Add blocking support for scatter ops (#144766)
llvmlistbot at llvm.org
Wed Jun 18 14:52:06 PDT 2025
Author: Jianhui Li
Date: 2025-06-18T14:52:03-07:00
New Revision: 118bfcda46c17349575217bc901e8e5942521955
URL: https://github.com/llvm/llvm-project/commit/118bfcda46c17349575217bc901e8e5942521955
DIFF: https://github.com/llvm/llvm-project/commit/118bfcda46c17349575217bc901e8e5942521955.diff
LOG: [MLIR][XEGPU] Add blocking support for scatter ops (#144766)
Add blocking support for the scatter ops: create_tdesc, update_offset, prefetch,
load, and store. It also enables load/store with a chunk size.
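For illustration (types taken from the new tests below; %src and %off are
placeholder names), a scattered descriptor with chunk_size = 4 and
inst_data = [16, 2]

  %tdesc = xegpu.create_tdesc %src, %off : ui64, vector<32xindex>
    -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size = 4>,
                          #xegpu.layout<inst_data = [16, 2]>>

is blocked into four descriptors whose shape matches inst_data, with the
chunk size reduced to the innermost inst_data dimension:

  !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>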
Added:
Modified:
mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index a3826c56e1f62..3950e8f70d1ca 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -134,11 +134,13 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
std::optional<SmallVector<int64_t>>
XeGPUBlockingPass::getTileShape(Operation *op) const {
- if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp>(op))
+ if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp, xegpu::CreateDescOp,
+ xegpu::UpdateOffsetOp>(op))
return getTileShape(op->getOpResult(0));
- if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp>(op))
+ if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp,
+ xegpu::LoadGatherOp>(op))
return getTileShape(op->getOpOperand(0));
- if (isa<xegpu::StoreNdOp>(op))
+ if (isa<xegpu::StoreNdOp, xegpu::StoreScatterOp>(op))
return getTileShape(op->getOpOperand(1));
if (isa<xegpu::DpasOp>(op)) {
@@ -295,12 +297,36 @@ void XeGPUBlockingPass::runOnOperation() {
Type elemTy = type.getElementType();
Type newTy;
- if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type))
- newTy = xegpu::TensorDescType::get(
- ctx, tileShape, elemTy, tdescTy.getEncoding(),
- tdescTy.getLayoutAttr().dropInstData());
- else
+ if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type)) {
+
+ Attribute encoding = tdescTy.getEncoding();
+ // If the encoding is a ScatterTensorDescAttr, we need to
+ // potentially adjust the chunk size based on the inst_data.
+ if (tdescTy.isScattered()) {
+ auto scatterAttr =
+ llvm::dyn_cast_if_present<xegpu::ScatterTensorDescAttr>(encoding);
+ int64_t chunkSize = scatterAttr.getChunkSize().getInt();
+
+ if (chunkSize > 1) {
+ int64_t blockedChunkSize = chunkSize;
+ auto instData = tdescTy.getLayoutAttr().getInstData();
+ if (!instData.empty())
+ blockedChunkSize = instData.asArrayRef().back();
+
+ // To create a new attribute with a different chunk_size:
+ auto newEncoding = xegpu::ScatterTensorDescAttr::get(
+ ctx, scatterAttr.getMemorySpace().getValue(), blockedChunkSize);
+
+ encoding = newEncoding;
+ }
+ }
+
+ newTy =
+ xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
+ tdescTy.getLayoutAttr().dropInstData());
+ } else {
newTy = type.clone(tileShape, elemTy);
+ }
std::optional<SmallVector<int64_t>> ratio =
computeShapeRatio(type.getShape(), tileShape);
diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index 67d3bd9b393c0..f977ba3c11bcf 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -250,8 +250,7 @@ gpu.module @test_kernel {
// -----
#l = #xegpu.layout<inst_data = [16, 16]>
#r = #xegpu.layout<inst_data = [16]>
-
-gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel {
gpu.func @reduce_dim_0(%a: memref<16x512xf32>, %b: memref<512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
%acc = arith.constant dense<0.0> : vector<64xf32>
%c64 = arith.constant 64 : index
@@ -271,8 +270,7 @@ gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
// -----
#l = #xegpu.layout<inst_data = [16, 16]>
#r = #xegpu.layout<inst_data = [16]>
-
-gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel {
gpu.func @reduce_dim_1(%a: memref<512x32xf32>, %b: memref<512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
%c1 = arith.constant 1 : index
%c32 = arith.constant 32 : index
@@ -299,8 +297,7 @@ gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
// -----
#r = #xegpu.layout<inst_data = [16]>
#l = #xegpu.layout<inst_data = [16, 16]>
-
-gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel {
gpu.func @broadcast_dim_0(%a: memref<512xf32>, %b: memref<16x512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
%c64 = arith.constant 64 : index
@@ -319,8 +316,7 @@ gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
// -----
#r = #xegpu.layout<inst_data = [16]>
#l = #xegpu.layout<inst_data = [16, 16]>
-
-gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel {
gpu.func @broadcast_dim_1(%a: memref<512xf32>, %b: memref<16x512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
%c32 = arith.constant 32 : index
@@ -340,8 +336,7 @@ gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
// -----
#l = #xegpu.layout<inst_data = [16, 8]>
#t = #xegpu.layout<inst_data = [8, 16]>
-
-gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel {
gpu.func @transpose(%a: memref<512x8xf32>, %b: memref<8x512xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
%c32 = arith.constant 32 : index
@@ -355,4 +350,100 @@ gpu.module @kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
xegpu.store_nd %2, %3: vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32, #t>
gpu.return
}
-}
\ No newline at end of file
+}
+
+// -----
+gpu.module @test_kernel {
+ // CHECK-LABEL: test_prefetch_load_store_update
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+ // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+ // CHECK-COUNT-2: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex>
+ // CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+ // CHECK-COUNT-2: xegpu.store {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
+
+ gpu.func @test_prefetch_load_store_update(%src: ui64) {
+
+ %cst = arith.constant dense<[
+ 0, 8, 16, 24, 32, 40, 48, 56,
+ 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184,
+ 192, 200, 208, 216, 224, 232, 240, 248
+ ]> : vector<32xindex>
+
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+ xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+
+ %delta = arith.constant dense<[
+ 32, 32, 32, 32, 32, 32, 32, 32,
+ 32, 32, 32, 32, 32, 32, 32, 64,
+ 128, 128, 128, 128, 128, 128, 128, 128,
+ 128, 128, 128, 128, 128, 128, 128, 256
+ ]> : vector<32xindex>
+ %new_tdesc = xegpu.update_offset %tdesc, %delta
+ : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xindex>
+
+ %c17 = arith.constant 17: index
+ %mask = vector.create_mask %c17: vector<32xi1>
+
+ %ld_vec = xegpu.load %new_tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
+
+ %st_vec = arith.addf %ld_vec, %ld_vec : vector<32xf32>
+ xegpu.store %st_vec, %tdesc, %mask:
+ vector<32xf32>,
+ !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>,
+ vector<32xi1>
+
+ gpu.return
+ }
+
+}
+
+// -----
+
+gpu.module @test_kernel {
+ // CHECK-LABEL: test_prefetch_load_store_update_chunk
+ // CHECK-SAME: [[arg0:%.+]]: ui64
+ // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+ // CHECK-COUNT-4: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+ // CHECK-COUNT-4: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xindex>
+ // CHECK-COUNT-4: xegpu.load {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1> -> vector<2x16xf32>
+ // CHECK-COUNT-4: xegpu.store {{.*}} : vector<2x16xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1>
+
+ gpu.func @test_prefetch_load_store_update_chunk(%src: ui64) {
+
+ %cst = arith.constant dense<[
+ 0, 8, 16, 24, 32, 40, 48, 56,
+ 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184,
+ 192, 200, 208, 216, 224, 232, 240, 248
+ ]> : vector<32xindex>
+
+ %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+ xegpu.prefetch %tdesc: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+
+ %delta = arith.constant dense<[
+ 32, 32, 32, 32, 32, 32, 32, 32,
+ 32, 32, 32, 32, 32, 32, 32, 64,
+ 128, 128, 128, 128, 128, 128, 128, 128,
+ 128, 128, 128, 128, 128, 128, 128, 256
+ ]> : vector<32xindex>
+ %new_tdesc = xegpu.update_offset %tdesc, %delta
+ : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xindex>
+
+ %c17 = arith.constant 17: index
+ %mask = vector.create_mask %c17: vector<32xi1>
+
+ %ld_vec = xegpu.load %new_tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xi1> -> vector<4x32xf32>
+
+ %st_vec = arith.addf %ld_vec, %ld_vec : vector<4x32xf32>
+ xegpu.store %st_vec, %tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>:
+ vector<4x32xf32>,
+ !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>,
+ vector<32xi1>
+
+ gpu.return
+ }
+}
+
+
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 4400d6d9625f7..c84eb74198544 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -102,14 +102,14 @@ struct TestXeGPUUnrollingPatterns
// attribute
if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type)) {
Attribute encoding = tdescTy.getEncoding();
- auto layout = llvm::dyn_cast_if_present<xegpu::LayoutAttr>(
- tdescTy.getLayout());
+ auto layout = tdescTy.getLayoutAttr();
// If the encoding is a ScatterTensorDescAttr, we need to
// potentially adjust the chunk size based on the inst_data.
- if (encoding && mlir::isa<xegpu::ScatterTensorDescAttr>(encoding)) {
+ if (tdescTy.isScattered()) {
auto scatterAttr =
- mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
+ llvm::dyn_cast_if_present<xegpu::ScatterTensorDescAttr>(
+ encoding);
int64_t chunkSize = scatterAttr.getChunkSize().getInt();
if (chunkSize > 1) {
@@ -118,12 +118,10 @@ struct TestXeGPUUnrollingPatterns
if (!instData.empty())
blockedChunkSize = instData.asArrayRef().back();
- auto chunkSizeAttr = mlir::IntegerAttr::get(
- mlir::IntegerType::get(ctx, 64), blockedChunkSize);
-
// To create a new attribute with a different chunk_size:
auto newEncoding = xegpu::ScatterTensorDescAttr::get(
- ctx, scatterAttr.getMemorySpace(), chunkSizeAttr);
+ ctx, scatterAttr.getMemorySpace().getValue(),
+ blockedChunkSize);
encoding = newEncoding;
}