[Mlir-commits] [mlir] [mlir][xegpu] Add SIMT distribution patterns for UpdateNdOffset and PrefetchNd ops. (PR #138033)
Adam Siemieniuk
llvmlistbot at llvm.org
Thu May 8 08:46:31 PDT 2025
================
@@ -1412,6 +1420,151 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
}
};
+/// Sink an update_nd_offset op feeding into the yield op of an enclosing
+/// `gpu.warp_execute_on_lane_0` region. The warp op will still contain the
+/// original op, which will no longer be used through the yield op (and should
+/// be cleaned up later). The yield op additionally forwards the updateOp's
+/// operands. The tensor descriptor type is not distributed; cast ops are
+/// inserted if the yielded types do not match expected xegpu SIMT types.
+/// Example:
+/// ```
+/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
+/// %r = gpu.warp_execute_on_lane_0(%laneid) ->
+/// (!xegpu.tensor_desc<4x8xf32, #layout0>) {
+/// ...
+/// %update = xegpu.update_nd_offset %arg0, [%c32, %c16]:
+/// !xegpu.tensor_desc<4x8xf32, #layout0>
+/// gpu.yield %update
+/// }
+/// ...
+/// ```
+/// To
+/// ```
+/// %r:4 = gpu.warp_execute_on_lane_0(%laneid) -> (
+/// !xegpu.tensor_desc<4x8xf32, #layout0>,
+/// !xegpu.tensor_desc<4x8xf32, #layout0>, index, index) {
+/// ...
+/// %dead = xegpu.update_nd_offset %arg0, [%c32, %c16]:
+/// !xegpu.tensor_desc<4x8xf32, #layout0>
+/// gpu.yield %dead, %arg0, %c32, %c16
+/// }
+/// %0 = xegpu.unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
+/// #layout0> -> !xegpu.tensor_desc<4x8xf32>
+/// %1 = xegpu.update_nd_offset %0, [%r#2, %r#3]:
+/// !xegpu.tensor_desc<4x8xf32>
+/// ...
+/// ```
+struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *operand =
+ getWarpResult(subgroupOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
+ if (!operand)
+ return rewriter.notifyMatchFailure(
+ subgroupOp, "warp result is not a xegpu::UpdateNdOffset op");
+ auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
+ unsigned operandIdx = operand->getOperandNumber();
+ // The new update op uses the tensor descriptor type with layouts dropped.
+ xegpu::TensorDescType newTensorDescTy =
+ updateOp.getTensorDescType().dropLayouts();
+
+ SmallVector<Value, 3> newYieldValues; // Every operand of the update op.
+ SmallVector<Type, 3> newYieldTypes; // Result type requested per operand.
+ for (Value operand : updateOp->getOperands()) { // NOTE(review): shadows the outer `OpOperand *operand`.
+ newYieldValues.push_back(operand);
+ if (isa<xegpu::TensorDescType>(operand.getType())) {
+ newYieldTypes.push_back(newTensorDescTy);
+ } else {
+ newYieldTypes.push_back(operand.getType());
+ }
+ }
+ SmallVector<size_t> newRetIndices; // Used below to index the new warp op's results.
+ gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+ rewriter.setInsertionPointAfter(newWarpOp);
+ SmallVector<Value> newUpdateOperands;
+ for (size_t i : newRetIndices) {
+ // For the tensor descriptor operand, the layout attribute is dropped
+ // after distribution. Types need to be resolved in this case.
+ if (isa<xegpu::TensorDescType>(newWarpOp.getResult(i).getType())) {
+ newUpdateOperands.push_back(resolveDistributedTy(
+ newWarpOp.getResult(i), newTensorDescTy, rewriter));
+ } else {
+ newUpdateOperands.push_back(newWarpOp.getResult(i));
+ }
+ }
+ // Create a new update op outside the warp op (inserted right after it).
+ auto newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
+ newWarpOp.getLoc(), newTensorDescTy, newUpdateOperands,
+ removeTemporaryLayoutAttributes(updateOp->getAttrs()));
+ Value distributedVal = newWarpOp.getResult(operandIdx); // Result that carried the original update.
+ rewriter.replaceAllUsesWith(distributedVal, newUpdateOp); // Redirect its users to the sunk op.
+ return success();
+ }
+};
+
+/// Distribute a prefetch_nd op at the end of enclosing
----------------
adam-smnk wrote:
nit: I'd add a comment explaining why the descriptor's shape remains unchanged, like in `CreateNdDescDistribution`, just to make it more obvious
https://github.com/llvm/llvm-project/pull/138033
More information about the Mlir-commits
mailing list