[Mlir-commits] [mlir] [MLIR][XeGPU][TransformOps] Add insert_prefetch op (PR #167356)

Tuomas Kärnä llvmlistbot at llvm.org
Tue Nov 11 10:57:00 PST 2025


================
@@ -341,6 +342,143 @@ void transform::SetOpLayoutAttrOp::getEffects(
   modifiesPayload(effects);
 }
 
+DiagnosedSilenceableFailure
+transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
+                                   transform::TransformResults &results,
+                                   transform::TransformState &state) {
+  auto targetValues = state.getPayloadValues(getTarget());
+  if (!llvm::hasSingleElement(targetValues)) {
+    return emitDefiniteFailure()
+           << "requires exactly one target value handle (got "
+           << llvm::range_size(targetValues) << ")";
+  }
+  auto value = *targetValues.begin();
+
+  int64_t nbPrefetch = getStaticNbPrefetch();
+  if (getDynamicNbPrefetch()) {
+    // Get dynamic prefetch count from transform param or handle.
+    SmallVector<int32_t> dynamicNbPrefetch;
+    auto status = convertMixedValuesToInt(state, (*this), dynamicNbPrefetch,
+                                          {getDynamicNbPrefetch()});
+    if (!status.succeeded())
+      return status;
+    if (dynamicNbPrefetch.size() != 1) {
+      return emitDefiniteFailure()
+             << "requires exactly one value for dynamic_nb_prefetch";
+    }
+    nbPrefetch = dynamicNbPrefetch[0];
+  }
+  if (nbPrefetch <= 0) {
+    return emitSilenceableFailure(getLoc())
+           << "nb_prefetch must be a positive integer.";
+  }
+
+  // Find load operation of the operand.
+  auto maybeLoadOp = findProducerOfType<xegpu::LoadNdOp>(value);
+  if (!maybeLoadOp) {
+    return emitSilenceableFailure(getLoc()) << "Could not find load op.";
+  }
+  auto loadOp = *maybeLoadOp;
+  if (loadOp.getMixedOffsets().size() == 0) {
+    auto diag = emitSilenceableFailure(getLoc())
+                << "Load op must have offsets.";
+    diag.attachNote(loadOp.getLoc()) << "load op";
+    return diag;
+  }
+
+  // Find the parent scf.for loop.
+  auto forOp = loadOp->getParentOfType<scf::ForOp>();
+  if (!forOp) {
+    auto diag = emitSilenceableFailure(getLoc())
+                << "Load op is not contained in a scf.for loop.";
+    diag.attachNote(loadOp.getLoc()) << "load op";
+    return diag;
+  }
+
+  // Find descriptor op.
+  auto maybeDescOp = findProducerOfType<xegpu::CreateNdDescOp>(value);
+  if (!maybeDescOp) {
+    return emitSilenceableFailure(getLoc()) << "Could not find descriptor op.";
+  }
+  auto descOp = *maybeDescOp;
+  if (descOp.getMixedOffsets().size() > 0) {
+    auto diag = emitSilenceableFailure(getLoc())
+                << "desc op with offsets is not supported.";
+    diag.attachNote(descOp.getLoc()) << "desc op";
+  }
+
+  // Clone desc op outside the loop.
+  rewriter.setInsertionPoint(forOp);
+  auto newDescOp =
+      cast<xegpu::CreateNdDescOp>(rewriter.clone(*descOp.getOperation()));
+
+  // Clone reduction loop to emit initial prefetches.
+  // Compute upper bound of the init loop: start + nbPrefetch * step.
+  auto nbPrefetchCst =
+      arith::ConstantIndexOp::create(rewriter, forOp.getLoc(), nbPrefetch);
+  auto nbStep = rewriter.createOrFold<arith::MulIOp>(
+      forOp.getLoc(), nbPrefetchCst, forOp.getStep());
+  auto initUpBound = rewriter.createOrFold<arith::AddIOp>(
+      forOp.getLoc(), forOp.getLowerBound(), nbStep);
+  auto initForOp =
+      scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
+                         initUpBound, forOp.getStep());
+
+  auto ctx = rewriter.getContext();
+  auto readCacheHint =
+      xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::CACHED);
+
+  // Modify loadOp mixedOffsets by replacing the for loop induction variable
+  // with the given value.
+  auto getPrefetchOffsets =
+      [&](Value replacementVal) -> SmallVector<OpFoldResult> {
+    IRMapping mapping;
+    mapping.map(forOp.getInductionVar(), replacementVal);
+    SmallVector<Value> dynamicOffsets =
+        llvm::to_vector(llvm::map_range(loadOp.getOffsets(), [&](Value v) {
+          return mapping.lookupOrDefault(v);
+        }));
+    auto constOffsets = loadOp.getConstOffsets().value();
+    return getMixedValues(constOffsets, dynamicOffsets, ctx);
+  };
+
+  // Insert prefetch op in init loop.
+  // Replace induction var with the init loop induction var.
+  rewriter.setInsertionPointToStart(initForOp.getBody());
+  xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
+                              newDescOp.getResult(),
+                              getPrefetchOffsets(initForOp.getInductionVar()),
+                              readCacheHint, readCacheHint, readCacheHint);
+
+  // Insert prefetch op in main loop.
+  // Calculate prefetch offset after the init prefetches have been issued.
+  rewriter.setInsertionPointToStart(forOp.getBody());
+  auto prefetchOffset = arith::AddIOp::create(rewriter, forOp.getLoc(),
+                                              forOp.getInductionVar(), nbStep);
+  // Replace induction var with correct offset.
+  xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
+                              newDescOp.getResult(),
+                              getPrefetchOffsets(prefetchOffset), readCacheHint,
+                              readCacheHint, readCacheHint);
+
+  // Unroll the init loop.
+  if (failed(loopUnrollFull(initForOp))) {
----------------
tkarna wrote:

removed

https://github.com/llvm/llvm-project/pull/167356


More information about the Mlir-commits mailing list