[Mlir-commits] [mlir] 3c52f53 - [MLIR][XeGPU][TransformOps] Add insert_prefetch op (#167356)
llvmlistbot at llvm.org
Wed Nov 12 02:24:27 PST 2025
Author: Tuomas Kärnä
Date: 2025-11-12T10:24:23Z
New Revision: 3c52f536902b1f4096e25e0e73bc3c26355cbf40
URL: https://github.com/llvm/llvm-project/commit/3c52f536902b1f4096e25e0e73bc3c26355cbf40
DIFF: https://github.com/llvm/llvm-project/commit/3c52f536902b1f4096e25e0e73bc3c26355cbf40.diff
LOG: [MLIR][XeGPU][TransformOps] Add insert_prefetch op (#167356)
Adds the `transform.xegpu.insert_prefetch` transform op, which inserts
`xegpu.prefetch_nd` ops for the given `Value` inside an `scf.for` loop.
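For reference, a minimal transform sequence driving the new op could look as
follows (a sketch mirroring the tests in this commit; the match/get_operand
pattern is just one way to obtain the target value):

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
      // Locate the dpas op and take a value handle to its A operand.
      %dpas = transform.structured.match ops{["xegpu.dpas"]} in %root
          : (!transform.any_op) -> !transform.any_op
      %a = transform.get_operand %dpas[0] : (!transform.any_op) -> !transform.any_value
      // Insert prefetches for the tile feeding %a; returns a handle to the
      // xegpu.create_nd_tdesc op cloned above the scf.for loop.
      %desc = transform.xegpu.insert_prefetch %a nb_prefetch = 2
          : (!transform.any_value) -> !transform.any_op
      transform.yield
    }
  }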
Added:
Modified:
mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
mlir/python/mlir/dialects/transform/xegpu.py
mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
mlir/test/Dialect/XeGPU/transform-ops.mlir
mlir/test/python/dialects/transform_xegpu_ext.py
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index f5e4afad535e5..68a75fdb5b9a5 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -200,4 +200,48 @@ def SetGPULaunchThreadsOp
}];
}
+def InsertPrefetchOp : Op<Transform_Dialect, "xegpu.insert_prefetch", [
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ TransformOpInterface
+]> {
+
+ let summary = "Adds xegpu prefetch ops to matmul operand tiles.";
+ let description = [{
+ Given a target value (e.g., `vector`) residing in a `scf.for` loop, this
+ transform finds the corresponding `xegpu.load_nd` op and inserts
+ `xegpu.prefetch_nd` operations for the tile. The load op must reside within
+ the `scf.for` loop. The number of prefetch steps is set by the
+ `nb_prefetch` argument (default value is 1). Returns a handle to the
+ created `xegpu.create_nd_tdesc` op.
+ }];
+
+ let arguments = (ins TransformValueHandleTypeInterface:$target,
+ Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_nb_prefetch,
+ DefaultValuedOptionalAttr<I64Attr, "1">:$static_nb_prefetch
+ );
+
+ let results = (outs TransformHandleTypeInterface:$desc_op);
+
+ let assemblyFormat = [{
+ $target
+ `nb_prefetch` `=` ($dynamic_nb_prefetch^):($static_nb_prefetch)?
+ attr-dict `:` functional-type(operands, results)
+ }];
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure apply(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::transform::TransformResults &transformResults,
+ ::mlir::transform::TransformState &state);
+
+ OpFoldResult getNbPrefetch() {
+ auto cxt = getContext();
+ if (getDynamicNbPrefetch())
+ return OpFoldResult(getDynamicNbPrefetch());
+ return OpFoldResult(IntegerAttr::get(
+ IntegerType::get(cxt, 64), getStaticNbPrefetch()));
+ }
+ }];
+}
+
#endif // XEGPU_TRANSFORM_OPS
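
As a reading aid for the assembly format above: the optional group prints the
dynamic operand when present and otherwise falls back to the static attribute,
so both of the following forms (taken from the tests in this commit) round-trip:

  %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = 1
      : (!transform.any_value) -> !transform.any_op
  %nb = transform.param.constant 2 : i64 -> !transform.param<i64>
  %3 = transform.xegpu.insert_prefetch %1 nb_prefetch = %nb
      : (!transform.any_value, !transform.param<i64>) -> !transform.any_op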
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 7a7a8c9066f09..d2235f18ceaec 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
@@ -405,6 +406,137 @@ void transform::SetGPULaunchThreadsOp::getEffects(
modifiesPayload(effects);
}
+DiagnosedSilenceableFailure
+transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto targetValues = state.getPayloadValues(getTarget());
+ if (!llvm::hasSingleElement(targetValues))
+ return emitDefiniteFailure()
+ << "requires exactly one target value handle (got "
+ << llvm::range_size(targetValues) << ")";
+ auto value = *targetValues.begin();
+
+ int64_t nbPrefetch = getStaticNbPrefetch();
+ if (getDynamicNbPrefetch()) {
+ // Get dynamic prefetch count from transform param or handle.
+ SmallVector<int32_t> dynamicNbPrefetch;
+ auto status = convertMixedValuesToInt(state, (*this), dynamicNbPrefetch,
+ {getDynamicNbPrefetch()});
+ if (!status.succeeded())
+ return status;
+ if (dynamicNbPrefetch.size() != 1)
+ return emitDefiniteFailure()
+ << "requires exactly one value for dynamic_nb_prefetch";
+ nbPrefetch = dynamicNbPrefetch[0];
+ }
+ if (nbPrefetch <= 0)
+ return emitSilenceableFailure(getLoc())
+ << "nb_prefetch must be a positive integer.";
+
+ // Find load operation of the operand.
+ auto maybeLoadOp = findProducerOfType<xegpu::LoadNdOp>(value);
+ if (!maybeLoadOp)
+ return emitSilenceableFailure(getLoc()) << "Could not find load op.";
+ auto loadOp = *maybeLoadOp;
+ if (loadOp.getMixedOffsets().size() == 0) {
+ auto diag = emitSilenceableFailure(getLoc())
+ << "Load op must have offsets.";
+ diag.attachNote(loadOp.getLoc()) << "load op";
+ return diag;
+ }
+
+ // Find the parent scf.for loop.
+ auto forOp = loadOp->getParentOfType<scf::ForOp>();
+ if (!forOp) {
+ auto diag = emitSilenceableFailure(getLoc())
+ << "Load op is not contained in a scf.for loop.";
+ diag.attachNote(loadOp.getLoc()) << "load op";
+ return diag;
+ }
+
+ // Find descriptor op.
+ auto maybeDescOp = findProducerOfType<xegpu::CreateNdDescOp>(value);
+ if (!maybeDescOp)
+ return emitSilenceableFailure(getLoc()) << "Could not find descriptor op.";
+ auto descOp = *maybeDescOp;
+ if (descOp.getMixedOffsets().size() > 0) {
+ auto diag = emitSilenceableFailure(getLoc())
+ << "desc op with offsets is not supported.";
+ diag.attachNote(descOp.getLoc()) << "desc op";
+ return diag;
+ }
+
+ // Clone desc op outside the loop.
+ rewriter.setInsertionPoint(forOp);
+ auto newDescOp =
+ cast<xegpu::CreateNdDescOp>(rewriter.clone(*descOp.getOperation()));
+
+ // Clone reduction loop to emit initial prefetches.
+ // Compute upper bound of the init loop: start + nbPrefetch * step.
+ auto nbPrefetchCst =
+ arith::ConstantIndexOp::create(rewriter, forOp.getLoc(), nbPrefetch);
+ auto nbStep = rewriter.createOrFold<arith::MulIOp>(
+ forOp.getLoc(), nbPrefetchCst, forOp.getStep());
+ auto initUpBound = rewriter.createOrFold<arith::AddIOp>(
+ forOp.getLoc(), forOp.getLowerBound(), nbStep);
+ auto initForOp =
+ scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
+ initUpBound, forOp.getStep());
+
+ auto ctx = rewriter.getContext();
+ auto readCacheHint =
+ xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::CACHED);
+
+ // Modify loadOp mixedOffsets by replacing the for loop induction variable
+ // with the given value.
+ auto getPrefetchOffsets =
+ [&](Value replacementVal) -> SmallVector<OpFoldResult> {
+ IRMapping mapping;
+ mapping.map(forOp.getInductionVar(), replacementVal);
+ SmallVector<Value> dynamicOffsets =
+ llvm::to_vector(llvm::map_range(loadOp.getOffsets(), [&](Value v) {
+ return mapping.lookupOrDefault(v);
+ }));
+ auto constOffsets = loadOp.getConstOffsets().value();
+ return getMixedValues(constOffsets, dynamicOffsets, ctx);
+ };
+
+ // Insert prefetch op in init loop.
+ // Replace induction var with the init loop induction var.
+ rewriter.setInsertionPointToStart(initForOp.getBody());
+ xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
+ newDescOp.getResult(),
+ getPrefetchOffsets(initForOp.getInductionVar()),
+ readCacheHint, readCacheHint, readCacheHint);
+
+ // Insert prefetch op in main loop.
+ // Calculate prefetch offset after the init prefetches have been issued.
+ rewriter.setInsertionPointToStart(forOp.getBody());
+ auto prefetchOffset = arith::AddIOp::create(rewriter, forOp.getLoc(),
+ forOp.getInductionVar(), nbStep);
+ // Replace induction var with correct offset.
+ xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
+ newDescOp.getResult(),
+ getPrefetchOffsets(prefetchOffset), readCacheHint,
+ readCacheHint, readCacheHint);
+
+ // Unroll the init loop.
+ if (failed(loopUnrollFull(initForOp)))
+ return emitSilenceableFailure(getLoc()) << "Failed to unroll the loop";
+
+ results.set(llvm::cast<OpResult>(getResult()), {newDescOp});
+
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::InsertPrefetchOp::getEffects(
+ ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(getTargetMutable(), effects);
+ onlyReadsHandle(getDynamicNbPrefetchMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
+ modifiesPayload(effects);
+}
+
namespace {
class XeGPUTransformDialectExtension
: public transform::TransformDialectExtension<
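
For intuition, the rewrite implemented in apply() above reshapes the A-operand
loop of the tests below roughly as follows (nb_prefetch = 1, after the init
loop is unrolled; cache-hint attributes omitted for brevity, so this is an
illustrative sketch rather than verbatim output):

  %desc = xegpu.create_nd_tdesc %arg0
      : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
  // Initial prefetch issued ahead of the loop (one per prefetch step).
  xegpu.prefetch_nd %desc[%c0, %c0] : !xegpu.tensor_desc<256x32xf16>
  scf.for %i = %c0 to %c4096 step %c32 {
    // Prefetch nb_prefetch * step ahead of the current iteration.
    %ahead = arith.addi %i, %c32 : index
    xegpu.prefetch_nd %desc[%c0, %ahead] : !xegpu.tensor_desc<256x32xf16>
    %a = xegpu.load_nd %desc[%c0, %i]
        : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
    // ... xegpu.dpas consumes %a as before ...
  }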
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index 309883cfc4518..aa3cea58623ea 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -11,6 +11,7 @@
from .._ods_common import _cext as _ods_cext
from .._ods_common import (
MixedValues,
+ MixedInt,
get_op_result_or_value as _get_op_result_or_value,
_dispatch_dynamic_index_list,
)
@@ -134,6 +135,7 @@ def __init__(
)
+@_ods_cext.register_operation(_Dialect, replace=True)
class SetGPULaunchThreadsOp(SetGPULaunchThreadsOp):
"""Specialization for SetGPULaunchThreadsOp class."""
@@ -168,3 +170,44 @@ def set_gpu_launch_threads(
ip=None,
) -> SetGPULaunchThreadsOp:
return SetGPULaunchThreadsOp(launch_op, threads, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class InsertPrefetchOp(InsertPrefetchOp):
+ """Specialization for InsertPrefetchOp class."""
+
+ def __init__(
+ self,
+ target: Value,
+ *,
+ nb_prefetch: Optional[MixedInt] = 1,
+ loc=None,
+ ip=None,
+ ):
+ static_nb_prefetch = 1
+ dynamic_nb_prefetch = None
+ if isinstance(nb_prefetch, int):
+ static_nb_prefetch = nb_prefetch
+ elif isinstance(nb_prefetch, IntegerAttr):
+ static_nb_prefetch = nb_prefetch.value # pytype: disable=attribute-error
+ elif isinstance(nb_prefetch, (Operation, Value, OpView)):
+ dynamic_nb_prefetch = nb_prefetch
+
+ super().__init__(
+ transform.AnyOpType.get(),
+ target,
+ dynamic_nb_prefetch=dynamic_nb_prefetch,
+ static_nb_prefetch=static_nb_prefetch,
+ loc=loc,
+ ip=ip,
+ )
+
+
+def insert_prefetch(
+ target: Value,
+ *,
+ nb_prefetch: Optional[MixedInt] = 1,
+ loc=None,
+ ip=None,
+) -> OpResult:
+ return InsertPrefetchOp(target, nb_prefetch=nb_prefetch, loc=loc, ip=ip).result
\ No newline at end of file
diff --git a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
index 24f500658f740..dce4a41982550 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
@@ -124,3 +124,34 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// CHECK-LABEL: @insert_prefetch_dpas_c
+func.func @insert_prefetch_dpas_c(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+ %c32 = arith.constant 32 : index
+ %c4096 = arith.constant 4096 : index
+ %c0 = arith.constant 0 : index
+ %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+ // expected-note@below {{load op}}
+ %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+ %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+ %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
+ %5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+ %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+ scf.yield %7 : vector<256x256xf16>
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_operand %0[2] : (!transform.any_op) -> !transform.any_value
+ // expected-error@below {{Load op is not contained in a scf.for loop.}}
+ %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = 1 : (!transform.any_value) -> !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index 7f2fbe4271a43..b3b883826c1c8 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -308,3 +308,95 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// CHECK-LABEL: @insert_prefetch_dpas_a
+func.func @insert_prefetch_dpas_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+ // CHECK: %[[C32:.+]] = arith.constant 32 : index
+ %c32 = arith.constant 32 : index
+ %c4096 = arith.constant 4096 : index
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+ %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+ // CHECK: xegpu.create_nd_tdesc %arg0
+ // CHECK: xegpu.create_nd_tdesc %arg1
+ // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+ // CHECK-SAME: !xegpu.tensor_desc<256x32xf16
+ // CHECK: xegpu.prefetch_nd %[[V0]][%[[C0]], %[[C0]]]
+ %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+ // CHECK: scf.for %[[ARG3:.+]] = %[[C0]]
+ %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
+ // CHECK: %[[ADD:.+]] = arith.addi %[[ARG3]], %[[C32]]
+ // CHECK: xegpu.prefetch_nd %[[V0]][%[[C0]], %[[ADD]]]
+ %5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+ %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+ scf.yield %7 : vector<256x256xf16>
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["xegpu.dpas"]} in %func : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
+ // CHECK: transform.xegpu.insert_prefetch %{{.*}}
+ %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = 1 : (!transform.any_value) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.canonicalization
+ } : !transform.any_op
+
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @insert_prefetch_dpas_a_nb_param2
+func.func @insert_prefetch_dpas_a_nb_param2(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+ // CHECK: %[[C64:.+]] = arith.constant 64 : index
+ // CHECK: %[[C32:.+]] = arith.constant 32 : index
+ %c32 = arith.constant 32 : index
+ %c4096 = arith.constant 4096 : index
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+ // CHECK: xegpu.create_nd_tdesc %arg0
+ // CHECK: xegpu.create_nd_tdesc %arg1
+ // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+ // CHECK-SAME: !xegpu.tensor_desc<256x32xf16
+ // CHECK: xegpu.prefetch_nd %[[V0]][0, %[[C0]]]
+ // CHECK: xegpu.prefetch_nd %[[V0]][0, %[[C32]]]
+ %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+ // CHECK: scf.for %[[ARG3:.+]] = %[[C0]]
+ %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
+ // CHECK: %[[ADD:.+]] = arith.addi %[[ARG3]], %[[C64]]
+ // CHECK: xegpu.prefetch_nd %[[V0]][0, %[[ADD]]]
+ %5 = xegpu.load_nd %3[0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %6 = xegpu.load_nd %4[%arg3, 0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+ %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+ scf.yield %7 : vector<256x256xf16>
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["xegpu.dpas"]} in %func : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
+ %nb = transform.param.constant 2 : i64 -> !transform.param<i64>
+ // CHECK: transform.xegpu.insert_prefetch %{{.*}}
+ %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = %nb : (!transform.any_value, !transform.param<i64>) -> !transform.any_op
+ transform.apply_patterns to %func {
+ transform.apply_patterns.canonicalization
+ } : !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
index dc91f5e982579..56c7d71f28431 100644
--- a/mlir/test/python/dialects/transform_xegpu_ext.py
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -3,7 +3,7 @@
from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects.transform import xegpu
-from mlir.dialects.transform import AnyValueType
+from mlir.dialects.transform import structured, AnyValueType
def run(f):
@@ -128,3 +128,68 @@ def setGPULaunchThreadsOp():
# CHECK-LABEL: TEST: setGPULaunchThreadsOp
# CHECK: transform.xegpu.set_gpu_launch_threads
# CHECK: threads = [8, 4, 1]
+
+
+@run
+def insertPrefetch0():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.OperationType.get("xegpu.dpas"),
+ )
+ with InsertionPoint(sequence.body):
+ operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
+ xegpu.insert_prefetch(
+ operand,
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: insertPrefetch0
+ # CHECK: %[[OPR:.*]] = get_operand
+ # CHECK: transform.xegpu.insert_prefetch %[[OPR]]
+
+
+@run
+def insertPrefetchNbPrefetch():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.OperationType.get("xegpu.dpas"),
+ )
+ with InsertionPoint(sequence.body):
+ operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
+ xegpu.insert_prefetch(
+ operand,
+ nb_prefetch=2,
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: insertPrefetchNbPrefetch
+ # CHECK: %[[OPR:.*]] = get_operand
+ # CHECK: transform.xegpu.insert_prefetch %[[OPR]]
+ # CHECK-SAME: nb_prefetch = 2
+
+
+@run
+def insertPrefetchNbPrefetchParam():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.OperationType.get("xegpu.dpas"),
+ )
+ with InsertionPoint(sequence.body):
+ operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
+ int32_t = IntegerType.get_signless(32)
+ param_int32_t = transform.ParamType.get(int32_t)
+ nb_param = transform.ParamConstantOp(
+ param_int32_t,
+ IntegerAttr.get(int32_t, 2),
+ )
+ xegpu.insert_prefetch(
+ operand,
+ nb_prefetch=nb_param,
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: insertPrefetchNbPrefetchParam
+ # CHECK: %[[OPR:.*]] = get_operand
+ # CHECK: %[[PARAM_OP:.*]] = transform.param.constant 2
+ # CHECK: transform.xegpu.insert_prefetch %[[OPR]]
+ # CHECK-SAME: nb_prefetch = %[[PARAM_OP]]