[Mlir-commits] [mlir] [mlir][xegpu] Add SIMT distribution support for GEMM transpose B case. (PR #155517)
Charitha Saumya
llvmlistbot at llvm.org
Tue Aug 26 17:01:23 PDT 2025
https://github.com/charithaintc updated https://github.com/llvm/llvm-project/pull/155517
>From 28c5c4c5f29a23dee72e9397e0f93063dc167e75 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 21 Aug 2025 16:11:50 +0000
Subject: [PATCH 01/15] pull changes
---
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 159 ++++++++++++++-
.../Transforms/XeGPUSubgroupDistribute.cpp | 188 +++++++++++++++++-
mlir/test/Dialect/XeGPU/propagate-layout.mlir | 17 ++
3 files changed, 353 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index bef88042fc663..10c2759493477 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -62,10 +62,17 @@ struct Layout {
SmallVector<int64_t, 3> layout;
Layout() = default;
Layout(std::initializer_list<int64_t> list) : layout(list) {}
+ Layout(SmallVector<int64_t, 3> &list) : layout(list) {}
void print(llvm::raw_ostream &os) const;
size_t size() const { return layout.size(); }
+ int64_t operator[](size_t idx) const;
};
+int64_t Layout::operator[](size_t idx) const {
+ assert(idx < layout.size() && "Index out of bounds");
+ return layout[idx];
+}
+
void Layout::print(llvm::raw_ostream &os) const {
os << llvm::interleaved_array(layout);
}
@@ -324,6 +331,13 @@ class LayoutInfoPropagation
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
+ void visitVectorBroadCastOp(vector::BroadcastOp broadcast,
+ ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results);
+ void visitShapeCastOp(vector::ShapeCastOp shapeCast,
+ ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results);
+
public:
LayoutInfoPropagation(DataFlowSolver &solver,
SymbolTableCollection &symbolTable)
@@ -383,6 +397,12 @@ LogicalResult LayoutInfoPropagation::visitOperation(
.Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
visitVectorMultiReductionOp(reductionOp, operands, results);
})
+ .Case<vector::BroadcastOp>([&](auto broadcastOp) {
+ visitVectorBroadCastOp(broadcastOp, operands, results);
+ })
+ .Case<vector::ShapeCastOp>([&](auto shapeCastOp) {
+ visitShapeCastOp(shapeCastOp, operands, results);
+ })
// All other ops.
.Default([&](Operation *op) {
for (const LayoutInfoLattice *resultInfo : results) {
@@ -437,6 +457,83 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
}
+void LayoutInfoPropagation::visitVectorBroadCastOp(
+ vector::BroadcastOp broadcast, ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results) {
+ // The layout of the result must be present.
+ LayoutInfo resultLayout = results[0]->getValue();
+ if (!resultLayout.isAssigned())
+ return;
+ // Only consider 1D -> 2D broadcasts or 2D -> 2D broadcasts.
+ VectorType resultTy = broadcast.getResultVectorType();
+ VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
+ if (!sourceTy) {
+ broadcast.emitWarning("Expecting source type to be a vector type.");
+ return;
+ }
+
+ // Only consider 2D -> 2D broadcasts.
+ if (sourceTy.getRank() != 2 || resultTy.getRank() != 2) {
+ broadcast.emitWarning("Expecting source type to be 2D vector and "
+ "result type to be 2D vector.");
+ return;
+ }
+ SetVector<int64_t> broadcastUnitDims = broadcast.computeBroadcastedUnitDims();
+ if (broadcastUnitDims.size() != 1) {
+ broadcast.emitWarning("Expecting source type to be 2D vector only with "
+ "one broadcasted dimension.");
+ return;
+ }
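+ // E.g., a vector<1x16xf16> -> vector<8x16xf16> broadcast (broadcasted unit
+ // dim 0) simply reuses the result layout for its source.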
+ // Propagate the result layout to the source operand.
+ propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
+}
+
+void LayoutInfoPropagation::visitShapeCastOp(
+ vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results) {
+ // The layout of the result must be present.
+ LayoutInfo resultLayout = results[0]->getValue();
+ if (!resultLayout.isAssigned())
+ return;
+ VectorType sourceTy = shapeCast.getSourceVectorType();
+ VectorType resultTy = shapeCast.getResultVectorType();
+ // Expecting source rank to be 1D or 2D.
+ if (sourceTy.getRank() != 1 && sourceTy.getRank() != 2) {
+ shapeCast.emitWarning("Expecting source type to be 1D or 2D vector.");
+ return;
+ }
+ // Expecting result rank to be 1D or 2D.
+ if (resultTy.getRank() != 1 && resultTy.getRank() != 2) {
+ shapeCast.emitWarning("Expecting result type to be 1D or 2D vector.");
+ return;
+ }
+ // For 2D -> 2D shape cast, propagate the result layout to the source.
+ if (sourceTy.getRank() == 2 && resultTy.getRank() == 2) {
+ // Propagate the result layout to the source operand.
+ propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
+ return;
+ }
+ auto resultLayoutArray = resultLayout.getLayoutAsArrayRef();
+ if (resultLayoutArray[0] != 1 && resultLayoutArray[1] != 1) {
+ shapeCast.emitWarning(
+ "Expecting result layout to be of form [1, subgroupSize] "
+ "or [subgroupSize, 1].");
+ return;
+ }
+ int64_t distributedDim = resultLayoutArray[0] == 1 ? 1 : 0;
+ // If the result shape can be evenly distributed in the distributed dimension,
+ // then the source layout should be [subgroupSize][1]. Otherwise, data is
+ // shared across lanes (broadcasted). In that case, just assign [1][1] for
+ // now (TODO: Use slice for this case)
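+ // E.g., assuming a target subgroup size of 16: for a vector<16xf16> ->
+ // vector<16x1xf16> shape cast whose result layout is [16, 1], the distributed
+ // dim is 0 and 16 % 16 == 0, so the 1D source gets lane layout [16] and lane
+ // data [1]; a distributed dim that is not a multiple of 16 falls back to
+ // [1], [1].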
+ LayoutInfo sourceLayout =
+ resultTy.getShape()[distributedDim] % xegpu::targetinfo::subgroupSize == 0
+ ? LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize}),
+ LaneData({1}))
+ : LayoutInfo(LaneLayout({1}), LaneData({1}));
+ // Propagate the source layout to the source operand.
+ propagateIfChanged(operands[0], operands[0]->meet(sourceLayout));
+}
+
/// Propagate the layout of the result tensor to the source tensor descriptor in
/// UpdateNdOffsetOp.
void LayoutInfoPropagation::visitUpdateNdOffsetOp(
@@ -529,16 +626,64 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
int outElemTyBitWidth =
bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
-
- // NOTE: We do not expect widening or narrowing bitcasts at this stage. Emit
- // a warning and return.
- if (inElemTyBitWidth != outElemTyBitWidth) {
- bitcast.emitWarning("Widening or narrowing bitcasts are not expected at "
- "layout propagation stage.");
+ // If the element bit widths are the same, then the layout does not change.
+ if (inElemTyBitWidth == outElemTyBitWidth) {
+ propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
return;
}
+ int64_t rank = bitcast.getSourceVectorType().getRank();
+ // A bitcast is `narrowing` if the input element type bit width is larger
+ // than the output element type bit width, e.g., f32 -> f16 is a narrowing
+ // bitcast.
+ bool isNarrowing = inElemTyBitWidth > outElemTyBitWidth;
+ int bitCastRatio = isNarrowing ? inElemTyBitWidth / outElemTyBitWidth
+ : outElemTyBitWidth / inElemTyBitWidth;
+ const LaneLayout &sourceLaneLayout =
+ resultLayout.getLayout(); // source lane layout is unchanged.
+ ArrayRef<int64_t> currData = resultLayout.getDataAsArrayRef();
+
+ // TODO: Currently we assume that bitcasts do not require cross-lane
+ // communication, so each lane must own the number of elements required to
+ // perform the bitcast locally.
+ // For 1D vectors, decide how many elements each lane owns based on whether
+ // the bitcast is narrowing or widening.
+ if (rank == 1) {
+ if ((currData[0] * outElemTyBitWidth) % inElemTyBitWidth != 0) {
+ bitcast.emitWarning(
+ "Narrowing bitcast with cross lane communication is not supported.");
+ return;
+ }
+ LaneData sourceLaneData = isNarrowing
+ ? LaneData({currData[0] / bitCastRatio})
+ : LaneData({currData[0] * bitCastRatio});
- propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
+ propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(
+ sourceLaneLayout, sourceLaneData)));
+ }
+ // For nD vectors, each lane is not allowed to own multiple elements in any
+ // dimension other than the innermost dimension.
+ // TODO: Add support for other cases depending on the use case.
+ SmallVector<int64_t, 3> sourceLaneDataStorage(currData.begin(),
+ currData.end() - 1);
+ if (llvm::any_of(sourceLaneDataStorage, [](int64_t d) { return d != 1; })) {
+ bitcast.emitWarning(
+ "Each lane must not own multiple elements in any dimension other than "
+ "the innermost dimension.");
+ return;
+ }
+ // Check if the bitcast requires cross lane communication.
+ if ((currData[rank - 1] * outElemTyBitWidth) % inElemTyBitWidth != 0) {
+ bitcast.emitWarning(
+ "Narrowing bitcast with cross lane communication is not supported.");
+ return;
+ }
+ // Decide lane data based on whether the bitcast is narrowing or widening.
+ int64_t innerMostLaneData = isNarrowing ? currData[rank - 1] / bitCastRatio
+ : currData[rank - 1] * bitCastRatio;
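+ // E.g., for an i32 -> f16 bitcast (narrowing, bitCastRatio = 2) with result
+ // lane data [1, 2], the source gets lane data [1, 1]; a widening f16 -> i32
+ // bitcast would instead multiply the innermost lane data by 2.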
+ sourceLaneDataStorage.push_back(innerMostLaneData);
+ LaneData sourceLaneData(sourceLaneDataStorage);
+
+ propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(
+ sourceLaneLayout, sourceLaneData)));
}
/// Propagate the layout of the result to the tensor descriptor and mask
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 2088c3c7fc5ec..61eece55a9bac 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -27,6 +27,7 @@
#include "mlir/IR/Value.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -34,6 +35,9 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/SmallVectorExtras.h"
+#include "llvm/Support/LogicalResult.h"
+#include <cstdint>
namespace mlir {
namespace xegpu {
@@ -146,6 +150,15 @@ static bool hasPackedLayout(xegpu::LayoutAttr layout) {
return laneData.asArrayRef()[0] != 1;
}
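+/// Returns true if the layout distributes the first (row) dimension across
+/// lanes, i.e. its lane layout has the form [N, 1] with N > 1 (e.g. [16, 1]).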
+static bool hasTransposedLayout(xegpu::LayoutAttr layout) {
+ if (layout == xegpu::LayoutAttr())
+ return false;
+ DenseI32ArrayAttr laneLayout = layout.getLaneLayout();
+ if (!laneLayout || laneLayout.size() != 2)
+ return false;
+ return laneLayout.asArrayRef()[0] > 1 && laneLayout.asArrayRef()[1] == 1;
+}
+
/// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body
/// of the original GPUFuncOp to the new GPUFuncOp such that entire body is
/// contained within a WarpExecuteOnLane0Op.
@@ -500,6 +513,9 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
xegpu::removeLayoutAttrs(newLoadOp);
// Set the packed attribute if the layout requires it.
newLoadOp.setPacked(hasPackedLayout(layout));
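+ // If rows are distributed across lanes (e.g. lane_layout [16, 1]), the load
+ // itself must transpose the data so each lane receives its share of the
+ // distributed dimension.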
+ if (hasTransposedLayout(layout))
+ newLoadOp.setTranspose(
+ DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));
Value distributedVal = newWarpOp.getResult(operandIdx);
// There can be a conflict between the vector type distributed by the
// warp op and (xegpu-specific) distributed type supported by the load
@@ -811,6 +827,135 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
}
};
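+/// Sink a memref.extract_aligned_pointer_as_index op feeding a warp-op yield
+/// outside of the warp op. The source memref is forwarded through the warp op
+/// and the extraction is re-created after it; the resulting index is the same
+/// for every lane, so no distribution of the result is needed.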
+struct MemrefExtractAlignedPointerAsIndexDistribution final
+ : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *operand = getWarpResult(
+ warpOp, llvm::IsaPred<memref::ExtractAlignedPointerAsIndexOp>);
+ if (!operand)
+ return rewriter.notifyMatchFailure(
+ warpOp,
+ "warp result is not a xegpu::MemrefExtractAlignedPointerAsIndex op");
+ auto extractOp =
+ operand->get().getDefiningOp<memref::ExtractAlignedPointerAsIndexOp>();
+ unsigned operandIdx = operand->getOperandNumber();
+ SmallVector<size_t> newRetIndices;
+ gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, extractOp.getSource(),
+ TypeRange{extractOp.getSource().getType()}, newRetIndices);
+ rewriter.setInsertionPointAfter(newWarpOp);
+ auto newExtractOp = memref::ExtractAlignedPointerAsIndexOp::create(
+ rewriter, newWarpOp.getLoc(), extractOp.getType(),
+ newWarpOp.getResult(newRetIndices[0]));
+ Value distributedVal = newWarpOp.getResult(operandIdx);
+ rewriter.replaceAllUsesWith(distributedVal, newExtractOp.getResult());
+ return success();
+ }
+};
+
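+/// Distribute a vector.bitcast op feeding a warp-op yield. The source vector
+/// is distributed according to the layout attached to it, and the bitcast is
+/// re-created outside the warp op between the distributed (2D) source and
+/// result types.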
+struct VectorBitcastDistribution final : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *operand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::BitCastOp>);
+ if (!operand)
+ return rewriter.notifyMatchFailure(
+ warpOp, "warp result is not a vector::BitCast op");
+ auto bitcastOp = operand->get().getDefiningOp<vector::BitCastOp>();
+ unsigned operandIdx = operand->getOperandNumber();
+ VectorType distributedSourceType =
+ getDistVecTypeBasedOnLaneLayout(
+ xegpu::getLayoutAttr(bitcastOp.getSource()),
+ bitcastOp.getSourceVectorType())
+ .value_or(VectorType());
+ if (!distributedSourceType)
+ return rewriter.notifyMatchFailure(
+ bitcastOp, "Failed to distribute the source vector type in "
+ "vector::BitCast op");
+ VectorType distributedResultType =
+ cast<VectorType>(warpOp.getResult(operandIdx).getType());
+ if (distributedSourceType.getRank() != 2 ||
+ distributedResultType.getRank() != 2)
+ return rewriter.notifyMatchFailure(
+ bitcastOp, "the source or result vector of the bitcast op "
+ "are not 2D vectors");
+ SmallVector<size_t> newRetIndices;
+ gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, bitcastOp.getSource(),
+ TypeRange{distributedSourceType}, newRetIndices);
+ rewriter.setInsertionPointAfter(newWarpOp);
+ auto newBitcastOp = vector::BitCastOp::create(
+ rewriter, newWarpOp.getLoc(), distributedResultType,
+ newWarpOp.getResult(newRetIndices[0]));
+ Value distributedVal = newWarpOp.getResult(operandIdx);
+ rewriter.replaceAllUsesWith(distributedVal, newBitcastOp.getResult());
+ return success();
+ }
+};
+
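+/// Distribute a vector.transpose op feeding a warp-op yield. The source and
+/// result must carry 2D lane layouts that are transposes of each other; the
+/// source is distributed with its own layout and the transpose is re-created
+/// outside the warp op with the original permutation.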
+struct VectorTransposeDistribution final : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *operand =
+ getWarpResult(warpOp, llvm::IsaPred<vector::TransposeOp>);
+ if (!operand)
+ return rewriter.notifyMatchFailure(
+ warpOp, "warp result is not a vector::Transpose op");
+ auto transposeOp = operand->get().getDefiningOp<vector::TransposeOp>();
+ unsigned operandIdx = operand->getOperandNumber();
+ xegpu::LayoutAttr sourceLayout =
+ xegpu::getLayoutAttr(transposeOp.getVector());
+ xegpu::LayoutAttr resultLayout =
+ xegpu::getLayoutAttr(transposeOp.getResult());
+ if (!sourceLayout || !resultLayout)
+ return rewriter.notifyMatchFailure(
+ transposeOp,
+ "the source or result vector of the transpose op lacks layout "
+ "attribute");
+ ArrayRef<int> sourceLaneLayout = sourceLayout.getLaneLayout().asArrayRef();
+ ArrayRef<int> resultLaneLayout = resultLayout.getLaneLayout().asArrayRef();
+ ArrayRef<int> sourceLaneData = sourceLayout.getLaneData().asArrayRef();
+ ArrayRef<int> resultLaneData = resultLayout.getLaneData().asArrayRef();
+ if (sourceLaneLayout.size() != 2 || resultLaneLayout.size() != 2)
+ return rewriter.notifyMatchFailure(
+ transposeOp, "the source or result vector of the transpose op "
+ "does not have 2D layout");
+ auto is2DTranspose = [](ArrayRef<int> input, ArrayRef<int> output) {
+ return input.size() == 2 && output.size() == 2 && input[0] == output[1] &&
+ input[1] == output[0];
+ };
+
+ if (!is2DTranspose(sourceLaneLayout, resultLaneLayout) ||
+ !is2DTranspose(sourceLaneData, resultLaneData))
+ return rewriter.notifyMatchFailure(
+ transposeOp,
+ "the source or result vector layouts must be transposes of each "
+ "other");
+ FailureOr<VectorType> distributedSourceTypeOrFailure =
+ getDistVecTypeBasedOnLaneLayout(sourceLayout,
+ transposeOp.getSourceVectorType());
+ if (failed(distributedSourceTypeOrFailure))
+ return rewriter.notifyMatchFailure(
+ transposeOp, "Failed to distribute the source vector type in "
+ "vector::Transpose op");
+ SmallVector<size_t> newRetIndices;
+ gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, transposeOp.getVector(),
+ TypeRange{distributedSourceTypeOrFailure.value()}, newRetIndices);
+ rewriter.setInsertionPointAfter(newWarpOp);
+ auto newTransposeOp = vector::TransposeOp::create(
+ rewriter, newWarpOp.getLoc(), newWarpOp.getResult(newRetIndices[0]),
+ transposeOp.getPermutation());
+ Value distributedVal = newWarpOp.getResult(operandIdx);
+ rewriter.replaceAllUsesWith(distributedVal, newTransposeOp.getResult());
+ return success();
+ }
+};
+
} // namespace
namespace {
@@ -825,7 +970,9 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
RewritePatternSet &patterns) {
patterns.add<CreateNdDescDistribution, StoreNdDistribution,
LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
- UpdateNdOffsetDistribution, GpuBarrierDistribution>(
+ UpdateNdOffsetDistribution, GpuBarrierDistribution,
+ VectorTransposeDistribution, VectorBitcastDistribution,
+ MemrefExtractAlignedPointerAsIndexDistribution>(
patterns.getContext());
}
@@ -903,14 +1050,47 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
int64_t warpSz) { return Value(); };
vector::populatePropagateWarpVectorDistributionPatterns(
patterns, distributionFn, shuffleFn);
+
+ auto warpReduction = [](Location loc, OpBuilder &builder, Value input,
+ vector::CombiningKind kind, uint32_t size) {
+ // First reduce on a single thread to get per lane reduction value.
+ Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input);
+ // Parallel reduction using butterfly shuffles.
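+ // With size = 16 this takes 4 rounds (XOR offsets 1, 2, 4, 8); after the
+ // last round every lane holds the reduction of all 16 per-lane values.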
+ for (uint64_t i = 1; i < size; i <<= 1) {
+ Value shuffled =
+ builder
+ .create<gpu::ShuffleOp>(loc, laneVal, i,
+ /*width=*/size,
+ /*mode=*/gpu::ShuffleMode::XOR)
+ .getShuffleResult();
+ laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
+ }
+ return laneVal;
+ };
+
+ vector::populateDistributeReduction(patterns, warpReduction);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
signalPassFailure();
return;
}
- // Step 4: Finllay, clean up UnrealizedConversionCastOps that were inserted
+ // Step 4: Finally, clean up UnrealizedConversionCastOps that were inserted
// due to tensor desc type mismatches created by using upstream distribution
- // patterns (scf.for)
+ // patterns (scf.for). This cleanup should only be done if all ops have been
+ // distributed successfully. If some ops are still not distributed and remain
+ // inside a WarpExecuteOnLane0Op, we skip this simplification step to avoid
+ // breaking the IR.
+ bool foundWarpOp = false;
+ getOperation()->walk([&](gpu::WarpExecuteOnLane0Op warpOp) {
+ // Look for WarpOps that are not trivially dead.
+ if (isOpTriviallyDead(warpOp))
+ return WalkResult::advance();
+ foundWarpOp = true;
+ return WalkResult::interrupt();
+ });
+ if (foundWarpOp)
+ return;
+
getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
// We are only interested in UnrealizedConversionCastOps that were added
// for resolving SIMT type mismatches.
@@ -929,7 +1109,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
"Unrealized conversion cast must have tensor descriptor types");
// tensor_desc<shape, layout> -> tensor_desc<shape> Type of conversions.
- // This occurs iside scf.for body to resolve the block argument type to
+ // This occurs inside scf.for body to resolve the block argument type to
// SIMT type.
if (inputDescType.getLayout()) {
auto argument = mlir::dyn_cast<mlir::BlockArgument>(input);
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index 0214d84f2c16f..4cbe4db271ad6 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -181,6 +181,23 @@ func.func @vector_bitcast_i16_to_f16(%arg0: memref<8x16xi16>, %arg1: memref<16x1
return
}
+// -----
+// CHECK-LABEL: func.func @vector_bitcast_i32_to_f16(
+// CHECK: %{{.*}} = vector.bitcast %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : vector<16x8xi32> to vector<16x16xf16>
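+// A transposed-B layout reaches the bitcast through the transpose: each lane
+// holds i32 elements that bitcast into pairs of f16 values, hence lane_layout
+// [16, 1] with lane_data [1, 2] on the result.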
+func.func @vector_bitcast_i32_to_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x8xi32>, %arg2: memref<8x16xf32>) {
+ %c0 = arith.constant 0 : index
+ %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+ %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x8xi32> -> !xegpu.tensor_desc<16x8xi32>
+ %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+ %3 = xegpu.load_nd %1 : !xegpu.tensor_desc<16x8xi32> -> vector<16x8xi32>
+ %4 = vector.bitcast %3 : vector<16x8xi32> to vector<16x16xf16>
+ %5 = vector.transpose %4, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
+ %6 = xegpu.dpas %2, %5 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+ %7 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+ xegpu.store_nd %6, %7 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+ return
+}
+
// -----
// CHECK-LABEL: func.func @binary_op_one_use(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>,
>From ad5d0a88a4f065dc3720d977c8e3d125c5b768b8 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 21 Aug 2025 17:58:25 +0000
Subject: [PATCH 02/15] rename getLayoutAttr util
---
.../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 66 +++++++++++++++++++
.../mlir/Dialect/XeGPU/IR/XeGPUDialect.td | 2 +-
.../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h | 27 ++++----
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 25 ++++---
.../XeGPU/Transforms/XeGPUBlocking.cpp | 16 ++---
.../Transforms/XeGPUSubgroupDistribute.cpp | 5 +-
.../Transforms/XeGPUWgToSgDistribute.cpp | 26 ++++----
mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 30 ++++-----
8 files changed, 132 insertions(+), 65 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index b4d696444cc44..5b4b376157c00 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -185,6 +185,9 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
InterfaceMethod<"Check the availability of workgroup level layouts",
"bool",
"isForWorkgroup">,
+ InterfaceMethod<"Check the availability of subgroup level layouts",
+ "bool",
+ "isForSubgroup">,
InterfaceMethod<"Get the rank of attribute",
"int64_t",
"getRank">,
@@ -202,6 +205,15 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
InterfaceMethod<"Get the SgData field of the attribute as integer array",
"std::optional<SmallVector<int64_t>>",
"getSgDataAsInt">,
+ InterfaceMethod<"Get the InstData field of the attribute as integer array",
+ "std::optional<SmallVector<int64_t>>",
+ "getInstDataAsInt">,
+ InterfaceMethod<"Get the LaneLayout field of the attribute as integer array",
+ "std::optional<SmallVector<int64_t>>",
+ "getLaneLayoutAsInt">,
+ InterfaceMethod<"Get the LaneData field of the attribute as integer array",
+ "std::optional<SmallVector<int64_t>>",
+ "getLaneDataAsInt">,
InterfaceMethod<"Derive a new layout by dropping sgLayout and sgData",
"xegpu::DistributeLayoutAttr",
"dropSgLayoutAndData">,
@@ -388,6 +400,24 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
return std::nullopt;
}
+ std::optional<SmallVector<int64_t>> getInstDataAsInt() const {
+ if (DenseI32ArrayAttr inst = getInstData())
+ return llvm::to_vector_of<int64_t>(inst.asArrayRef());
+ return std::nullopt;
+ }
+
+ std::optional<SmallVector<int64_t>> getLaneLayoutAsInt() const {
+ if (DenseI32ArrayAttr layout = getLaneLayout())
+ return llvm::to_vector_of<int64_t>(layout.asArrayRef());
+ return std::nullopt;
+ }
+
+ std::optional<SmallVector<int64_t>> getLaneDataAsInt() const {
+ if (DenseI32ArrayAttr data = getLaneData())
+ return llvm::to_vector_of<int64_t>(data.asArrayRef());
+ return std::nullopt;
+ }
+
/// Delinearizes a linear subgroup ID into its multidimensional indices
/// based on the effective subgroup layout.
FailureOr<SmallVector<Value>>
@@ -488,6 +518,42 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
return std::nullopt;
}
+ /// Returns the InstData of the attribute, computed by applying
+ /// the slice dimensions to the underlying LayoutAttr.
+ std::optional<SmallVector<int64_t>> getInstDataAsInt() const {
+ SliceAttr attr = flatten();
+ auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+ if (auto inst = parent.getInstDataAsInt()) {
+ ArrayRef<int64_t> dims = attr.getDims().asArrayRef();
+ return XeGPUDialect::slice(llvm::ArrayRef<int64_t>(*inst), dims);
+ }
+ return std::nullopt;
+ }
+
+ /// Returns the LaneLayout of the attribute, computed by applying
+ /// the slice dimensions to the underlying LayoutAttr.
+ std::optional<SmallVector<int64_t>> getLaneLayoutAsInt() const {
+ SliceAttr attr = flatten();
+ auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+ if (auto layout = parent.getLaneLayoutAsInt()) {
+ ArrayRef<int64_t> dims = attr.getDims().asArrayRef();
+ return XeGPUDialect::slice(llvm::ArrayRef<int64_t>(*layout), dims);
+ }
+ return std::nullopt;
+ }
+
+ /// Returns the LaneData of the attribute, computed by applying
+ /// the slice dimensions to the underlying LayoutAttr.
+ std::optional<SmallVector<int64_t>> getLaneDataAsInt() const {
+ SliceAttr attr = flatten();
+ auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+ if (auto data = parent.getLaneDataAsInt()) {
+ ArrayRef<int64_t> dims = attr.getDims().asArrayRef();
+ return XeGPUDialect::slice(llvm::ArrayRef<int64_t>(*data), dims);
+ }
+ return std::nullopt;
+ }
+
SliceAttr dropSgLayoutAndData() {
SliceAttr attr = flatten();
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
index 76d58e5ea2424..c173b93face98 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
@@ -40,7 +40,7 @@ def XeGPU_Dialect : Dialect {
let extraClassDeclaration = [{
/// Checks if the given shape can be evenly distributed based on the layout
/// and data factors provided by the LayoutAttr.
- static bool isEvenlyDistributable(llvm::ArrayRef<int64_t> shape, xegpu::LayoutAttr attr);
+ static bool isEvenlyDistributable(llvm::ArrayRef<int64_t> shape, xegpu::DistributeLayoutAttr attr);
/// drops/slices the shape in the specified dims, and return the rest. e.g.,
/// for shape = [32, 64, 8], dims = [0, 2], it will return [64]
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index b2b2d3ab85231..010199083add9 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -21,6 +21,7 @@ class ValueRange;
class TypeConverter;
namespace xegpu {
+class DistributeLayoutAttr;
class LayoutAttr;
class TensorDescType;
} // namespace xegpu
@@ -60,22 +61,22 @@ FailureOr<VectorType> getDistributedVectorType(xegpu::TensorDescType tdescTy);
FailureOr<VectorType> getDistributedVectorType(VectorType originalType,
LayoutAttr layout);
-/// Return the attribute name for the OpOperand to attach LayoutAttr
+/// Return the attribute name for the OpOperand to attach DistributeLayoutAttr
std::string getLayoutName(const OpOperand &operand);
-/// Return the attribute name for the OpResult to attach LayoutAttr
+/// Return the attribute name for the OpResult to attach DistributeLayoutAttr
std::string getLayoutName(const OpResult result);
-/// Retrieves the LayoutAttr associated with a given Value. For TensorDescType
-/// values, the LayoutAttr is extracted from the TensorDescType itself. For
+/// Retrieves the DistributeLayoutAttr associated with a given Value. For TensorDescType
+/// values, the DistributeLayoutAttr is extracted from the TensorDescType itself. For
/// other values, it is obtained from the attributes of the defining operation.
-/// Returns nullptr if no LayoutAttr is found.
-LayoutAttr getLayoutAttr(const Value value);
+/// Returns nullptr if no DistributeLayoutAttr is found.
+DistributeLayoutAttr getDistributeLayoutAttr(const Value value);
-/// Retrieves the LayoutAttr associated with a given OpOperand. It will
+/// Retrieves the DistributeLayoutAttr associated with a given OpOperand. It will
/// first check the operand_layout_{id} of the owner operation. If not found,
/// it will check the operand itself and its defining op.
-LayoutAttr getLayoutAttr(const OpOperand &opr);
+DistributeLayoutAttr getDistributeLayoutAttr(const OpOperand &opr);
/// Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
template <typename T,
@@ -83,23 +84,23 @@ template <typename T,
std::is_same_v<T, OpResult>>>
void removeLayoutAttr(const T &operandOrResult);
-/// Removes the LayoutAttr for each OpOperand and OpResult of the given
+/// Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given
/// operation if they exist. If the operation contains regions, it is also
/// applied recursively to the contained operations
void removeLayoutAttrs(Operation *op);
-/// Sets the LayoutAttr for a given OpOperand or OpResult by attaching
+/// Sets the DistributeLayoutAttr for a given OpOperand or OpResult by attaching
/// it to the owner's dictionary attributes
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
std::is_same_v<T, OpResult>>>
-void setLayoutAttr(const T &operandOrResult, const LayoutAttr layout);
+void setLayoutAttr(const T &operandOrResult, const DistributeLayoutAttr layout);
-/// Set the LayoutAttr for each OpOperand and OpResult of the given operation.
+/// Set the DistributeLayoutAttr for each OpOperand and OpResult of the given operation.
/// If the operation contains regions, it is also applied recursively to the
/// contained operations
void setLayoutAttrs(Operation *op,
- function_ref<LayoutAttr(Value)> getLayoutImpl);
+ function_ref<DistributeLayoutAttr(Value)> getLayoutImpl);
/// Extract a set of small vectors from a value with a given shape using
/// vector.extract_stride_slice
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index a2d708be0e937..2079848c878a3 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -91,7 +91,7 @@ genOffsetsComputingInsts(OpBuilder &builder, Location loc,
// Checks if the given shape can be evenly distributed based on the layout
// and data factors provided by the LayoutAttr.
bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
- xegpu::LayoutAttr attr) {
+ xegpu::DistributeLayoutAttr attr) {
assert(attr && "Layout attribute is missing.");
// Checks whether the given shape can be evenly distributed using the
@@ -104,52 +104,51 @@ bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
// smaller than `layout[i] * data[i]`, allowing multiple compute units to
// share the data.
auto tryDistribute = [&](llvm::ArrayRef<int64_t> shape,
- DenseI32ArrayAttr layout, DenseI32ArrayAttr data,
+ std::optional<SmallVector<int64_t>> layout,
+ std::optional<SmallVector<int64_t>> data,
bool rr = true) -> optional<SmallVector<int64_t>> {
llvm::SmallVector<int64_t> newShape(shape);
if (layout) {
- auto vec = llvm::to_vector_of<int64_t>(layout.asArrayRef());
- if (vec.size() != shape.size())
+ if ((*layout).size() != shape.size())
return std::nullopt;
- auto ratio = computeShapeRatio(shape, vec);
+ auto ratio = computeShapeRatio(shape, *layout);
if (!ratio.has_value())
return std::nullopt;
newShape = ratio.value();
}
if (data) {
- auto vec = llvm::to_vector_of<int64_t>(data.asArrayRef());
- if (vec.size() != shape.size())
+ if ((*data).size() != shape.size())
return std::nullopt;
- auto ratio = computeShapeRatio(newShape, vec);
+ auto ratio = computeShapeRatio(newShape, *data);
if (!ratio.has_value() && rr)
- ratio = computeShapeRatio(vec, newShape);
+ ratio = computeShapeRatio(*data, newShape);
if (!ratio.has_value())
return std::nullopt;
// if data is not null, we always return it for next phase.
- newShape = vec;
+ newShape = *data;
}
return newShape;
};
// check the sgLayout and sgData
auto maybeSgShape =
- tryDistribute(shape, attr.getSgLayout(), attr.getSgData());
+ tryDistribute(shape, attr.getSgLayoutAsInt(), attr.getSgDataAsInt());
if (!maybeSgShape)
return false;
auto sgShape = maybeSgShape.value();
// check InstData, it neither have layout nor need round-robin
auto maybeInstShape =
- tryDistribute(sgShape, nullptr, attr.getInstData(), false);
+ tryDistribute(sgShape, std::nullopt, attr.getInstDataAsInt(), false);
if (!maybeInstShape)
return false;
auto instShape = maybeInstShape.value();
// check LaneLayout and LaneData
auto maybeLaneShape =
- tryDistribute(instShape, attr.getLaneLayout(), attr.getLaneData(), false);
+ tryDistribute(instShape, attr.getLaneLayoutAsInt(), attr.getLaneDataAsInt(), false);
return maybeLaneShape.has_value();
}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index b3144e4c1e55d..c62597df1f895 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -140,10 +140,10 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
else
value = (Value)operandOrResult;
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operandOrResult);
+ xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(operandOrResult);
if (layout && layout.isForSubgroup()) {
- if (auto inst_data = layout.getInstData())
- return llvm::to_vector_of<int64_t>(inst_data.asArrayRef());
+ if (auto inst_data = layout.getInstDataAsInt())
+ return inst_data.value();
if (auto type = dyn_cast<ShapedType>(value.getType()))
return llvm::to_vector(type.getShape());
@@ -204,12 +204,12 @@ bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
// skip the op if any of its operands or results has workgroup level layouts
bool hasWgLayoutOperands =
llvm::any_of(op->getOpOperands(), [](OpOperand &opr) {
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(opr);
+ xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(opr);
return layout && layout.isForWorkgroup();
});
bool hasWgLayoutResults =
llvm::any_of(op->getOpResults(), [](OpResult result) {
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(result);
+ xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(result);
return layout && layout.isForWorkgroup();
});
if (hasWgLayoutOperands || hasWgLayoutResults) {
@@ -220,8 +220,8 @@ bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
auto isUnrollable = [](Value value, ArrayRef<int64_t> tileShape) {
Type valTy = value.getType();
if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(valTy)) {
- xegpu::LayoutAttr layout = tdescTy.getLayoutAttr();
- return layout && layout.getInstData();
+ xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr();
+ return layout && layout.getInstDataAsInt();
}
auto shapedType = dyn_cast<ShapedType>(valTy);
return shapedType && !llvm::equal(tileShape, shapedType.getShape());
@@ -247,7 +247,7 @@ void XeGPUBlockingPass::runOnOperation() {
// Preserve the LayoutAttr for each operand to the owner's DictionaryAttr.
// This ensures that the LayoutAttr remains accessible even if the defining
// operation is replaced.
- xegpu::setLayoutAttrs(op, [](Value v) { return xegpu::getLayoutAttr(v); });
+ xegpu::setLayoutAttrs(op, [](Value v) { return xegpu::getDistributeLayoutAttr(v); });
auto getTileShapeAndCount = [](llvm::ArrayRef<int64_t> shape,
xegpu::LayoutAttr layout) {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 2088c3c7fc5ec..de9378bd7a6f6 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -841,7 +841,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
if (!isa<VectorType>(operand.get().getType()))
continue;
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operand);
+ auto layout = dyn_cast<xegpu::LayoutAttr>(xegpu::getDistributeLayoutAttr(operand));
if (!layout) {
op->emitError("Could not find layout attribute for operand ")
<< operand.getOperandNumber() << " of operation " << op->getName();
@@ -882,7 +882,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
if (vecRank == 0)
return AffineMap::get(val.getContext());
// Get the layout of the vector type.
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(val);
+ // TODO: support more layout types
+ auto layout = dyn_cast<xegpu::LayoutAttr>(xegpu::getDistributeLayoutAttr(val));
// If no layout is specified, assume the inner most dimension is distributed
// for now.
if (!layout)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 93b4efcd125ec..c60f9e361bf8e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -406,7 +406,7 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
if (resultTy.getRank() != 2)
return failure();
- auto originalLayout = xegpu::getLayoutAttr(op.getResult());
+ auto originalLayout = xegpu::getDistributeLayoutAttr(op.getResult());
if (!originalLayout)
return failure();
@@ -470,8 +470,8 @@ struct WgToSgVectorBroadcastOp
VectorType resultType = op.getResult().getType();
ArrayRef<int64_t> wgShape = resultType.getShape();
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
- if (!layout || !layout.getSgLayout())
+ xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(op.getResult());
+ if (!layout || !layout.isForWorkgroup())
return failure();
// TODO: Currently only supports cases where the source and result ranks
@@ -487,8 +487,8 @@ struct WgToSgVectorBroadcastOp
// Check if the output layout is distributable
SmallVector<int64_t> sgLayout;
- if (auto sgLayoutAttr = layout.getSgLayout())
- sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
+ if (auto maybeSgLayout = layout.getSgLayoutAsInt())
+ sgLayout = *maybeSgLayout;
else
return failure();
@@ -535,8 +535,8 @@ struct WgToSgElementwiseOp : public ConversionPattern {
ArrayRef<int64_t> wgShape = resultType.getShape();
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0));
- if (!layout || !layout.getSgLayout())
+ xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
+ if (!layout || !layout.isForWorkgroup())
return failure();
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
@@ -737,8 +737,8 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
if (!vecAttr || !vecAttr.isSplat() || !vecType)
return failure();
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
- if (!layout || !layout.getSgLayout())
+ xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(op.getResult());
+ if (!layout || !layout.isForWorkgroup())
return failure();
ArrayRef<int64_t> wgShape = vecType.getShape();
@@ -928,7 +928,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
});
target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
- auto layout = xegpu::getLayoutAttr(op.getResult());
+ auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
return isLegal(layout);
});
@@ -947,12 +947,12 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
auto vecType = dyn_cast<VectorType>(op.getType());
if (!vecType)
return true;
- return isLegal(xegpu::getLayoutAttr(op.getResult()));
+ return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
});
target.addDynamicallyLegalOp<vector::BroadcastOp>(
[=](vector::BroadcastOp op) -> bool {
- return isLegal(xegpu::getLayoutAttr(op.getResult()));
+ return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
});
target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
@@ -980,7 +980,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
}
}
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0));
+ xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
return isLegal(layout);
});
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 6835f64ad8ef7..5ae025ef34739 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -114,7 +114,7 @@ std::string xegpu::getLayoutName(const OpResult result) {
return llvm::formatv("{0}{1}", prefix, result.getResultNumber()).str();
}
-xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) {
+xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
if (!value)
return nullptr;
@@ -132,11 +132,11 @@ xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) {
// for LoadNdOp, the layout is stored in the tensor descriptor
if (auto loadNd = dyn_cast<xegpu::LoadNdOp>(defOp))
- return getLayoutAttr(loadNd.getTensorDesc());
+ return getDistributeLayoutAttr(loadNd.getTensorDesc());
std::string layoutName = getLayoutName(result);
if (defOp->hasAttr(layoutName))
- return defOp->getAttrOfType<xegpu::LayoutAttr>(layoutName);
+ return defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
}
if (auto arg = dyn_cast<BlockArgument>(value)) {
@@ -144,41 +144,41 @@ xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) {
if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
OpOperand *tiedInit = loop.getTiedLoopInit(arg);
if (tiedInit)
- return getLayoutAttr(tiedInit->get());
+ return getDistributeLayoutAttr(tiedInit->get());
}
}
return nullptr;
}
-xegpu::LayoutAttr xegpu::getLayoutAttr(const OpOperand &opr) {
+xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
Operation *op = opr.getOwner();
std::string layoutName = xegpu::getLayoutName(opr);
if (op->hasAttr(layoutName))
- return op->getAttrOfType<xegpu::LayoutAttr>(layoutName);
- return getLayoutAttr(opr.get());
+ return op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
+ return getDistributeLayoutAttr(opr.get());
}
template <typename T, typename>
-void xegpu::setLayoutAttr(const T &operandOrResult, const LayoutAttr layout) {
+void xegpu::setLayoutAttr(const T &operandOrResult, const DistributeLayoutAttr layout) {
Operation *owner = operandOrResult.getOwner();
std::string name = xegpu::getLayoutName(operandOrResult);
- if (layout && !owner->hasAttrOfType<LayoutAttr>(name))
+ if (layout && !owner->hasAttrOfType<DistributeLayoutAttr>(name))
owner->setAttr(name, layout);
}
// Explicit instantiation for OpResult
template void
xegpu::setLayoutAttr<mlir::OpResult>(const mlir::OpResult &result,
- const mlir::xegpu::LayoutAttr layout);
+ const mlir::xegpu::DistributeLayoutAttr layout);
// Explicit instantiation for OpOperand
template void
xegpu::setLayoutAttr<mlir::OpOperand>(const mlir::OpOperand &operand,
- const mlir::xegpu::LayoutAttr layout);
+ const mlir::xegpu::DistributeLayoutAttr layout);
void xegpu::setLayoutAttrs(Operation *op,
- function_ref<LayoutAttr(Value)> getLayoutImpl) {
+ function_ref<DistributeLayoutAttr(Value)> getLayoutImpl) {
op->walk([&](Operation *nestOp) {
for (OpOperand &opr : nestOp->getOpOperands()) {
auto layout = getLayoutImpl(opr.get());
@@ -195,7 +195,7 @@ template <typename T, typename>
void xegpu::removeLayoutAttr(const T &operandOrResult) {
Operation *owner = operandOrResult.getOwner();
std::string name = xegpu::getLayoutName(operandOrResult);
- if (owner->hasAttrOfType<LayoutAttr>(name))
+ if (owner->hasAttrOfType<DistributeLayoutAttr>(name))
owner->removeAttr(name);
}
@@ -306,7 +306,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
if (!inputTy || !resultTy)
return WalkResult::skip();
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(input);
+ xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(input);
if (!layout)
return WalkResult::skip();
@@ -344,7 +344,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
}
{ // perform the conversion from RankedTensorType to VectorType based on the
- // LayoutAttr
+ // DistributeLayoutAttr
// Handle the UnrealizedConversionCastOp introduced by the first step.
// For vector->RankedTensorType, it will simply forward the inputs.
>From 0e34f36690a34f071afd181649b8f86c90dde9b4 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 21 Aug 2025 18:10:49 +0000
Subject: [PATCH 03/15] refine
---
.../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h | 17 +++++++++++---
.../XeGPU/Transforms/XeGPUBlocking.cpp | 5 ++--
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 4 ++--
.../Transforms/XeGPUSubgroupDistribute.cpp | 7 +++---
.../Transforms/XeGPUWgToSgDistribute.cpp | 10 ++++----
mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 23 ++++++++++---------
6 files changed, 40 insertions(+), 26 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 010199083add9..7089559d0c51b 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -73,11 +73,21 @@ std::string getLayoutName(const OpResult result);
/// Returns nullptr if no DistributeLayoutAttr is found.
DistributeLayoutAttr getDistributeLayoutAttr(const Value value);
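+/// Typed variant of the above: returns the layout only if it is of the
+/// requested attribute class, e.g.
+/// getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(value).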
+template <typename AttrTy>
+AttrTy getDistributeLayoutAttrOfType(const Value value) {
+ return dyn_cast_if_present<AttrTy>(getDistributeLayoutAttr(value));
+}
+
/// Retrieves the DistributeLayoutAttr associated with a given OpOperand. It will
/// first check the operand_layout_{id} of the owner operation. If not found,
/// it will check the operand itself and its defining op.
DistributeLayoutAttr getDistributeLayoutAttr(const OpOperand &opr);
+template <typename AttrTy>
+AttrTy getDistributeLayoutAttrOfType(const OpOperand &opr) {
+ return dyn_cast_if_present<AttrTy>(getDistributeLayoutAttr(opr));
+}
+
/// Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
@@ -94,13 +104,14 @@ void removeLayoutAttrs(Operation *op);
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
std::is_same_v<T, OpResult>>>
-void setLayoutAttr(const T &operandOrResult, const DistributeLayoutAttr layout);
+void setDistributeLayoutAttr(const T &operandOrResult,
+ const DistributeLayoutAttr layout);
/// Set the DistributeLayoutAttr for each OpOperand and OpResult of the given operation.
/// If the operation contains regions, it is also applied recursively to the
/// contained operations
-void setLayoutAttrs(Operation *op,
- function_ref<DistributeLayoutAttr(Value)> getLayoutImpl);
+void setDistributeLayoutAttrs(
+ Operation *op, function_ref<DistributeLayoutAttr(Value)> getLayoutImpl);
/// Extract a set of small vectors from a value with a given shape using
/// vector.extract_stride_slice
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index c62597df1f895..2e3e40ed2d457 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -247,7 +247,8 @@ void XeGPUBlockingPass::runOnOperation() {
// Preserve the LayoutAttr for each operand to the owner's DictionaryAttr.
// This ensures that the LayoutAttr remains accessible even if the defining
// operation is replaced.
- xegpu::setLayoutAttrs(op, [](Value v) { return xegpu::getDistributeLayoutAttr(v); });
+ xegpu::setDistributeLayoutAttrs(
+ op, [](Value v) { return xegpu::getDistributeLayoutAttr(v); });
auto getTileShapeAndCount = [](llvm::ArrayRef<int64_t> shape,
xegpu::LayoutAttr layout) {
@@ -377,7 +378,7 @@ void XeGPUBlockingPass::runOnOperation() {
if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
op->removeAttr(name);
if (!isa<LoopLikeOpInterface>(op))
- xegpu::setLayoutAttr(result, layout.dropInstData());
+ xegpu::setDistributeLayoutAttr(result, layout.dropInstData());
}
}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index bef88042fc663..5cb47b2accd68 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -718,7 +718,7 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
}
// If the result is a vector type, add a temporary layout attribute to the
// op.
- xegpu::setLayoutAttr(result, layout);
+ xegpu::setDistributeLayoutAttr(result, layout);
}
return success();
}
@@ -800,7 +800,7 @@ updateControlFlowOps(mlir::OpBuilder &builder,
// If the type is a vector type and this region argument is an OpResult,
// set the layout attribute on the OpResult.
if (auto result = dyn_cast<OpResult>(successorInput))
- xegpu::setLayoutAttr(result, successorOperandLayout);
+ xegpu::setDistributeLayoutAttr(result, successorOperandLayout);
}
}
return success();
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index de9378bd7a6f6..e48e2180197ec 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -841,14 +841,15 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
if (!isa<VectorType>(operand.get().getType()))
continue;
- auto layout = dyn_cast<xegpu::LayoutAttr>(xegpu::getDistributeLayoutAttr(operand));
+ auto layout =
+ xegpu::getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(operand);
if (!layout) {
op->emitError("Could not find layout attribute for operand ")
<< operand.getOperandNumber() << " of operation " << op->getName();
signalPassFailure();
return;
}
- xegpu::setLayoutAttr(operand, layout);
+ xegpu::setDistributeLayoutAttr(operand, layout);
}
});
// Step 2: Move all operations of a GPU function inside
@@ -883,7 +884,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
return AffineMap::get(val.getContext());
// Get the layout of the vector type.
// TODO: support more layout types
- auto layout = dyn_cast<xegpu::LayoutAttr>(xegpu::getDistributeLayoutAttr(val));
+ auto layout = xegpu::getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(val);
// If no layout is specified, assume the inner most dimension is distributed
// for now.
if (!layout)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index c60f9e361bf8e..a8700ca73efc4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -429,8 +429,8 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
resultTy.getElementType());
tmpC = xegpu::DpasOp::create(rewriter, loc, resTy, operands);
- xegpu::setLayoutAttr(cast<OpResult>(tmpC),
- originalLayout.dropSgLayoutAndData());
+ xegpu::setDistributeLayoutAttr(cast<OpResult>(tmpC),
+ originalLayout.dropSgLayoutAndData());
newDpasOps.push_back(tmpC);
}
@@ -508,8 +508,8 @@ struct WgToSgVectorBroadcastOp
for (auto operand : adaptor.getOperands().front()) {
auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
newResultType, operand);
- xegpu::setLayoutAttr(newBroadcast->getResult(0),
- layout.dropSgLayoutAndData());
+ xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
+ layout.dropSgLayoutAndData());
newBroadcastOps.push_back(newBroadcast.getResult());
}
@@ -755,7 +755,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
auto cstOp =
arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
if (auto newLayout = layout.dropSgLayoutAndData())
- xegpu::setLayoutAttr(cstOp->getResult(0), newLayout);
+ xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout);
SmallVector<Value> newConsts(count, cstOp);
rewriter.replaceOpWithMultiple(op, {newConsts});
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 5ae025ef34739..1d4de68754c20 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -160,7 +160,8 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const OpOperand &opr)
}
template <typename T, typename>
-void xegpu::setLayoutAttr(const T &operandOrResult, const DistributeLayoutAttr layout) {
+void xegpu::setDistributeLayoutAttr(const T &operandOrResult,
+ const DistributeLayoutAttr layout) {
Operation *owner = operandOrResult.getOwner();
std::string name = xegpu::getLayoutName(operandOrResult);
if (layout && !owner->hasAttrOfType<DistributeLayoutAttr>(name))
@@ -168,25 +169,25 @@ void xegpu::setLayoutAttr(const T &operandOrResult, const DistributeLayoutAttr l
}
// Explicit instantiation for OpResult
-template void
-xegpu::setLayoutAttr<mlir::OpResult>(const mlir::OpResult &result,
- const mlir::xegpu::DistributeLayoutAttr layout);
+template void xegpu::setDistributeLayoutAttr<mlir::OpResult>(
+ const mlir::OpResult &result,
+ const mlir::xegpu::DistributeLayoutAttr layout);
// Explicit instantiation for OpOperand
-template void
-xegpu::setLayoutAttr<mlir::OpOperand>(const mlir::OpOperand &operand,
- const mlir::xegpu::DistributeLayoutAttr layout);
+template void xegpu::setDistributeLayoutAttr<mlir::OpOperand>(
+ const mlir::OpOperand &operand,
+ const mlir::xegpu::DistributeLayoutAttr layout);
-void xegpu::setLayoutAttrs(Operation *op,
- function_ref<DistributeLayoutAttr(Value)> getLayoutImpl) {
+void xegpu::setDistributeLayoutAttrs(
+ Operation *op, function_ref<DistributeLayoutAttr(Value)> getLayoutImpl) {
op->walk([&](Operation *nestOp) {
for (OpOperand &opr : nestOp->getOpOperands()) {
auto layout = getLayoutImpl(opr.get());
- setLayoutAttr(opr, layout);
+ setDistributeLayoutAttr(opr, layout);
}
for (OpResult result : nestOp->getOpResults()) {
auto layout = getLayoutImpl(result);
- setLayoutAttr(result, layout);
+ setDistributeLayoutAttr(result, layout);
}
});
}
>From a84014ff42002dc5b036558c62e5387536e74019 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 21 Aug 2025 18:12:17 +0000
Subject: [PATCH 04/15] format
---
.../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h | 25 ++++++++++---------
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 4 +--
.../XeGPU/Transforms/XeGPUBlocking.cpp | 9 ++++---
.../Transforms/XeGPUWgToSgDistribute.cpp | 12 ++++++---
mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 6 +++--
5 files changed, 33 insertions(+), 23 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 7089559d0c51b..82fd70571c022 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -67,10 +67,11 @@ std::string getLayoutName(const OpOperand &operand);
/// Return the attribute name for the OpResult to attach DistributeLayoutAttr
std::string getLayoutName(const OpResult result);
-/// Retrieves the DistributeLayoutAttr associated with a given Value. For TensorDescType
-/// values, the DistributeLayoutAttr is extracted from the TensorDescType itself. For
-/// other values, it is obtained from the attributes of the defining operation.
-/// Returns nullptr if no DistributeLayoutAttr is found.
+/// Retrieves the DistributeLayoutAttr associated with a given Value. For
+/// TensorDescType values, the DistributeLayoutAttr is extracted from the
+/// TensorDescType itself. For other values, it is obtained from the attributes
+/// of the defining operation. Returns nullptr if no DistributeLayoutAttr is
+/// found.
DistributeLayoutAttr getDistributeLayoutAttr(const Value value);
template <typename AttrTy>
@@ -78,9 +79,9 @@ AttrTy getDistributeLayoutAttrOfType(const Value value) {
return dyn_cast_if_present<AttrTy>(getDistributeLayoutAttr(value));
}
-/// Retrieves the DistributeLayoutAttr associated with a given OpOperand. It will
-/// first check the operand_layout_{id} of the owner operation. If not found,
-/// it will check the operand itself and its defining op.
+/// Retrieves the DistributeLayoutAttr associated with a given OpOperand. It
+/// will first check the operand_layout_{id} of the owner operation. If not
+/// found, it will check the operand itself and its defining op.
DistributeLayoutAttr getDistributeLayoutAttr(const OpOperand &opr);
template <typename AttrTy>
@@ -94,8 +95,8 @@ template <typename T,
std::is_same_v<T, OpResult>>>
void removeLayoutAttr(const T &operandOrResult);
-/// Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given
-/// operation if they exist. If the operation contains regions, it is also
+/// Removes the DistributeLayoutAttr for each OpOperand and OpResult of the
+/// given operation if they exist. If the operation contains regions, it is also
/// applied recursively to the contained operations
void removeLayoutAttrs(Operation *op);
@@ -107,9 +108,9 @@ template <typename T,
void setDistributeLayoutAttr(const T &operandOrResult,
const DistributeLayoutAttr layout);
-/// Set the DistributeLayoutAttr for each OpOperand and OpResult of the given operation.
-/// If the operation contains regions, it is also applied recursively to the
-/// contained operations
+/// Set the DistributeLayoutAttr for each OpOperand and OpResult of the given
+/// operation. If the operation contains regions, it is also applied recursively
+/// to the contained operations
void setDistributeLayoutAttrs(
Operation *op, function_ref<DistributeLayoutAttr(Value)> getLayoutImpl);
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 2079848c878a3..6de6049facfc6 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -147,8 +147,8 @@ bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
auto instShape = maybeInstShape.value();
// check LaneLayout and LaneData
- auto maybeLaneShape =
- tryDistribute(instShape, attr.getLaneLayoutAsInt(), attr.getLaneDataAsInt(), false);
+ auto maybeLaneShape = tryDistribute(instShape, attr.getLaneLayoutAsInt(),
+ attr.getLaneDataAsInt(), false);
return maybeLaneShape.has_value();
}
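
As a reminder of what the reformatted call is checking: tryDistribute decides whether the inst shape can be split across lane_layout and lane_data. A simplified standalone sketch of that idea, under the assumption that each dimension must divide evenly by lane_layout[i] * lane_data[i] (the real helper is more general):

#include <cstdint>
#include <optional>
#include <vector>

// Returns the per-lane shape when `shape` splits evenly over the lane layout
// and lane data, and nullopt otherwise. Simplified illustration only.
std::optional<std::vector<int64_t>>
tryDistributeSketch(const std::vector<int64_t> &shape,
                    const std::vector<int64_t> &laneLayout,
                    const std::vector<int64_t> &laneData) {
  if (shape.size() != laneLayout.size() || shape.size() != laneData.size())
    return std::nullopt;
  std::vector<int64_t> perLaneShape;
  for (size_t i = 0; i < shape.size(); ++i) {
    int64_t chunk = laneLayout[i] * laneData[i];
    if (chunk == 0 || shape[i] % chunk != 0)
      return std::nullopt; // not evenly distributable
    perLaneShape.push_back(shape[i] / laneLayout[i]);
  }
  return perLaneShape;
}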
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 2e3e40ed2d457..45fed8e548a89 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -140,7 +140,8 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
else
value = (Value)operandOrResult;
- xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(operandOrResult);
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(operandOrResult);
if (layout && layout.isForSubgroup()) {
if (auto inst_data = layout.getInstDataAsInt())
return inst_data.value();
@@ -204,12 +205,14 @@ bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
// skip the op if any of its operands or results has workgroup level layouts
bool hasWgLayoutOperands =
llvm::any_of(op->getOpOperands(), [](OpOperand &opr) {
- xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(opr);
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(opr);
return layout && layout.isForWorkgroup();
});
bool hasWgLayoutResults =
llvm::any_of(op->getOpResults(), [](OpResult result) {
- xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(result);
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(result);
return layout && layout.isForWorkgroup();
});
if (hasWgLayoutOperands || hasWgLayoutResults) {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index a8700ca73efc4..518c7817a516e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -470,7 +470,8 @@ struct WgToSgVectorBroadcastOp
VectorType resultType = op.getResult().getType();
ArrayRef<int64_t> wgShape = resultType.getShape();
- xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(op.getResult());
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op.getResult());
if (!layout || !layout.isForWorkgroup())
return failure();
@@ -535,7 +536,8 @@ struct WgToSgElementwiseOp : public ConversionPattern {
ArrayRef<int64_t> wgShape = resultType.getShape();
- xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op->getResult(0));
if (!layout || !layout.isForWorkgroup())
return failure();
@@ -737,7 +739,8 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
if (!vecAttr || !vecAttr.isSplat() || !vecType)
return failure();
- xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(op.getResult());
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op.getResult());
if (!layout || !layout.isForWorkgroup())
return failure();
@@ -980,7 +983,8 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
}
}
- xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(op->getResult(0));
return isLegal(layout);
});
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 1d4de68754c20..cac1ffe4d3bc3 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -151,7 +151,8 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
return nullptr;
}
-xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
+xegpu::DistributeLayoutAttr
+xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
Operation *op = opr.getOwner();
std::string layoutName = xegpu::getLayoutName(opr);
if (op->hasAttr(layoutName))
@@ -307,7 +308,8 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
if (!inputTy || !resultTy)
return WalkResult::skip();
- xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(input);
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getDistributeLayoutAttr(input);
if (!layout)
return WalkResult::skip();
>From f3af2c307597bf13a04579b3235b45af7ea10392 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 21 Aug 2025 18:59:45 +0000
Subject: [PATCH 05/15] update convert_layout
---
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 3 +++
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 4 ++--
mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp | 6 +++---
mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 5 +++--
4 files changed, 11 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 5b4b376157c00..77e3c257f234e 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -217,6 +217,9 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
InterfaceMethod<"Derive a new layout by dropping sgLayout and sgData",
"xegpu::DistributeLayoutAttr",
"dropSgLayoutAndData">,
+ InterfaceMethod<"Derive a new layout by dropping InstData",
+ "xegpu::DistributeLayoutAttr",
+ "dropInstData">,
InterfaceMethod<[{Delinearizes a linear subgroup ID into its multidimensional
indices based on the effective subgroup layout.}],
"FailureOr<SmallVector<Value>>",
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index ab471a1f33ef9..2f6671c5e37cc 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1162,8 +1162,8 @@ def XeGPU_ConvertLayoutOp: XeGPU_Op<"convert_layout", [Pure, AllTypesMatch<["sou
the IR is lowered to WI level because that is the end result of all distributions.
}];
let arguments = (ins XeGPU_VectorType: $source,
- XeGPU_LayoutAttr: $input_layout,
- XeGPU_LayoutAttr: $target_layout);
+ DistributeLayoutAttr: $input_layout,
+ DistributeLayoutAttr: $target_layout);
let results = (outs XeGPU_VectorType: $result);
let assemblyFormat = [{
$source prop-dict attr-dict `:` type($source)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 45fed8e548a89..80e9d4d25b06c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -84,9 +84,9 @@ struct ConvertLayoutOpPattern
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
PatternRewriter &rewriter) const override {
- xegpu::LayoutAttr input_layout = op.getInputLayoutAttr();
- xegpu::LayoutAttr target_layout = op.getTargetLayoutAttr();
- if (!input_layout.getInstData() || !target_layout.getInstData())
+ xegpu::DistributeLayoutAttr input_layout = op.getInputLayoutAttr();
+ xegpu::DistributeLayoutAttr target_layout = op.getTargetLayoutAttr();
+ if (!input_layout.getInstDataAsInt() || !target_layout.getInstDataAsInt())
return rewriter.notifyMatchFailure(op, "Not a target ConvertLayoutOp.");
input_layout = input_layout.dropInstData();
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 518c7817a516e..4fb962908793f 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -613,8 +613,9 @@ struct WgToSgConvertLayoutOp
LogicalResult
matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- xegpu::LayoutAttr input = op.getInputLayout();
- xegpu::LayoutAttr target = op.getTargetLayout();
+ // TODO: currently, we only support LayoutAttr
+ auto input = dyn_cast<xegpu::LayoutAttr>(op.getInputLayout());
+ auto target = dyn_cast<xegpu::LayoutAttr>(op.getTargetLayout());
if (!input || !target || !input.isForWorkgroup() ||
!target.isForWorkgroup())
>From ee5baca1ccae6549aca46693814f9c8ea8b995e7 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 21 Aug 2025 22:54:47 +0000
Subject: [PATCH 06/15] save work
---
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 42 ++++++----------
mlir/test/Dialect/XeGPU/propagate-layout.mlir | 48 +++++++++++++++++--
2 files changed, 58 insertions(+), 32 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 10c2759493477..8dce63b80f373 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -639,46 +639,32 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
: outElemTyBitWidth / inElemTyBitWidth;
const LaneLayout &sourceLaneLayout =
resultLayout.getLayout(); // source lane layout is unchanged.
- ArrayRef<int64_t> currData = resultLayout.getDataAsArrayRef();
+ ArrayRef<int64_t> outData = resultLayout.getDataAsArrayRef();
// TODO: Currently we assume that bitcasts does not require cross lane
// communication. So each lane must own the required number of elements to
// perform the bitcast locally without cross-lane communication.
- // For 1D vectors, decide how many elements each lane owns based on whether
- // the bitcast is narrowing or widening.
- if (rank == 1) {
- if ((currData[0] * outElemTyBitWidth) % inElemTyBitWidth != 0) {
- bitcast.emitWarning(
- "Narrowing bitcast with cross lane communication is not supported.");
- return;
- }
- LaneData sourceLaneData = isNarrowing
- ? LaneData({currData[0] / bitCastRatio})
- : LaneData({currData[0] * bitCastRatio});
-
- propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(
- sourceLaneLayout, sourceLaneData)));
+ int outInnerBitsPerLane = outData[rank - 1] * outElemTyBitWidth;
+ if (outInnerBitsPerLane < inElemTyBitWidth) {
+ bitcast.emitWarning(
+ "Narrowing bitcast with cross lane communication is not supported.");
+ return;
}
- // For nD vectors, Each lane is not allowed to own multiple elements in any
- // dimension other than the innermost dimension.
- // TODO: Add support for other case depending on the use case.
- SmallVector<int64_t, 3> sourceLaneDataStorage(currData.begin(),
- currData.end() - 1);
+ // Check if each lane owns a single element in all dimensions except the
+ // innermost dimension. For example, if the result layout is [1, 16][2, 1], we
+ // are not allowed to bitcast such vectors.
+ // TODO: Relax this based on use cases.
+ SmallVector<int64_t, 3> sourceLaneDataStorage(outData.begin(),
+ outData.end() - 1);
if (llvm::any_of(sourceLaneDataStorage, [](int64_t d) { return d != 1; })) {
bitcast.emitWarning(
"Each lane must not own multiple elements in any dimension other than "
"the innermost dimension.");
return;
}
- // Check if the bitcast requires cross lane communication.
- if ((currData[rank - 1] * outElemTyBitWidth) % inElemTyBitWidth != 0) {
- bitcast.emitWarning(
- "Narrowing bitcast with cross lane communication is not supported.");
- return;
- }
// Decide lane data based on whether the bitcast is narrowing or widening.
- int64_t innerMostLaneData = isNarrowing ? currData[rank - 1] / bitCastRatio
- : currData[rank - 1] * bitCastRatio;
+ int64_t innerMostLaneData = isNarrowing ? outData[rank - 1] / bitCastRatio
+ : outData[rank - 1] * bitCastRatio;
sourceLaneDataStorage.push_back(innerMostLaneData);
LaneData sourceLaneData(sourceLaneDataStorage);
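
To make the reshuffled control flow easier to review, here is the source lane-data derivation above restated as a standalone sketch (plain C++ containers; an empty return stands in for the emitted warnings):

#include <cstdint>
#include <vector>

std::vector<int64_t>
sourceLaneDataForBitcast(const std::vector<int64_t> &outData, int inElemBits,
                         int outElemBits) {
  bool isNarrowing = inElemBits > outElemBits;
  int ratio = isNarrowing ? inElemBits / outElemBits : outElemBits / inElemBits;
  size_t rank = outData.size();
  // Each lane must already own enough result bits to form one source element.
  if (outData[rank - 1] * outElemBits < inElemBits)
    return {};
  // Lanes may only own multiple elements along the innermost dimension.
  for (size_t i = 0; i + 1 < rank; ++i)
    if (outData[i] != 1)
      return {};
  std::vector<int64_t> srcData(outData.begin(), outData.end() - 1);
  srcData.push_back(isNarrowing ? outData[rank - 1] / ratio
                                : outData[rank - 1] * ratio);
  return srcData;
}

// Examples: narrowing i32 -> f16 maps result lane data [1, 2] to source
// [1, 1]; widening i16 -> i32 maps result [1, 1] to source [1, 2].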
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index 4cbe4db271ad6..994fa44cab0b6 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -164,9 +164,14 @@ func.func @store_scatter_1d(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
// -----
// CHECK-LABEL: func.func @vector_bitcast_i16_to_f16(
-// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xi16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xi16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
-// CHECK: %{{.*}} = vector.bitcast %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xi16> to vector<8x16xf16>
-// CHECK: %{{.*}} = vector.bitcast %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : vector<16x16xi16> to vector<16x16xf16>
+// CHECK: %[[LOAD0:.*]] = xegpu.load_nd %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-SAME: !xegpu.tensor_desc<8x16xi16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xi16>
+// CHECK: %[[LOAD1:.*]] = xegpu.load_nd %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}
+// CHECK-SAME: !xegpu.tensor_desc<16x16xi16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xi16>
+// CHECK: %{{.*}} = vector.bitcast %[[LOAD0]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-SAME: vector<8x16xi16> to vector<8x16xf16>
+// CHECK: %{{.*}} = vector.bitcast %[[LOAD1]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}
+// CHECK-SAME: vector<16x16xi16> to vector<16x16xf16>
func.func @vector_bitcast_i16_to_f16(%arg0: memref<8x16xi16>, %arg1: memref<16x16xi16>, %arg2: memref<8x16xf32>) {
%c0 = arith.constant 0 : index
%0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
@@ -183,7 +188,10 @@ func.func @vector_bitcast_i16_to_f16(%arg0: memref<8x16xi16>, %arg1: memref<16x1
// -----
// CHECK-LABEL: func.func @vector_bitcast_i32_to_f16(
-// CHECK: %{{.*}} = vector.bitcast %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : vector<16x8xi32> to vector<16x16xf16>
+// CHECK: %[[LOAD:.*]] = xegpu.load_nd %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK-SAME: !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32>
+// CHECK-NEXT: %{{.*}} = vector.bitcast %[[LOAD]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-SAME: vector<16x8xi32> to vector<16x16xf16>
func.func @vector_bitcast_i32_to_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x8xi32>, %arg2: memref<8x16xf32>) {
%c0 = arith.constant 0 : index
%0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -198,6 +206,38 @@ func.func @vector_bitcast_i32_to_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x8
return
}
+// -----
+// CHECK-LABEL: func.func @vector_bitcast_i16_to_i32(
+// CHECK: %[[LOAD:.*]] = xegpu.load_nd %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>}
+// CHECK-SAME: !xegpu.tensor_desc<8x32xi16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>> -> vector<8x32xi16>
+// CHECK-NEXT: %{{.*}} = vector.bitcast %[[LOAD]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-SAME: vector<8x32xi16> to vector<8x16xi32>
+func.func @vector_bitcast_i16_to_i32(%arg0: memref<8x32xi16>, %arg1: memref<8x16xi32>) {
+ %c0 = arith.constant 0 : index
+ %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x32xi16> -> !xegpu.tensor_desc<8x32xi16>
+ %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
+ %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x32xi16> -> vector<8x32xi16>
+ %3 = vector.bitcast %2 : vector<8x32xi16> to vector<8x16xi32>
+ xegpu.store_nd %3, %1 : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32>
+ return
+}
+
+// -----
+// CHECK-LABEL: func.func @vector_bitcast_require_cross_lane_shuffle(
+// CHECK-NOT: %[[LOAD:.*]] = xegpu.load_nd %{{.*}} {layout_result_0 = {{.*}}} : !xegpu.tensor_desc<8x16xi32> -> vector<8x16xi32>
+// CHECK: %{{.*}} = vector.bitcast %[[LOAD]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-SAME: vector<8x16xi32> to vector<8x32xi16>
+func.func @vector_bitcast_require_cross_lane_shuffle(%arg0: memref<8x16xi32>, %arg1: memref<8x32xi16>) {
+ %c0 = arith.constant 0 : index
+ %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
+ %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<8x32xi16> -> !xegpu.tensor_desc<8x32xi16>
+ %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xi32> -> vector<8x16xi32>
+ %3 = vector.bitcast %2 : vector<8x16xi32> to vector<8x32xi16>
+ xegpu.store_nd %3, %1 : vector<8x32xi16>, !xegpu.tensor_desc<8x32xi16>
+ return
+}
+
+
// -----
// CHECK-LABEL: func.func @binary_op_one_use(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>,
>From 621122c50d7df5adb6ed33d94b8055fdc480ecdd Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 21 Aug 2025 23:14:40 +0000
Subject: [PATCH 07/15] save work
---
mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 8dce63b80f373..d8c447dd46338 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -107,7 +107,6 @@ struct LayoutInfo {
private:
LaneLayout laneLayout;
LaneData laneData;
- xegpu::LayoutAttr layoutAttr;
public:
LayoutInfo() = default;
@@ -464,7 +463,7 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
LayoutInfo resultLayout = results[0]->getValue();
if (!resultLayout.isAssigned())
return;
- // Only consider 1D -> 2D broadcasts or 2D -> 2D broadcasts.
+ // Only consider vector to vector broadcasts for now.
VectorType resultTy = broadcast.getResultVectorType();
VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
if (!sourceTy) {
@@ -472,7 +471,7 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
return;
}
- // Only conside 2D -> 2D broadcast.
+ // Only consider 2D -> 2D broadcast.
if (sourceTy.getRank() != 2 || resultTy.getRank() != 2) {
broadcast.emitWarning("Expecting source type to be 2D vector and "
"result type to be 2D vector.");
>From 35c64895111db5d7019a64078fbe719dce317b95 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Fri, 22 Aug 2025 14:45:35 +0000
Subject: [PATCH 08/15] fix compilation error in clang
---
mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h | 1 +
1 file changed, 1 insertion(+)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 82fd70571c022..bad734dbfd9f0 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_XEGPU_UTILS_XEGPUUTILS_H_
#define MLIR_DIALECT_XEGPU_UTILS_XEGPUUTILS_H_
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
namespace mlir {
>From b912c21cf84eee0b574f4acc8db036270d9efb36 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Mon, 25 Aug 2025 22:08:11 +0000
Subject: [PATCH 09/15] save work
---
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 289 +++++++++++-------
1 file changed, 173 insertions(+), 116 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index d8c447dd46338..5bba85dd4d3bc 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -21,6 +21,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Visitors.h"
@@ -29,6 +30,7 @@
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
@@ -36,6 +38,7 @@
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
+#include <cstdint>
namespace mlir {
namespace xegpu {
@@ -58,30 +61,32 @@ namespace {
/// Helper class to store the ND layout of lanes within a subgroup and data
/// owned by each lane.
-struct Layout {
- SmallVector<int64_t, 3> layout;
- Layout() = default;
- Layout(std::initializer_list<int64_t> list) : layout(list) {}
- Layout(SmallVector<int64_t, 3> &list) : layout(list) {}
- void print(llvm::raw_ostream &os) const;
- size_t size() const { return layout.size(); }
- int64_t operator[](size_t idx) const;
-};
-
-int64_t Layout::operator[](size_t idx) const {
- assert(idx < layout.size() && "Index out of bounds");
- return layout[idx];
-}
-
-void Layout::print(llvm::raw_ostream &os) const {
- os << llvm::interleaved_array(layout);
-}
-
-/// LaneLayout represents the logical layout of lanes within a subgroup when it
-/// accesses some value. LaneData represents the logical layout of data owned by
-/// each work item.
-using LaneLayout = Layout;
-using LaneData = Layout;
+// struct Layout {
+// SmallVector<int64_t, 3> layout;
+// Layout() = default;
+// Layout(std::initializer_list<int64_t> list) : layout(list) {}
+// Layout(SmallVector<int64_t, 3> &list) : layout(list) {}
+// void print(llvm::raw_ostream &os) const;
+// size_t size() const { return layout.size(); }
+// int64_t operator[](size_t idx) const;
+// };
+
+// int64_t Layout::operator[](size_t idx) const {
+// assert(idx < layout.size() && "Index out of bounds");
+// return layout[idx];
+// }
+
+// void Layout::print(llvm::raw_ostream &os) const {
+// os << llvm::interleaved_array(layout);
+// }
+
+// /// LaneLayout represents the logical layout of lanes within a subgroup
+// /// when it accesses some value. LaneData represents the logical layout of
+// /// data owned by each work item.
+// using LaneLayout = Layout;
+// using LaneData = Layout;
//===----------------------------------------------------------------------===//
// LayoutInfo
@@ -105,13 +110,14 @@ using LaneData = Layout;
struct LayoutInfo {
private:
- LaneLayout laneLayout;
- LaneData laneData;
+ mlir::Attribute storage = nullptr;
public:
LayoutInfo() = default;
- LayoutInfo(const LaneLayout &layout, const LaneData &data)
- : laneLayout(layout), laneData(data) {}
+ LayoutInfo(const xegpu::LayoutAttr &layout) : storage(layout) {}
+ LayoutInfo(const xegpu::SliceAttr &slice) : storage(slice) {
+ storage = slice.flatten();
+ }
// Two lattice values are equal if they have `some` layout. The actual
// content of the layout does not matter.
@@ -125,24 +131,44 @@ struct LayoutInfo {
void print(raw_ostream &os) const;
- bool isAssigned() const {
- return laneLayout.size() > 0 && laneData.size() > 0;
- }
+ bool isAssigned() const { return storage != nullptr; }
+
+ LayoutInfo transpose(ArrayRef<int64_t> permutation) const;
- LayoutInfo getTransposedLayout(ArrayRef<int64_t> permutation) const;
+ ArrayRef<int> getLaneLayout() const {
+ if (!isAssigned())
+ return {};
+ if (isa<xegpu::LayoutAttr>(storage))
+ return cast<xegpu::LayoutAttr>(storage).getLaneLayout().asArrayRef();
+ xegpu::SliceAttr slice = cast<xegpu::SliceAttr>(storage);
+ assert(isa<xegpu::LayoutAttr>(slice.getParent()) &&
+ "Slice parent must be a LayoutAttr");
+ auto parent = cast<xegpu::LayoutAttr>(slice.getParent());
+ return parent.getLaneLayout().asArrayRef();
+ }
+ ArrayRef<int> getLaneData() const {
+ if (!isAssigned())
+ return {};
+ if (isa<xegpu::LayoutAttr>(storage))
+ return cast<xegpu::LayoutAttr>(storage).getLaneData().asArrayRef();
+ xegpu::SliceAttr slice = cast<xegpu::SliceAttr>(storage);
+ assert(isa<xegpu::LayoutAttr>(slice.getParent()) &&
+ "Slice parent must be a LayoutAttr");
+ auto parent = cast<xegpu::LayoutAttr>(slice.getParent());
+ return parent.getLaneData().asArrayRef();
+ }
+ bool isSliceLayout() const {
+ if (!isAssigned())
+ return false;
+ return isa<xegpu::SliceAttr>(storage);
+ }
- const LaneLayout &getLayout() const { return laneLayout; }
- const LaneData &getData() const { return laneData; }
- ArrayRef<int64_t> getLayoutAsArrayRef() const { return laneLayout.layout; }
- ArrayRef<int64_t> getDataAsArrayRef() const { return laneData.layout; }
+ Attribute get() { return storage; }
};
void LayoutInfo::print(raw_ostream &os) const {
if (isAssigned()) {
- os << "lane_layout: ";
- laneLayout.print(os);
- os << ", lane_data: ";
- laneData.print(os);
+ os << storage;
} else {
os << "Not assigned.";
}
@@ -159,18 +185,30 @@ LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
llvm_unreachable("Join should not be triggered by layout propagation.");
}
-/// Get the transposed layout according to the given permutation.
-LayoutInfo
-LayoutInfo::getTransposedLayout(ArrayRef<int64_t> permutation) const {
+/// Construct a new layout with the transposed lane layout and lane data.
+LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
if (!isAssigned())
return {};
- LaneLayout newLayout;
- LaneData newData;
+ // Check if the permutation is valid.
+ llvm::SmallSet<int64_t, 4> seen(permutation.begin(), permutation.end());
+ bool hasDuplicates = seen.size() != permutation.size();
+  bool withinRange = llvm::all_of(permutation, [&](int64_t idx) {
+    return idx >= 0 && idx < static_cast<int64_t>(permutation.size());
+  });
+
+ if (!withinRange || hasDuplicates) {
+ assert(false && "Invalid permutation for transpose.");
+ return {};
+ }
+
+ SmallVector<int32_t> laneLayout;
+ SmallVector<int32_t> laneData;
for (int64_t idx : permutation) {
- newLayout.layout.push_back(laneLayout.layout[idx]);
- newData.layout.push_back(laneData.layout[idx]);
+ laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
+ laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
}
- return LayoutInfo(newLayout, newData);
+ return LayoutInfo(
+ xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData));
}
//===----------------------------------------------------------------------===//
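
The permutation handling added to transpose() is straightforward; restated as a small self-contained sketch with plain containers:

#include <cstdint>
#include <set>
#include <vector>

// A permutation is valid when it has no duplicates and every index is in range.
bool isValidPermutation(const std::vector<int64_t> &perm) {
  std::set<int64_t> seen(perm.begin(), perm.end());
  if (seen.size() != perm.size())
    return false; // duplicates
  for (int64_t idx : perm)
    if (idx < 0 || idx >= static_cast<int64_t>(perm.size()))
      return false; // out of range
  return true;
}

// Apply the permutation to a lane_layout or lane_data array.
std::vector<int32_t> permute(const std::vector<int32_t> &in,
                             const std::vector<int64_t> &perm) {
  std::vector<int32_t> out;
  for (int64_t idx : perm)
    out.push_back(in[idx]);
  return out;
}

// Example: lane_layout [1, 16] under permutation [1, 0] becomes [16, 1].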
@@ -190,13 +228,15 @@ struct LayoutInfoLattice : public Lattice<LayoutInfo> {
/// Helper Function to get the default layout for uniform values like constants.
/// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1].
/// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
-static LayoutInfo getDefaultSIMTLayoutInfo(unsigned rank) {
+static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
+ unsigned rank) {
assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
- if (rank == 1)
- return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize}),
- LaneData({1}));
- return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
- LaneData({1, 1}));
+ if (rank == 1) {
+ return LayoutInfo(
+ xegpu::LayoutAttr::get(ctx, {xegpu::targetinfo::subgroupSize}, {1}));
+ }
+ return LayoutInfo(xegpu::LayoutAttr::get(
+ ctx, {1, xegpu::targetinfo::subgroupSize}, {1, 1}));
}
/// Helper to get the default layout for a vector type.
@@ -209,14 +249,15 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) {
"Expected int or float element type.");
// If the rank is 1, then return default layout for 1D vector.
if (vectorTy.getRank() == 1)
- return getDefaultSIMTLayoutInfo(1);
+ return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1);
// Packing factor is determined by the element type bitwidth.
int packingFactor = 1;
unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
if (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
packingFactor = xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth;
- return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
- LaneData({1, packingFactor}));
+ return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
+ {1, xegpu::targetinfo::subgroupSize},
+ {1, packingFactor}));
}
/// Helper to get the default layout for a vector type.
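
The packing-factor rule used by the default 2D layout, written out as a tiny standalone example (the concrete packed size lives in xegpu::targetinfo; 16 bits is assumed here purely for illustration):

#include <vector>

// Elements narrower than the default packed size are packed along the
// innermost dimension so each lane still owns a full packed word.
std::vector<int> defaultLaneData(unsigned elemBitwidth,
                                 unsigned packedSizeInBitsForDefault = 16) {
  int packingFactor = 1;
  if (elemBitwidth < packedSizeInBitsForDefault)
    packingFactor = packedSizeInBitsForDefault / elemBitwidth;
  return {1, packingFactor}; // lane_data for a 2D vector
}

// e.g. 16-bit elements -> [1, 1]; 8-bit elements -> [1, 2], under the assumed
// 16-bit default packed size.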
@@ -229,7 +270,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy) {
"Expected int or float element type.");
// If the rank is 1, then return default layout for 1D vector.
if (tdescTy.getRank() == 1)
- return getDefaultSIMTLayoutInfo(1);
+ return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1);
// Packing factor is determined by the element type bitwidth.
unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
@@ -238,16 +279,18 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy) {
bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
: 1;
- return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize, 1}),
- LaneData({1, packingFactor}));
+ return LayoutInfo(xegpu::LayoutAttr::get(
+ tdescTy.getContext(), {xegpu::targetinfo::subgroupSize, 1},
+ {1, packingFactor}));
}
int packingFactor =
(bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
? xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth
: 1;
- return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
- LaneData({1, packingFactor}));
+ return LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(),
+ {1, xegpu::targetinfo::subgroupSize},
+ {1, packingFactor}));
}
/// Helper Function to get the expected layouts for DPAS operands. `lane_data`
@@ -261,15 +304,17 @@ static LayoutInfo getSIMTLayoutInfoForDPASOperand(VectorType vectorTy,
Type elementTy = vectorTy.getElementType();
assert(elementTy.isIntOrFloat() &&
"Expected int or float type in DPAS operands");
- LaneLayout layout({1, xegpu::targetinfo::subgroupSize});
+ SmallVector<int32_t, 2> layout({1, xegpu::targetinfo::subgroupSize});
// For B operand, data must be packed in minimum `packedDpasBSizeInBits` and
// must have the VNNI format.
if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() <
xegpu::targetinfo::packedSizeInBitsForDpasB) {
- LaneData data({xegpu::targetinfo::packedSizeInBitsForDpasB /
- elementTy.getIntOrFloatBitWidth(),
- 1});
- return LayoutInfo(layout, data);
+ SmallVector<int32_t, 2> data(
+ {static_cast<int32_t>(xegpu::targetinfo::packedSizeInBitsForDpasB /
+ elementTy.getIntOrFloatBitWidth()),
+ 1});
+ return LayoutInfo(
+ xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data));
}
// Otherwise, return the default layout for the vector type.
return getDefaultSIMTLayoutInfo(vectorTy);
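
For reviewers less familiar with the DPAS constraint: when the B operand's element type is narrower than packedSizeInBitsForDpasB, each lane must own a short column of elements (VNNI packing) rather than a single element. A simplified sketch of that decision (the fall-through to the default layout is reduced to {1, 1} here):

#include <cstdint>
#include <vector>

std::vector<int32_t> dpasOperandLaneData(unsigned operandNum,
                                         unsigned elemBitwidth,
                                         unsigned packedSizeInBitsForDpasB) {
  if (operandNum == 1 && elemBitwidth < packedSizeInBitsForDpasB) {
    int32_t packFactor =
        static_cast<int32_t>(packedSizeInBitsForDpasB / elemBitwidth);
    return {packFactor, 1}; // VNNI: pack rows for the B operand
  }
  return {1, 1}; // simplified stand-in for the default layout
}

// e.g. an f16 B operand with a 32-bit DPAS packing requirement gets
// lane_data [2, 1], matching the propagate-layout tests.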
@@ -450,7 +495,8 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
}
// Given that the result is 1D, the layout of the operand should be 2D with
// default layout.
- LayoutInfo operandLayout = getDefaultSIMTLayoutInfo(2);
+ LayoutInfo operandLayout =
+ getDefaultSIMTLayoutInfo(reduction->getContext(), 2);
propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
// Accumulator should have the same layout as the result.
propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
@@ -494,43 +540,55 @@ void LayoutInfoPropagation::visitShapeCastOp(
LayoutInfo resultLayout = results[0]->getValue();
if (!resultLayout.isAssigned())
return;
- VectorType sourceTy = shapeCast.getSourceVectorType();
- VectorType resultTy = shapeCast.getResultVectorType();
+ int64_t sourceRank = shapeCast.getSourceVectorType().getRank();
+ int64_t resultRank = shapeCast.getResultVectorType().getRank();
// Expecting source rank to be 1D or 2D.
- if (sourceTy.getRank() != 1 && sourceTy.getRank() != 2) {
+ if (sourceRank != 1 && sourceRank != 2) {
shapeCast.emitWarning("Expecting source type to be 1D or 2D vector.");
return;
}
// Expecting result rank to be 1D or 2D.
- if (resultTy.getRank() != 1 && resultTy.getRank() != 2) {
+ if (resultRank != 1 && resultRank != 2) {
shapeCast.emitWarning("Expecting result type to be 1D or 2D vector.");
return;
}
// For 2D -> 2D shape cast, propagate the result layout to the source.
- if (sourceTy.getRank() == 2 && resultTy.getRank() == 2) {
- // Propagate the result layout to the source operand.
+ if (sourceRank == 2 && resultRank == 2) {
propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
return;
}
- auto resultLayoutArray = resultLayout.getLayoutAsArrayRef();
- if (resultLayoutArray[0] != 1 && resultLayoutArray[1] != 1) {
+ auto resultLaneLayout = resultLayout.getLaneLayout();
+ if (resultLaneLayout[0] != 1 && resultLaneLayout[1] != 1) {
shapeCast.emitWarning(
"Expecting result layout to be of form [1, subgroupSize] "
"or [subgroupSize, 1].");
return;
}
- int64_t distributedDim = resultLayoutArray[0] == 1 ? 1 : 0;
- // If the result shape can be evenly distributed in the distributed dimension,
- // then the source layout should be [subgroupSize][1]. Otherwise, data is
- // shared accross lanes (broadcasted). In that case, just assign [1][1] for
- // now (TODO: Use slice for this case)
- LayoutInfo sourceLayout =
- resultTy.getShape()[distributedDim] % xegpu::targetinfo::subgroupSize == 0
- ? LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize}),
- LaneData({1}))
- : LayoutInfo(LaneLayout({1}), LaneData({1}));
- // Propagate the source layout to the source operand.
- propagateIfChanged(operands[0], operands[0]->meet(sourceLayout));
+ ArrayRef<int64_t> resultShape = shapeCast.getResultVectorType().getShape();
+  // For 2D -> 1D case, source gets the result's lane layout and lane data.
+ if (sourceRank == 2 && resultRank == 1) {
+ propagateIfChanged(operands[0],
+ operands[0]->meet(LayoutInfo(xegpu::LayoutAttr::get(
+ shapeCast->getContext(), resultLaneLayout,
+ resultLayout.getLaneData()))));
+ return;
+ }
+
+  // For 1D -> 2D case, if the result shape can be evenly distributed in the
+ // distributed dimension, then the source layout should be [subgroupSize][1].
+  // Otherwise, data is shared across lanes (broadcasted). We use slice
+ // attribute for the broadcast case.
+ int64_t distributedDim = resultLaneLayout[0] == 1 ? 1 : 0;
+ xegpu::LayoutAttr plainLayout = xegpu::LayoutAttr::get(
+ shapeCast->getContext(), resultLaneLayout, resultLayout.getLaneData());
+ if (resultShape[distributedDim] % xegpu::targetinfo::subgroupSize != 0) {
+ xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
+ shapeCast->getContext(), plainLayout,
+ DenseI64ArrayAttr::get(shapeCast->getContext(), {distributedDim}));
+ propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
+ return;
+ }
+ propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(plainLayout)));
}
/// Propagate the layout of the result tensor to the source tensor descriptor in
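
The 1D -> 2D decision above in one line: the source only needs a slice of the result layout when the result shape does not divide evenly by the subgroup size along the distributed dimension (the broadcast case). A standalone restatement:

#include <cstdint>
#include <vector>

// True when the 1D source of a 1D -> 2D shape_cast should get a slice of the
// result layout (data broadcast across lanes) instead of the plain layout.
bool useSliceLayoutForShapeCastSource(const std::vector<int64_t> &resultShape,
                                      const std::vector<int> &resultLaneLayout,
                                      int subgroupSize) {
  int distributedDim = resultLaneLayout[0] == 1 ? 1 : 0;
  return resultShape[distributedDim] % subgroupSize != 0;
}

// e.g. a result vector<16x8xf32> with lane_layout [1, 16]: dimension 1 has
// size 8 and 8 % 16 != 0, so the source takes the slice (broadcast) layout.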
@@ -591,7 +649,7 @@ void LayoutInfoPropagation::visitLoadNdOp(
if (auto transpose = load.getTranspose()) {
load.emitWarning("Transpose effect is not expected for LoadNdOp at "
"LayoutInfoPropagation stage.");
- tensorDescLayout = valueLayout.getTransposedLayout(transpose.value());
+ tensorDescLayout = valueLayout.transpose(transpose.value());
}
// Propagate the new layout to the tensor descriptor operand.
propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
@@ -606,8 +664,7 @@ void LayoutInfoPropagation::visitTransposeOp(
LayoutInfo resultLayout = results[0]->getValue();
if (!resultLayout.isAssigned())
return;
- LayoutInfo newLayout =
- resultLayout.getTransposedLayout(transpose.getPermutation());
+ LayoutInfo newLayout = resultLayout.transpose(transpose.getPermutation());
// Propagate the new layout to the vector operand.
propagateIfChanged(operands[0], operands[0]->meet(newLayout));
}
@@ -636,9 +693,9 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
bool isNarrowing = inElemTyBitWidth > outElemTyBitWidth;
int bitCastRatio = isNarrowing ? inElemTyBitWidth / outElemTyBitWidth
: outElemTyBitWidth / inElemTyBitWidth;
- const LaneLayout &sourceLaneLayout =
- resultLayout.getLayout(); // source lane layout is unchanged.
- ArrayRef<int64_t> outData = resultLayout.getDataAsArrayRef();
+ ArrayRef<int> sourceLaneLayout =
+ resultLayout.getLaneLayout(); // Lane layout does not change for bitcast.
+ ArrayRef<int> outData = resultLayout.getLaneData();
// TODO: Currently we assume that bitcasts does not require cross lane
// communication. So each lane must own the required number of elements to
@@ -650,12 +707,9 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
return;
}
// Check if each lane owns a single element in all dimensions except the
- // innermost dimension. For example, if the result layout is [1, 16][2, 1], we
- // are not allowed to bitcast such vectors.
- // TODO: Relax this based on use cases.
- SmallVector<int64_t, 3> sourceLaneDataStorage(outData.begin(),
- outData.end() - 1);
- if (llvm::any_of(sourceLaneDataStorage, [](int64_t d) { return d != 1; })) {
+ // innermost dimension.
+ SmallVector<int> sourceLaneData(outData.begin(), outData.end() - 1);
+ if (llvm::any_of(sourceLaneData, [](int64_t d) { return d != 1; })) {
bitcast.emitWarning(
"Each lane must not own multiple elements in any dimension other than "
"the innermost dimension.");
@@ -664,11 +718,12 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
// Decide lane data based on whether the bitcast is narrowing or widening.
int64_t innerMostLaneData = isNarrowing ? outData[rank - 1] / bitCastRatio
: outData[rank - 1] * bitCastRatio;
- sourceLaneDataStorage.push_back(innerMostLaneData);
- LaneData sourceLaneData(sourceLaneDataStorage);
+ sourceLaneData.push_back(innerMostLaneData);
- propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(
- sourceLaneLayout, sourceLaneData)));
+ propagateIfChanged(
+ operands[0],
+ operands[0]->meet(LayoutInfo(xegpu::LayoutAttr::get(
+ bitcast->getContext(), sourceLaneLayout, sourceLaneData))));
}
/// Propagate the layout of the result to the tensor descriptor and mask
@@ -680,7 +735,7 @@ void LayoutInfoPropagation::visitLoadGatherOp(
LayoutInfo layout = getDefaultSIMTLayoutInfo(load.getTensorDescType());
// Mask operand should have 1D default layout.
- LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
+ LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1);
// Propagate the new layout to the tensor descriptor operand.
propagateIfChanged(operands[0], operands[0]->meet(layout));
@@ -698,7 +753,7 @@ void LayoutInfoPropagation::visitCreateDescOp(
if (!descLayout.isAssigned())
return;
// For offset operand propagate 1D default layout.
- LayoutInfo layout = getDefaultSIMTLayoutInfo(1);
+ LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1);
propagateIfChanged(operands[1], operands[1]->meet(layout));
}
@@ -725,7 +780,8 @@ void LayoutInfoPropagation::visitStoreScatterOp(
// Propagate the tensor descriptor layout.
propagateIfChanged(operands[1], operands[1]->meet(layout));
// Use default 1D layout for mask operand.
- LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
+ LayoutInfo maskLayout =
+ getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1);
propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
}
@@ -813,7 +869,7 @@ void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
printFunctionResult(funcOp);
}
-using GetLayoutFnTy = function_ref<xegpu::LayoutAttr(Value)>;
+using GetLayoutFnTy = function_ref<xegpu::LayoutTrait(Value)>;
/// Update an operation with the layout of its results. If the result type is a
/// vector type, a temporary layout attribute is added to the operation. If the
/// result type is a tensor descriptor type, the type is updated with the layout
@@ -832,7 +888,7 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
if (!isa<VectorType, xegpu::TensorDescType>(resultType))
continue;
// If the result has no layout but has users, emit a warning and continue.
- xegpu::LayoutAttr layout = getLayoutOfValue(result);
+ xegpu::LayoutTrait layout = getLayoutOfValue(result);
if (!layout && result.getNumUses() > 0) {
op->emitWarning("op has users but no layout assigned for its result");
continue;
@@ -898,8 +954,9 @@ updateControlFlowOps(mlir::OpBuilder &builder,
// We only need to operate on tensor descriptor or vector types.
if (!isa<xegpu::TensorDescType, VectorType>(inputType))
continue;
- xegpu::LayoutAttr successorInputLayout = getLayoutOfValue(successorInput);
- xegpu::LayoutAttr successorOperandLayout =
+ xegpu::LayoutTrait successorInputLayout =
+ getLayoutOfValue(successorInput);
+ xegpu::LayoutTrait successorOperandLayout =
getLayoutOfValue(successorOperand);
// If either of the layouts is not assigned, we cannot proceed.
@@ -947,7 +1004,7 @@ static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder,
newArgTypes.push_back(argType);
if (!isa<VectorType, xegpu::TensorDescType>(argType))
continue;
- xegpu::LayoutAttr layout = getLayoutOfValue(arg);
+ xegpu::LayoutTrait layout = getLayoutOfValue(arg);
if (!layout) {
LLVM_DEBUG(DBGS() << "Expecting layout for function argument: " << arg
<< " but got none.\n");
@@ -989,13 +1046,13 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
return;
}
// Helper to convert LayoutInfo to xegpu::LayoutAttr.
- auto getXeGPULayoutForValue = [&](Value val) -> xegpu::LayoutAttr {
+ auto getXeGPULayoutForValue = [&](Value val) -> xegpu::LayoutTrait {
LayoutInfo layout = analysis.getLayoutInfo(val);
if (!layout.isAssigned())
return {};
- return xegpu::LayoutAttr::get(
- val.getContext(), llvm::to_vector_of<int>(layout.getLayoutAsArrayRef()),
- llvm::to_vector_of<int>(layout.getDataAsArrayRef()));
+ if (layout.isSliceLayout())
+ return cast<xegpu::SliceAttr>(layout.get());
+ return cast<xegpu::LayoutAttr>(layout.get());
};
mlir::OpBuilder builder(&getContext());
>From 5a683b443e3160c2c81449338beeedebfe6ac229 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Mon, 25 Aug 2025 23:53:39 +0000
Subject: [PATCH 10/15] save work
---
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 142 ++++++++++--------
.../Transforms/XeGPUSubgroupDistribute.cpp | 26 ++--
2 files changed, 94 insertions(+), 74 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 9a7c9570af6b6..0434566e21f4e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -110,14 +110,11 @@ namespace {
struct LayoutInfo {
private:
- mlir::Attribute storage = nullptr;
+ xegpu::DistributeLayoutAttr storage = nullptr;
public:
LayoutInfo() = default;
- LayoutInfo(const xegpu::LayoutAttr &layout) : storage(layout) {}
- LayoutInfo(const xegpu::SliceAttr &slice) : storage(slice) {
- storage = slice.flatten();
- }
+ LayoutInfo(const xegpu::DistributeLayoutAttr &layout) : storage(layout) {}
// Two lattice values are equal if they have `some` layout. The actual
// content of the layout does not matter.
@@ -135,28 +132,26 @@ struct LayoutInfo {
LayoutInfo transpose(ArrayRef<int64_t> permutation) const;
- ArrayRef<int> getLaneLayout() const {
+ SmallVector<int> getLaneLayout() const {
if (!isAssigned())
return {};
- if (isa<xegpu::LayoutAttr>(storage))
- return cast<xegpu::LayoutAttr>(storage).getLaneLayout().asArrayRef();
- xegpu::SliceAttr slice = cast<xegpu::SliceAttr>(storage);
- assert(isa<xegpu::LayoutAttr>(slice.getParent()) &&
- "Slice parent must be a LayoutAttr");
- auto parent = cast<xegpu::LayoutAttr>(slice.getParent());
- return parent.getLaneLayout().asArrayRef();
+ assert(storage.getLaneLayoutAsInt().has_value() &&
+ "Expected lane layout to be assigned");
+ return llvm::map_to_vector(
+ storage.getLaneLayoutAsInt().value(),
+ [](int64_t val) { return static_cast<int>(val); });
}
- ArrayRef<int> getLaneData() const {
+
+ SmallVector<int> getLaneData() const {
if (!isAssigned())
return {};
- if (isa<xegpu::LayoutAttr>(storage))
- return cast<xegpu::LayoutAttr>(storage).getLaneData().asArrayRef();
- xegpu::SliceAttr slice = cast<xegpu::SliceAttr>(storage);
- assert(isa<xegpu::LayoutAttr>(slice.getParent()) &&
- "Slice parent must be a LayoutAttr");
- auto parent = cast<xegpu::LayoutAttr>(slice.getParent());
- return parent.getLaneData().asArrayRef();
+ assert(storage.getLaneDataAsInt().has_value() &&
+ "Expected lane data to be assigned");
+ return llvm::map_to_vector(
+ storage.getLaneDataAsInt().value(),
+ [](int64_t val) { return static_cast<int>(val); });
}
+
bool isSliceLayout() const {
if (!isAssigned())
return false;
@@ -558,26 +553,49 @@ void LayoutInfoPropagation::visitShapeCastOp(
return;
}
auto resultLaneLayout = resultLayout.getLaneLayout();
- if (resultLaneLayout[0] != 1 && resultLaneLayout[1] != 1) {
+ if (resultRank == 2 && resultLaneLayout[0] != 1 && resultLaneLayout[1] != 1) {
shapeCast.emitWarning(
- "Expecting result layout to be of form [1, subgroupSize] "
+ "Expecting 2D result layout to be of form [1, subgroupSize] "
"or [subgroupSize, 1].");
return;
}
ArrayRef<int64_t> resultShape = shapeCast.getResultVectorType().getShape();
-  // For 2D -> 1D case, source gets the result's lane layout and lane data.
+ ArrayRef<int64_t> sourceShape = shapeCast.getSourceVectorType().getShape();
+ // For 2D -> 1D case.
if (sourceRank == 2 && resultRank == 1) {
- propagateIfChanged(operands[0],
- operands[0]->meet(LayoutInfo(xegpu::LayoutAttr::get(
- shapeCast->getContext(), resultLaneLayout,
- resultLayout.getLaneData()))));
- return;
+ // If the result had slice layout, simply assign the parent layout of the
+ // slice.
+ if (resultLayout.isSliceLayout()) {
+ auto sliceAttr = cast<xegpu::SliceAttr>(resultLayout.get());
+ propagateIfChanged(operands[0],
+ operands[0]->meet(LayoutInfo(sliceAttr.getParent())));
+ return;
+ }
+ // If the result has a regular 1D layout, then we find the first dimension
+ // that can be fully evenly distributed to lanes. This dimension becomes
+ // the distributed dimension for deciding the lane layout.
+ int sourceDistributedDim =
+ sourceShape[0] % xegpu::targetinfo::subgroupSize == 0
+ ? 0
+            : (sourceShape[1] % xegpu::targetinfo::subgroupSize == 0 ? 1 : -1);
+ if (sourceDistributedDim == -1) {
+ shapeCast.emitWarning(
+ "Source vector can not be evenly distributed across lanes.");
+ return;
+ }
+ SmallVector<int> sourceLaneLayout = {1, 1},
+ laneData = {1, resultLayout.getLaneData()[0]};
+ sourceLaneLayout[sourceDistributedDim] = xegpu::targetinfo::subgroupSize;
+ propagateIfChanged(
+ operands[0],
+ operands[0]->meet(LayoutInfo(xegpu::LayoutAttr::get(
+ shapeCast->getContext(), sourceLaneLayout, laneData))));
}
   // For 1D -> 2D case, if the result shape can be evenly distributed in the
- // distributed dimension, then the source layout should be [subgroupSize][1].
-  // Otherwise, data is shared across lanes (broadcasted). We use slice
-  // attribute for the broadcast case.
+  // distributed dimension, then the source layout should be
+  // [subgroupSize][1]. Otherwise, data is shared across lanes (broadcasted).
+ // We use slice attribute for the broadcast case.
int64_t distributedDim = resultLaneLayout[0] == 1 ? 1 : 0;
xegpu::LayoutAttr plainLayout = xegpu::LayoutAttr::get(
shapeCast->getContext(), resultLaneLayout, resultLayout.getLaneData());
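
And the new 2D -> 1D case in isolation: pick the first source dimension that the subgroup size divides evenly; that dimension carries the lane layout, and -1 signals the warning path. A standalone sketch:

#include <cstdint>
#include <vector>

int pickSourceDistributedDim(const std::vector<int64_t> &sourceShape,
                             int subgroupSize) {
  if (sourceShape[0] % subgroupSize == 0)
    return 0;
  if (sourceShape[1] % subgroupSize == 0)
    return 1;
  return -1; // neither dimension is evenly distributable
}

// e.g. vector<16x8xf16> with subgroupSize 16 distributes dimension 0, giving a
// source lane_layout of [16, 1]; vector<8x16xf16> distributes dimension 1 and
// gives [1, 16].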
@@ -591,8 +609,8 @@ void LayoutInfoPropagation::visitShapeCastOp(
propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(plainLayout)));
}
-/// Propagate the layout of the result tensor to the source tensor descriptor in
-/// UpdateNdOffsetOp.
+/// Propagate the layout of the result tensor to the source tensor descriptor
+/// in UpdateNdOffsetOp.
void LayoutInfoPropagation::visitUpdateNdOffsetOp(
xegpu::UpdateNdOffsetOp updateNdOffset,
ArrayRef<LayoutInfoLattice *> operands,
@@ -710,9 +728,9 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
// innermost dimension.
SmallVector<int> sourceLaneData(outData.begin(), outData.end() - 1);
if (llvm::any_of(sourceLaneData, [](int64_t d) { return d != 1; })) {
- bitcast.emitWarning(
- "Each lane must not own multiple elements in any dimension other than "
- "the innermost dimension.");
+ bitcast.emitWarning("Each lane must not own multiple elements in any "
+ "dimension other than "
+ "the innermost dimension.");
return;
}
// Decide lane data based on whether the bitcast is narrowing or widening.
@@ -869,15 +887,16 @@ void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
printFunctionResult(funcOp);
}
-using GetLayoutFnTy = function_ref<xegpu::LayoutTrait(Value)>;
-/// Update an operation with the layout of its results. If the result type is a
-/// vector type, a temporary layout attribute is added to the operation. If the
-/// result type is a tensor descriptor type, the type is updated with the layout
-/// attribute. The users of the result are also updated with the layout
+using GetLayoutFnTy = function_ref<xegpu::DistributeLayoutAttr(Value)>;
+/// Update an operation with the layout of its results. If the result type is
+/// a vector type, a temporary layout attribute is added to the operation. If
+/// the result type is a tensor descriptor type, the type is updated with the
+/// layout attribute. The users of the result are also updated with the layout
/// attribute.
static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
GetLayoutFnTy getLayoutOfValue) {
- // Region ops (like scf.for) are already handled by the updateControlFlowOps.
+ // Region ops (like scf.for) are already handled by the
+ // updateControlFlowOps.
if (mlir::isa<mlir::RegionBranchOpInterface>(op))
return success();
@@ -888,7 +907,7 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
if (!isa<VectorType, xegpu::TensorDescType>(resultType))
continue;
// If the result has no layout but has users, emit a warning and continue.
- xegpu::LayoutTrait layout = getLayoutOfValue(result);
+ xegpu::DistributeLayoutAttr layout = getLayoutOfValue(result);
if (!layout && result.getNumUses() > 0) {
op->emitWarning("op has users but no layout assigned for its result");
continue;
@@ -910,14 +929,14 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
}
/// Region ops like scf.for need special handling because they have blocks
-/// inside. If the blocks have tensor descriptor type as block arguments, thier
-/// types must be updated. Also region op can have results that may not have any
-/// users (e.g. A and B tiles). They are not assigned a layout by layout
-/// analysis because they have no users. However inside the region op
-/// corresponding block arguments for these results do have layouts. Therefore,
-/// in this case we still need to update the result types with the layout
-/// attribute. This function function updates the internal block arguments and
-/// the result types of the region op with the assigned layouts.
+/// inside. If the blocks have tensor descriptor type as block arguments,
+/// their types must be updated. Also region op can have results that may not
+/// have any users (e.g. A and B tiles). They are not assigned a layout by
+/// layout analysis because they have no users. However inside the region op
+/// corresponding block arguments for these results do have layouts.
+/// Therefore, in this case we still need to update the result types with the
+/// layout attribute. This function updates the internal block
+/// arguments and the result types of the region op with the assigned layouts.
/// clang-format off
/// Example: scf.for ... iter_args(...) -> (out types) {
/// ^bb0(block types):
@@ -929,8 +948,8 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
/// regions. One is the ^bb0 (for loop body) and the other is the scf.for op
/// itself (yield the results). So we update both the block arguments of the
/// successor region (i.e. block types) and the result types of the scf.for op
-/// (i.e. out types). Note that yield types are updated by respective producers
-/// inside bb0.
+/// (i.e. out types). Note that yield types are updated by respective
+/// producers inside bb0.
static LogicalResult
updateControlFlowOps(mlir::OpBuilder &builder,
mlir::RegionBranchTerminatorOpInterface terminator,
@@ -954,17 +973,16 @@ updateControlFlowOps(mlir::OpBuilder &builder,
// We only need to operate on tensor descriptor or vector types.
if (!isa<xegpu::TensorDescType, VectorType>(inputType))
continue;
- xegpu::LayoutTrait successorInputLayout =
+ xegpu::DistributeLayoutAttr successorInputLayout =
getLayoutOfValue(successorInput);
- xegpu::LayoutTrait successorOperandLayout =
+ xegpu::DistributeLayoutAttr successorOperandLayout =
getLayoutOfValue(successorOperand);
// If either of the layouts is not assigned, we cannot proceed.
if (!successorOperandLayout) {
- LLVM_DEBUG(
- DBGS()
- << "No layout assigned for forwarded operand in branch terminator: "
- << successorOperand << "\n");
+ LLVM_DEBUG(DBGS() << "No layout assigned for forwarded operand in "
+ "branch terminator: "
+ << successorOperand << "\n");
return failure();
}
// We expect the layouts to match.
@@ -1004,7 +1022,7 @@ static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder,
newArgTypes.push_back(argType);
if (!isa<VectorType, xegpu::TensorDescType>(argType))
continue;
- xegpu::LayoutTrait layout = getLayoutOfValue(arg);
+ xegpu::DistributeLayoutAttr layout = getLayoutOfValue(arg);
if (!layout) {
LLVM_DEBUG(DBGS() << "Expecting layout for function argument: " << arg
<< " but got none.\n");
@@ -1046,7 +1064,7 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
return;
}
// Helper to convert LayoutInfo to xegpu::LayoutAttr.
- auto getXeGPULayoutForValue = [&](Value val) -> xegpu::LayoutTrait {
+ auto getXeGPULayoutForValue = [&](Value val) -> xegpu::DistributeLayoutAttr {
LayoutInfo layout = analysis.getLayoutInfo(val);
if (!layout.isAssigned())
return {};
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 27b8fc1c2919d..31821ee07d418 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -76,12 +76,12 @@ namespace {
/// | 32x16 | [2, 8] | 16x2 |
/// | 2x32x16 | [1, 16] | 2x32x1 |
static FailureOr<VectorType>
-getDistVecTypeBasedOnLaneLayout(xegpu::LayoutAttr layout,
+getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
VectorType originalType) {
if (!layout)
return failure();
- auto laneLayout = layout.getLaneLayout().asArrayRef();
+ auto laneLayout = layout.getLaneLayoutAsInt().value();
assert(originalType.getShape().size() >= laneLayout.size() &&
"Rank of the original vector type should be greater or equal to the "
"size of the lane layout to distribute the vector type.");
@@ -868,7 +868,7 @@ struct VectorBitcastDistribution final : public gpu::WarpDistributionPattern {
unsigned operandIdx = operand->getOperandNumber();
VectorType distributedSourceType =
getDistVecTypeBasedOnLaneLayout(
- xegpu::getLayoutAttr(bitcastOp.getSource()),
+ xegpu::getDistributeLayoutAttr(bitcastOp.getSource()),
bitcastOp.getSourceVectorType())
.value_or(VectorType());
if (!distributedSourceType)
@@ -907,24 +907,26 @@ struct VectorTransposeDistribution final : public gpu::WarpDistributionPattern {
warpOp, "warp result is not a vector::Transpose op");
auto transposeOp = operand->get().getDefiningOp<vector::TransposeOp>();
unsigned operandIdx = operand->getOperandNumber();
- xegpu::LayoutAttr sourceLayout =
- xegpu::getLayoutAttr(transposeOp.getVector());
- xegpu::LayoutAttr resultLayout =
- xegpu::getLayoutAttr(transposeOp.getResult());
+ xegpu::DistributeLayoutAttr sourceLayout =
+ xegpu::getDistributeLayoutAttr(transposeOp.getVector());
+ xegpu::DistributeLayoutAttr resultLayout =
+ xegpu::getDistributeLayoutAttr(transposeOp.getResult());
if (!sourceLayout || !resultLayout)
return rewriter.notifyMatchFailure(
transposeOp,
"the source or result vector of the transpose op lacks layout "
"attribute");
- ArrayRef<int> sourceLaneLayout = sourceLayout.getLaneLayout().asArrayRef();
- ArrayRef<int> resultLaneLayout = resultLayout.getLaneLayout().asArrayRef();
- ArrayRef<int> sourceLaneData = sourceLayout.getLaneData().asArrayRef();
- ArrayRef<int> resultLaneData = resultLayout.getLaneData().asArrayRef();
+ ArrayRef<int64_t> sourceLaneLayout =
+ sourceLayout.getLaneLayoutAsInt().value();
+ ArrayRef<int64_t> resultLaneLayout =
+ resultLayout.getLaneLayoutAsInt().value();
+ ArrayRef<int64_t> sourceLaneData = sourceLayout.getLaneDataAsInt().value();
+ ArrayRef<int64_t> resultLaneData = resultLayout.getLaneDataAsInt().value();
if (sourceLaneLayout.size() != 2 || resultLaneLayout.size() != 2)
return rewriter.notifyMatchFailure(
transposeOp, "the source or result vector of the transpose op "
"does not have 2D layout");
- auto is2DTranspose = [](ArrayRef<int> input, ArrayRef<int> output) {
+ auto is2DTranspose = [](ArrayRef<int64_t> input, ArrayRef<int64_t> output) {
return input.size() == 2 && output.size() == 2 && input[0] == output[1] &&
input[1] == output[0];
};
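Concretely, for the transpose-B GEMM case exercised by the mma_transpose_b test later in this patch, the B tile is loaded with lane layout [16, 1] and the transpose result carries [1, 16], so the check above succeeds. A standalone sketch of the same check (illustrative only, not the patch code):

#include <array>
#include <cstdint>

// Same check as the lambda above, written out standalone: the output lane
// layout (or lane data) must be the swap of the input.
constexpr bool is2DTranspose(std::array<int64_t, 2> input,
                             std::array<int64_t, 2> output) {
  return input[0] == output[1] && input[1] == output[0];
}

// Source lane layout [16, 1] transposed to result lane layout [1, 16].
static_assert(is2DTranspose({16, 1}, {1, 16}), "valid transpose distribution");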
>From 2da2c6de6f3043462d871b8083a19e09738cc509 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 26 Aug 2025 19:54:35 +0000
Subject: [PATCH 11/15] save work
---
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 19 ++---
mlir/test/Dialect/XeGPU/propagate-layout.mlir | 80 ++++++++++++++++++-
2 files changed, 88 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 0434566e21f4e..3f30751875679 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -187,8 +187,8 @@ LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
// Check if the permutation is valid.
llvm::SmallSet<int64_t, 4> seen(permutation.begin(), permutation.end());
bool hasDuplicates = seen.size() != permutation.size();
- bool withinRange = llvm::all_of(permutation, [&](size_t idx) {
- return idx >= 0 && idx < permutation.size();
+ bool withinRange = llvm::all_of(permutation, [&](int64_t idx) {
+ return idx >= 0 && idx < static_cast<int64_t>(permutation.size());
});
if (!withinRange || hasDuplicates) {
@@ -577,7 +577,7 @@ void LayoutInfoPropagation::visitShapeCastOp(
int sourceDistributedDim =
sourceShape[0] % xegpu::targetinfo::subgroupSize == 0
? 0
- : (sourceShape[1] % xegpu::targetinfo::subgroupSize ? 1 : -1);
+ : (sourceShape[1] % xegpu::targetinfo::subgroupSize == 0 ? 1 : -1);
if (sourceDistributedDim == -1) {
shapeCast.emitWarning(
"Source vector can not be evenly distributed across lanes.");
@@ -597,16 +597,17 @@ void LayoutInfoPropagation::visitShapeCastOp(
// [subgroupSize][1]. Otherwise, data is shared across lanes (broadcasted).
// We use slice attribute for the broadcast case.
int64_t distributedDim = resultLaneLayout[0] == 1 ? 1 : 0;
- xegpu::LayoutAttr plainLayout = xegpu::LayoutAttr::get(
- shapeCast->getContext(), resultLaneLayout, resultLayout.getLaneData());
if (resultShape[distributedDim] % xegpu::targetinfo::subgroupSize != 0) {
+ xegpu::LayoutAttr parentLayout = xegpu::LayoutAttr::get(
+ shapeCast->getContext(), resultLaneLayout, resultLayout.getLaneData());
xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
- shapeCast->getContext(), plainLayout,
+ shapeCast->getContext(), parentLayout,
DenseI64ArrayAttr::get(shapeCast->getContext(), {distributedDim}));
propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
return;
}
- propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(plainLayout)));
+ propagateIfChanged(operands[0], operands[0]->meet(getDefaultSIMTLayoutInfo(
+ shapeCast.getSourceVectorType())));
}
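To summarize the decision made above for the 1-D source of a 2-D shape_cast, here is a standalone sketch (the names pickSourceLayout and SourceLayoutKind are hypothetical, the subgroup size of 16 matches the tests below, and the real code builds an xegpu::SliceAttr or the default SIMT layout instead of returning an enum):

#include <array>
#include <cstdint>

constexpr int64_t kSubgroupSize = 16; // matches the 16-lane subgroup in the tests

enum class SourceLayoutKind { Default1D, SliceOfParent };

// Result lane layout [1, N] distributes dim 1, [N, 1] distributes dim 0. If
// the distributed result dim is not a multiple of the subgroup size, the data
// is broadcast across lanes and the 1-D source gets a slice of the parent
// layout over that dim; otherwise it gets the default 1-D SIMT layout.
constexpr SourceLayoutKind
pickSourceLayout(std::array<int64_t, 2> resultShape,
                 std::array<int64_t, 2> resultLaneLayout) {
  int64_t distributedDim = resultLaneLayout[0] == 1 ? 1 : 0;
  return resultShape[distributedDim] % kSubgroupSize != 0
             ? SourceLayoutKind::SliceOfParent
             : SourceLayoutKind::Default1D;
}

// e.g. a vector<16x1xf16> result with lane layout [1, 16] picks a slice over
// dim 1 (the broadcast case in vector_shape_cast_1d_to_2d_dim0_broadcasted),
// while a vector<1x16xf16> result with the same lane layout picks the default
// 1-D layout.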
/// Propagate the layout of the result tensor to the source tensor descriptor
@@ -711,9 +712,9 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
bool isNarrowing = inElemTyBitWidth > outElemTyBitWidth;
int bitCastRatio = isNarrowing ? inElemTyBitWidth / outElemTyBitWidth
: outElemTyBitWidth / inElemTyBitWidth;
- ArrayRef<int> sourceLaneLayout =
+ SmallVector<int> sourceLaneLayout =
resultLayout.getLaneLayout(); // Lane layout does not change for bitcast.
- ArrayRef<int> outData = resultLayout.getLaneData();
+ SmallVector<int> outData = resultLayout.getLaneData();
// TODO: Currently we assume that bitcasts do not require cross lane
// communication. So each lane must own the required number of elements to
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index 994fa44cab0b6..25d237c58e2ce 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -455,7 +455,7 @@ func.func @prefetch_1d(%arg0: memref<256xf16>){
}
// -----
-// CHECK-LABEL: func.func @test_scf_while_and_condition(
+// CHECK-LABEL: func.func @scf_while_and_condition(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf32>, %[[ARG1:[0-9a-zA-Z]+]]: memref<256xf32>) {
// CHECK: %{{.*}}:3 = scf.while ({{.*}}) : (vector<16xf32>, i32, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>)
// CHECK-SAME: -> (vector<16xf32>, i32, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>) {
@@ -464,7 +464,7 @@ func.func @prefetch_1d(%arg0: memref<256xf16>){
// CHECK-NEXT: ^bb0(%{{.*}}: vector<16xf32>, %{{.*}}: i32, %{{.*}}: !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>):
// CHECK: scf.yield {{.*}} : vector<16xf32>, i32, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
// CHECK-NEXT: } attributes {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
-func.func @test_scf_while_and_condition(%arg0: memref<256xf32>, %arg1: memref<256xf32>) {
+func.func @scf_while_and_condition(%arg0: memref<256xf32>, %arg1: memref<256xf32>) {
%c0 = arith.constant 0 : i32
%c16 = arith.constant 16 : i32
%c256 = arith.constant 256 : i32
@@ -486,3 +486,79 @@ func.func @test_scf_while_and_condition(%arg0: memref<256xf32>, %arg1: memref<25
}
return
}
+
+// -----
+// CHECK-LABEL: func.func @vector_shape_cast_2d_to_1d_dim0_distributed(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x1xf16, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>,
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>) {
+// CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[ARG0]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} :
+// CHECK-SAME: !xegpu.tensor_desc<16x1xf16, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x1xf16>
+// CHECK-NEXT: %{{.*}} = vector.shape_cast %[[LOAD]] {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+// CHECK-SAME: : vector<16x1xf16> to vector<16xf16>
+func.func @vector_shape_cast_2d_to_1d_dim0_distributed(%arg0: !xegpu.tensor_desc<16x1xf16>, %arg1: !xegpu.tensor_desc<16xf16>) {
+ %c0 = arith.constant 0 : index
+ %3 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x1xf16> -> vector<16x1xf16>
+ %2 = vector.shape_cast %3 : vector<16x1xf16> to vector<16xf16>
+ xegpu.store_nd %2, %arg1 : vector<16xf16>, !xegpu.tensor_desc<16xf16>
+ return
+}
+
+// -----
+// CHECK-LABEL: func.func @vector_shape_cast_2d_to_1d_dim1_distributed(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<1x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>,
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>) {
+// CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[ARG0]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-SAME: !xegpu.tensor_desc<1x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<1x16xf16>
+// CHECK: %{{.*}} = vector.shape_cast %[[LOAD]] {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+// CHECK-SAME: vector<1x16xf16> to vector<16xf16>
+func.func @vector_shape_cast_2d_to_1d_dim1_distributed(%arg0: !xegpu.tensor_desc<1x16xf16>, %arg1: !xegpu.tensor_desc<16xf16>) {
+ %c0 = arith.constant 0 : index
+ %3 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<1x16xf16> -> vector<1x16xf16>
+ %2 = vector.shape_cast %3 : vector<1x16xf16> to vector<16xf16>
+ xegpu.store_nd %2, %arg1 : vector<16xf16>, !xegpu.tensor_desc<16xf16>
+ return
+}
+
+// -----
+// CHECK-LABEL: func.func @vector_shape_cast_1d_to_2d_dim1_distributed(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>,
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
+// CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[ARG0]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
+// CHECK-NEXT: %[[REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %{{[0-9a-zA-Z]+}}
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} [0] : vector<16x16xf16> to vector<16xf16>
+// CHECK-NEXT: %[[CAST:.*]] = vector.shape_cast %[[REDUCE]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-SAME: vector<16xf16> to vector<1x16xf16>
+func.func @vector_shape_cast_1d_to_2d_dim1_distributed(%arg0: !xegpu.tensor_desc<16x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant dense<0.0000> : vector<16xf16>
+ %3 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+ %4 = vector.multi_reduction <add>, %3, %cst [0] : vector<16x16xf16> to vector<16xf16>
+ %2 = vector.shape_cast %4 : vector<16xf16> to vector<1x16xf16>
+ %5 = vector.broadcast %2 : vector<1x16xf16> to vector<16x16xf16>
+ xegpu.store_nd %5, %arg1 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+ return
+}
+
+// -----
+// CHECK-LABEL: func.func @vector_shape_cast_1d_to_2d_dim0_broadcasted(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>,
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
+// CHECK: %[[LOAD:.*]] = xegpu.load_nd %arg0 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
+// CHECK-NEXT: %[[REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %{{[0-9a-zA-Z]+}}
+// CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>} [1]
+// CHECK-SAME: vector<16x16xf16> to vector<16xf16>
+// CHECK-NEXT: %[[CAST:.*]] = vector.shape_cast %[[REDUCE]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-SAME: vector<16xf16> to vector<16x1xf16>
+func.func @vector_shape_cast_1d_to_2d_dim0_broadcasted(%arg0: !xegpu.tensor_desc<16x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant dense<0.0000> : vector<16xf16>
+ %3 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+ %4 = vector.multi_reduction <add>, %3, %cst [1] : vector<16x16xf16> to vector<16xf16>
+ %2 = vector.shape_cast %4 : vector<16xf16> to vector<16x1xf16>
+ %5 = vector.broadcast %2 : vector<16x1xf16> to vector<16x16xf16>
+ xegpu.store_nd %5, %arg1 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+ return
+}
>From 7eabad47a70eaac1c15207a62d844f01c4205b62 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 26 Aug 2025 23:10:04 +0000
Subject: [PATCH 12/15] save work
---
.../Transforms/XeGPUSubgroupDistribute.cpp | 3 +
.../Dialect/XeGPU/subgroup-distribute.mlir | 96 +++++++++++++++++++
2 files changed, 99 insertions(+)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 31821ee07d418..3e67e6406b956 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -827,6 +827,9 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
}
};
+/// Sink a memref::ExtractAlignedPointerAsIndex op feeding into the yield op
+/// of an enclosing `gpu.warp_execute_on_lane_0` region. This simply moves the
+/// op outside of the warp op.
struct MemrefExtractAlignedPointerAsIndexDistribution final
: public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 54ef56e013abb..690b13f5a2973 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -319,3 +319,99 @@ gpu.module @test {
gpu.return
}
}
+
+// -----
+// CHECK-LABEL: gpu.func @memref_extract_aligned_pointer_as_index(
+// CHECK: %{{.*}} = memref.extract_aligned_pointer_as_index %{{.*}} : memref<256x256xf16> -> index
+gpu.module @test {
+ gpu.func @memref_extract_aligned_pointer_as_index(%arg0 : memref<256x256xf16>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1.000000e+00> : vector<16xf16>
+ %ptr = memref.extract_aligned_pointer_as_index %arg0 : memref<256x256xf16> -> index
+ %ptr_i64 = arith.index_cast %ptr : index to i64
+ %tdesc = xegpu.create_nd_tdesc %ptr_i64[%c0], shape: [16], strides: [16] : i64
+ -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+ xegpu.store_nd %cst, %tdesc : vector<16xf16>, !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+ gpu.return
+ }
+}
+
+
+// -----
+// CHECK-LABEL: gpu.func @vector_transpose(
+// CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<2xf32>
+// CHECK: %[[DEST:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<2x16xf32> -> !xegpu.tensor_desc<2x16xf32>
+// CHECK: xegpu.store_nd %[[CST]], %[[DEST]] : vector<2xf32>, !xegpu.tensor_desc<2x16xf32>
+gpu.module @test {
+ gpu.func @vector_transpose(%arg0: memref<2x16xf32>) {
+ %cst = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} dense<1.000000e+00>
+ : vector<16x2xf32>
+ %c0 = arith.constant 0 : index
+ %transpose = vector.transpose %cst, [1, 0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : vector<16x2xf32> to vector<2x16xf32>
+ %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<2x16xf32>
+ -> !xegpu.tensor_desc<2x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ xegpu.store_nd %transpose, %0 : vector<2x16xf32>,
+ !xegpu.tensor_desc<2x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ gpu.return
+ }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @vector_bitcast(
+// CHECK: %[[CAST:.*]] = vector.bitcast %{{.*}} : vector<4x2xi8> to vector<4x1xi16>
+// CHECK-NEXT: %[[DEST:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<4x16xi16> -> !xegpu.tensor_desc<4x16xi16>
+// CHECK-NEXT: %[[T0:.*]] = vector.shape_cast %[[CAST]] : vector<4x1xi16> to vector<4xi16>
+// CHECK-NEXT: xegpu.store_nd %[[T0]], %[[DEST]] : vector<4xi16>, !xegpu.tensor_desc<4x16xi16>
+gpu.module @test {
+ gpu.func @vector_bitcast(%arg0: memref<4x16xi16>) {
+ %cst = "some_op"() {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>}
+ : () -> (vector<4x32xi8>)
+ %bitcast = vector.bitcast %cst {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : vector<4x32xi8> to vector<4x16xi16>
+ %c0 = arith.constant 0 : index
+ %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<4x16xi16>
+ -> !xegpu.tensor_desc<4x16xi16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ xegpu.store_nd %bitcast, %0 : vector<4x16xi16>,
+ !xegpu.tensor_desc<4x16xi16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ gpu.return
+ }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @mma_transpose_b(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x8xi32>,
+// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
+// CHECK: %[[ADESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+// CHECK-NEXT: %[[A:.*]] = xegpu.load_nd %[[ADESC]] : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
+// CHECK-NEXT: %[[BDESC:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x8xi32> -> !xegpu.tensor_desc<16x8xi32>
+// CHECK-NEXT: %[[B:.*]] = xegpu.load_nd %[[BDESC]] <{transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xi32> -> vector<8xi32>
+// CHECK-NEXT: %[[BCAST0:.*]] = vector.shape_cast %[[B]] : vector<8xi32> to vector<1x8xi32>
+// CHECK-NEXT: %[[BCAST1:.*]] = vector.bitcast %[[BCAST0]] : vector<1x8xi32> to vector<1x16xf16>
+// CHECK-NEXT: %[[BCAST2:.*]] = vector.shape_cast %[[BCAST1]] : vector<1x16xf16> to vector<16xf16>
+// CHECK-NEXT: %[[C:.*]] = xegpu.dpas %[[A]], %[[BCAST2]] : vector<8xf16>, vector<16xf16> -> vector<8xf32>
+gpu.module @test {
+ gpu.func @mma_transpose_b(%arg0: memref<8x16xf16>, %arg1: memref<16x8xi32>, %arg2: memref<8x16xf32>) {
+ %c0 = arith.constant 0 : index
+ %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16>
+ -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %1 = xegpu.load_nd %0 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16>
+ %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x8xi32>
+ -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+ %3 = xegpu.load_nd %2 {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+ : !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32>
+ %4 = vector.bitcast %3 {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+ : vector<16x8xi32> to vector<16x16xf16>
+ %5 = vector.transpose %4, [1, 0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}
+ : vector<16x16xf16> to vector<16x16xf16>
+ %6 = xegpu.dpas %1, %5 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+ %7 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32>
+ -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ xegpu.store_nd %6, %7 : vector<8x16xf32>,
+ !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ gpu.return
+
+ }
+}
>From 635a00679d1287f23a18594d8643811bbc6297f5 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 26 Aug 2025 23:20:54 +0000
Subject: [PATCH 13/15] save work
---
mlir/test/Dialect/XeGPU/propagate-layout.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index 25d237c58e2ce..29592ec76f918 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -224,7 +224,7 @@ func.func @vector_bitcast_i16_to_i32(%arg0: memref<8x32xi16>, %arg1: memref<8x16
// -----
// CHECK-LABEL: func.func @vector_bitcast_require_cross_lane_shuffle(
-// CHECK-NOT: %[[LOAD:.*]] = xegpu.load_nd %{{.*}} {layout_result_0 = {{.*}}} : !xegpu.tensor_desc<8x16xi32> -> vector<8x16xi32>
+// CHECK: %[[LOAD:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<8x16xi32> -> vector<8x16xi32>
// CHECK: %{{.*}} = vector.bitcast %[[LOAD]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
// CHECK-SAME: vector<8x16xi32> to vector<8x32xi16>
func.func @vector_bitcast_require_cross_lane_shuffle(%arg0: memref<8x16xi32>, %arg1: memref<8x32xi16>) {
>From 74ab5a37ee0acee4d564c5eecb1fb0b564a5157b Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 26 Aug 2025 23:38:59 +0000
Subject: [PATCH 14/15] save work
---
mlir/test/Dialect/XeGPU/subgroup-distribute.mlir | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 690b13f5a2973..8ecd080c96922 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -382,10 +382,10 @@ gpu.module @test {
// CHECK-LABEL: gpu.func @mma_transpose_b(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x8xi32>,
// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
-// CHECK: %[[ADESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-// CHECK-NEXT: %[[A:.*]] = xegpu.load_nd %[[ADESC]] : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
-// CHECK-NEXT: %[[BDESC:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x8xi32> -> !xegpu.tensor_desc<16x8xi32>
-// CHECK-NEXT: %[[B:.*]] = xegpu.load_nd %[[BDESC]] <{transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xi32> -> vector<8xi32>
+// CHECK-DAG: %[[ADESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+// CHECK-DAG: %[[BDESC:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x8xi32> -> !xegpu.tensor_desc<16x8xi32>
+// CHECK-DAG: %[[A:.*]] = xegpu.load_nd %[[ADESC]] : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
+// CHECK-DAG: %[[B:.*]] = xegpu.load_nd %[[BDESC]] <{transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xi32> -> vector<8xi32>
// CHECK-NEXT: %[[BCAST0:.*]] = vector.shape_cast %[[B]] : vector<8xi32> to vector<1x8xi32>
// CHECK-NEXT: %[[BCAST1:.*]] = vector.bitcast %[[BCAST0]] : vector<1x8xi32> to vector<1x16xf16>
// CHECK-NEXT: %[[BCAST2:.*]] = vector.shape_cast %[[BCAST1]] : vector<1x16xf16> to vector<16xf16>
>From b36e109eb628a9262000ddfe4eb5e9c1e0d9bc5b Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 27 Aug 2025 00:00:59 +0000
Subject: [PATCH 15/15] save work
---
mlir/test/Dialect/XeGPU/subgroup-distribute.mlir | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 8ecd080c96922..d2af6d064bb03 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -380,8 +380,7 @@ gpu.module @test {
// -----
// CHECK-LABEL: gpu.func @mma_transpose_b(
-// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x8xi32>,
-// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
+// CHECK: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x8xi32>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
// CHECK-DAG: %[[ADESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
// CHECK-DAG: %[[BDESC:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x8xi32> -> !xegpu.tensor_desc<16x8xi32>
// CHECK-DAG: %[[A:.*]] = xegpu.load_nd %[[ADESC]] : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>