[Mlir-commits] [mlir] [MLIR] [XeGPU] Add distribution support for memref.alloca and xegpu.create_memdesc ops (PR #179018)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 30 22:14:51 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu
Author: Jianhui Li (Jianhui-Li)
<details>
<summary>Changes</summary>
This PR adds distribution support for the memref.alloca and xegpu.create_mem_desc ops.
It also removes the slice-layout requirement for rank-changing vector.shape_cast ops.
---
Full diff: https://github.com/llvm/llvm-project/pull/179018.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp (+61-16)
- (modified) mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir (+37-8)
``````````diff
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index a8ed5a289f84a..edead2e3be292 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1611,19 +1611,6 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
warpOp,
"the source or result of shape_cast op lacks distribution layout");
- // For rank reducing or increasing shape_cast ops, the lower rank layout
- // must be a slice of higher rank layout.
- int64_t sourceRank = shapeCastOp.getSourceVectorType().getRank();
- int64_t resultRank = shapeCastOp.getResultVectorType().getRank();
- if (sourceRank < resultRank && !sourceLayout.isSliceOf(resultLayout))
- return rewriter.notifyMatchFailure(
- warpOp, "shape_cast is rank reducing but source layout is not a "
- "slice of result layout");
- if (sourceRank > resultRank && !resultLayout.isSliceOf(sourceLayout))
- return rewriter.notifyMatchFailure(
- warpOp, "shape_cast is rank increasing but result layout is not a "
- "slice of source layout");
-
FailureOr<VectorType> sourceDistTypeOrFailure =
getDistVecTypeBasedOnLaneLayout(sourceLayout,
shapeCastOp.getSourceVectorType());
@@ -1902,8 +1889,65 @@ struct MemrefExtractAlignedPointerAsIndexDistribution final
auto newExtractOp = memref::ExtractAlignedPointerAsIndexOp::create(
rewriter, newWarpOp.getLoc(), extractOp.getType(),
newWarpOp.getResult(newRetIndices[0]));
- Value distributedVal = newWarpOp.getResult(operandIdx);
- rewriter.replaceAllUsesWith(distributedVal, newExtractOp.getResult());
+ Value resultVal = newWarpOp.getResult(operandIdx);
+ rewriter.replaceAllUsesWith(resultVal, newExtractOp.getResult());
+ return success();
+ }
+};
+
+/// Re-creates a yielded memref.alloca after the warp op with the same memref
+/// type and alignment; allocas with size/symbol operands are not matched.
+struct MemrefAllocaDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
+      return llvm::IsaPred<memref::AllocaOp>(op) && op->getNumOperands() == 0 &&
+             warpOp.getTerminator()->getPrevNode() == op;
+    });
+    if (!operand)
+      return rewriter.notifyMatchFailure(
+          warpOp, "warp result is not a memref::Alloca op");
+    auto allocaOp = operand->get().getDefiningOp<memref::AllocaOp>();
+    unsigned operandIdx = operand->getOperandNumber();
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, ValueRange{}, TypeRange{}, newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    auto newAllocaOp = memref::AllocaOp::create(rewriter, newWarpOp.getLoc(),
+        allocaOp.getType(), allocaOp.getAlignmentAttr());
+    Value resultVal = newWarpOp.getResult(operandIdx);
+    rewriter.replaceAllUsesWith(resultVal, newAllocaOp.getResult());
+    return success();
+  }
+};
+
+struct CreateMemDescDistribution final : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
+ // Match a create_mem_desc that is the last op before the yield. Its memref
+ // source is yielded so the op can be re-created after the warp op.
+ return llvm::IsaPred<xegpu::CreateMemDescOp>(op) &&
+ warpOp.getTerminator()->getPrevNode() == op;
+ });
+ if (!operand)
+ return rewriter.notifyMatchFailure(
+ warpOp, "warp result is not a xegpu::CreateMemDesc op");
+ auto createMemDescOp =
+ operand->get().getDefiningOp<xegpu::CreateMemDescOp>();
+ unsigned operandIdx = operand->getOperandNumber();
+ SmallVector<size_t> newRetIndices;
+ gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, createMemDescOp.getSource(),
+ TypeRange{createMemDescOp.getSource().getType()}, newRetIndices);
+ rewriter.setInsertionPointAfter(newWarpOp);
+ auto newCreateMemDescOp = xegpu::CreateMemDescOp::create(
+ rewriter, newWarpOp.getLoc(), createMemDescOp.getType(),
+ newWarpOp.getResult(newRetIndices[0]));
+ Value resultVal = newWarpOp.getResult(operandIdx);
+ rewriter.replaceAllUsesWith(resultVal, newCreateMemDescOp.getResult());
return success();
}
};
@@ -2031,7 +2075,8 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
LoadDistribution, StoreDistribution, VectorTransposeDistribution,
VectorBitcastDistribution, LoadMatrixDistribution,
StoreMatrixDistribution,
- MemrefExtractAlignedPointerAsIndexDistribution>(
+ MemrefExtractAlignedPointerAsIndexDistribution,
+ MemrefAllocaDistribution, CreateMemDescDistribution>(
patterns.getContext(),
/*pattern benefit=*/PatternHierarchy::Regular);
// For following patterns, we need to override the regular vector distribution
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
index a99f850de6175..81f25cc85359f 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
@@ -486,7 +486,36 @@ gpu.func @memref_extract_aligned_pointer_as_index(%arg0 : memref<256x256xf16>, %
gpu.return
}
+// CHECK-LABEL: gpu.func @memref_alloca(
+// CHECK-NEXT:    %[[BUF:.*]] = memref.alloca() : memref<2048xi8, 3>
+// CHECK-NEXT:    %[[ADDR:.*]] = memref.extract_aligned_pointer_as_index %[[BUF]] : memref<2048xi8, 3> -> index
+// CHECK-NEXT:    %{{.*}} = arith.index_cast %[[ADDR]] : index to i64
+gpu.func @memref_alloca(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (memref<2048xi8, 3>) {
+    %buf = memref.alloca() : memref<2048xi8, 3>
+    gpu.yield %buf : memref<2048xi8, 3>
+  }
+  %addr = memref.extract_aligned_pointer_as_index %r : memref<2048xi8, 3> -> index
+  %addr_i64 = arith.index_cast %addr : index to i64
+  "some_user_op"(%addr_i64) : (i64) -> ()
+  gpu.return
+}
+// CHECK-LABEL: gpu.func @create_memdesc(
+// CHECK:         %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (!xegpu.mem_desc<4x128xf32>, memref<2048xi8, 3>) {
+// CHECK:           gpu.yield %{{.*}}, %{{.*}} : !xegpu.mem_desc<4x128xf32>, memref<2048xi8, 3>
+// CHECK-NEXT:    }
+// CHECK-NEXT:    %[[MDESC:.*]] = xegpu.create_mem_desc %[[W]]#1 : memref<2048xi8, 3> -> !xegpu.mem_desc<4x128xf32>
+gpu.func @create_memdesc(%laneid: index, %arg0 : memref<2048xi8, 3>) {
+  %c0 = arith.constant 0 : index
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (!xegpu.mem_desc<4x128xf32>) {
+    %desc = xegpu.create_mem_desc %arg0 : memref<2048xi8, 3> -> !xegpu.mem_desc<4x128xf32>
+    gpu.yield %desc : !xegpu.mem_desc<4x128xf32>
+  }
+  %ld = xegpu.load_matrix %r[%c0, %c0] : !xegpu.mem_desc<4x128xf32>, index, index -> vector<1x16xf32>
+  "some_user_op"(%ld) : (vector<1x16xf32>) -> ()
+  gpu.return
+}
// CHECK-LABEL: gpu.func @vector_transpose(
// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2x1xf32>, vector<1x2xf32>) {
@@ -582,16 +611,15 @@ gpu.func @vector_shapecast_rank_reducing(%laneid: index) {
}
-// NOTE: Layouts are still valid, but distribution still requires a slice layout for the operand.
-//
-// CHECK-LABEL: gpu.func @vector_shapecast_unsupported
-// CHECK: %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1x1xf32>) {
-// CHECK: %[[T1:.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<1x16xf32>
-// CHECK: gpu.yield %[[T1]] : vector<1x16xf32>
+// CHECK-LABEL: gpu.func @vector_shapecast_rank_increasing_without_slicing_layout
+// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1x1xf32>, vector<1xf32>) {
+// CHECK: %[[T1:.*]] = vector.shape_cast %{{.*}} {layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf32> to vector<1x16xf32>
+// CHECK: gpu.yield %[[T1]], %{{.*}} : vector<1x16xf32>, vector<16xf32>
// CHECK: }
-// CHECK: "some_user_op"(%[[W]]) : (vector<1x1xf32>) -> ()
+// CHECK: %{{.*}} = vector.shape_cast %[[W]]#1 : vector<1xf32> to vector<1x1xf32>
// CHECK: gpu.return
-gpu.func @vector_shapecast_unsupported(%laneid: index) {
+gpu.module @xevm_module{
+gpu.func @vector_shapecast_rank_increasing_without_slicing_layout(%laneid: index) {
%r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) {
%cst = "some_op"()
{layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]> }
@@ -607,6 +635,7 @@ gpu.func @vector_shapecast_unsupported(%laneid: index) {
"some_user_op"(%r) : (vector<1x1xf32>) -> ()
gpu.return
}
+}
// CHECK-LABEL: gpu.func @vector_extract_strided_slice_distributed_dim_fully_extracted
``````````
</details>
https://github.com/llvm/llvm-project/pull/179018
More information about the Mlir-commits
mailing list