[Mlir-commits] [mlir] [MLIR] [XeGPU] Add distribution support for memref.alloca and xegpu.create_memdesc ops (PR #179018)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 30 22:14:52 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Jianhui Li (Jianhui-Li)

<details>
<summary>Changes</summary>

This PR adds distribution support for the memref.alloca and xegpu.create_mem_desc ops.
It also removes the slice-layout requirement for shape_cast.

---
Full diff: https://github.com/llvm/llvm-project/pull/179018.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp (+61-16) 
- (modified) mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir (+37-8) 


``````````diff
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index a8ed5a289f84a..edead2e3be292 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1611,19 +1611,6 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
           warpOp,
           "the source or result of shape_cast op lacks distribution layout");
 
-    // For rank reducing or increasing shape_cast ops, the lower rank layout
-    // must be a slice of higher rank layout.
-    int64_t sourceRank = shapeCastOp.getSourceVectorType().getRank();
-    int64_t resultRank = shapeCastOp.getResultVectorType().getRank();
-    if (sourceRank < resultRank && !sourceLayout.isSliceOf(resultLayout))
-      return rewriter.notifyMatchFailure(
-          warpOp, "shape_cast is rank reducing but source layout is not a "
-                  "slice of result layout");
-    if (sourceRank > resultRank && !resultLayout.isSliceOf(sourceLayout))
-      return rewriter.notifyMatchFailure(
-          warpOp, "shape_cast is rank increasing but result layout is not a "
-                  "slice of source layout");
-
     FailureOr<VectorType> sourceDistTypeOrFailure =
         getDistVecTypeBasedOnLaneLayout(sourceLayout,
                                         shapeCastOp.getSourceVectorType());
@@ -1902,8 +1889,65 @@ struct MemrefExtractAlignedPointerAsIndexDistribution final
     auto newExtractOp = memref::ExtractAlignedPointerAsIndexOp::create(
         rewriter, newWarpOp.getLoc(), extractOp.getType(),
         newWarpOp.getResult(newRetIndices[0]));
-    Value distributedVal = newWarpOp.getResult(operandIdx);
-    rewriter.replaceAllUsesWith(distributedVal, newExtractOp.getResult());
+    Value resultVal = newWarpOp.getResult(operandIdx);
+    rewriter.replaceAllUsesWith(resultVal, newExtractOp.getResult());
+    return success();
+  }
+};
+
+struct MemrefAllocaDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
+      // Only match a static alloca that is the last op before the yield, so
+      // the yield is its only user and hoisting cannot split the buffer.
+      return isa<memref::AllocaOp>(op) && op->getNumOperands() == 0 &&
+             warpOp.getTerminator()->getPrevNode() == op;
+    });
+    if (!operand)
+      return rewriter.notifyMatchFailure(
+          warpOp, "warp result is not a memref::Alloca op");
+    auto allocaOp = operand->get().getDefiningOp<memref::AllocaOp>();
+    unsigned operandIdx = operand->getOperandNumber();
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, ValueRange{}, TypeRange{}, newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    // Clone instead of rebuilding so alignment and other attrs are kept.
+    Operation *newAllocaOp = rewriter.clone(*allocaOp.getOperation());
+    Value resultVal = newWarpOp.getResult(operandIdx);
+    rewriter.replaceAllUsesWith(resultVal, newAllocaOp->getResult(0));
+    return success();
+  }
+};
+
+struct CreateMemDescDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
+      // Only match a create_mem_desc that is the last op before the yield, so
+      // the yield is its only user and the op is recreated outside just once.
+      return llvm::IsaPred<xegpu::CreateMemDescOp>(op) &&
+             warpOp.getTerminator()->getPrevNode() == op;
+    });
+    if (!operand)
+      return rewriter.notifyMatchFailure(
+          warpOp, "warp result is not a xegpu::CreateMemDesc op");
+    auto createMemDescOp =
+        operand->get().getDefiningOp<xegpu::CreateMemDescOp>();
+    unsigned operandIdx = operand->getOperandNumber();
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, createMemDescOp.getSource(),
+        TypeRange{createMemDescOp.getSource().getType()}, newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    auto newCreateMemDescOp = xegpu::CreateMemDescOp::create(
+        rewriter, newWarpOp.getLoc(), createMemDescOp.getType(),
+        newWarpOp.getResult(newRetIndices[0]));
+    Value resultVal = newWarpOp.getResult(operandIdx);
+    rewriter.replaceAllUsesWith(resultVal, newCreateMemDescOp.getResult());
     return success();
   }
 };
@@ -2031,7 +2075,8 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
                LoadDistribution, StoreDistribution, VectorTransposeDistribution,
                VectorBitcastDistribution, LoadMatrixDistribution,
                StoreMatrixDistribution,
