[Mlir-commits] [mlir] [mlir][xegpu] Add `XeGPUSgToWiDistributeExperimental` pass. (PR #177492)
Charitha Saumya
llvmlistbot at llvm.org
Wed Jan 28 13:50:11 PST 2026
================
@@ -0,0 +1,561 @@
+//===- XeGPUSgToWiDistributeExperimental.cpp - XeGPU SG to WI Pass --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVectorExtras.h"
+#include "llvm/Support/GraphWriter.h"
+#include "llvm/Support/LogicalResult.h"
+#include "llvm/Support/raw_ostream.h"
+#include <optional>
+
+// Defining GEN_PASS_DEF_XEGPUSGTOWIDISTRIBUTEEXPERIMENTAL before including
+// Passes.h.inc presumably makes the include emit the TableGen-generated code
+// for this pass (standard MLIR pass pattern — confirm against Passes.td).
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUSGTOWIDISTRIBUTEEXPERIMENTAL
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+using namespace mlir;
+
+// Debug tag for this pass; DBGS() prints to llvm::dbgs() prefixed with it.
+#define DEBUG_TYPE "xegpu-sg-to-wi-distribute-experimental"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+namespace {
+
+/// Casts the vector value `v` to the vector type `expectedTy`.
+///
+/// - If `v` already has type `expectedTy`, `v` itself is returned.
+/// - If both types share the element type and the total element count, a
+///   `vector.shape_cast` is emitted (shape_cast can only reshape; it cannot
+///   change the element type).
+/// - Otherwise an `unrealized_conversion_cast` is emitted (presumably to be
+///   resolved by a later cleanup — confirm).
+static Value castValueTo(ConversionPatternRewriter &rewriter,
+                         TypedValue<VectorType> v, VectorType expectedTy) {
+  VectorType srcTy = v.getType();
+  // Nothing to do if the type already matches.
+  if (srcTy == expectedTy)
+    return v;
+  // Pure reshape: same element type and same number of elements.
+  if (srcTy.getElementType() == expectedTy.getElementType() &&
+      srcTy.getNumElements() == expectedTy.getNumElements())
+    return vector::ShapeCastOp::create(rewriter, v.getLoc(), expectedTy, v);
+
+  // Any other mismatch falls back to an unrealized cast.
+  auto castOp = UnrealizedConversionCastOp::create(rewriter, v.getLoc(),
+                                                   expectedTy, ValueRange{v});
+  return castOp.getResult(0);
+}
+
+/// Verifies that every op nested under `root` carries the layout information
+/// this pass relies on:
+///   - ops implementing AnchorLayoutInterface must have an anchor layout;
+///   - every other op must have a distribute layout on each vector result.
+/// An error is emitted on the first offending op and the walk stops there.
+static LogicalResult verifyLayouts(Operation *root) {
+  WalkResult status = root->walk([&](Operation *op) -> WalkResult {
+    // Anchor ops are judged solely by their anchor layout.
+    if (auto anchor = dyn_cast<xegpu::AnchorLayoutInterface>(op)) {
+      if (anchor.getAnchorLayout())
+        return WalkResult::advance();
+      op->emitError("expected anchor layout attribute on operation");
+      return WalkResult::interrupt();
+    }
+    // Non-anchor ops: each vector-typed result needs a layout.
+    for (OpResult res : op->getResults()) {
+      if (!isa<VectorType>(res.getType()))
+        continue;
+      if (xegpu::getDistributeLayoutAttr(res))
+        continue;
+      op->emitError("expected result layout attribute on vector result");
+      return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  });
+  return failure(status.wasInterrupted());
+}
+
+/// Distributes a subgroup-level CreateNdDesc op to a workitem-level
+/// CreateNdDesc op. The only change is that the layout attribute is dropped
+/// from the resulting tensor descriptor type; operands and attributes are
+/// carried over.
+struct SgToWiCreateNdDesc : public OpConversionPattern<xegpu::CreateNdDescOp> {
+  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::CreateNdDescOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::TensorDescType resultType = op.getType();
+    // If no layout, nothing to do.
+    if (!resultType.getLayout())
+      return failure();
+
+    // Build from the adaptor's (remapped) operands rather than the original
+    // op's operands, as is recommended inside conversion patterns.
+    auto newOp = xegpu::CreateNdDescOp::create(
+        rewriter, op.getLoc(), resultType.dropLayouts(), adaptor.getOperands(),
+        op->getAttrs());
+    rewriter.replaceOp(op, newOp.getResult());
+    return success();
+  }
+};
+
+/// Distributes a subgroup-level LoadNd op to a workitem-level LoadNd op.
+/// The workitem-level LoadNd produces a 1D vector whose type is derived from
+/// the tensor descriptor. castValueTo then converts it to the rank-preserving
+/// workitem type computed from the lane layout (a shape_cast when only the
+/// shape differs).
+struct SgToWiLoadNd : public OpConversionPattern<xegpu::LoadNdOp> {
+  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::LoadNdOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
+    // If no layout, nothing to do.
+    if (!layout)
+      return failure();
+    // Check if the layout attached to the tensor descriptor is same as the
+    // anchor layout. Otherwise, this is a conflict.
+    if (op.getTensorDescType().getLayout() != layout)
+      return rewriter.notifyMatchFailure(
+          op, "conflicting layout attributes on tensor descriptor and anchor");
+    // `layout` is a DistributeLayoutAttr, which is not necessarily a
+    // LayoutAttr. requirePacked/requireTranspose below take a LayoutAttr, and
+    // an unchecked cast<> would assert on any other kind. Reject such ops up
+    // front, before any IR is created.
+    auto laneLayout = dyn_cast<xegpu::LayoutAttr>(layout);
+    if (!laneLayout)
+      return rewriter.notifyMatchFailure(
+          op, "expected anchor layout to be an xegpu::LayoutAttr");
+    // The target uArch is needed to decide whether a transpose is required.
+    auto uArch = getUArch(xegpu::getChipStr(op).value_or(""));
+    if (!uArch)
+      return rewriter.notifyMatchFailure(
+          op, "xegpu::LoadNdOp require target attribute attached to "
+              "determine transpose "
+              "requirement");
+    // 1D type that the workitem-level LoadNd actually produces.
+    auto supportedWiResultTyOrFailure =
+        xegpu::getDistributedVectorType(op.getTensorDescType());
+    // Rank-preserving type that users of the result expect.
+    auto expectedWiResultTyOrFailure =
+        xegpu::getDistVecTypeBasedOnLaneLayout(layout, op.getType());
+    if (failed(supportedWiResultTyOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "unable to compute the workitem vector type for LoadNdOp");
+    if (failed(expectedWiResultTyOrFailure))
+      return rewriter.notifyMatchFailure(
+          op,
+          "unable to compute expected workitem vector type from lane layout");
+    auto newOp = xegpu::LoadNdOp::create(
+        rewriter, op.getLoc(), supportedWiResultTyOrFailure.value(),
+        adaptor.getTensorDesc(), op.getMixedOffsets(), op.getPackedAttr(),
+        op.getTransposeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
+        op.getL3HintAttr(), /*layout=*/nullptr);
+    // Set the packed attribute if the layout requires it.
+    newOp.setPacked(xegpu::requirePacked(laneLayout));
+    // Set the transpose attribute if the layout requires it.
+    if (xegpu::requireTranspose(laneLayout, uArch))
+      newOp.setTranspose(DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));
+    rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
+                                       expectedWiResultTyOrFailure.value()));
+    return success();
+  }
+};
+
+/// Lowers a subgroup-level StoreNd op to its workitem-level counterpart.
+/// The workitem-level StoreNd consumes a 1D vector, so the incoming value is
+/// first converted (via castValueTo) to the 1D type derived from the tensor
+/// descriptor.
+struct SgToWiStoreNd : public OpConversionPattern<xegpu::StoreNdOp> {
+  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::StoreNdOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr anchorLayout = op.getAnchorLayout();
+    // Ops without an anchor layout are left untouched.
+    if (!anchorLayout)
+      return failure();
+
+    // The tensor descriptor and the stored value (operand 0) must both agree
+    // with the anchor layout; any mismatch is reported as a conflict.
+    auto descTy = op.getTensorDescType();
+    if (descTy.getLayout() != anchorLayout)
+      return rewriter.notifyMatchFailure(
+          op, "conflicting layout attributes on tensor descriptor and anchor");
+    if (xegpu::getDistributeLayoutAttr(op->getOpOperand(0)) != anchorLayout)
+      return rewriter.notifyMatchFailure(
+          op, "conflicting layout attributes on value and anchor");
+
+    auto wiValueTy = xegpu::getDistributedVectorType(descTy);
+    if (failed(wiValueTy))
+      return rewriter.notifyMatchFailure(
+          op,
+          "unable to compute wi vector type for StoreNdOp value from tensor "
+          "descriptor");
+
+    // Bring the (already converted) value into the 1D workitem type.
+    auto storedValue = cast<TypedValue<VectorType>>(adaptor.getValue());
+    Value wiValue = castValueTo(rewriter, storedValue, *wiValueTy);
+    xegpu::StoreNdOp::create(rewriter, op.getLoc(), wiValue,
+                             adaptor.getTensorDesc(), op.getMixedOffsets(),
+                             op.getL1HintAttr(), op.getL2HintAttr(),
+                             op.getL3HintAttr(), /*layout=*/nullptr);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Distributes a subgroup-level Dpas op to workitem-level Dpas op. All inputs
+/// and the output of the workitem-level Dpas op are 1D. Necessary casts are
+/// added to convert the inputs and output to/from 1D.
+struct SgToWiDpas : public OpConversionPattern<xegpu::DpasOp> {
+ using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(xegpu::DpasOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // llvm::errs() << "DpasOpPattern matchAndRewrite called\n";
+ // Check if the op has A, B and CD layouts attached.
+ auto layoutA = cast<xegpu::LayoutAttr>(op.getLayoutAAttr());
+ auto layoutB = cast<xegpu::LayoutAttr>(op.getLayoutBAttr());
+ auto layoutCd = cast<xegpu::LayoutAttr>(op.getLayoutCdAttr());
+ if (!layoutA || !layoutB || !layoutCd)
+ return failure();
+ // llvm::errs() << "tryning to calculate wi types for dpas op\n";
+ auto wiResultTyOrFailure =
+ xegpu::getDistributedVectorType(op.getType(), layoutCd);
+ auto wiATypeOrFailure =
+ xegpu::getDistributedVectorType(op.getLhs().getType(), layoutA);
+ auto wiBTypeOrFailure =
+ xegpu::getDistributedVectorType(op.getRhs().getType(), layoutB);
----------------
charithaintc wrote:
The difference is between the WI types used for XeGPU ops and the WI types used for vector/arith ops.
`getDistributedVectorType` is intended for using with xegpu ops to decide its WI level type. This type is 1D.
Example: 32x16xf16 becomes 32xf16
`getDistVecTypeBasedOnLaneLayout`: this is used for everything else (vector, arith, etc.). It preserves the rank of the vector and simply divides the input shape by the lane layout.
Example: 32x16xf16 becomes 32x1xf16
Rank preservation is needed to make the pattern logic simple for vector ops.
Also note that even in the XeGPU case we insert a shape_cast to go back to the 2D shape, so that 2D types are fed to the other ops.
Example:
```
%0 = xegpu.load_nd () -> vector<32xf16>
%1 = shape_cast %0 vector<32xf16> -> vector<32x1xf16>
```
https://github.com/llvm/llvm-project/pull/177492
More information about the Mlir-commits
mailing list