[Mlir-commits] [mlir] [mlir][xegpu] Add SIMT distribution patterns for UpdateNdOffset and PrefetchNd ops. (PR #138033)
llvmlistbot at llvm.org
Mon May 5 13:24:26 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-gpu
Author: Charitha Saumya (charithaintc)
Full diff: https://github.com/llvm/llvm-project/pull/138033.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+1-1)
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp (+166-2)
- (modified) mlir/test/Dialect/XeGPU/subgroup-distribution.mlir (+66)
- (modified) mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir (+59)
``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 5fa18754305ca..a892f701f724e 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -409,7 +409,7 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
}
def XeGPU_UpdateNdOffsetOp : XeGPU_Op<"update_nd_offset",
- [AllTypesMatch<["TensorDesc", "result"]>]> {
+ [Pure, AllTypesMatch<["TensorDesc", "result"]>]> {
let summary = "It updates the offsets for the TensorDesc.";
let description = [{The op updates the offset of the given TensorDesc.
The offsets are relative offset to the current position in the number
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 019032f7743bf..c7128666da7e8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -301,6 +301,10 @@ class LayoutInfoPropagation
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
+ void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
+ ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results);
+
void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
@@ -352,6 +356,9 @@ LogicalResult LayoutInfoPropagation::visitOperation(
.Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
})
+ .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
+ visitPrefetchNdOp(prefetchNdOp, operands, results);
+ })
// No need to propagate the layout to operands in CreateNdDescOp because
// they are scalars (offsets, sizes, etc.).
.Case<xegpu::CreateNdDescOp>([&](auto createNdDescOp) {})
@@ -381,6 +388,18 @@ LogicalResult LayoutInfoPropagation::visitOperation(
return success();
}
+void LayoutInfoPropagation::visitPrefetchNdOp(
+ xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results) {
+ // Here we assign the default layout to the tensor descriptor operand of
+ // prefetch.
+ auto tdescTy = prefetch.getTensorDescType();
+ auto prefetchLayout = getDefaultLayoutInfo(
+ VectorType::get(tdescTy.getShape(), tdescTy.getElementType()));
+ // Propagate the layout to the source tensor descriptor.
+ propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
+}
+
void LayoutInfoPropagation::visitVectorMultiReductionOp(
vector::MultiDimReductionOp reduction,
ArrayRef<LayoutInfoLattice *> operands,
@@ -1173,7 +1192,7 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
newStoreOperands.push_back(resolveDistributedTy(
newWarpOp.getResult(newRetIndices[0]),
storeNdDistributedValueTyOrFailure.value(), rewriter));
- // For the tensor descriptor operand, the layout attibute is dropped after
+ // For the tensor descriptor operand, the layout attribute is dropped after
// distribution. Types needs to be resolved in this case also.
xegpu::TensorDescType distributedTensorDescTy =
dropLayouts(storeOp.getTensorDescType());
@@ -1412,6 +1431,150 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
}
};
+/// Sink an update_nd_offset op feeding into the yield op of an enclosing
+/// `gpu.warp_execute_on_lane_0` region. The warp op will still contain the
+/// original op, which will not be used by the yield op (and should be cleaned
+/// up later). The yield op will bypass the updateOp's arguments. The tensor
+/// descriptor type is not distributed. Appropriate cast ops are inserted if
+/// the distributed types do not match the expected xegpu SIMT types.
+/// Example:
+/// ```
+/// #lo0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
+/// %r = gpu.warp_execute_on_lane_0(%laneid) ->
+/// (!xegpu.tensor_desc<4x8xf32, #lo0>) {
+/// ...
+/// %update = xegpu.update_nd_offset %arg0, [%c32, %c16]:
+/// !xegpu.tensor_desc<4x8xf32, #lo0>
+/// gpu.yield %update
+/// }
+/// ...
+/// ```
+/// To
+/// ```
+/// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
+/// !xegpu.tensor_desc<4x8xf32, #lo0>) {
+/// ...
+/// %dead = xegpu.update_nd_offset %arg0, [%c32, %c16]:
+/// !xegpu.tensor_desc<4x8xf32, #lo0>
+/// gpu.yield %dead, %arg0, %c32, %c16
+/// }
+/// %0 = xegpu.unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
+/// #lo0> -> !xegpu.tensor_desc<4x8xf32>
+/// %1 = xegpu.update_nd_offset %0, [%c32, %c16]:
+/// !xegpu.tensor_desc<4x8xf32>
+/// ...
+/// ```
+struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *operand =
+ getWarpResult(subgroupOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
+ if (!operand)
+ return rewriter.notifyMatchFailure(
+ subgroupOp, "warp result is not a xegpu::UpdateNdOffset op");
+ auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
+ unsigned operandIdx = operand->getOperandNumber();
+ // The new update op does not have a layout attribute.
+ xegpu::TensorDescType newTensorDescTy =
+ dropLayouts(updateOp.getTensorDescType());
+
+ SmallVector<Value, 3> newYieldValues;
+ SmallVector<Type, 3> newYieldTypes;
+ for (Value operand : updateOp->getOperands()) {
+ newYieldValues.push_back(operand);
+ if (isa<xegpu::TensorDescType>(operand.getType())) {
+ newYieldTypes.push_back(newTensorDescTy);
+ } else {
+ newYieldTypes.push_back(operand.getType());
+ }
+ }
+ SmallVector<size_t> newRetIndices;
+ gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+ rewriter.setInsertionPointAfter(newWarpOp);
+ SmallVector<Value> newUpdateOperands;
+ for (size_t i : newRetIndices) {
+ // For the tensor descriptor operand, the layout attribute is dropped
+ // after distribution. Types need to be resolved in this case.
+ if (isa<xegpu::TensorDescType>(newWarpOp.getResult(i).getType())) {
+ newUpdateOperands.push_back(resolveDistributedTy(
+ newWarpOp.getResult(i), newTensorDescTy, rewriter));
+ } else {
+ newUpdateOperands.push_back(newWarpOp.getResult(i));
+ }
+ }
+ // Create a new update op outside the warp op.
+ auto newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
+ newWarpOp.getLoc(), newTensorDescTy, newUpdateOperands,
+ removeTemporaryLayoutAttributes(updateOp->getAttrs()));
+ Value distributedVal = newWarpOp.getResult(operandIdx);
+ rewriter.replaceAllUsesWith(distributedVal, newUpdateOp);
+ return success();
+ }
+};
+
+/// Distribute a prefetch_nd op at the end of the enclosing
+/// `gpu.warp_execute_on_lane_0` region. If the prefetch's arguments are passed
+/// through the warp op interface, they are propagated as returned values.
+/// Appropriate cast ops are inserted if the distributed types do not match the
+/// expected xegpu SIMT types.
+///
+/// Example:
+///
+/// ```
+/// #lo0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
+/// gpu.warp_execute_on_lane_0(%laneid) -> () {
+/// ...
+/// xegpu.prefetch_nd %arg0 : !xegpu.tensor_desc<4x8xf32, #lo0>
+/// }
+/// ```
+/// To
+/// ```
+/// %r = gpu.warp_execute_on_lane_0(%laneid) -> (
+///   !xegpu.tensor_desc<4x8xf32, #lo0>) {
+///   gpu.yield %arg0: !xegpu.tensor_desc<4x8xf32, #lo0>
+/// }
+/// %1 = unrealized_conversion_cast %r: !xegpu.tensor_desc<4x8xf32,
+/// #lo0> -> !xegpu.tensor_desc<4x8xf32>
+/// xegpu.prefetch_nd %1 : !xegpu.tensor_desc<4x8xf32>
+///
+/// ```
+struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+ PatternRewriter &rewriter) const override {
+ auto yield = cast<gpu::YieldOp>(
+ subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ Operation *lastNode = yield->getPrevNode();
+ auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
+ if (!prefetchOp)
+ return failure();
+ xegpu::LayoutAttr layout = prefetchOp.getTensorDescType().getLayoutAttr();
+ if (!layout)
+ return rewriter.notifyMatchFailure(
+ prefetchOp, "the source tensor descriptor lacks layout attribute");
+
+ SmallVector<Value, 1> newYieldValues = {prefetchOp.getTensorDesc()};
+ SmallVector<Type, 1> newYieldTypes = {prefetchOp.getTensorDescType()};
+ SmallVector<size_t> newRetIndices;
+ gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+ // Create a new prefetch op outside the warp op with the updated tensor
+ // descriptor type. The source tensor descriptor requires type resolution.
+ xegpu::TensorDescType newTensorDescTy =
+ dropLayouts(prefetchOp.getTensorDescType());
+ rewriter.setInsertionPointAfter(newWarpOp);
+ SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
+ newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
+ rewriter.create<xegpu::PrefetchNdOp>(
+ newWarpOp.getLoc(), TypeRange{}, newPrefetchOperands,
+ removeTemporaryLayoutAttributes(prefetchOp->getAttrs()));
+ rewriter.eraseOp(prefetchOp);
+ return success();
+ }
+};
+
} // namespace
namespace {
@@ -1430,7 +1593,8 @@ struct XeGPUSubgroupDistributePass final
void xegpu::populateXeGPUSubgroupDistributePatterns(
RewritePatternSet &patterns) {
patterns.add<CreateNdDescDistribution, StoreNdDistribution,
- LoadNdDistribution, DpasDistribution>(patterns.getContext());
+ LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
+ UpdateNdOffsetDistribution>(patterns.getContext());
}
void XeGPUSubgroupDistributePass::runOnOperation() {
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribution.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribution.mlir
index f8f2cd55c28d0..1df0520980766 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribution.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribution.mlir
@@ -160,3 +160,69 @@ gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %arg1: ui64,
gpu.return
}
}
+
+// -----
+// CHECK-LABEL: gpu.func @update_nd_offset_1d(
+// CHECK: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf32>) {
+// CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<1xf32>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<256xf32> -> !xegpu.tensor_desc<16xf32>
+// CHECK: %[[T1:.*]] = xegpu.update_nd_offset %[[T0]], [%c32] : !xegpu.tensor_desc<16xf32>
+// CHECK: xegpu.store_nd %[[CST]], %[[T1]] : vector<1xf32>, !xegpu.tensor_desc<16xf32>
+gpu.module @test {
+gpu.func @update_nd_offset_1d(%arg0: memref<256xf32>){
+ %c0 = arith.constant 0 : index
+ %c32 = arith.constant 32 : index
+ %1 = arith.constant dense<1.000000e+00> : vector<16xf32>
+ %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<256xf32> -> !xegpu.tensor_desc<16xf32>
+ %2 = xegpu.update_nd_offset %0, [%c32] : !xegpu.tensor_desc<16xf32>
+ xegpu.store_nd %1, %2 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
+ gpu.return
+}
+}
+
+// -----
+// CHECK-LABEL: gpu.func @update_nd_offset_2d
+// CHECK: %[[ARG0:[0-9a-zA-Z]+]]: memref<256x256xf32>) {
+// CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<16xf32>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<256x256xf32> -> !xegpu.tensor_desc<16x16xf32>
+// CHECK: %[[T1:.*]] = xegpu.update_nd_offset %[[T0]], [%c32, %c32] : !xegpu.tensor_desc<16x16xf32>
+// CHECK: xegpu.store_nd %[[CST]], %[[T1]] : vector<16xf32>, !xegpu.tensor_desc<16x16xf32>
+gpu.module @test {
+gpu.func @update_nd_offset_2d(%arg0: memref<256x256xf32>){
+ %c0 = arith.constant 0 : index
+ %c32 = arith.constant 32 : index
+ %1 = arith.constant dense<1.000000e+00> : vector<16x16xf32>
+ %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<256x256xf32> -> !xegpu.tensor_desc<16x16xf32>
+ %2 = xegpu.update_nd_offset %0, [%c32, %c32] : !xegpu.tensor_desc<16x16xf32>
+ xegpu.store_nd %1, %2 : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32>
+ gpu.return
+}
+}
+
+// -----
+// CHECK-LABEL: gpu.func @prefetch_2d
+// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<256x256xf16>) {
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK: xegpu.prefetch_nd %[[T0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16>
+gpu.module @test {
+gpu.func @prefetch_2d(%arg0: memref<256x256xf16>){
+ %c0 = arith.constant 0 : index
+ %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16>
+ xegpu.prefetch_nd %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<16x16xf16>
+ gpu.return
+}
+}
+
+// -----
+// CHECK-LABEL: gpu.func @prefetch_1d
+// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<256xf16> -> !xegpu.tensor_desc<16xf16>
+// CHECK: xegpu.prefetch_nd %[[T0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf16>
+gpu.module @test {
+gpu.func @prefetch_1d(%arg0: memref<256xf16>){
+ %c0 = arith.constant 0 : index
+ %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<256xf16> -> !xegpu.tensor_desc<16xf16>
+ xegpu.prefetch_nd %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<16xf16>
+ gpu.return
+}
+}
diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
index a5468681e68dc..c7c82fc8dbb3c 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir
@@ -561,3 +561,62 @@ func.func @test_vector_inner_reduction(%arg0: vector<16x16xf32>, %arg1: !xegpu.t
xegpu.store_nd %0, %arg1 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
return
}
+
+// -----
+// CHECK: function: update_nd_offset_1d:
+// CHECK: op : %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<16xf32>
+// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
+// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}] : memref<256xf32> -> !xegpu.tensor_desc<16xf32>
+// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
+// CHECK-NEXT: op : %[[T1:.*]] = xegpu.update_nd_offset %[[T0]], [%{{.*}}] : !xegpu.tensor_desc<16xf32>
+// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
+func.func @update_nd_offset_1d(%arg0: memref<256xf32>){
+ %c0 = arith.constant 0 : index
+ %c32 = arith.constant 32 : index
+ %1 = arith.constant dense<1.000000e+00> : vector<16xf32>
+ %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<256xf32> -> !xegpu.tensor_desc<16xf32>
+ %2 = xegpu.update_nd_offset %0, [%c32] : !xegpu.tensor_desc<16xf32>
+ xegpu.store_nd %1, %2 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
+ return
+}
+
+// -----
+// CHECK: function: update_nd_offset_2d:
+// CHECK: op : %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<16x16xf32>
+// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
+// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}] : memref<256x256xf32> -> !xegpu.tensor_desc<16x16xf32>
+// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
+// CHECK-NEXT: op : %[[T1:.*]] = xegpu.update_nd_offset %[[T0]], [%{{.*}}] : !xegpu.tensor_desc<16x16xf32>
+// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
+func.func @update_nd_offset_2d(%arg0: memref<256x256xf32>){
+ %c0 = arith.constant 0 : index
+ %c32 = arith.constant 32 : index
+ %1 = arith.constant dense<1.000000e+00> : vector<16x16xf32>
+ %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<256x256xf32> -> !xegpu.tensor_desc<16x16xf32>
+ %2 = xegpu.update_nd_offset %0, [%c32, %c32] : !xegpu.tensor_desc<16x16xf32>
+ xegpu.store_nd %1, %2 : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32>
+ return
+}
+
+// -----
+// CHECK: function: prefetch_2d:
+// CHECK: layout for result #0: Not assigned.
+// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK-NEXT: layout for result #0: lane_layout: [1, 16], lane_data: [1, 1]
+func.func @prefetch_2d(%arg0: memref<256x256xf16>){
+ %c0 = arith.constant 0 : index
+ %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16>
+ xegpu.prefetch_nd %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<16x16xf16>
+ return
+}
+
+// -----
+// CHECK: function: prefetch_1d:
+// CHECK: op : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}] : memref<256xf16> -> !xegpu.tensor_desc<16xf16>
+// CHECK-NEXT: layout for result #0: lane_layout: [16], lane_data: [1]
+func.func @prefetch_1d(%arg0: memref<256xf16>){
+ %c0 = arith.constant 0 : index
+ %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<256xf16> -> !xegpu.tensor_desc<16xf16>
+ xegpu.prefetch_nd %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<16xf16>
+ return
+}
``````````
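For readers wanting to try the new patterns outside the in-tree pass, below is a minimal sketch (not part of this patch) of applying `xegpu::populateXeGPUSubgroupDistributePatterns` from a custom entry point. The header paths, the helper name `distributeXeGPUToSIMT`, and the greedy-driver plumbing are assumptions based on the usual MLIR pattern-application flow; note that the in-tree `XeGPUSubgroupDistributePass` also runs the layout propagation shown in `subgroup-map-propagation.mlir` before distributing, which this sketch omits.

```cpp
// Illustrative sketch only -- not part of the patch. Applies the XeGPU
// subgroup distribution patterns (including the new PrefetchNdDistribution
// and UpdateNdOffsetDistribution patterns) to a given root operation.
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult distributeXeGPUToSIMT(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  // Registers CreateNdDesc/StoreNd/LoadNd/Dpas distribution plus the
  // PrefetchNd and UpdateNdOffset patterns added in this PR.
  mlir::xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
  // Drive the patterns to a fixed point with the greedy rewriter.
  return mlir::applyPatternsGreedily(root, std::move(patterns));
}
```

End to end, the same patterns are exercised by the XeGPU subgroup distribute pass used in the tests above (`subgroup-distribution.mlir`), which handles layout assignment and warp-region construction before these sinking rewrites fire.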
https://github.com/llvm/llvm-project/pull/138033