[Mlir-commits] [mlir] d9c65f9 - [mlir][xegpu] Add `XeGPUSgToWiDistributeExperimental` pass. (#177492)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jan 29 09:57:05 PST 2026
Author: Charitha Saumya
Date: 2026-01-29T09:57:01-08:00
New Revision: d9c65f94b11f2fed7b78af3a8e7bec6c84bdf219
URL: https://github.com/llvm/llvm-project/commit/d9c65f94b11f2fed7b78af3a8e7bec6c84bdf219
DIFF: https://github.com/llvm/llvm-project/commit/d9c65f94b11f2fed7b78af3a8e7bec6c84bdf219.diff
LOG: [mlir][xegpu] Add `XeGPUSgToWiDistributeExperimental` pass. (#177492)
Currently the XeGPU lowering pipeline uses the `XeGPUSubgroupDistribute` pass
to perform subgroup to work item distribution of ops. This pass is well established
and relies on vector distribution's `WarpOp` based distribution
mechanism. However, recent experiments with larger kernels have shown
that this pass is very expensive in terms of compile time (see below).
This prompted us to create a new pass that does not rely on `WarpOp`
based distribution. This PR adds the initial infra to move away from the
old way and align Sg to WI distribution with Wg to Sg distribution. The new
pass also uses context-aware type conversion based on XeGPU layouts to
distribute vector types from SG to WI.
This PR adds the following changes:
* SG to WI distribution pass based on context-aware type conversions
using `OpConversionPatterns`
* Test pass for testing individual patterns
(`TestXeGPUSgToWiDistributeExperimental`)
* `XeGPUSgToWiDistributeExperimentalPass` which will eventually replace
`XeGPUSubgroupDistribute`
Flash attention e2e compilations stats:
```
----Wall Time---- ----Name----
0.0032 ( 0.2%) Parser
0.0008 ( 0.0%) CSE
0.0000 ( 0.0%) (A) DominanceInfo
0.0002 ( 0.0%) GpuXeVMAttachTarget
1.1427 ( 58.7%) 'gpu.module' Pipeline
0.0019 ( 0.1%) XeGPUWgToSgDistribute
0.0003 ( 0.0%) CSE
0.0000 ( 0.0%) (A) DominanceInfo
0.0002 ( 0.0%) LowerAffinePass
0.0001 ( 0.0%) CSE
0.0000 ( 0.0%) (A) DominanceInfo
0.0008 ( 0.0%) XeGPUPropagateLayout
0.0056 ( 0.3%) XeGPUBlocking
0.0010 ( 0.1%) Canonicalizer
0.0004 ( 0.0%) CSE
0.0000 ( 0.0%) (A) DominanceInfo
0.0015 ( 0.1%) XeGPUPropagateLayout
0.0007 ( 0.0%) XeGPUOptimizeBlockLoads
0.0010 ( 0.0%) Canonicalizer
0.0004 ( 0.0%) CSE
0.0000 ( 0.0%) (A) DominanceInfo
0.0015 ( 0.1%) XeGPUPropagateLayout
1.1274 ( 57.9%) XeGPUSubgroupDistribute
0.7959 ( 40.9%) Output
0.0022 ( 0.1%) Rest
1.9461 (100.0%) Total
```
Added:
mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
mlir/test/Dialect/XeGPU/sg-to-wi-experimental.mlir
Modified:
mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index e25adbd1673d9..cb71f19da62f0 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -114,4 +114,14 @@ def XeGPUPeepHoleOptimizer : Pass<"xegpu-optimize-peephole"> {
"vector::VectorDialect"];
}
+// Experimental subgroup-to-workitem distribution pass built on dialect
+// conversion (OpConversionPatterns driven by XeGPU layouts) instead of
+// WarpOp-based distribution. Precondition: every XeGPU anchor op must carry an
+// anchor layout and every vector result must carry a result layout; otherwise
+// the pass signals failure.
+def XeGPUSgToWiDistributeExperimental : Pass<"xegpu-sg-to-wi-distribute-experimental"> {
+  let summary = "Distribute XeGPU ops to work items";
+  let description = [{
+    The pass distributes subgroup level XeGPU ops to work item level XeGPU ops.
+  }];
+  let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
+                           "vector::VectorDialect", "index::IndexDialect"];
+}
+
+
#endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index 9628d3064eabf..fede329990be4 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -11,6 +11,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"
+#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/LogicalResult.h"
@@ -71,6 +72,16 @@ void populateXeGPUMoveFuncBodyToWarpOpPatterns(RewritePatternSet &patterns);
/// Appends patterns for XeGPU workgroup to subgroup distribution into
/// `patterns`.
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns);
+/// Define only the type conversions needed for XeGPU subgroup to workitem
+/// distribution.
+void populateXeGPUSgToWiDistributeTypeConversions(TypeConverter &typeConverter);
+/// Defines type conversions (into `typeConverter`) and legality (into
+/// `target`) for XeGPU subgroup to workitem distribution, and appends the
+/// required conversion patterns into `patterns`.
+void populateXeGPUSgToWiDistributeTypeConversionAndLegality(
+ TypeConverter &typeConverter, RewritePatternSet &patterns,
+ ConversionTarget &target);
/// Collect a set of patterns to unroll xegpu operations to a smaller shapes.
/// Users can control whether an operation to be unrolled or not, as well as
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 6573343a8bc97..700db5f9dd9be 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -26,6 +26,10 @@ namespace xegpu {
class DistributeLayoutAttr;
class LayoutAttr;
class TensorDescType;
+
+namespace uArch {
+struct uArch;
+} // namespace uArch
} // namespace xegpu
namespace xegpu {
@@ -63,6 +67,21 @@ FailureOr<VectorType> getDistributedVectorType(xegpu::TensorDescType tdescTy);
FailureOr<VectorType> getDistributedVectorType(VectorType originalType,
LayoutAttr layout);
+/// Helper function to get distributed vector type for a source vector type
+/// according to the lane_layout. Each trailing dimension covered by
+/// lane_layout is divided by the corresponding lane_layout dimension; leading
+/// dimensions not covered by lane_layout (e.g. an array_length dimension) are
+/// kept unchanged. Returns failure if a dimension is not evenly divisible.
+///
+/// Examples:
+/// | original vector shape | lane_layout | distributed vector shape |
+/// |-----------------------|-------------|--------------------------|
+/// | 32x16 | [1, 16] | 32x1 |
+/// | 32x16 | [2, 8] | 16x2 |
+/// | 2x32x16 | [1, 16] | 2x32x1 |
+FailureOr<VectorType>
+getDistVecTypeBasedOnLaneLayout(DistributeLayoutAttr layout,
+ VectorType originalType);
+
/// Extract a set of small vectors from a value with a given shape using
/// vector.extract_stride_slice
SmallVector<Value> extractVectorsWithShapeFromValue(OpBuilder &builder,
@@ -190,6 +209,14 @@ void recoverTemporaryLayoutsDeprecated(Operation *op);
/// a layout attribute.
bool recoverTemporaryLayouts(Operation *rootOp);
+/// Helper function to check if the layout is packed. Layout is packed if it is
+/// 2D and lane_data[0] != 1 (data packed from col dimension).
+/// TODO: Move to target info.
+bool requirePacked(const LayoutAttr layout);
+
+/// Helper function to check if the layout requires a transpose effect.
+bool requireTranspose(const LayoutAttr layout, const uArch::uArch *uArch);
+
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 15d31eadcb6df..47a3f371164fd 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRXeGPUTransforms
XeGPUBlocking.cpp
XeGPUFoldAliasOps.cpp
+ XeGPUSgToWiDistributeExperimental.cpp
XeGPUSubgroupDistribute.cpp
XeGPUUnroll.cpp
XeGPUWgToSgDistribute.cpp
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
new file mode 100644
index 0000000000000..4ae858363d5b6
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -0,0 +1,561 @@
+//===- XeGPUSgToWiDistributeExperimental.cpp - XeGPU SG to WI Pass --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVectorExtras.h"
+#include "llvm/Support/GraphWriter.h"
+#include "llvm/Support/LogicalResult.h"
+#include "llvm/Support/raw_ostream.h"
+#include <optional>
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUSGTOWIDISTRIBUTEEXPERIMENTAL
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+using namespace mlir;
+
+#define DEBUG_TYPE "xegpu-sg-to-wi-distribute-experimental"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+namespace {
+
+/// Casts the given vector value `v` to the expected vector type `expectedTy`.
+///
+/// - If `v` already has type `expectedTy`, it is returned unchanged.
+/// - If both types have the same element type and the same number of
+///   elements (i.e. only the shape differs), a vector.shape_cast is created.
+/// - Otherwise an UnrealizedConversionCastOp bridging the two types is
+///   created and its result returned.
+static Value castValueTo(ConversionPatternRewriter &rewriter,
+                         TypedValue<VectorType> v, VectorType expectedTy) {
+  VectorType srcTy = v.getType();
+  // If the type matches, simply return the value itself.
+  if (srcTy == expectedTy)
+    return v;
+  // If only shape differs, use shape cast. `v` is statically a vector, so no
+  // isa<> check is needed; vector.shape_cast additionally requires identical
+  // element types, so check that as well as the element count.
+  if (srcTy.getElementType() == expectedTy.getElementType() &&
+      srcTy.getNumElements() == expectedTy.getNumElements())
+    return vector::ShapeCastOp::create(rewriter, v.getLoc(), expectedTy, v);
+
+  // Else create an unrealized cast.
+  auto newOp = UnrealizedConversionCastOp::create(rewriter, v.getLoc(),
+                                                  expectedTy, ValueRange{v});
+  return newOp.getResult(0);
+}
+
+/// Walks `root` and checks layout presence:
+///  - every op implementing xegpu::AnchorLayoutInterface must have a non-null
+///    anchor layout;
+///  - every other op must have a distribute layout on each vector result.
+/// An error is emitted on the first offending op and the walk stops there.
+static LogicalResult verifyLayouts(Operation *root) {
+  WalkResult status = root->walk([&](Operation *op) -> WalkResult {
+    // Anchor ops are judged solely by their anchor layout.
+    if (auto anchor = dyn_cast<xegpu::AnchorLayoutInterface>(op)) {
+      if (anchor.getAnchorLayout())
+        return WalkResult::advance();
+      op->emitError("expected anchor layout attribute on operation");
+      return WalkResult::interrupt();
+    }
+    // Non-anchor ops: any vector result lacking a layout is an error.
+    bool missingLayout = llvm::any_of(op->getResults(), [](OpResult res) {
+      return isa<VectorType>(res.getType()) &&
+             !xegpu::getDistributeLayoutAttr(res);
+    });
+    if (!missingLayout)
+      return WalkResult::advance();
+    op->emitError("expected result layout attribute on vector result");
+    return WalkResult::interrupt();
+  });
+  return failure(status.wasInterrupted());
+}
+
+/// Distributes a subgroup-level CreateNdDesc op to a workitem-level
+/// CreateNdDesc op. The only change is that the layout attribute is dropped
+/// from the resulting tensor descriptor type; operands and attributes are
+/// forwarded unchanged.
+struct SgToWiCreateNdDesc : public OpConversionPattern<xegpu::CreateNdDescOp> {
+  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::CreateNdDescOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::TensorDescType resultType = op.getType();
+    // If no layout, nothing to do.
+    if (!resultType.getLayout())
+      return failure();
+
+    // Build the new op from the adaptor's remapped operands rather than the
+    // original op's operands, as is the convention inside conversion patterns.
+    auto newOp = xegpu::CreateNdDescOp::create(
+        rewriter, op.getLoc(), resultType.dropLayouts(), adaptor.getOperands(),
+        op->getAttrs());
+    rewriter.replaceOp(op, newOp.getResult());
+    return success();
+  }
+};
+
+/// Distributes a subgroup-level LoadNd op to a workitem-level LoadNd op. The
+/// new op yields the type computed by xegpu::getDistributedVectorType from the
+/// tensor descriptor; that value is then cast (vector.shape_cast when
+/// possible) to the type derived from the lane layout, which is the type the
+/// users of the original result expect.
+struct SgToWiLoadNd : public OpConversionPattern<xegpu::LoadNdOp> {
+  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::LoadNdOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
+    // If no layout, nothing to do.
+    if (!layout)
+      return failure();
+    // Check if the layout attached to the tensor descriptor is same as the
+    // anchor layout. Otherwise, this is a conflict.
+    if (op.getTensorDescType().getLayout() != layout)
+      return rewriter.notifyMatchFailure(
+          op, "conflicting layout attributes on tensor descriptor and anchor");
+    // requirePacked/requireTranspose take a plain LayoutAttr. Use dyn_cast so
+    // that any other DistributeLayoutAttr (e.g. a SliceAttr) is reported as a
+    // match failure instead of asserting inside `cast`.
+    auto plainLayout = dyn_cast<xegpu::LayoutAttr>(layout);
+    if (!plainLayout)
+      return rewriter.notifyMatchFailure(
+          op, "expected xegpu::LayoutAttr as the anchor layout of LoadNdOp");
+    auto uArch = getUArch(xegpu::getChipStr(op).value_or(""));
+    if (!uArch)
+      return rewriter.notifyMatchFailure(
+          op, "xegpu::LoadNdOp require target attribute attached to "
+              "determine transpose "
+              "requirement");
+    auto supportedWiResultTyOrFailure =
+        xegpu::getDistributedVectorType(op.getTensorDescType());
+    auto expectedWiResultTyOrFailure =
+        xegpu::getDistVecTypeBasedOnLaneLayout(layout, op.getType());
+    if (failed(supportedWiResultTyOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "unable to compute the workitem vector type for LoadNdOp");
+    if (failed(expectedWiResultTyOrFailure))
+      return rewriter.notifyMatchFailure(
+          op,
+          "unable to compute expected workitem vector type from lane layout");
+    auto newOp = xegpu::LoadNdOp::create(
+        rewriter, op.getLoc(), supportedWiResultTyOrFailure.value(),
+        adaptor.getTensorDesc(), op.getMixedOffsets(), op.getPackedAttr(),
+        op.getTransposeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
+        op.getL3HintAttr(), /*layout=*/nullptr);
+    // Set the packed attribute if the layout requires it.
+    newOp.setPacked(xegpu::requirePacked(plainLayout));
+    // Set the transpose attribute if the layout requires it.
+    if (xegpu::requireTranspose(plainLayout, uArch))
+      newOp.setTranspose(DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));
+    rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
+                                       expectedWiResultTyOrFailure.value()));
+    return success();
+  }
+};
+
+/// Distributes a subgroup-level StoreNd op to a workitem-level StoreNd op.
+/// The stored value is first cast to the type returned by
+/// xegpu::getDistributedVectorType for the tensor descriptor; a new StoreNd
+/// without a layout is then created and the original op is erased.
+struct SgToWiStoreNd : public OpConversionPattern<xegpu::StoreNdOp> {
+  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::StoreNdOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr anchorLayout = op.getAnchorLayout();
+    // Without an anchor layout there is nothing to distribute.
+    if (!anchorLayout)
+      return failure();
+
+    // Both the tensor descriptor layout and the stored value's layout have to
+    // agree with the anchor layout.
+    bool tdescConflicts = op.getTensorDescType().getLayout() != anchorLayout;
+    if (tdescConflicts)
+      return rewriter.notifyMatchFailure(
+          op, "conflicting layout attributes on tensor descriptor and anchor");
+    bool valueConflicts =
+        xegpu::getDistributeLayoutAttr(op->getOpOperand(0)) != anchorLayout;
+    if (valueConflicts)
+      return rewriter.notifyMatchFailure(
+          op, "conflicting layout attributes on value and anchor");
+
+    FailureOr<VectorType> wiValueTy =
+        xegpu::getDistributedVectorType(op.getTensorDescType());
+    if (failed(wiValueTy))
+      return rewriter.notifyMatchFailure(
+          op,
+          "unable to compute wi vector type for StoreNdOp value from tensor "
+          "descriptor");
+
+    Value wiValue = castValueTo(
+        rewriter, cast<TypedValue<VectorType>>(adaptor.getValue()), *wiValueTy);
+    xegpu::StoreNdOp::create(rewriter, op.getLoc(), wiValue,
+                             adaptor.getTensorDesc(), op.getMixedOffsets(),
+                             op.getL1HintAttr(), op.getL2HintAttr(),
+                             op.getL3HintAttr(), /*layout=*/nullptr);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Distributes a subgroup-level Dpas op to a workitem-level Dpas op. The A, B
+/// and (if present) accumulator operands are cast to the workitem types
+/// computed by xegpu::getDistributedVectorType from the A/B/CD layouts, and
+/// the new result is cast back to the type derived from the CD lane layout.
+struct SgToWiDpas : public OpConversionPattern<xegpu::DpasOp> {
+  using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::DpasOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // The op must carry A, B and CD layouts. `cast` would assert on a missing
+    // (null) attribute, so use dyn_cast_if_present and fail the match instead.
+    auto layoutA = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutAAttr());
+    auto layoutB = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutBAttr());
+    auto layoutCd =
+        dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutCdAttr());
+    if (!layoutA || !layoutB || !layoutCd)
+      return rewriter.notifyMatchFailure(
+          op, "expected xegpu::LayoutAttr A, B and CD layouts on DpasOp");
+    auto wiResultTyOrFailure =
+        xegpu::getDistributedVectorType(op.getType(), layoutCd);
+    auto wiATypeOrFailure =
+        xegpu::getDistributedVectorType(op.getLhs().getType(), layoutA);
+    auto wiBTypeOrFailure =
+        xegpu::getDistributedVectorType(op.getRhs().getType(), layoutB);
+    auto expectedWiResultTyOrFailure =
+        xegpu::getDistVecTypeBasedOnLaneLayout(layoutCd, op.getType());
+    if (failed(wiResultTyOrFailure) || failed(wiATypeOrFailure) ||
+        failed(wiBTypeOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "failed to calculate supported workitem vector types for DpasOp "
+              "from layouts");
+    if (failed(expectedWiResultTyOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "unable to compute expected workitem vector type for DpasOp from "
+              "lane layout");
+    Value wiLhs =
+        castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getLhs()),
+                    wiATypeOrFailure.value());
+    Value wiRhs =
+        castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getRhs()),
+                    wiBTypeOrFailure.value());
+    // The accumulator operand may be absent: only cast it when present and
+    // pass a null Value otherwise (casting a null Value would assert).
+    Value wiAcc;
+    if (Value acc = adaptor.getAcc())
+      wiAcc = castValueTo(rewriter, cast<TypedValue<VectorType>>(acc),
+                          wiResultTyOrFailure.value());
+    auto newOp = xegpu::DpasOp::create(
+        rewriter, op->getLoc(), wiResultTyOrFailure.value(), wiLhs, wiRhs,
+        wiAcc, /*layoutA=*/nullptr, /*layoutB=*/nullptr, /*layoutCd=*/nullptr);
+    // Cast the workitem result to the lane-layout-based type users expect.
+    rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
+                                       expectedWiResultTyOrFailure.value()));
+    return success();
+  }
+};
+
+/// Distributes elementwise-mappable ops with a single vector result to
+/// workitem-level ops of the same kind. The new result type is derived from
+/// the result's lane layout, the operands are the converted operands supplied
+/// by the conversion driver, and every attribute whose value is not a
+/// DistributeLayoutAttr is copied over.
+struct SgToWiElementWise : public ConversionPattern {
+  // Forward the type converter to ConversionPattern. Without it the driver
+  // does not type-convert (or materialize) the operands passed to
+  // matchAndRewrite, so they could keep their subgroup-level types while the
+  // new result gets a workitem-level type.
+  SgToWiElementWise(const TypeConverter &typeConverter, MLIRContext *ctx)
+      : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
+                          ctx) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only match ops with elementwise trait and single result.
+    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
+      return failure();
+
+    OpResult result = op->getResult(0);
+    auto resultType = dyn_cast<VectorType>(result.getType());
+    if (!resultType)
+      return rewriter.notifyMatchFailure(
+          op, "operation result is not a vector type");
+
+    xegpu::DistributeLayoutAttr layout = xegpu::getTemporaryLayout(result);
+    if (!layout || !layout.isForSubgroup())
+      return rewriter.notifyMatchFailure(
+          op, "operation result does not have subgroup distribute layout");
+
+    auto wiShapeOrFailure =
+        xegpu::getDistVecTypeBasedOnLaneLayout(layout, resultType);
+    if (failed(wiShapeOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "unable to compute workitem vector type from the layout");
+
+    // Rebuild the op generically with the converted operands and the
+    // distributed result type.
+    OperationState state(op->getLoc(), op->getName());
+    state.addOperands(operands);
+    state.addTypes(wiShapeOrFailure.value());
+    // Copy all attributes except for DistributeLayoutAttr.
+    for (NamedAttribute attr : op->getAttrs()) {
+      if (!isa<xegpu::DistributeLayoutAttr>(attr.getValue()))
+        state.addAttribute(attr.getName(), attr.getValue());
+    }
+    Operation *newOp = rewriter.create(state);
+
+    rewriter.replaceOp(op, newOp->getResult(0));
+    return success();
+  }
+};
+
+/// Distributes a subgroup-level splat vector arith.constant to a
+/// workitem-level arith.constant that holds the same splat value with the
+/// distributed vector type derived from the result's lane layout.
+struct SgToWiArithConstant : public OpConversionPattern<arith::ConstantOp> {
+  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto sgVecTy = dyn_cast<VectorType>(op.getType());
+    if (!sgVecTy)
+      return failure();
+
+    // A splat can be distributed by simply shrinking its type.
+    auto splat = dyn_cast<SplatElementsAttr>(op.getValue());
+    if (!splat)
+      return rewriter.notifyMatchFailure(
+          op, "only dense splat vector constants are supported");
+
+    xegpu::DistributeLayoutAttr resultLayout =
+        xegpu::getTemporaryLayout(llvm::cast<OpResult>(op.getResult()));
+    bool hasSgLayout = resultLayout && resultLayout.isForSubgroup();
+    if (!hasSgLayout)
+      return rewriter.notifyMatchFailure(
+          op, "operation result does not have subgroup distribute layout");
+
+    FailureOr<VectorType> wiVecTy =
+        xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, sgVecTy);
+    if (failed(wiVecTy))
+      return rewriter.notifyMatchFailure(
+          op, "unable to compute workitem vector type from the layout");
+
+    Attribute splatScalar = splat.getSplatValue<Attribute>();
+    auto wiValue = DenseElementsAttr::get(*wiVecTy, splatScalar);
+    auto wiConst =
+        arith::ConstantOp::create(rewriter, op.getLoc(), *wiVecTy, wiValue);
+    rewriter.replaceOp(op, wiConst.getResult());
+    return success();
+  }
+};
+
+/// Distributes a subgroup-level PrefetchNd op to a workitem-level PrefetchNd
+/// op by re-creating it on the converted tensor descriptor without an anchor
+/// layout, then erasing the original op.
+struct SgToWiPrefetchNd : public OpConversionPattern<xegpu::PrefetchNdOp> {
+  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::PrefetchNdOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Ops without an anchor layout are left untouched.
+    if (!op.getAnchorLayout())
+      return failure();
+
+    Location loc = op.getLoc();
+    xegpu::PrefetchNdOp::create(rewriter, loc, adaptor.getTensorDesc(),
+                                op.getMixedOffsets(), op.getL1HintAttr(),
+                                op.getL2HintAttr(), op.getL3HintAttr(),
+                                /*layout=*/nullptr);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Pass driver for experimental subgroup-to-workitem distribution; the work
+/// is done in runOnOperation.
+struct XeGPUSgToWiDistributeExperimentalPass final
+    : public xegpu::impl::XeGPUSgToWiDistributeExperimentalBase<
+          XeGPUSgToWiDistributeExperimentalPass> {
+  void runOnOperation() override;
+};
+
+} // namespace
+
+/// Runs the subgroup-to-workitem distribution in three steps:
+///  1. verify that the required layouts are present (fails the pass if not);
+///  2. run a partial dialect conversion (SCF structural conversion plus the
+///     XeGPU SG-to-WI patterns) that bridges type mismatches with
+///     UnrealizedConversionCastOps;
+///  3. fold back-to-back casts created by step 2 into vector.shape_cast where
+///     legal and erase casts that became dead. Casts present before the pass
+///     are never folded or erased.
+void XeGPUSgToWiDistributeExperimentalPass::runOnOperation() {
+  // Verify if all XeGPU anchor ops and vector ops have result layouts.
+  // TODO: This can be removed once the full layout refactoring is done.
+  Operation *root = getOperation();
+  if (failed(verifyLayouts(root))) {
+    LLVM_DEBUG(DBGS() << "XeGPUSgToWiDistributeExperimentalPass: layout "
+                         "verification failed\n");
+    signalPassFailure();
+    return;
+  }
+  // Collect existing UnrealizedConversionCastOps. These must be preserved.
+  llvm::SmallSetVector<UnrealizedConversionCastOp, 8> existingCasts;
+  root->walk(
+      [&](UnrealizedConversionCastOp castOp) { existingCasts.insert(castOp); });
+  // Perform a structural type conversion to convert structural ops to have WI
+  // types. This will insert UnrealizedConversionCastOps to make the IR
+  // valid.
+  auto materializeCast = [](mlir::OpBuilder &builder, mlir::Type type,
+                            mlir::ValueRange inputs,
+                            mlir::Location loc) -> mlir::Value {
+    return UnrealizedConversionCastOp::create(builder, loc, type, inputs)
+        .getResult(0);
+  };
+  {
+    ConversionTarget target(getContext());
+    TypeConverter typeConverter;
+    RewritePatternSet patterns(&getContext());
+    typeConverter.addSourceMaterialization(materializeCast);
+    typeConverter.addTargetMaterialization(materializeCast);
+    // The SCF helper only reads the (const) converter, and
+    // populateXeGPUSgToWiDistributeTypeConversionAndLegality registers the
+    // XeGPU type conversions itself, so they are not registered twice here.
+    scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter,
+                                                         patterns, target);
+    xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
+        typeConverter, patterns, target);
+    target.addLegalOp<UnrealizedConversionCastOp>();
+    // The conversion result is intentionally ignored.
+    (void)applyPartialConversion(root, target, std::move(patterns));
+  }
+  // Structural type conversion can generate some redundant
+  // UnrealizedConversionCastOps to materialize the SG type from type converted
+  // WI type. These are redundant at this point and can be eliminated by
+  // inserting shape casts instead.
+  // Example:
+  // %1 = UnrealizedConversionCastOp %0 : vector<16x1xf32> to vector<16x16xf32>
+  // %2 = UnrealizedConversionCastOp %1 : vector<16x16xf32> to vector<16xf32>
+  // This can be replaced with:
+  // %2 = vector.shape_cast %0 : vector<16x1xf32> to vector<16xf32>
+  OpBuilder builder(root);
+  root->walk([&](UnrealizedConversionCastOp op) {
+    // If this op existed before, nothing to do.
+    if (existingCasts.contains(op))
+      return;
+    // number of inputs and outputs must be 1.
+    if (op.getNumOperands() != 1 || op.getNumResults() != 1)
+      return;
+    // Both input and output types must be vector types.
+    Value singleInput = op.getInputs()[0];
+    auto inputTy = dyn_cast<VectorType>(singleInput.getType());
+    auto outputTy = dyn_cast<VectorType>(op.getResult(0).getType());
+    if (!inputTy || !outputTy)
+      return;
+
+    // The input must come from another cast created by this pass (pre-existing
+    // casts must be preserved), that cast must have exactly one input (so no
+    // operand is silently dropped), and this op must be its only user.
+    auto definingOp = singleInput.getDefiningOp<UnrealizedConversionCastOp>();
+    if (!definingOp || existingCasts.contains(definingOp) ||
+        definingOp.getNumOperands() != 1 || !definingOp->hasOneUse())
+      return;
+    Value inputOfDefiningOp = definingOp.getInputs()[0];
+    auto inputOfDefiningOpTy =
+        dyn_cast<VectorType>(inputOfDefiningOp.getType());
+    if (!inputOfDefiningOpTy)
+      return;
+    // The two casts cancel out exactly: forward the original value.
+    if (inputOfDefiningOpTy == outputTy) {
+      op.replaceAllUsesWith(ValueRange{inputOfDefiningOp});
+      return;
+    }
+    // vector.shape_cast requires identical element types and element counts.
+    if (inputOfDefiningOpTy.getElementType() != outputTy.getElementType() ||
+        inputOfDefiningOpTy.getNumElements() != outputTy.getNumElements())
+      return;
+    builder.setInsertionPoint(op);
+    auto shapeCast = vector::ShapeCastOp::create(builder, op.getLoc(),
+                                                 outputTy, inputOfDefiningOp);
+    op.replaceAllUsesWith(ValueRange{shapeCast.getResult()});
+  });
+  // At this point, we will have some dead UnrealizedConversionCastOps. Erase
+  // them; iterate to a fixed point because erasing one cast can make the cast
+  // feeding it dead as well.
+  bool changed = true;
+  while (changed) {
+    changed = false;
+    root->walk([&](UnrealizedConversionCastOp op) {
+      // Skip existing casts.
+      if (existingCasts.contains(op))
+        return;
+      if (op.use_empty()) {
+        op.erase();
+        changed = true;
+      }
+    });
+  }
+}
+
+void xegpu::populateXeGPUSgToWiDistributeTypeConversions(
+    TypeConverter &typeConverter) {
+  // Fallback: any type other than TensorDescType and VectorType is kept as
+  // is. Returning std::nullopt defers those two types to other conversions.
+  typeConverter.addConversion([](Type type) -> std::optional<Type> {
+    if (isa<TensorDescType, VectorType>(type))
+      return std::nullopt;
+    return type;
+  });
+  // TensorDescType: strip the layout attribute when one is attached.
+  typeConverter.addConversion([](TensorDescType type) -> Type {
+    if (!type.getLayoutAttr())
+      return type;
+    return type.dropLayouts();
+  });
+  // Vector values: if the value carries a subgroup distribute layout, use the
+  // lane-layout-based distributed type. Otherwise, or if distribution fails,
+  // keep the original type. Non-vector values are deferred to the
+  // conversions above.
+  typeConverter.addConversion([](Value v) -> std::optional<Type> {
+    auto vecTy = dyn_cast<VectorType>(v.getType());
+    if (!vecTy)
+      return std::nullopt;
+    Type original = vecTy;
+    xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(v);
+    if (!layout || !layout.isForSubgroup())
+      return original;
+    FailureOr<VectorType> wiTy = getDistVecTypeBasedOnLaneLayout(layout, vecTy);
+    if (succeeded(wiTy))
+      return Type(*wiTy);
+    return original;
+  });
+}
+
+void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target) {
+  populateXeGPUSgToWiDistributeTypeConversions(typeConverter);
+
+  // CreateNdDescOp is legal only once its result type carries no layout.
+  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
+      [](xegpu::CreateNdDescOp op) { return !op.getType().getLayoutAttr(); });
+
+  // Other XeGPU ops are legal unless they are anchor ops that still carry an
+  // anchor layout.
+  target.addDynamicallyLegalDialect<xegpu::XeGPUDialect>([](Operation *op) {
+    auto anchorOp = dyn_cast<AnchorLayoutInterface>(op);
+    return !anchorOp || !anchorOp.getAnchorLayout();
+  });
+
+  // Non-vector arith constants are always legal; vector ones become legal once
+  // they no longer carry a temporary layout.
+  target.addDynamicallyLegalOp<arith::ConstantOp>(
+      [=](arith::ConstantOp op) -> bool {
+        return !isa<VectorType>(op.getResult().getType()) ||
+               !xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
+      });
+
+  // Math/arith: an elementwise-mappable op with a single vector result whose
+  // operands are all vectors of the result's shape is legal only once its
+  // result has no temporary layout. Every other op in these dialects is legal.
+  target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
+      [=](Operation *op) -> std::optional<bool> {
+        if (!OpTrait::hasElementwiseMappableTraits(op) ||
+            op->getNumResults() != 1)
+          return true;
+
+        auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
+        if (!resultType)
+          return true;
+
+        bool allSameShapeVectors =
+            llvm::all_of(op->getOperands(), [&](Value operand) {
+              auto operandType = dyn_cast<VectorType>(operand.getType());
+              return operandType &&
+                     operandType.getShape() == resultType.getShape();
+            });
+        if (!allSameShapeVectors)
+          return true;
+        return !xegpu::getTemporaryLayout(dyn_cast<OpResult>(op->getResult(0)));
+      });
+
+  // Ops from any other dialect are left alone.
+  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
+  patterns.add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
+               SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd>(
+      typeConverter, patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index c95d3cabc270f..a8ed5a289f84a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -65,48 +65,6 @@ namespace {
/// priorities to patterns.
enum PatternHierarchy : unsigned { Regular = 1, AboveRegular = 2 };
-/// Helper function to get distributed vector type for a source vector type
-/// according to the lane_layout. We simply divide each dimension of tensor
-/// descriptor shape by corresponding lane_layout dimension. If
-/// array_length > 1, that is appended to the front of the ditributed shape.
-/// NOTE: This is the vector type that will be returned by the
-/// gpu.warp_execute_on_lane0 op.
-///
-/// Examples:
-/// | original vector shape | lane_layout | distributed vector shape |
-/// |-----------------------|-------------|--------------------------|
-/// | 32x16 | [1, 16] | 32x1 |
-/// | 32x16 | [2, 8] | 16x2 |
-/// | 2x32x16 | [1, 16] | 2x32x1 |
-static FailureOr<VectorType>
-getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
- VectorType originalType) {
- if (!layout)
- return failure();
- assert((isa<xegpu::LayoutAttr>(layout) || isa<xegpu::SliceAttr>(layout)) &&
- "Expecting a valid layout.");
- SmallVector<int64_t> effectiveLaneLayout =
- layout.getEffectiveLaneLayoutAsInt();
- assert(static_cast<size_t>(originalType.getRank()) >=
- effectiveLaneLayout.size() &&
- "Rank of the original vector type should be greater or equal to the "
- "size of the lane layout to distribute the vector type.");
- SmallVector<int64_t> distributedShape(originalType.getShape());
- // Only distribute the last `laneLayout.size()` dimensions. The remaining
- // dimensions are not distributed.
- unsigned distributionStart =
- originalType.getRank() - effectiveLaneLayout.size();
- for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
- if (i < distributionStart)
- continue;
- // Check if the dimension can be distributed evenly.
- if (dim % effectiveLaneLayout[i - distributionStart] != 0)
- return failure();
- distributedShape[i] = dim / effectiveLaneLayout[i - distributionStart];
- }
- return VectorType::get(distributedShape, originalType.getElementType());
-}
-
/// Helper function to resolve types if the distributed type out of
/// gpu.warp_execute_on_lane0 is different from the expected xegpu SIMT type.
/// Example 1:
@@ -145,34 +103,6 @@ static Value resolveDistributedTy(Value orig, T expected,
return orig;
}
-/// Helper function to check if the layout is packed. Layout is packed if it is
-/// 2D and lane_data[0] != 1 (data packed from col dimension).
-/// TODO: Move to target info.
-static bool requirePacked(const xegpu::LayoutAttr layout) {
- if (!layout)
- return false;
- auto laneData = layout.getEffectiveLaneDataAsInt();
- if (laneData.size() != 2)
- return false;
- return laneData[0] != 1;
-}
-
-/// Helper function to check if the layout requires a transpose effect.
-static bool requireTranspose(const xegpu::LayoutAttr layout,
- const xegpu::uArch::uArch *uArch) {
- // Return false for unsupported targets.
- // TODO: Add more support or move to target info.
- if (uArch->getName().equals_insensitive("pvc") &&
- uArch->getName().equals_insensitive("bmg"))
- return false;
- if (!layout)
- return false;
- auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
- if (laneLayout.size() != 2)
- return false;
- return laneLayout[0] == uArch->getSubgroupSize() && laneLayout[1] == 1;
-}
-
/// Given a vector type and its distributed vector type, return the list of
/// dimensions that are distributed.
static SmallVector<int64_t> getDistributedDims(VectorType originalType,
@@ -409,7 +339,7 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
storeOp, "the source tensor descriptor lacks layout attribute");
FailureOr<VectorType> distributedTypeByWarpOpOrFailure =
- getDistVecTypeBasedOnLaneLayout(layout, storeOp.getValueType());
+ xegpu::getDistVecTypeBasedOnLaneLayout(layout, storeOp.getValueType());
if (failed(distributedTypeByWarpOpOrFailure))
return rewriter.notifyMatchFailure(storeOp,
"Failed to distribute the type");
@@ -575,9 +505,9 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
newLoadOperands, loadOp->getAttrs());
xegpu::removeLayoutAttrs(newLoadOp);
// Set the packed attribute if the layout requires it.
- newLoadOp.setPacked(requirePacked(layout));
+ newLoadOp.setPacked(xegpu::requirePacked(layout));
// Set the transpose attribute if the layout requires it.
- if (requireTranspose(layout, uArch))
+ if (xegpu::requireTranspose(layout, uArch))
newLoadOp.setTranspose(
DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));
Value distributedVal = newWarpOp.getResult(operandIdx);
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 93f81c0ad71a9..7e28c756f2d72 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -16,11 +16,13 @@
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include <cstdint>
#include <numeric>
@@ -101,6 +103,35 @@ mlir::xegpu::getDistributedVectorType(VectorType originalType,
return xegpu::getDistributedVectorType(helperTdescTy);
}
+FailureOr<VectorType>
+xegpu::getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
+ VectorType originalType) {
+ if (!layout)
+ return failure();
+ assert((isa<xegpu::LayoutAttr>(layout) || isa<xegpu::SliceAttr>(layout)) &&
+ "Expecting a valid layout.");
+ SmallVector<int64_t> effectiveLaneLayout =
+ layout.getEffectiveLaneLayoutAsInt();
+ assert(static_cast<size_t>(originalType.getRank()) >=
+ effectiveLaneLayout.size() &&
+ "Rank of the original vector type should be greater or equal to the "
+ "size of the lane layout to distribute the vector type.");
+ SmallVector<int64_t> distributedShape(originalType.getShape());
+ // Only the trailing `effectiveLaneLayout.size()` dimensions are divided by
+ // the lane layout; any leading dimensions keep their original size.
+ unsigned distributionStart =
+ originalType.getRank() - effectiveLaneLayout.size();
+ for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
+ if (i < distributionStart)
+ continue;
+ // Fail unless the dimension splits evenly across the lanes.
+ if (dim % effectiveLaneLayout[i - distributionStart] != 0)
+ return failure();
+ distributedShape[i] = dim / effectiveLaneLayout[i - distributionStart];
+ }
+ return VectorType::get(distributedShape, originalType.getElementType());
+}
+
std::string xegpu::getTemporaryLayoutName(const OpOperand &operand) {
const StringRef prefix("layout_operand_");
unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
@@ -139,7 +170,7 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
if (auto arg = dyn_cast<BlockArgument>(value)) {
auto *parentOp = arg.getOwner()->getParentOp();
- if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
+ if (auto loop = dyn_cast_if_present<LoopLikeOpInterface>(parentOp)) {
OpOperand *tiedInit = loop.getTiedLoopInit(arg);
if (tiedInit)
return getDistributeLayoutAttr(tiedInit->get());
@@ -731,3 +762,33 @@ template int xegpu::getLargestDivisor<int>(int dim, ArrayRef<int> candidates,
template int
xegpu::getLargestDivisor<unsigned>(unsigned dim, ArrayRef<unsigned> candidates,
ArrayRef<unsigned> candidateMultiples);
+
+/// Returns true if \p layout is 2D and lane_data[0] != 1, i.e. the data is
+/// packed from the column dimension. A null layout is never packed.
+bool xegpu::requirePacked(const xegpu::LayoutAttr layout) {
+ if (!layout)
+ return false;
+ auto laneData = layout.getEffectiveLaneDataAsInt();
+ if (laneData.size() != 2)
+ return false;
+ return laneData[0] != 1;
+}
+
+/// Returns true if \p layout is 2D with lane_layout = [subgroupSize, 1], which
+/// requires a transposed load. Always false for a null \p uArch, a null layout,
+/// or a target other than PVC or BMG.
+bool xegpu::requireTranspose(const xegpu::LayoutAttr layout,
+ const xegpu::uArch::uArch *uArch) {
+ // Only PVC and BMG are handled: bail out unless the target matches one of
+ // them (a name can never equal both, so the checks must be negated).
+ // TODO: Add more support or move to target info.
+ if (!uArch || (!uArch->getName().equals_insensitive("pvc") &&
+ !uArch->getName().equals_insensitive("bmg")))
+ return false;
+ if (!layout)
+ return false;
+ auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
+ if (laneLayout.size() != 2)
+ return false;
+ return laneLayout[0] == uArch->getSubgroupSize() && laneLayout[1] == 1;
+}
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
new file mode 100644
index 0000000000000..0e9843f4626d4
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -0,0 +1,152 @@
+
+// RUN: mlir-opt --xevm-attach-target='module=xevm_* chip=pvc' --allow-unregistered-dialect \
+// RUN: --test-xegpu-sg-to-wi-distribute-experimental --split-input-file %s | FileCheck %s
+
+
+
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @create_nd_tdesc
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[TD:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16>
+gpu.func @create_nd_tdesc(%arg0: memref<256x256xf16>) {
+ %c0 = arith.constant 0 : index
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<256x256xf16>
+ -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @create_nd_tdesc_nonmemref_source
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[TD:.*]] = xegpu.create_nd_tdesc %{{.*}}, shape : [256, 256], strides : [256, 1] : ui64 -> !xegpu.tensor_desc<16x16xf16>
+gpu.func @create_nd_tdesc_nonmemref_source(%arg0: ui64) {
+ %c0 = arith.constant 0 : index
+ %0 = xegpu.create_nd_tdesc %arg0, shape : [256, 256], strides : [256, 1] : ui64
+ -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @load_nd
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[LOAD:.*]] = xegpu.load_nd %{{.*}}[%[[C0]], %[[C0]]] : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
+// CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<16xf16> to vector<16x1xf16>
+gpu.func @load_nd() {
+ %c0 = arith.constant 0 : index
+ %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %1 = xegpu.load_nd %0[%c0, %c0] {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @load_nd_packed
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[LOAD:.*]] = xegpu.load_nd %{{.*}}[%[[C0]], %[[C0]]] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
+// CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<16xf16> to vector<16x1xf16>
+gpu.func @load_nd_packed() {
+ %c0 = arith.constant 0 : index
+ %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
+ %1 = xegpu.load_nd %0[%c0, %c0] {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}
+ : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @load_nd_transpose
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[LOAD:.*]] = xegpu.load_nd %{{.*}}[%[[C0]], %[[C0]]] <{transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8xf32>
+// CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<8xf32> to vector<1x8xf32>
+gpu.func @load_nd_transpose() {
+ %c0 = arith.constant 0 : index
+ %0 = "some_op"() : () -> !xegpu.tensor_desc<16x8xf32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+ %1 = xegpu.load_nd %0[%c0, %c0] {layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+ : !xegpu.tensor_desc<16x8xf32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xf32>
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @store_nd
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[LOAD:.*]] = xegpu.load_nd %{{.*}}[%[[C0]], %[[C0]]] : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
+// CHECK: %[[CAST2:.*]] = vector.shape_cast %[[LOAD]] : vector<16xf16> to vector<16x1xf16>
+// CHECK: %[[CAST3:.*]] = vector.shape_cast %[[CAST2]] : vector<16x1xf16> to vector<16xf16>
+// CHECK: xegpu.store_nd %[[CAST3]], %{{.*}}[%[[C0]], %[[C0]]] : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
+gpu.func @store_nd() {
+ %c0 = arith.constant 0 : index
+ %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %1 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %2 = xegpu.load_nd %0[%c0, %c0] {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
+ xegpu.store_nd %2, %1[%c0, %c0] {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @dpas
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
+// CHECK-DAG: %[[LOAD0:.*]] = xegpu.load_nd %{{.*}}[%[[C0]], %[[C0]]] : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
+// CHECK-DAG: %[[CAST2:.*]] = vector.shape_cast %[[LOAD0]] : vector<8xf16> to vector<8x1xf16>
+// CHECK-DAG: %[[LOAD1:.*]] = xegpu.load_nd %{{.*}}[%[[C0]], %[[C0]]] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
+// CHECK-DAG: %[[CAST3:.*]] = vector.shape_cast %[[LOAD1]] : vector<16xf16> to vector<16x1xf16>
+// CHECK-DAG: %[[CAST4:.*]] = vector.shape_cast %[[CST]] : vector<8x1xf32> to vector<8xf32>
+// CHECK-DAG: %[[CAST5:.*]] = vector.shape_cast %[[CAST3]] : vector<16x1xf16> to vector<16xf16>
+// CHECK-DAG: %[[CAST6:.*]] = vector.shape_cast %[[CAST2]] : vector<8x1xf16> to vector<8xf16>
+// CHECK: %[[DPAS:.*]] = xegpu.dpas %[[CAST6]], %[[CAST5]], %[[CAST4]] : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
+// CHECK: %[[CAST7:.*]] = vector.shape_cast %[[DPAS]] : vector<8xf32> to vector<8x1xf32>
+// CHECK: gpu.return
+gpu.func @dpas() {
+ %c0 = arith.constant 0 : index
+ %0 = "some_op"() : () -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %1 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
+ %5 = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ dense<0.0> : vector<8x16xf32>
+ %2 = xegpu.load_nd %0[%c0, %c0] {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16>
+ %3 = xegpu.load_nd %1[%c0, %c0] {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}
+ : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
+ %4 = xegpu.dpas %2, %3, %5
+ {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @elementwise
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<16x1xf32>
+// CHECK: %[[LOAD:.*]] = xegpu.load_nd %{{.*}}[%[[C0]], %[[C0]]] : !xegpu.tensor_desc<16x16xf32> -> vector<16xf32>
+// CHECK: %[[CAST2:.*]] = vector.shape_cast %[[LOAD]] : vector<16xf32> to vector<16x1xf32>
+// CHECK: %[[ADD:.*]] = arith.addf %[[CAST2]], %[[CST]] : vector<16x1xf32>
+// CHECK: gpu.return
+gpu.func @elementwise() {
+ %c0 = arith.constant 0 : index
+ %0 = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ dense<1.0> : vector<16x16xf32>
+ %1 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %2 = xegpu.load_nd %1[%c0, %c0] {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf32>
+ %3 = arith.addf %0, %2
+ {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : vector<16x16xf32>
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @arith_constant
+// CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<16x1xf32>
+// CHECK: gpu.return
+gpu.func @arith_constant() {
+ %0 = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ dense<1.0> : vector<16x16xf32>
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @prefetch_nd
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: xegpu.prefetch_nd %{{.*}}[%[[C0]], %[[C0]]] : !xegpu.tensor_desc<16x16xf16>
+// CHECK: gpu.return
+gpu.func @prefetch_nd() {
+ %c0 = arith.constant 0 : index
+ %0 = "some_op"() : () -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ xegpu.prefetch_nd %0[%c0, %c0] {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ gpu.return
+}
+}
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental.mlir
new file mode 100644
index 0000000000000..9172cd3018b71
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental.mlir
@@ -0,0 +1,219 @@
+// RUN: mlir-opt --allow-unregistered-dialect --xevm-attach-target='module=xevm_* chip=pvc' \
+// RUN: --xegpu-sg-to-wi-distribute-experimental --split-input-file %s --canonicalize --cse | FileCheck %s
+
+// CHECK-LABEL: gpu.func @gemm
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C1024:.*]] = arith.constant 1024 : index
+// CHECK-DAG: %[[BID_X:.*]] = gpu.block_id x
+// CHECK-DAG: %[[BID_Y:.*]] = gpu.block_id y
+// CHECK-DAG: %[[MUL_X:.*]] = arith.muli %[[BID_X]], %[[C8]] : index
+// CHECK-DAG: %[[MUL_Y:.*]] = arith.muli %[[BID_Y]], %[[C16]] : index
+// CHECK: %[[TD_C:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK: %[[LOAD_C:.*]] = xegpu.load_nd %[[TD_C]][%[[MUL_X]], %[[MUL_Y]]] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+// CHECK-DAG: %[[CAST_C:.*]] = vector.shape_cast %[[LOAD_C]] : vector<8xf32> to vector<8x1xf32>
+// CHECK-DAG: %[[TD_A:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
+// CHECK-DAG: %[[TD_B:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
+// CHECK: %[[FOR:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C1024]] step %[[C16]] iter_args(%[[ACC:.*]] = %[[CAST_C]]) -> (vector<8x1xf32>) {
+// CHECK-DAG: %[[LOAD_A:.*]] = xegpu.load_nd %[[TD_A]][%[[MUL_X]], %[[IV]]] : !xegpu.tensor_desc<8x16xbf16> -> vector<8xbf16>
+// CHECK-DAG: %[[LOAD_B:.*]] = xegpu.load_nd %[[TD_B]][%[[IV]], %[[MUL_Y]]] <{packed}> : !xegpu.tensor_desc<16x16xbf16> -> vector<16xbf16>
+// CHECK-DAG: %[[CAST_ACC:.*]] = vector.shape_cast %[[ACC]] : vector<8x1xf32> to vector<8xf32>
+// CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]], %[[CAST_ACC]] : vector<8xbf16>, vector<16xbf16>, vector<8xf32> -> vector<8xf32>
+// CHECK: %[[CAST_DPAS:.*]] = vector.shape_cast %[[DPAS]] : vector<8xf32> to vector<8x1xf32>
+// CHECK: scf.yield %[[CAST_DPAS]] : vector<8x1xf32>
+// CHECK: } {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK: %[[CAST_FOR:.*]] = vector.shape_cast %[[FOR]] : vector<8x1xf32> to vector<8xf32>
+// CHECK: xegpu.store_nd %[[CAST_FOR]], %[[TD_C]][%[[MUL_X]], %[[MUL_Y]]] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+// CHECK: gpu.return
+gpu.module @xevm_module{
+gpu.func @gemm(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xf32>){
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c8 = arith.constant 8 : index
+ %c1024 = arith.constant 1024 : index
+ %block_id_x = gpu.block_id x
+ %block_id_y = gpu.block_id y
+ %0 = arith.muli %block_id_x, %c8 : index
+ %1 = arith.muli %block_id_y, %c16 : index
+ %2 = xegpu.create_nd_tdesc %arg2 : memref<1024x1024xf32> ->
+ !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %3 = xegpu.load_nd %2[%0, %1]
+ {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf32>
+ %5 = xegpu.create_nd_tdesc %arg0: memref<1024x1024xbf16>
+ -> !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %6 = xegpu.create_nd_tdesc %arg1 : memref<1024x1024xbf16>
+ -> !xegpu.tensor_desc<16x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
+
+ %4 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args(%arg4 = %3) -> (vector<8x16xf32>) {
+ %7 = xegpu.load_nd %5[%0, %arg3]
+ {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xbf16>
+ %8 = xegpu.load_nd %6[%arg3, %1]
+ {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}
+ : !xegpu.tensor_desc<16x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xbf16>
+
+ %9 = xegpu.dpas %7, %8, %arg4
+ {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>,
+ layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : vector<8x16xbf16>, vector<16x16xbf16>, vector<8x16xf32> -> vector<8x16xf32>
+
+ scf.yield %9 : vector<8x16xf32>
+ } {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ xegpu.store_nd %4, %2[%0, %1] {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}: vector<8x16xf32>,
+ !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @gemm_with_preop
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C1024:.*]] = arith.constant 1024 : index
+// CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<8x1xbf16>
+// CHECK-DAG: %[[BID_X:.*]] = gpu.block_id x
+// CHECK-DAG: %[[BID_Y:.*]] = gpu.block_id y
+// CHECK-DAG: %[[MUL_X:.*]] = arith.muli %[[BID_X]], %[[C8]] : index
+// CHECK-DAG: %[[MUL_Y:.*]] = arith.muli %[[BID_Y]], %[[C16]] : index
+// CHECK: %[[TD_C:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK: %[[LOAD_C:.*]] = xegpu.load_nd %[[TD_C]][%[[MUL_X]], %[[MUL_Y]]] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+// CHECK-DAG: %[[CAST_C:.*]] = vector.shape_cast %[[LOAD_C]] : vector<8xf32> to vector<8x1xf32>
+// CHECK-DAG: %[[TD_A:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
+// CHECK-DAG: %[[TD_B:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
+// CHECK: %[[FOR:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C1024]] step %[[C16]] iter_args(%[[ACC:.*]] = %[[CAST_C]]) -> (vector<8x1xf32>) {
+// CHECK-DAG: %[[LOAD_A:.*]] = xegpu.load_nd %[[TD_A]][%[[MUL_X]], %[[IV]]] : !xegpu.tensor_desc<8x16xbf16> -> vector<8xbf16>
+// CHECK: %[[CAST_A:.*]] = vector.shape_cast %[[LOAD_A]] : vector<8xbf16> to vector<8x1xbf16>
+// CHECK: %[[PREOP:.*]] = arith.addf %[[CAST_A]], %[[CST]] : vector<8x1xbf16>
+// CHECK-DAG: %[[LOAD_B:.*]] = xegpu.load_nd %[[TD_B]][%[[IV]], %[[MUL_Y]]] <{packed}> : !xegpu.tensor_desc<16x16xbf16> -> vector<16xbf16>
+// CHECK-DAG: %[[CAST_ACC:.*]] = vector.shape_cast %[[ACC]] : vector<8x1xf32> to vector<8xf32>
+// CHECK: %[[CAST_PREOP:.*]] = vector.shape_cast %[[PREOP]] : vector<8x1xbf16> to vector<8xbf16>
+// CHECK: %[[DPAS:.*]] = xegpu.dpas %[[CAST_PREOP]], %[[LOAD_B]], %[[CAST_ACC]] : vector<8xbf16>, vector<16xbf16>, vector<8xf32> -> vector<8xf32>
+// CHECK: %[[CAST_DPAS:.*]] = vector.shape_cast %[[DPAS]] : vector<8xf32> to vector<8x1xf32>
+// CHECK: scf.yield %[[CAST_DPAS]] : vector<8x1xf32>
+// CHECK: } {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK: %[[CAST_FOR:.*]] = vector.shape_cast %[[FOR]] : vector<8x1xf32> to vector<8xf32>
+// CHECK: xegpu.store_nd %[[CAST_FOR]], %[[TD_C]][%[[MUL_X]], %[[MUL_Y]]] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+// CHECK: gpu.return
+gpu.func @gemm_with_preop(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xf32>){
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c8 = arith.constant 8 : index
+ %c1024 = arith.constant 1024 : index
+ %cst = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<1.0> : vector<8x16xbf16>
+ %block_id_x = gpu.block_id x
+ %block_id_y = gpu.block_id y
+ %0 = arith.muli %block_id_x, %c8 : index
+ %1 = arith.muli %block_id_y, %c16 : index
+ %2 = xegpu.create_nd_tdesc %arg2 : memref<1024x1024xf32> ->
+ !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %3 = xegpu.load_nd %2[%0, %1]
+ {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf32>
+ %5 = xegpu.create_nd_tdesc %arg0: memref<1024x1024xbf16>
+ -> !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %6 = xegpu.create_nd_tdesc %arg1 : memref<1024x1024xbf16>
+ -> !xegpu.tensor_desc<16x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
+
+ %4 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args(%arg4 = %3) -> (vector<8x16xf32>) {
+ %7 = xegpu.load_nd %5[%0, %arg3]
+ {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xbf16>
+ %preop = arith.addf %7, %cst {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xbf16>
+ %8 = xegpu.load_nd %6[%arg3, %1]
+ {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}
+ : !xegpu.tensor_desc<16x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xbf16>
+
+ %9 = xegpu.dpas %preop, %8, %arg4
+ {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>,
+ layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : vector<8x16xbf16>, vector<16x16xbf16>, vector<8x16xf32> -> vector<8x16xf32>
+
+ scf.yield %9 : vector<8x16xf32>
+ } {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ xegpu.store_nd %4, %2[%0, %1] {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}: vector<8x16xf32>,
+ !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @gemm_with_postop
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C1024:.*]] = arith.constant 1024 : index
+// CHECK-DAG: %[[BID_X:.*]] = gpu.block_id x
+// CHECK-DAG: %[[BID_Y:.*]] = gpu.block_id y
+// CHECK-DAG: %[[MUL_X:.*]] = arith.muli %[[BID_X]], %[[C8]] : index
+// CHECK-DAG: %[[MUL_Y:.*]] = arith.muli %[[BID_Y]], %[[C16]] : index
+// CHECK: %[[TD_C:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK: %[[LOAD_C:.*]] = xegpu.load_nd %[[TD_C]][%[[MUL_X]], %[[MUL_Y]]] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+// CHECK-DAG: %[[CAST_C:.*]] = vector.shape_cast %[[LOAD_C]] : vector<8xf32> to vector<8x1xf32>
+// CHECK-DAG: %[[TD_A:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
+// CHECK-DAG: %[[TD_B:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
+// CHECK: %[[FOR:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C1024]] step %[[C16]] iter_args(%[[ACC:.*]] = %[[CAST_C]]) -> (vector<8x1xf32>) {
+// CHECK-DAG: %[[LOAD_A:.*]] = xegpu.load_nd %[[TD_A]][%[[MUL_X]], %[[IV]]] : !xegpu.tensor_desc<8x16xbf16> -> vector<8xbf16>
+// CHECK-DAG: %[[LOAD_B:.*]] = xegpu.load_nd %[[TD_B]][%[[IV]], %[[MUL_Y]]] <{packed}> : !xegpu.tensor_desc<16x16xbf16> -> vector<16xbf16>
+// CHECK-DAG: %[[CAST_ACC:.*]] = vector.shape_cast %[[ACC]] : vector<8x1xf32> to vector<8xf32>
+// CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]], %[[CAST_ACC]] : vector<8xbf16>, vector<16xbf16>, vector<8xf32> -> vector<8xf32>
+// CHECK: %[[CAST_DPAS:.*]] = vector.shape_cast %[[DPAS]] : vector<8xf32> to vector<8x1xf32>
+// CHECK: scf.yield %[[CAST_DPAS]] : vector<8x1xf32>
+// CHECK: } {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK: %[[POSTOP:.*]] = math.exp %[[FOR]] : vector<8x1xf32>
+// CHECK: %[[CAST_POSTOP:.*]] = vector.shape_cast %[[POSTOP]] : vector<8x1xf32> to vector<8xf32>
+// CHECK: xegpu.store_nd %[[CAST_POSTOP]], %[[TD_C]][%[[MUL_X]], %[[MUL_Y]]] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+gpu.func @gemm_with_postop(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xf32>){
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c8 = arith.constant 8 : index
+ %c1024 = arith.constant 1024 : index
+ %block_id_x = gpu.block_id x
+ %block_id_y = gpu.block_id y
+ %0 = arith.muli %block_id_x, %c8 : index
+ %1 = arith.muli %block_id_y, %c16 : index
+ %2 = xegpu.create_nd_tdesc %arg2 : memref<1024x1024xf32> ->
+ !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %3 = xegpu.load_nd %2[%0, %1]
+ {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf32>
+ %5 = xegpu.create_nd_tdesc %arg0: memref<1024x1024xbf16>
+ -> !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %6 = xegpu.create_nd_tdesc %arg1 : memref<1024x1024xbf16>
+ -> !xegpu.tensor_desc<16x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
+
+ %4 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args(%arg4 = %3) -> (vector<8x16xf32>) {
+ %7 = xegpu.load_nd %5[%0, %arg3]
+ {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xbf16>
+ %8 = xegpu.load_nd %6[%arg3, %1]
+ {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}
+ : !xegpu.tensor_desc<16x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xbf16>
+
+ %9 = xegpu.dpas %7, %8, %arg4
+ {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>,
+ layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : vector<8x16xbf16>, vector<16x16xbf16>, vector<8x16xf32> -> vector<8x16xf32>
+
+ scf.yield %9 : vector<8x16xf32>
+ } {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ %postop = math.exp %4 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf32>
+ xegpu.store_nd %postop, %2[%0, %1] {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}: vector<8x16xf32>,
+ !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ gpu.return
+}
+
+}
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index c8a6a6d7b8eb8..405e974500e08 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -6,16 +6,23 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/raw_ostream.h"
+#include <optional>
using namespace mlir;
using namespace mlir::xegpu;
@@ -247,6 +254,57 @@ struct TestXeGPUSGDistribute
}
};
+/// This test pass is intended to test the subgroup to workitem distribution of
+/// xegpu/vector/arith operations in isolation, it does not handle any
+/// structural ops like scf.for etc.
+struct TestXeGPUSgToWiDistributeExperimental
+ : public PassWrapper<TestXeGPUSgToWiDistributeExperimental,
+ OperationPass<gpu::GPUModuleOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ TestXeGPUSgToWiDistributeExperimental)
+
+ StringRef getArgument() const final {
+ return "test-xegpu-sg-to-wi-distribute-experimental";
+ }
+
+ StringRef getDescription() const final {
+ return "Test the experimental implementation of XeGPU Subgroup to "
+ "Work-item Distribution";
+ }
+
+ void getDependentDialects(::mlir::DialectRegistry ®istry) const override {
+ registry.insert<arith::ArithDialect>();
+ registry.insert<memref::MemRefDialect>();
+ registry.insert<xegpu::XeGPUDialect>();
+ registry.insert<vector::VectorDialect>();
+ registry.insert<index::IndexDialect>();
+ registry.insert<gpu::GPUDialect>();
+ }
+
+ TestXeGPUSgToWiDistributeExperimental() = default;
+ TestXeGPUSgToWiDistributeExperimental(
+ const TestXeGPUSgToWiDistributeExperimental &pass) = default;
+
+ void runOnOperation() override {
+ MLIRContext *ctx = &getContext();
+ TypeConverter typeConverter;
+ // Define type materializations using UnrealizedConversionCastOp.
+ auto materializeCast = [&](mlir::OpBuilder &builder, mlir::Type type,
+ mlir::ValueRange inputs,
+ mlir::Location loc) -> mlir::Value {
+ return UnrealizedConversionCastOp::create(builder, loc, type, inputs)
+ .getResult(0);
+ };
+ typeConverter.addSourceMaterialization(materializeCast);
+ typeConverter.addTargetMaterialization(materializeCast);
+ ConversionTarget target(*ctx);
+ RewritePatternSet patterns(ctx);
+ xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
+ typeConverter, patterns, target);
+ (void)applyPartialConversion(getOperation(), target, std::move(patterns));
+ }
+};
+
struct TestXeGPUMoveFuncBodyToWarpOp
: public PassWrapper<TestXeGPUMoveFuncBodyToWarpOp,
OperationPass<gpu::GPUModuleOp>> {
@@ -415,6 +473,7 @@ void registerTestXeGPULowerings() {
PassRegistration<TestXeGPUUnrollingPatterns>();
PassRegistration<TestXeGPULayoutInterface>();
PassRegistration<TestXeGPUSGDistribute>();
+ PassRegistration<TestXeGPUSgToWiDistributeExperimental>();
PassRegistration<TestXeGPUMoveFuncBodyToWarpOp>();
PassRegistration<TestXeGPUPropagateLayouts>();
PassRegistration<TestXeGPUResolveLayoutConflicts>();
More information about the Mlir-commits
mailing list