-               MemrefExtractAlignedPointerAsIndexDistribution>(
+               MemrefExtractAlignedPointerAsIndexDistribution,
+               MemrefAllocaDistribution, CreateMemDescDistribution>(
       patterns.getContext(),
       /*pattern benefit=*/PatternHierarchy::Regular);
   // For following patterns, we need to override the regular vector distribution
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
index a99f850de6175..81f25cc85359f 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
@@ -486,7 +486,36 @@ gpu.func @memref_extract_aligned_pointer_as_index(%arg0 : memref<256x256xf16>, %
   gpu.return
 }
 
+// CHECK-LABEL: gpu.func @memref_alloca(
+// CHECK-NEXT:    %[[ALLOCA:.*]] = memref.alloca() : memref<2048xi8, 3>
+// CHECK-NEXT:    %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ALLOCA]] : memref<2048xi8, 3> -> index
+// CHECK-NEXT:    %[[CAST:.*]] = arith.index_cast %[[INTPTR]] : index to i64
+gpu.func @memref_alloca(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (memref<2048xi8, 3>) {
+    %alloca = memref.alloca() : memref<2048xi8, 3>
+    gpu.yield %alloca : memref<2048xi8, 3>
+  }
+  %ptr = memref.extract_aligned_pointer_as_index %r : memref<2048xi8, 3> -> index
+  %ptr_i64 = arith.index_cast %ptr : index to i64
+  "some_user_op"(%ptr_i64) : (i64) -> ()
+  gpu.return
+}
 
+// CHECK-LABEL: gpu.func @create_memdesc(
+// CHECK:       %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (!xegpu.mem_desc<4x128xf32>, memref<2048xi8, 3>) {
+// CHECK:         gpu.yield %{{.*}}, %{{.*}} : !xegpu.mem_desc<4x128xf32>, memref<2048xi8, 3>
+// CHECK-NEXT:  }
+// CHECK-NEXT:  %[[MDESC:.*]] = xegpu.create_mem_desc %[[W]]#1 : memref<2048xi8, 3> -> !xegpu.mem_desc<4x128xf32>
+gpu.func @create_memdesc(%laneid: index, %arg0 : memref<2048xi8, 3>) {
+  %c0 = arith.constant 0 : index
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (!xegpu.mem_desc<4x128xf32>) {
+    %mdesc = xegpu.create_mem_desc %arg0 : memref<2048xi8, 3> -> !xegpu.mem_desc<4x128xf32>
+    gpu.yield %mdesc : !xegpu.mem_desc<4x128xf32>
+  }
+  %ld = xegpu.load_matrix %r[%c0, %c0] : !xegpu.mem_desc<4x128xf32>, index, index -> vector<1x16xf32>
+  "some_user_op"(%ld) : (vector<1x16xf32>) -> ()
+  gpu.return
+}
 
 // CHECK-LABEL: gpu.func @vector_transpose(
 // CHECK:       %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2x1xf32>, vector<1x2xf32>) {
@@ -582,16 +611,15 @@ gpu.func @vector_shapecast_rank_reducing(%laneid: index) {
 }
 
 
-// NOTE: Layouts are still valid, but distribution still requires a slice layout for the operand.
-//
-// CHECK-LABEL:  gpu.func @vector_shapecast_unsupported
-// CHECK:          %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1x1xf32>) {
-// CHECK:            %[[T1:.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<1x16xf32>
-// CHECK:            gpu.yield %[[T1]] : vector<1x16xf32>
+// CHECK-LABEL:  gpu.func @vector_shapecast_rank_increasing_without_slicing_layout
+// CHECK:          %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1x1xf32>, vector<1xf32>) {
+// CHECK:            %[[T1:.*]] = vector.shape_cast %{{.*}} {layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf32> to vector<1x16xf32>
+// CHECK:            gpu.yield %[[T1]], %{{.*}} : vector<1x16xf32>, vector<16xf32>
 // CHECK:          }
-// CHECK:          "some_user_op"(%[[W]]) : (vector<1x1xf32>) -> ()
+// CHECK:          %{{.*}} = vector.shape_cast %[[W]]#1 : vector<1xf32> to vector<1x1xf32>
 // CHECK:          gpu.return
-gpu.func @vector_shapecast_unsupported(%laneid: index) {
+gpu.module @xevm_module{
+gpu.func @vector_shapecast_rank_increasing_without_slicing_layout(%laneid: index) {
   %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) {
     %cst = "some_op"()
       {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]> }
@@ -607,6 +635,7 @@ gpu.func @vector_shapecast_unsupported(%laneid: index) {
   "some_user_op"(%r) : (vector<1x1xf32>) -> ()
   gpu.return
 }
+}
 
 
 // CHECK-LABEL:  gpu.func @vector_extract_strided_slice_distributed_dim_fully_extracted

``````````

</details>


https://github.com/llvm/llvm-project/pull/179018


More information about the Mlir-commits mailing list