[Mlir-commits] [mlir] [mlir][xegpu] Add SIMT distribution support for GEMM transpose B case. (PR #155517)

Charitha Saumya llvmlistbot at llvm.org
Tue Aug 26 17:01:23 PDT 2025


https://github.com/charithaintc updated https://github.com/llvm/llvm-project/pull/155517

>From 28c5c4c5f29a23dee72e9397e0f93063dc167e75 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 21 Aug 2025 16:11:50 +0000
Subject: [PATCH 01/15] pull changes

---
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 159 ++++++++++++++-
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 188 +++++++++++++++++-
 mlir/test/Dialect/XeGPU/propagate-layout.mlir |  17 ++
 3 files changed, 353 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index bef88042fc663..10c2759493477 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -62,10 +62,17 @@ struct Layout {
   SmallVector<int64_t, 3> layout;
   Layout() = default;
   Layout(std::initializer_list<int64_t> list) : layout(list) {}
+  Layout(SmallVector<int64_t, 3> &list) : layout(list) {}
   void print(llvm::raw_ostream &os) const;
   size_t size() const { return layout.size(); }
+  int64_t operator[](size_t idx) const;
 };
 
+int64_t Layout::operator[](size_t idx) const {
+  assert(idx < layout.size() && "Index out of bounds");
+  return layout[idx];
+}
+
 void Layout::print(llvm::raw_ostream &os) const {
   os << llvm::interleaved_array(layout);
 }
@@ -324,6 +331,13 @@ class LayoutInfoPropagation
                                    ArrayRef<LayoutInfoLattice *> operands,
                                    ArrayRef<const LayoutInfoLattice *> results);
 
+  void visitVectorBroadCastOp(vector::BroadcastOp broadcast,
+                              ArrayRef<LayoutInfoLattice *> operands,
+                              ArrayRef<const LayoutInfoLattice *> results);
+  void visitShapeCastOp(vector::ShapeCastOp shapeCast,
+                        ArrayRef<LayoutInfoLattice *> operands,
+                        ArrayRef<const LayoutInfoLattice *> results);
+
 public:
   LayoutInfoPropagation(DataFlowSolver &solver,
                         SymbolTableCollection &symbolTable)
@@ -383,6 +397,12 @@ LogicalResult LayoutInfoPropagation::visitOperation(
       .Case<vector::MultiDimReductionOp>([&](auto reductionOp) {
         visitVectorMultiReductionOp(reductionOp, operands, results);
       })
+      .Case<vector::BroadcastOp>([&](auto broadcastOp) {
+        visitVectorBroadCastOp(broadcastOp, operands, results);
+      })
+      .Case<vector::ShapeCastOp>([&](auto shapeCastOp) {
+        visitShapeCastOp(shapeCastOp, operands, results);
+      })
       // All other ops.
       .Default([&](Operation *op) {
         for (const LayoutInfoLattice *resultInfo : results) {
@@ -437,6 +457,83 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
   propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
 }
 
+void LayoutInfoPropagation::visitVectorBroadCastOp(
+    vector::BroadcastOp broadcast, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  // The layout of the result must be present.
+  LayoutInfo resultLayout = results[0]->getValue();
+  if (!resultLayout.isAssigned())
+    return;
+  // The source must be a vector type; scalar broadcasts are not handled.
+  VectorType resultTy = broadcast.getResultVectorType();
+  VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
+  if (!sourceTy) {
+    broadcast.emitWarning("Expecting source type to be a vector type.");
+    return;
+  }
+
+  // Only consider 2D -> 2D broadcasts.
+  if (sourceTy.getRank() != 2 || resultTy.getRank() != 2) {
+    broadcast.emitWarning("Expecting source type to be 2D vector and "
+                          "result type to be 2D vector.");
+    return;
+  }
+  SetVector<int64_t> broadcastUnitDims = broadcast.computeBroadcastedUnitDims();
+  if (broadcastUnitDims.size() != 1) {
+    broadcast.emitWarning("Expecting source type to be 2D vector only with "
+                          "one broadcasted dimension.");
+    return;
+  }
+  // Propagate the result layout to the source operand.
+  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
+}
+
+void LayoutInfoPropagation::visitShapeCastOp(
+    vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  // The layout of the result must be present.
+  LayoutInfo resultLayout = results[0]->getValue();
+  if (!resultLayout.isAssigned())
+    return;
+  VectorType sourceTy = shapeCast.getSourceVectorType();
+  VectorType resultTy = shapeCast.getResultVectorType();
+  // Expecting source rank to be 1D or 2D.
+  if (sourceTy.getRank() != 1 && sourceTy.getRank() != 2) {
+    shapeCast.emitWarning("Expecting source type to be 1D or 2D vector.");
+    return;
+  }
+  // Expecting result rank to be 1D or 2D.
+  if (resultTy.getRank() != 1 && resultTy.getRank() != 2) {
+    shapeCast.emitWarning("Expecting result type to be 1D or 2D vector.");
+    return;
+  }
+  // For 2D -> 2D shape cast, propagate the result layout to the source.
+  if (sourceTy.getRank() == 2 && resultTy.getRank() == 2) {
+    // Propagate the result layout to the source operand.
+    propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
+    return;
+  }
+  auto resultLayoutArray = resultLayout.getLayoutAsArrayRef();
+  if (resultLayoutArray[0] != 1 && resultLayoutArray[1] != 1) {
+    shapeCast.emitWarning(
+        "Expecting result layout to be of form [1, subgroupSize] "
+        "or [subgroupSize, 1].");
+    return;
+  }
+  int64_t distributedDim = resultLayoutArray[0] == 1 ? 1 : 0;
+  // If the result shape can be evenly distributed in the distributed dimension,
+  // the source gets lane_layout = [subgroupSize], lane_data = [1]. Otherwise
+  // data is shared across lanes (broadcast); in that case just assign
+  // lane_layout = [1], lane_data = [1] for now (TODO: use a slice layout here).
+  LayoutInfo sourceLayout =
+      resultTy.getShape()[distributedDim] % xegpu::targetinfo::subgroupSize == 0
+          ? LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize}),
+                       LaneData({1}))
+          : LayoutInfo(LaneLayout({1}), LaneData({1}));
+  // Propagate the source layout to the source operand.
+  propagateIfChanged(operands[0], operands[0]->meet(sourceLayout));
+}
+
 /// Propagate the layout of the result tensor to the source tensor descriptor in
 /// UpdateNdOffsetOp.
 void LayoutInfoPropagation::visitUpdateNdOffsetOp(
@@ -529,16 +626,64 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
       bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
   int outElemTyBitWidth =
       bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
-
-  // NOTE: We do not expect widening or narrowing bitcasts at this stage. Emit
-  // a warning and return.
-  if (inElemTyBitWidth != outElemTyBitWidth) {
-    bitcast.emitWarning("Widening or narrowing bitcasts are not expected at "
-                        "layout propagation stage.");
+  // If the element bit widths are the same, then the layout does not change.
+  if (inElemTyBitWidth == outElemTyBitWidth) {
+    propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
     return;
   }
+  int64_t rank = bitcast.getSourceVectorType().getRank();
+  // Bitcast is `narrowing` if the input element type bit width is larger than
+  // the output element type bit width, e.g. f32 -> f16 is a narrowing bitcast.
+  bool isNarrowing = inElemTyBitWidth > outElemTyBitWidth;
+  int bitCastRatio = isNarrowing ? inElemTyBitWidth / outElemTyBitWidth
+                                 : outElemTyBitWidth / inElemTyBitWidth;
+  const LaneLayout &sourceLaneLayout =
+      resultLayout.getLayout(); // source lane layout is unchanged.
+  ArrayRef<int64_t> currData = resultLayout.getDataAsArrayRef();
+
+  // TODO: Currently we assume that bitcasts do not require cross-lane
+  // communication. So each lane must own the required number of elements to
+  // perform the bitcast locally without cross-lane communication.
+  // For 1D vectors, decide how many elements each lane owns based on whether
+  // the bitcast is narrowing or widening.
+  if (rank == 1) {
+    if ((currData[0] * outElemTyBitWidth) % inElemTyBitWidth != 0) {
+      bitcast.emitWarning(
+          "Narrowing bitcast with cross lane communication is not supported.");
+      return;
+    }
+    LaneData sourceLaneData = isNarrowing
+                                  ? LaneData({currData[0] / bitCastRatio})
+                                  : LaneData({currData[0] * bitCastRatio});
 
-  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
+    propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(
+                                        sourceLaneLayout, sourceLaneData)));
+  }
+  // For nD vectors, each lane is not allowed to own multiple elements in any
+  // dimension other than the innermost dimension.
+  // TODO: Add support for other cases depending on the use case.
+  SmallVector<int64_t, 3> sourceLaneDataStorage(currData.begin(),
+                                                currData.end() - 1);
+  if (llvm::any_of(sourceLaneDataStorage, [](int64_t d) { return d != 1; })) {
+    bitcast.emitWarning(
+        "Each lane must not own multiple elements in any dimension other than "
+        "the innermost dimension.");
+    return;
+  }
+  // Check if the bitcast requires cross lane communication.
+  if ((currData[rank - 1] * outElemTyBitWidth) % inElemTyBitWidth != 0) {
+    bitcast.emitWarning(
+        "Narrowing bitcast with cross lane communication is not supported.");
+    return;
+  }
+  // Decide lane data based on whether the bitcast is narrowing or widening.
+  int64_t innerMostLaneData = isNarrowing ? currData[rank - 1] / bitCastRatio
+                                          : currData[rank - 1] * bitCastRatio;
+  sourceLaneDataStorage.push_back(innerMostLaneData);
+  LaneData sourceLaneData(sourceLaneDataStorage);
+
+  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(
+                                      sourceLaneLayout, sourceLaneData)));
 }
 
 /// Propagate the layout of the result to the tensor descriptor and mask
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 2088c3c7fc5ec..61eece55a9bac 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -27,6 +27,7 @@
 #include "mlir/IR/Value.h"
 #include "mlir/IR/Visitors.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -34,6 +35,9 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/SmallVectorExtras.h"
+#include "llvm/Support/LogicalResult.h"
+#include <cstdint>
 
 namespace mlir {
 namespace xegpu {
@@ -146,6 +150,15 @@ static bool hasPackedLayout(xegpu::LayoutAttr layout) {
   return laneData.asArrayRef()[0] != 1;
 }
 
+static bool hasTransposedLayout(xegpu::LayoutAttr layout) {
+  if (layout == xegpu::LayoutAttr())
+    return false;
+  DenseI32ArrayAttr laneLayout = layout.getLaneLayout();
+  if (!laneLayout || laneLayout.size() != 2)
+    return false;
+  return laneLayout.asArrayRef()[0] > 1 && laneLayout.asArrayRef()[1] == 1;
+}
+
 /// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body
 /// of the original GPUFuncOp to the new GPUFuncOp such that entire body is
 /// contained within a WarpExecuteOnLane0Op.
@@ -500,6 +513,9 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
     xegpu::removeLayoutAttrs(newLoadOp);
     // Set the packed attribute if the layout requires it.
     newLoadOp.setPacked(hasPackedLayout(layout));
+    if (hasTransposedLayout(layout))
+      newLoadOp.setTranspose(
+          DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));
     Value distributedVal = newWarpOp.getResult(operandIdx);
     // There can be a conflict between the vector type distributed by the
     // warp op and (xegpu-specific) distributed type supported by the load
@@ -811,6 +827,135 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+struct MemrefExtractAlignedPointerAsIndexDistribution final
+    : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand = getWarpResult(
+        warpOp, llvm::IsaPred<memref::ExtractAlignedPointerAsIndexOp>);
+    if (!operand)
+      return rewriter.notifyMatchFailure(
+          warpOp,
+          "warp result is not a xegpu::MemrefExtractAlignedPointerAsIndex op");
+    auto extractOp =
+        operand->get().getDefiningOp<memref::ExtractAlignedPointerAsIndexOp>();
+    unsigned operandIdx = operand->getOperandNumber();
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, extractOp.getSource(),
+        TypeRange{extractOp.getSource().getType()}, newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    auto newExtractOp = memref::ExtractAlignedPointerAsIndexOp::create(
+        rewriter, newWarpOp.getLoc(), extractOp.getType(),
+        newWarpOp.getResult(newRetIndices[0]));
+    Value distributedVal = newWarpOp.getResult(operandIdx);
+    rewriter.replaceAllUsesWith(distributedVal, newExtractOp.getResult());
+    return success();
+  }
+};
+
+struct VectorBitcastDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::BitCastOp>);
+    if (!operand)
+      return rewriter.notifyMatchFailure(
+          warpOp, "warp result is not a vector::BitCast op");
+    auto bitcastOp = operand->get().getDefiningOp<vector::BitCastOp>();
+    unsigned operandIdx = operand->getOperandNumber();
+    VectorType distributedSourceType =
+        getDistVecTypeBasedOnLaneLayout(
+            xegpu::getLayoutAttr(bitcastOp.getSource()),
+            bitcastOp.getSourceVectorType())
+            .value_or(VectorType());
+    if (!distributedSourceType)
+      return rewriter.notifyMatchFailure(
+          bitcastOp, "Failed to distribute the source vector type in "
+                     "vector::BitCast op");
+    VectorType distributedResultType =
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
+    if (distributedSourceType.getRank() != 2 ||
+        distributedResultType.getRank() != 2)
+      return rewriter.notifyMatchFailure(
+          bitcastOp, "the source or result vector of the bitcast op "
+                     "are not 2D vectors");
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, bitcastOp.getSource(),
+        TypeRange{distributedSourceType}, newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    auto newBitcastOp = vector::BitCastOp::create(
+        rewriter, newWarpOp.getLoc(), distributedResultType,
+        newWarpOp.getResult(newRetIndices[0]));
+    Value distributedVal = newWarpOp.getResult(operandIdx);
+    rewriter.replaceAllUsesWith(distributedVal, newBitcastOp.getResult());
+    return success();
+  }
+};
+
+struct VectorTransposeDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::TransposeOp>);
+    if (!operand)
+      return rewriter.notifyMatchFailure(
+          warpOp, "warp result is not a vector::Transpose op");
+    auto transposeOp = operand->get().getDefiningOp<vector::TransposeOp>();
+    unsigned operandIdx = operand->getOperandNumber();
+    xegpu::LayoutAttr sourceLayout =
+        xegpu::getLayoutAttr(transposeOp.getVector());
+    xegpu::LayoutAttr resultLayout =
+        xegpu::getLayoutAttr(transposeOp.getResult());
+    if (!sourceLayout || !resultLayout)
+      return rewriter.notifyMatchFailure(
+          transposeOp,
+          "the source or result vector of the transpose op lacks layout "
+          "attribute");
+    ArrayRef<int> sourceLaneLayout = sourceLayout.getLaneLayout().asArrayRef();
+    ArrayRef<int> resultLaneLayout = resultLayout.getLaneLayout().asArrayRef();
+    ArrayRef<int> sourceLaneData = sourceLayout.getLaneData().asArrayRef();
+    ArrayRef<int> resultLaneData = resultLayout.getLaneData().asArrayRef();
+    if (sourceLaneLayout.size() != 2 || resultLaneLayout.size() != 2)
+      return rewriter.notifyMatchFailure(
+          transposeOp, "the source or result vector of the transpose op "
+                       "does not have 2D layout");
+    auto is2DTranspose = [](ArrayRef<int> input, ArrayRef<int> output) {
+      return input.size() == 2 && output.size() == 2 && input[0] == output[1] &&
+             input[1] == output[0];
+    };
+
+    if (!is2DTranspose(sourceLaneLayout, resultLaneLayout) ||
+        !is2DTranspose(sourceLaneData, resultLaneData))
+      return rewriter.notifyMatchFailure(
+          transposeOp,
+          "the source or result vector layouts must be transposes of each "
+          "other");
+    FailureOr<VectorType> distributedSourceTypeOrFailure =
+        getDistVecTypeBasedOnLaneLayout(sourceLayout,
+                                        transposeOp.getSourceVectorType());
+    if (failed(distributedSourceTypeOrFailure))
+      return rewriter.notifyMatchFailure(
+          transposeOp, "Failed to distribute the source vector type in "
+                       "vector::Transpose op");
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, transposeOp.getVector(),
+        TypeRange{distributedSourceTypeOrFailure.value()}, newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    auto newTransposeOp = vector::TransposeOp::create(
+        rewriter, newWarpOp.getLoc(), newWarpOp.getResult(newRetIndices[0]),
+        transposeOp.getPermutation());
+    Value distributedVal = newWarpOp.getResult(operandIdx);
+    rewriter.replaceAllUsesWith(distributedVal, newTransposeOp.getResult());
+    return success();
+  }
+};
+
 } // namespace
 
 namespace {
@@ -825,7 +970,9 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
     RewritePatternSet &patterns) {
   patterns.add<CreateNdDescDistribution, StoreNdDistribution,
                LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
-               UpdateNdOffsetDistribution, GpuBarrierDistribution>(
+               UpdateNdOffsetDistribution, GpuBarrierDistribution,
+               VectorTransposeDistribution, VectorBitcastDistribution,
+               MemrefExtractAlignedPointerAsIndexDistribution>(
       patterns.getContext());
 }
 
@@ -903,14 +1050,47 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
                       int64_t warpSz) { return Value(); };
   vector::populatePropagateWarpVectorDistributionPatterns(
       patterns, distributionFn, shuffleFn);
+
+  auto warpReduction = [](Location loc, OpBuilder &builder, Value input,
+                          vector::CombiningKind kind, uint32_t size) {
+    // First reduce within each lane to get the per-lane reduction value.
+    Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input);
+    // Parallel reduction using butterfly shuffles.
+    for (uint64_t i = 1; i < size; i <<= 1) {
+      Value shuffled =
+          builder
+              .create<gpu::ShuffleOp>(loc, laneVal, i,
+                                      /*width=*/size,
+                                      /*mode=*/gpu::ShuffleMode::XOR)
+              .getShuffleResult();
+      laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
+    }
+    return laneVal;
+  };
+
+  vector::populateDistributeReduction(patterns, warpReduction);
   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
     signalPassFailure();
     return;
   }
 
-  // Step 4: Finllay, clean up UnrealizedConversionCastOps that were inserted
+  // Step 4: Finally, clean up UnrealizedConversionCastOps that were inserted
   // due to tensor desc type mismatches created by using upstream distribution
-  // patterns (scf.for)
+  // patterns (scf.for). This cleanup should only be done if all the ops are
+  // distributed successfully. If some ops are still not distributed and remain
+  // inside any WarpExecuteOnLane0Op, we skip this simplification step to avoid
+  // breaking the IR.
+  bool foundWarpOp = false;
+  getOperation()->walk([&](gpu::WarpExecuteOnLane0Op warpOp) {
+    // Look for WarpOps that are not trivially dead.
+    if (isOpTriviallyDead(warpOp))
+      return WalkResult::advance();
+    foundWarpOp = true;
+    return WalkResult::interrupt();
+  });
+  if (foundWarpOp)
+    return;
+
   getOperation()->walk([&](mlir::UnrealizedConversionCastOp op) {
     // We are only interested in UnrealizedConversionCastOps there were added
     // for resolving SIMT type mismatches.
@@ -929,7 +1109,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
            "Unrealized conversion cast must have tensor descriptor types");
 
     // tensor_desc<shape, layout> -> tensor_desc<shape> Type of conversions.
-    // This occurs iside scf.for body to resolve the block argument type to
+    // This occurs inside scf.for body to resolve the block argument type to
     // SIMT type.
     if (inputDescType.getLayout()) {
       auto argument = mlir::dyn_cast<mlir::BlockArgument>(input);
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index 0214d84f2c16f..4cbe4db271ad6 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -181,6 +181,23 @@ func.func @vector_bitcast_i16_to_f16(%arg0: memref<8x16xi16>, %arg1: memref<16x1
   return
 }
 
+// -----
+// CHECK-LABEL: func.func @vector_bitcast_i32_to_f16(
+// CHECK: %{{.*}} = vector.bitcast %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : vector<16x8xi32> to vector<16x16xf16>
+func.func @vector_bitcast_i32_to_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x8xi32>, %arg2: memref<8x16xf32>) {
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x8xi32> -> !xegpu.tensor_desc<16x8xi32>
+  %2 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %3 = xegpu.load_nd %1  : !xegpu.tensor_desc<16x8xi32> -> vector<16x8xi32>
+  %4 = vector.bitcast %3 : vector<16x8xi32> to vector<16x16xf16>
+  %5 = vector.transpose %4, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
+  %6 = xegpu.dpas %2, %5 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  %7 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %6, %7  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
+}
+
 // -----
 // CHECK-LABEL: func.func @binary_op_one_use(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>,

>From ad5d0a88a4f065dc3720d977c8e3d125c5b768b8 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 21 Aug 2025 17:58:25 +0000
Subject: [PATCH 02/15] rename getLayoutAttr util

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       | 66 +++++++++++++++++++
 .../mlir/Dialect/XeGPU/IR/XeGPUDialect.td     |  2 +-
 .../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h     | 27 ++++----
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    | 25 ++++---
 .../XeGPU/Transforms/XeGPUBlocking.cpp        | 16 ++---
 .../Transforms/XeGPUSubgroupDistribute.cpp    |  5 +-
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 26 ++++----
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   | 30 ++++-----
 8 files changed, 132 insertions(+), 65 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index b4d696444cc44..5b4b376157c00 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -185,6 +185,9 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
     InterfaceMethod<"Check the availability of workgroup level layouts",
                     "bool",
                     "isForWorkgroup">,
+    InterfaceMethod<"Check the availability of subgroup level layouts",
+                    "bool",
+                    "isForSubgroup">,
     InterfaceMethod<"Get the rank of attribute",
                     "int64_t",
                     "getRank">,
@@ -202,6 +205,15 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
     InterfaceMethod<"Get the SgData field of the attribute as integer array",
                     "std::optional<SmallVector<int64_t>>",
                     "getSgDataAsInt">,
+    InterfaceMethod<"Get the InstData field of the attribute as integer array",
+                    "std::optional<SmallVector<int64_t>>",
+                    "getInstDataAsInt">,
+    InterfaceMethod<"Get the LaneLayout field of the attribute as integer array",
+                    "std::optional<SmallVector<int64_t>>",
+                    "getLaneLayoutAsInt">,
+    InterfaceMethod<"Get the LaneData field of the attribute as integer array",
+                    "std::optional<SmallVector<int64_t>>",
+                    "getLaneDataAsInt">,
     InterfaceMethod<"Derive a new layout by dropping sgLayout and sgData",
                     "xegpu::DistributeLayoutAttr",
                     "dropSgLayoutAndData">,
@@ -388,6 +400,24 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
       return std::nullopt;
     }
 
+    std::optional<SmallVector<int64_t>> getInstDataAsInt() const {
+      if (DenseI32ArrayAttr inst = getInstData())
+        return llvm::to_vector_of<int64_t>(inst.asArrayRef());
+      return std::nullopt;
+    }
+
+    std::optional<SmallVector<int64_t>> getLaneLayoutAsInt() const {
+      if (DenseI32ArrayAttr layout = getLaneLayout())
+        return llvm::to_vector_of<int64_t>(layout.asArrayRef());
+      return std::nullopt;
+    }
+
+    std::optional<SmallVector<int64_t>> getLaneDataAsInt() const {
+      if (DenseI32ArrayAttr data = getLaneData())
+        return llvm::to_vector_of<int64_t>(data.asArrayRef());
+      return std::nullopt;
+    }
+
     /// Delinearizes a linear subgroup ID into its multidimensional indices
     /// based on the effective subgroup layout.
     FailureOr<SmallVector<Value>>
@@ -488,6 +518,42 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
       return std::nullopt;
     }
 
+    /// Returns the InstData of the attribute, computed by applying
+    /// the slice dimensions to the underlying LayoutAttr.
+    std::optional<SmallVector<int64_t>> getInstDataAsInt() const {
+      SliceAttr attr = flatten();
+      auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+      if (auto inst = parent.getInstDataAsInt()) {
+        ArrayRef<int64_t> dims = attr.getDims().asArrayRef();
+        return XeGPUDialect::slice(llvm::ArrayRef<int64_t>(*inst), dims);
+      }
+      return std::nullopt;
+    }
+
+    /// Returns the LaneLayout of the attribute, computed by applying
+    /// the slice dimensions to the underlying LayoutAttr.
+    std::optional<SmallVector<int64_t>> getLaneLayoutAsInt() const {
+      SliceAttr attr = flatten();
+      auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+      if (auto layout = parent.getLaneLayoutAsInt()) {
+        ArrayRef<int64_t> dims = attr.getDims().asArrayRef();
+        return XeGPUDialect::slice(llvm::ArrayRef<int64_t>(*layout), dims);
+      }
+      return std::nullopt;
+    }
+
+    /// Returns the LaneData of the attribute, computed by applying
+    /// the slice dimensions to the underlying LayoutAttr.
+    std::optional<SmallVector<int64_t>> getLaneDataAsInt() const {
+      SliceAttr attr = flatten();
+      auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+      if (auto data = parent.getLaneDataAsInt()) {
+        ArrayRef<int64_t> dims = attr.getDims().asArrayRef();
+        return XeGPUDialect::slice(llvm::ArrayRef<int64_t>(*data), dims);
+      }
+      return std::nullopt;
+    }
+
     SliceAttr dropSgLayoutAndData() {
       SliceAttr attr = flatten();
       auto parent = dyn_cast<LayoutAttr>(attr.getParent());
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
index 76d58e5ea2424..c173b93face98 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
@@ -40,7 +40,7 @@ def XeGPU_Dialect : Dialect {
     let extraClassDeclaration = [{
       /// Checks if the given shape can be evenly distributed based on the layout
       /// and data factors provided by the LayoutAttr.
-      static bool isEvenlyDistributable(llvm::ArrayRef<int64_t> shape, xegpu::LayoutAttr attr);
+      static bool isEvenlyDistributable(llvm::ArrayRef<int64_t> shape, xegpu::DistributeLayoutAttr attr);
 
       /// drops/slices the shape in the specified dims, and return the rest. e.g.,
       /// for shape = [32, 64, 8], dims = [0, 2], it will return [64]
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index b2b2d3ab85231..010199083add9 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -21,6 +21,7 @@ class ValueRange;
 class TypeConverter;
 
 namespace xegpu {
+class DistributeLayoutAttr;
 class LayoutAttr;
 class TensorDescType;
 } // namespace xegpu
@@ -60,22 +61,22 @@ FailureOr<VectorType> getDistributedVectorType(xegpu::TensorDescType tdescTy);
 FailureOr<VectorType> getDistributedVectorType(VectorType originalType,
                                                LayoutAttr layout);
 
-/// Return the attribute name for the OpOperand to attach LayoutAttr
+/// Return the attribute name for the OpOperand to attach DistributeLayoutAttr
 std::string getLayoutName(const OpOperand &operand);
 
-/// Return the attribute name for the OpResult to attach LayoutAttr
+/// Return the attribute name for the OpResult to attach DistributeLayoutAttr
 std::string getLayoutName(const OpResult result);
 
-/// Retrieves the LayoutAttr associated with a given Value. For TensorDescType
-/// values, the LayoutAttr is extracted from the TensorDescType itself. For
+/// Retrieves the DistributeLayoutAttr associated with a given Value. For TensorDescType
+/// values, the DistributeLayoutAttr is extracted from the TensorDescType itself. For
 /// other values, it is obtained from the attributes of the defining operation.
-/// Returns nullptr if no LayoutAttr is found.
-LayoutAttr getLayoutAttr(const Value value);
+/// Returns nullptr if no DistributeLayoutAttr is found.
+DistributeLayoutAttr getDistributeLayoutAttr(const Value value);
 
-/// Retrieves the LayoutAttr associated with a given OpOperand. It will
+/// Retrieves the DistributeLayoutAttr associated with a given OpOperand. It will
 /// first check the operand_layout_{id} of the owner operation. If not found,
 /// it will check the operand itself and its defining op.
-LayoutAttr getLayoutAttr(const OpOperand &opr);
+DistributeLayoutAttr getDistributeLayoutAttr(const OpOperand &opr);
 
 /// Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
 template <typename T,
@@ -83,23 +84,23 @@ template <typename T,
                                       std::is_same_v<T, OpResult>>>
 void removeLayoutAttr(const T &operandOrResult);
 
-/// Removes the LayoutAttr for each OpOperand and OpResult of the given
+/// Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given
 /// operation if they exist. If the operation contains regions, it is also
 /// applied recursively to the contained operations
 void removeLayoutAttrs(Operation *op);
 
-/// Sets the LayoutAttr for a given OpOperand or OpResult by attaching
+/// Sets the DistributeLayoutAttr for a given OpOperand or OpResult by attaching
 /// it to the owner's dictionary attributes
 template <typename T,
           typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
                                       std::is_same_v<T, OpResult>>>
-void setLayoutAttr(const T &operandOrResult, const LayoutAttr layout);
+void setLayoutAttr(const T &operandOrResult, const DistributeLayoutAttr layout);
 
-/// Set the LayoutAttr for each OpOperand and OpResult of the given operation.
+/// Set the DistributeLayoutAttr for each OpOperand and OpResult of the given operation.
 /// If the operation contains regions, it is also applied recursively to the
 /// contained operations
 void setLayoutAttrs(Operation *op,
-                    function_ref<LayoutAttr(Value)> getLayoutImpl);
+                    function_ref<DistributeLayoutAttr(Value)> getLayoutImpl);
 
 /// Extract a set of small vectors from a value with a given shape using
 /// vector.extract_stride_slice
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index a2d708be0e937..2079848c878a3 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -91,7 +91,7 @@ genOffsetsComputingInsts(OpBuilder &builder, Location loc,
 // Checks if the given shape can be evenly distributed based on the layout
 // and data factors provided by the LayoutAttr.
 bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
-                                         xegpu::LayoutAttr attr) {
+                                         xegpu::DistributeLayoutAttr attr) {
   assert(attr && "Layout attribute is missing.");
 
   // Checks whether the given shape can be evenly distributed using the
@@ -104,52 +104,51 @@ bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
   // smaller than `layout[i] * data[i]`, allowing multiple compute units to
   // share the data.
   auto tryDistribute = [&](llvm::ArrayRef<int64_t> shape,
-                           DenseI32ArrayAttr layout, DenseI32ArrayAttr data,
+                           std::optional<SmallVector<int64_t>> layout,
+                           std::optional<SmallVector<int64_t>> data,
                            bool rr = true) -> optional<SmallVector<int64_t>> {
     llvm::SmallVector<int64_t> newShape(shape);
     if (layout) {
-      auto vec = llvm::to_vector_of<int64_t>(layout.asArrayRef());
-      if (vec.size() != shape.size())
+      if ((*layout).size() != shape.size())
         return std::nullopt;
-      auto ratio = computeShapeRatio(shape, vec);
+      auto ratio = computeShapeRatio(shape, *layout);
       if (!ratio.has_value())
         return std::nullopt;
       newShape = ratio.value();
     }
 
     if (data) {
-      auto vec = llvm::to_vector_of<int64_t>(data.asArrayRef());
-      if (vec.size() != shape.size())
+      if ((*data).size() != shape.size())
         return std::nullopt;
-      auto ratio = computeShapeRatio(newShape, vec);
+      auto ratio = computeShapeRatio(newShape, *data);
       if (!ratio.has_value() && rr)
-        ratio = computeShapeRatio(vec, newShape);
+        ratio = computeShapeRatio(*data, newShape);
       if (!ratio.has_value())
         return std::nullopt;
 
       // if data is not null, we always return it for next phase.
-      newShape = vec;
+      newShape = *data;
     }
     return newShape;
   };
 
   // check the sgLayout and sgData
   auto maybeSgShape =
-      tryDistribute(shape, attr.getSgLayout(), attr.getSgData());
+      tryDistribute(shape, attr.getSgLayoutAsInt(), attr.getSgDataAsInt());
   if (!maybeSgShape)
     return false;
   auto sgShape = maybeSgShape.value();
 
   // check InstData, it neither have layout nor need round-robin
   auto maybeInstShape =
-      tryDistribute(sgShape, nullptr, attr.getInstData(), false);
+      tryDistribute(sgShape, std::nullopt, attr.getInstDataAsInt(), false);
   if (!maybeInstShape)
     return false;
   auto instShape = maybeInstShape.value();
 
   // check LaneLayout and LaneData
   auto maybeLaneShape =
-      tryDistribute(instShape, attr.getLaneLayout(), attr.getLaneData(), false);
+      tryDistribute(instShape, attr.getLaneLayoutAsInt(), attr.getLaneDataAsInt(), false);
   return maybeLaneShape.has_value();
 }
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index b3144e4c1e55d..c62597df1f895 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -140,10 +140,10 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
   else
     value = (Value)operandOrResult;
 
-  xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operandOrResult);
+  xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(operandOrResult);
   if (layout && layout.isForSubgroup()) {
-    if (auto inst_data = layout.getInstData())
-      return llvm::to_vector_of<int64_t>(inst_data.asArrayRef());
+    if (auto inst_data = layout.getInstDataAsInt())
+      return inst_data.value();
 
     if (auto type = dyn_cast<ShapedType>(value.getType()))
       return llvm::to_vector(type.getShape());
@@ -204,12 +204,12 @@ bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
   // skip the op if any of its operands or results has workgroup level layouts
   bool hasWgLayoutOperands =
       llvm::any_of(op->getOpOperands(), [](OpOperand &opr) {
-        xegpu::LayoutAttr layout = xegpu::getLayoutAttr(opr);
+        xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(opr);
         return layout && layout.isForWorkgroup();
       });
   bool hasWgLayoutResults =
       llvm::any_of(op->getOpResults(), [](OpResult result) {
-        xegpu::LayoutAttr layout = xegpu::getLayoutAttr(result);
+        xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(result);
         return layout && layout.isForWorkgroup();
       });
   if (hasWgLayoutOperands || hasWgLayoutResults) {
@@ -220,8 +220,8 @@ bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
   auto isUnrollable = [](Value value, ArrayRef<int64_t> tileShape) {
     Type valTy = value.getType();
     if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(valTy)) {
-      xegpu::LayoutAttr layout = tdescTy.getLayoutAttr();
-      return layout && layout.getInstData();
+      xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr();
+      return layout && layout.getInstDataAsInt();
     }
     auto shapedType = dyn_cast<ShapedType>(valTy);
     return shapedType && !llvm::equal(tileShape, shapedType.getShape());
@@ -247,7 +247,7 @@ void XeGPUBlockingPass::runOnOperation() {
   // Preserve the LayoutAttr for each operand to the owner's DictionaryAttr.
   // This ensures that the LayoutAttr remains accessible even if the defining
   // operation is replaced.
-  xegpu::setLayoutAttrs(op, [](Value v) { return xegpu::getLayoutAttr(v); });
+  xegpu::setLayoutAttrs(op, [](Value v) { return xegpu::getDistributeLayoutAttr(v); });
 
   auto getTileShapeAndCount = [](llvm::ArrayRef<int64_t> shape,
                                  xegpu::LayoutAttr layout) {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 2088c3c7fc5ec..de9378bd7a6f6 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -841,7 +841,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       if (!isa<VectorType>(operand.get().getType()))
         continue;
 
-      xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operand);
+      auto layout = dyn_cast<xegpu::LayoutAttr>(xegpu::getDistributeLayoutAttr(operand));
       if (!layout) {
         op->emitError("Could not find layout attribute for operand ")
             << operand.getOperandNumber() << " of operation " << op->getName();
@@ -882,7 +882,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
     if (vecRank == 0)
       return AffineMap::get(val.getContext());
     // Get the layout of the vector type.
-    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(val);
+    // TODO: support more layout types
+    auto layout = dyn_cast<xegpu::LayoutAttr>(xegpu::getDistributeLayoutAttr(val));
     // If no layout is specified, assume the inner most dimension is distributed
     // for now.
     if (!layout)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 93b4efcd125ec..c60f9e361bf8e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -406,7 +406,7 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
     if (resultTy.getRank() != 2)
       return failure();
 
-    auto originalLayout = xegpu::getLayoutAttr(op.getResult());
+    auto originalLayout = xegpu::getDistributeLayoutAttr(op.getResult());
     if (!originalLayout)
       return failure();
 
@@ -470,8 +470,8 @@ struct WgToSgVectorBroadcastOp
     VectorType resultType = op.getResult().getType();
     ArrayRef<int64_t> wgShape = resultType.getShape();
 
-    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
-    if (!layout || !layout.getSgLayout())
+    xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(op.getResult());
+    if (!layout || !layout.isForWorkgroup())
       return failure();
 
     // TODO: Currently only supports cases where the source and result ranks
@@ -487,8 +487,8 @@ struct WgToSgVectorBroadcastOp
 
     // Check if the output layout is distributable
     SmallVector<int64_t> sgLayout;
-    if (auto sgLayoutAttr = layout.getSgLayout())
-      sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
+    if (auto maybeSgLayout = layout.getSgLayoutAsInt())
+      sgLayout = *maybeSgLayout;
     else
       return failure();
 
@@ -535,8 +535,8 @@ struct WgToSgElementwiseOp : public ConversionPattern {
 
     ArrayRef<int64_t> wgShape = resultType.getShape();
 
-    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0));
-    if (!layout || !layout.getSgLayout())
+    xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
+    if (!layout || !layout.isForWorkgroup())
       return failure();
 
     SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
@@ -737,8 +737,8 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
     if (!vecAttr || !vecAttr.isSplat() || !vecType)
       return failure();
 
-    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
-    if (!layout || !layout.getSgLayout())
+    xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(op.getResult());
+    if (!layout || !layout.isForWorkgroup())
       return failure();
 
     ArrayRef<int64_t> wgShape = vecType.getShape();
@@ -928,7 +928,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
   });
 
   target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
-    auto layout = xegpu::getLayoutAttr(op.getResult());
+    auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
     return isLegal(layout);
   });
 
@@ -947,12 +947,12 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         auto vecType = dyn_cast<VectorType>(op.getType());
         if (!vecType)
           return true;
-        return isLegal(xegpu::getLayoutAttr(op.getResult()));
+        return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
       });
 
   target.addDynamicallyLegalOp<vector::BroadcastOp>(
       [=](vector::BroadcastOp op) -> bool {
-        return isLegal(xegpu::getLayoutAttr(op.getResult()));
+        return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
       });
 
   target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
@@ -980,7 +980,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
           }
         }
 
-        xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0));
+        xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
         return isLegal(layout);
       });
 
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 6835f64ad8ef7..5ae025ef34739 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -114,7 +114,7 @@ std::string xegpu::getLayoutName(const OpResult result) {
   return llvm::formatv("{0}{1}", prefix, result.getResultNumber()).str();
 }
 
-xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) {
+xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
   if (!value)
     return nullptr;
 
@@ -132,11 +132,11 @@ xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) {
 
     // for LoadNdOp, the layout is stored in the tensor descriptor
     if (auto loadNd = dyn_cast<xegpu::LoadNdOp>(defOp))
-      return getLayoutAttr(loadNd.getTensorDesc());
+      return getDistributeLayoutAttr(loadNd.getTensorDesc());
 
     std::string layoutName = getLayoutName(result);
     if (defOp->hasAttr(layoutName))
-      return defOp->getAttrOfType<xegpu::LayoutAttr>(layoutName);
+      return defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
   }
 
   if (auto arg = dyn_cast<BlockArgument>(value)) {
@@ -144,41 +144,41 @@ xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) {
     if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
       OpOperand *tiedInit = loop.getTiedLoopInit(arg);
       if (tiedInit)
-        return getLayoutAttr(tiedInit->get());
+        return getDistributeLayoutAttr(tiedInit->get());
     }
   }
 
   return nullptr;
 }
 
-xegpu::LayoutAttr xegpu::getLayoutAttr(const OpOperand &opr) {
+xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
   Operation *op = opr.getOwner();
   std::string layoutName = xegpu::getLayoutName(opr);
   if (op->hasAttr(layoutName))
-    return op->getAttrOfType<xegpu::LayoutAttr>(layoutName);
-  return getLayoutAttr(opr.get());
+    return op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
+  return getDistributeLayoutAttr(opr.get());
 }
 
 template <typename T, typename>
-void xegpu::setLayoutAttr(const T &operandOrResult, const LayoutAttr layout) {
+void xegpu::setLayoutAttr(const T &operandOrResult, const DistributeLayoutAttr layout) {
   Operation *owner = operandOrResult.getOwner();
   std::string name = xegpu::getLayoutName(operandOrResult);
-  if (layout && !owner->hasAttrOfType<LayoutAttr>(name))
+  if (layout && !owner->hasAttrOfType<DistributeLayoutAttr>(name))
     owner->setAttr(name, layout);
 }
 
 // Explicit instantiation for OpResult
 template void
 xegpu::setLayoutAttr<mlir::OpResult>(const mlir::OpResult &result,
-                                     const mlir::xegpu::LayoutAttr layout);
+                                     const mlir::xegpu::DistributeLayoutAttr layout);
 
 // Explicit instantiation for OpOperand
 template void
 xegpu::setLayoutAttr<mlir::OpOperand>(const mlir::OpOperand &operand,
-                                      const mlir::xegpu::LayoutAttr layout);
+                                      const mlir::xegpu::DistributeLayoutAttr layout);
 
 void xegpu::setLayoutAttrs(Operation *op,
-                           function_ref<LayoutAttr(Value)> getLayoutImpl) {
+                           function_ref<DistributeLayoutAttr(Value)> getLayoutImpl) {
   op->walk([&](Operation *nestOp) {
     for (OpOperand &opr : nestOp->getOpOperands()) {
       auto layout = getLayoutImpl(opr.get());
@@ -195,7 +195,7 @@ template <typename T, typename>
 void xegpu::removeLayoutAttr(const T &operandOrResult) {
   Operation *owner = operandOrResult.getOwner();
   std::string name = xegpu::getLayoutName(operandOrResult);
-  if (owner->hasAttrOfType<LayoutAttr>(name))
+  if (owner->hasAttrOfType<DistributeLayoutAttr>(name))
     owner->removeAttr(name);
 }
 
@@ -306,7 +306,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
       if (!inputTy || !resultTy)
         return WalkResult::skip();
 
-      xegpu::LayoutAttr layout = xegpu::getLayoutAttr(input);
+      xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(input);
       if (!layout)
         return WalkResult::skip();
 
@@ -344,7 +344,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
   }
 
   { // perform the conversion from RankedTensorType to VectorType based on the
-    // LayoutAttr
+    // DistributeLayoutAttr
 
     // Handle the UnrealizedConversionCastOp introduced by the first step.
     // For vector->RankedTensorType, it will simply forward the inputs.

>From 0e34f36690a34f071afd181649b8f86c90dde9b4 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 21 Aug 2025 18:10:49 +0000
Subject: [PATCH 03/15] refine

---
 .../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h     | 17 +++++++++++---
 .../XeGPU/Transforms/XeGPUBlocking.cpp        |  5 ++--
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp |  4 ++--
 .../Transforms/XeGPUSubgroupDistribute.cpp    |  7 +++---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 10 ++++----
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   | 23 ++++++++++---------
 6 files changed, 40 insertions(+), 26 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 010199083add9..7089559d0c51b 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -73,11 +73,21 @@ std::string getLayoutName(const OpResult result);
 /// Returns nullptr if no DistributeLayoutAttr is found.
 DistributeLayoutAttr getDistributeLayoutAttr(const Value value);
 
+template <typename AttrTy>
+AttrTy getDistributeLayoutAttrOfType(const Value value) {
+  return dyn_cast_if_present<AttrTy>(getDistributeLayoutAttr(value));
+}
+
 /// Retrieves the DistributeLayoutAttr associated with a given OpOperand. It will
 /// first check the operand_layout_{id} of the owner operation. If not found,
 /// it will check the operand itself and its defining op.
 DistributeLayoutAttr getDistributeLayoutAttr(const OpOperand &opr);
 
+template <typename AttrTy>
+AttrTy getDistributeLayoutAttrOfType(const OpOperand &opr) {
+  return dyn_cast_if_present<AttrTy>(getDistributeLayoutAttr(opr));
+}
+
 /// Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
 template <typename T,
           typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
@@ -94,13 +104,14 @@ void removeLayoutAttrs(Operation *op);
 template <typename T,
           typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
                                       std::is_same_v<T, OpResult>>>
-void setLayoutAttr(const T &operandOrResult, const DistributeLayoutAttr layout);
+void setDistributeLayoutAttr(const T &operandOrResult,
+                             const DistributeLayoutAttr layout);
 
 /// Set the DistributeLayoutAttr for each OpOperand and OpResult of the given operation.
 /// If the operation contains regions, it is also applied recursively to the
 /// contained operations
-void setLayoutAttrs(Operation *op,
-                    function_ref<DistributeLayoutAttr(Value)> getLayoutImpl);
+void setDistributeLayoutAttrs(
+    Operation *op, function_ref<DistributeLayoutAttr(Value)> getLayoutImpl);
 
 /// Extract a set of small vectors from a value with a given shape using
 /// vector.extract_stride_slice
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index c62597df1f895..2e3e40ed2d457 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -247,7 +247,8 @@ void XeGPUBlockingPass::runOnOperation() {
   // Preserve the LayoutAttr for each operand to the owner's DictionaryAttr.
   // This ensures that the LayoutAttr remains accessible even if the defining
   // operation is replaced.
-  xegpu::setLayoutAttrs(op, [](Value v) { return xegpu::getDistributeLayoutAttr(v); });
+  xegpu::setDistributeLayoutAttrs(
+      op, [](Value v) { return xegpu::getDistributeLayoutAttr(v); });
 
   auto getTileShapeAndCount = [](llvm::ArrayRef<int64_t> shape,
                                  xegpu::LayoutAttr layout) {
@@ -377,7 +378,7 @@ void XeGPUBlockingPass::runOnOperation() {
       if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
         op->removeAttr(name);
         if (!isa<LoopLikeOpInterface>(op))
-          xegpu::setLayoutAttr(result, layout.dropInstData());
+          xegpu::setDistributeLayoutAttr(result, layout.dropInstData());
       }
     }
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index bef88042fc663..5cb47b2accd68 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -718,7 +718,7 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
     }
     // If the result is a vector type, add a temporary layout attribute to the
     // op.
-    xegpu::setLayoutAttr(result, layout);
+    xegpu::setDistributeLayoutAttr(result, layout);
   }
   return success();
 }
@@ -800,7 +800,7 @@ updateControlFlowOps(mlir::OpBuilder &builder,
       // If the type is a vector type and this region argument is an OpResult,
       // set the layout attribute on the OpResult.
       if (auto result = dyn_cast<OpResult>(successorInput))
-        xegpu::setLayoutAttr(result, successorOperandLayout);
+        xegpu::setDistributeLayoutAttr(result, successorOperandLayout);
     }
   }
   return success();
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index de9378bd7a6f6..e48e2180197ec 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -841,14 +841,15 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       if (!isa<VectorType>(operand.get().getType()))
         continue;
 
-      auto layout = dyn_cast<xegpu::LayoutAttr>(xegpu::getDistributeLayoutAttr(operand));
+      auto layout =
+          xegpu::getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(operand);
       if (!layout) {
         op->emitError("Could not find layout attribute for operand ")
             << operand.getOperandNumber() << " of operation " << op->getName();
         signalPassFailure();
         return;
       }
-      xegpu::setLayoutAttr(operand, layout);
+      xegpu::setDistributeLayoutAttr(operand, layout);
     }
   });
   // Step 2: Move all operations of a GPU function inside
@@ -883,7 +884,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       return AffineMap::get(val.getContext());
     // Get the layout of the vector type.
     // TODO: support more layout types
-    auto layout = dyn_cast<xegpu::LayoutAttr>(xegpu::getDistributeLayoutAttr(val));
+    auto layout = xegpu::getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(val);
     // If no layout is specified, assume the inner most dimension is distributed
     // for now.
     if (!layout)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index c60f9e361bf8e..a8700ca73efc4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -429,8 +429,8 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
         VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
                                            resultTy.getElementType());
         tmpC = xegpu::DpasOp::create(rewriter, loc, resTy, operands);
-        xegpu::setLayoutAttr(cast<OpResult>(tmpC),
-                             originalLayout.dropSgLayoutAndData());
+        xegpu::setDistributeLayoutAttr(cast<OpResult>(tmpC),
+                                       originalLayout.dropSgLayoutAndData());
 
         newDpasOps.push_back(tmpC);
       }
@@ -508,8 +508,8 @@ struct WgToSgVectorBroadcastOp
     for (auto operand : adaptor.getOperands().front()) {
       auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
                                                       newResultType, operand);
-      xegpu::setLayoutAttr(newBroadcast->getResult(0),
-                           layout.dropSgLayoutAndData());
+      xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
+                                     layout.dropSgLayoutAndData());
       newBroadcastOps.push_back(newBroadcast.getResult());
     }
 
@@ -755,7 +755,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
     auto cstOp =
         arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
     if (auto newLayout = layout.dropSgLayoutAndData())
-      xegpu::setLayoutAttr(cstOp->getResult(0), newLayout);
+      xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout);
     SmallVector<Value> newConsts(count, cstOp);
 
     rewriter.replaceOpWithMultiple(op, {newConsts});
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 5ae025ef34739..1d4de68754c20 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -160,7 +160,8 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const OpOperand &opr)
 }
 
 template <typename T, typename>
-void xegpu::setLayoutAttr(const T &operandOrResult, const DistributeLayoutAttr layout) {
+void xegpu::setDistributeLayoutAttr(const T &operandOrResult,
+                                    const DistributeLayoutAttr layout) {
   Operation *owner = operandOrResult.getOwner();
   std::string name = xegpu::getLayoutName(operandOrResult);
   if (layout && !owner->hasAttrOfType<DistributeLayoutAttr>(name))
@@ -168,25 +169,25 @@ void xegpu::setLayoutAttr(const T &operandOrResult, const DistributeLayoutAttr l
 }
 
 // Explicit instantiation for OpResult
-template void
-xegpu::setLayoutAttr<mlir::OpResult>(const mlir::OpResult &result,
-                                     const mlir::xegpu::DistributeLayoutAttr layout);
+template void xegpu::setDistributeLayoutAttr<mlir::OpResult>(
+    const mlir::OpResult &result,
+    const mlir::xegpu::DistributeLayoutAttr layout);
 
 // Explicit instantiation for OpOperand
-template void
-xegpu::setLayoutAttr<mlir::OpOperand>(const mlir::OpOperand &operand,
-                                      const mlir::xegpu::DistributeLayoutAttr layout);
+template void xegpu::setDistributeLayoutAttr<mlir::OpOperand>(
+    const mlir::OpOperand &operand,
+    const mlir::xegpu::DistributeLayoutAttr layout);
 
-void xegpu::setLayoutAttrs(Operation *op,
-                           function_ref<DistributeLayoutAttr(Value)> getLayoutImpl) {
+void xegpu::setDistributeLayoutAttrs(
+    Operation *op, function_ref<DistributeLayoutAttr(Value)> getLayoutImpl) {
   op->walk([&](Operation *nestOp) {
     for (OpOperand &opr : nestOp->getOpOperands()) {
       auto layout = getLayoutImpl(opr.get());
-      setLayoutAttr(opr, layout);
+      setDistributeLayoutAttr(opr, layout);
     }
     for (OpResult result : nestOp->getOpResults()) {
       auto layout = getLayoutImpl(result);
-      setLayoutAttr(result, layout);
+      setDistributeLayoutAttr(result, layout);
     }
   });
 }

>From a84014ff42002dc5b036558c62e5387536e74019 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 21 Aug 2025 18:12:17 +0000
Subject: [PATCH 04/15] format

---
 .../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h     | 25 ++++++++++---------
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    |  4 +--
 .../XeGPU/Transforms/XeGPUBlocking.cpp        |  9 ++++---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 12 ++++++---
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   |  6 +++--
 5 files changed, 33 insertions(+), 23 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 7089559d0c51b..82fd70571c022 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -67,10 +67,11 @@ std::string getLayoutName(const OpOperand &operand);
 /// Return the attribute name for the OpResult to attach DistributeLayoutAttr
 std::string getLayoutName(const OpResult result);
 
-/// Retrieves the DistributeLayoutAttr associated with a given Value. For TensorDescType
-/// values, the DistributeLayoutAttr is extracted from the TensorDescType itself. For
-/// other values, it is obtained from the attributes of the defining operation.
-/// Returns nullptr if no DistributeLayoutAttr is found.
+/// Retrieves the DistributeLayoutAttr associated with a given Value. For
+/// TensorDescType values, the DistributeLayoutAttr is extracted from the
+/// TensorDescType itself. For other values, it is obtained from the attributes
+/// of the defining operation. Returns nullptr if no DistributeLayoutAttr is
+/// found.
 DistributeLayoutAttr getDistributeLayoutAttr(const Value value);
 
 template <typename AttrTy>
@@ -78,9 +79,9 @@ AttrTy getDistributeLayoutAttrOfType(const Value value) {
   return dyn_cast_if_present<AttrTy>(getDistributeLayoutAttr(value));
 }
 
-/// Retrieves the DistributeLayoutAttr associated with a given OpOperand. It will
-/// first check the operand_layout_{id} of the owner operation. If not found,
-/// it will check the operand itself and its defining op.
+/// Retrieves the DistributeLayoutAttr associated with a given OpOperand. It
+/// will first check the operand_layout_{id} of the owner operation. If not
+/// found, it will check the operand itself and its defining op.
 DistributeLayoutAttr getDistributeLayoutAttr(const OpOperand &opr);
 
 template <typename AttrTy>
@@ -94,8 +95,8 @@ template <typename T,
                                       std::is_same_v<T, OpResult>>>
 void removeLayoutAttr(const T &operandOrResult);
 
-/// Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given
-/// operation if they exist. If the operation contains regions, it is also
+/// Removes the DistributeLayoutAttr for each OpOperand and OpResult of the
+/// given operation if they exist. If the operation contains regions, it is also
 /// applied recursively to the contained operations
 void removeLayoutAttrs(Operation *op);
 
@@ -107,9 +108,9 @@ template <typename T,
 void setDistributeLayoutAttr(const T &operandOrResult,
                              const DistributeLayoutAttr layout);
 
-/// Set the DistributeLayoutAttr for each OpOperand and OpResult of the given operation.
-/// If the operation contains regions, it is also applied recursively to the
-/// contained operations
+/// Set the DistributeLayoutAttr for each OpOperand and OpResult of the given
+/// operation. If the operation contains regions, it is also applied recursively
+/// to the contained operations
 void setDistributeLayoutAttrs(
     Operation *op, function_ref<DistributeLayoutAttr(Value)> getLayoutImpl);
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 2079848c878a3..6de6049facfc6 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -147,8 +147,8 @@ bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
   auto instShape = maybeInstShape.value();
 
   // check LaneLayout and LaneData
-  auto maybeLaneShape =
-      tryDistribute(instShape, attr.getLaneLayoutAsInt(), attr.getLaneDataAsInt(), false);
+  auto maybeLaneShape = tryDistribute(instShape, attr.getLaneLayoutAsInt(),
+                                      attr.getLaneDataAsInt(), false);
   return maybeLaneShape.has_value();
 }
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 2e3e40ed2d457..45fed8e548a89 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -140,7 +140,8 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
   else
     value = (Value)operandOrResult;
 
-  xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(operandOrResult);
+  xegpu::DistributeLayoutAttr layout =
+      xegpu::getDistributeLayoutAttr(operandOrResult);
   if (layout && layout.isForSubgroup()) {
     if (auto inst_data = layout.getInstDataAsInt())
       return inst_data.value();
@@ -204,12 +205,14 @@ bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
   // skip the op if any of its operands or results has workgroup level layouts
   bool hasWgLayoutOperands =
       llvm::any_of(op->getOpOperands(), [](OpOperand &opr) {
-        xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(opr);
+        xegpu::DistributeLayoutAttr layout =
+            xegpu::getDistributeLayoutAttr(opr);
         return layout && layout.isForWorkgroup();
       });
   bool hasWgLayoutResults =
       llvm::any_of(op->getOpResults(), [](OpResult result) {
-        xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(result);
+        xegpu::DistributeLayoutAttr layout =
+            xegpu::getDistributeLayoutAttr(result);
         return layout && layout.isForWorkgroup();
       });
   if (hasWgLayoutOperands || hasWgLayoutResults) {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index a8700ca73efc4..518c7817a516e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -470,7 +470,8 @@ struct WgToSgVectorBroadcastOp
     VectorType resultType = op.getResult().getType();
     ArrayRef<int64_t> wgShape = resultType.getShape();
 
-    xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(op.getResult());
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getResult());
     if (!layout || !layout.isForWorkgroup())
       return failure();
 
@@ -535,7 +536,8 @@ struct WgToSgElementwiseOp : public ConversionPattern {
 
     ArrayRef<int64_t> wgShape = resultType.getShape();
 
-    xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op->getResult(0));
     if (!layout || !layout.isForWorkgroup())
       return failure();
 
@@ -737,7 +739,8 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
     if (!vecAttr || !vecAttr.isSplat() || !vecType)
       return failure();
 
-    xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(op.getResult());
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getResult());
     if (!layout || !layout.isForWorkgroup())
       return failure();
 
@@ -980,7 +983,8 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
           }
         }
 
-        xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
+        xegpu::DistributeLayoutAttr layout =
+            xegpu::getDistributeLayoutAttr(op->getResult(0));
         return isLegal(layout);
       });
 
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 1d4de68754c20..cac1ffe4d3bc3 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -151,7 +151,8 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
   return nullptr;
 }
 
-xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
+xegpu::DistributeLayoutAttr
+xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
   Operation *op = opr.getOwner();
   std::string layoutName = xegpu::getLayoutName(opr);
   if (op->hasAttr(layoutName))
@@ -307,7 +308,8 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
       if (!inputTy || !resultTy)
         return WalkResult::skip();
 
-      xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(input);
+      xegpu::DistributeLayoutAttr layout =
+          xegpu::getDistributeLayoutAttr(input);
       if (!layout)
         return WalkResult::skip();
 

>From f3af2c307597bf13a04579b3235b45af7ea10392 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 21 Aug 2025 18:59:45 +0000
Subject: [PATCH 05/15] update convert_layout

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td            | 3 +++
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td              | 4 ++--
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp         | 6 +++---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 5 +++--
 4 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 5b4b376157c00..77e3c257f234e 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -217,6 +217,9 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
     InterfaceMethod<"Derive a new layout by dropping sgLayout and sgData",
                     "xegpu::DistributeLayoutAttr",
                     "dropSgLayoutAndData">,
+    InterfaceMethod<"Derive a new layout by dropping InstData",
+                    "xegpu::DistributeLayoutAttr",
+                    "dropInstData">,
     InterfaceMethod<[{Delinearizes a linear subgroup ID into its multidimensional
                       indices based on the effective subgroup layout.}],
                     "FailureOr<SmallVector<Value>>",
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index ab471a1f33ef9..2f6671c5e37cc 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1162,8 +1162,8 @@ def XeGPU_ConvertLayoutOp: XeGPU_Op<"convert_layout", [Pure, AllTypesMatch<["sou
       the IR is lowered to WI level because that is the end result of all distributions.
     }];
     let arguments = (ins XeGPU_VectorType: $source,
-                         XeGPU_LayoutAttr: $input_layout,
-                         XeGPU_LayoutAttr: $target_layout);
+                         DistributeLayoutAttr: $input_layout,
+                         DistributeLayoutAttr: $target_layout);
     let results = (outs XeGPU_VectorType: $result);
     let assemblyFormat = [{
         $source prop-dict attr-dict `:` type($source)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 45fed8e548a89..80e9d4d25b06c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -84,9 +84,9 @@ struct ConvertLayoutOpPattern
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op,
                                 PatternRewriter &rewriter) const override {
-    xegpu::LayoutAttr input_layout = op.getInputLayoutAttr();
-    xegpu::LayoutAttr target_layout = op.getTargetLayoutAttr();
-    if (!input_layout.getInstData() || !target_layout.getInstData())
+    xegpu::DistributeLayoutAttr input_layout = op.getInputLayoutAttr();
+    xegpu::DistributeLayoutAttr target_layout = op.getTargetLayoutAttr();
+    if (!input_layout.getInstDataAsInt() || !target_layout.getInstDataAsInt())
       return rewriter.notifyMatchFailure(op, "Not a target ConvertLayoutOp.");
 
     input_layout = input_layout.dropInstData();
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 518c7817a516e..4fb962908793f 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -613,8 +613,9 @@ struct WgToSgConvertLayoutOp
   LogicalResult
   matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    xegpu::LayoutAttr input = op.getInputLayout();
-    xegpu::LayoutAttr target = op.getTargetLayout();
+    // TODO: currently, we only support LayoutAttr
+    auto input = dyn_cast<xegpu::LayoutAttr>(op.getInputLayout());
+    auto target = dyn_cast<xegpu::LayoutAttr>(op.getTargetLayout());
 
     if (!input || !target || !input.isForWorkgroup() ||
         !target.isForWorkgroup())

>From ee5baca1ccae6549aca46693814f9c8ea8b995e7 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 21 Aug 2025 22:54:47 +0000
Subject: [PATCH 06/15] save work

---
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 42 ++++++----------
 mlir/test/Dialect/XeGPU/propagate-layout.mlir | 48 +++++++++++++++++--
 2 files changed, 58 insertions(+), 32 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 10c2759493477..8dce63b80f373 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -639,46 +639,32 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
                                  : outElemTyBitWidth / inElemTyBitWidth;
   const LaneLayout &sourceLaneLayout =
       resultLayout.getLayout(); // source lane layout is unchanged.
-  ArrayRef<int64_t> currData = resultLayout.getDataAsArrayRef();
+  ArrayRef<int64_t> outData = resultLayout.getDataAsArrayRef();
 
   // TODO: Currently we assume that bitcasts does not require cross lane
   // communication. So each lane must own the required number of elements to
   // perform the bitcast locally without cross-lane communication.
-  // For 1D vectors, decide how many elements each lane owns based on whether
-  // the bitcast is narrowing or widening.
-  if (rank == 1) {
-    if ((currData[0] * outElemTyBitWidth) % inElemTyBitWidth != 0) {
-      bitcast.emitWarning(
-          "Narrowing bitcast with cross lane communication is not supported.");
-      return;
-    }
-    LaneData sourceLaneData = isNarrowing
-                                  ? LaneData({currData[0] / bitCastRatio})
-                                  : LaneData({currData[0] * bitCastRatio});
-
-    propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(
-                                        sourceLaneLayout, sourceLaneData)));
+  int64_t outInnerBitsPerLane = outData[rank - 1] * outElemTyBitWidth;
+  if (outInnerBitsPerLane % inElemTyBitWidth != 0) {
+    bitcast.emitWarning(
+        "Narrowing bitcast with cross lane communication is not supported.");
+    return;
   }
-  // For nD vectors, Each lane is not allowed to own multiple elements in any
-  // dimension other than the innermost dimension.
-  // TODO: Add support for other case depending on the use case.
-  SmallVector<int64_t, 3> sourceLaneDataStorage(currData.begin(),
-                                                currData.end() - 1);
+  // Check if each lane owns a single element in all dimensions except the
+  // innermost dimension. For example, a result layout with lane_layout [1, 16]
+  // and lane_data [2, 1] is rejected.
+  // TODO: Relax this based on use cases.
+  SmallVector<int64_t, 3> sourceLaneDataStorage(outData.begin(),
+                                                outData.end() - 1);
   if (llvm::any_of(sourceLaneDataStorage, [](int64_t d) { return d != 1; })) {
     bitcast.emitWarning(
         "Each lane must not own multiple elements in any dimension other than "
         "the innermost dimension.");
     return;
   }
-  // Check if the bitcast requires cross lane communication.
-  if ((currData[rank - 1] * outElemTyBitWidth) % inElemTyBitWidth != 0) {
-    bitcast.emitWarning(
-        "Narrowing bitcast with cross lane communication is not supported.");
-    return;
-  }
   // Decide lane data based on whether the bitcast is narrowing or widening.
-  int64_t innerMostLaneData = isNarrowing ? currData[rank - 1] / bitCastRatio
-                                          : currData[rank - 1] * bitCastRatio;
+  int64_t innerMostLaneData = isNarrowing ? outData[rank - 1] / bitCastRatio
+                                          : outData[rank - 1] * bitCastRatio;
   sourceLaneDataStorage.push_back(innerMostLaneData);
   LaneData sourceLaneData(sourceLaneDataStorage);
 
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index 4cbe4db271ad6..994fa44cab0b6 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -164,9 +164,14 @@ func.func @store_scatter_1d(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
 
 // -----
 // CHECK-LABEL: func.func @vector_bitcast_i16_to_f16(
-// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xi16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xi16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
-// CHECK: %{{.*}} = vector.bitcast %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xi16> to vector<8x16xf16>
-// CHECK: %{{.*}} = vector.bitcast %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : vector<16x16xi16> to vector<16x16xf16>
+// CHECK:       %[[LOAD0:.*]] = xegpu.load_nd %{{.*}}  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-SAME:     !xegpu.tensor_desc<8x16xi16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xi16>
+// CHECK:       %[[LOAD1:.*]] = xegpu.load_nd %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}
+// CHECK-SAME:     !xegpu.tensor_desc<16x16xi16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xi16>
+// CHECK:       %{{.*}} = vector.bitcast %[[LOAD0]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-SAME:      vector<8x16xi16> to vector<8x16xf16>
+// CHECK:       %{{.*}} = vector.bitcast %[[LOAD1]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}
+// CHECK-SAME:      vector<16x16xi16> to vector<16x16xf16>
 func.func @vector_bitcast_i16_to_f16(%arg0: memref<8x16xi16>, %arg1: memref<16x16xi16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16>
@@ -183,7 +188,10 @@ func.func @vector_bitcast_i16_to_f16(%arg0: memref<8x16xi16>, %arg1: memref<16x1
 
 // -----
 // CHECK-LABEL: func.func @vector_bitcast_i32_to_f16(
-// CHECK: %{{.*}} = vector.bitcast %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : vector<16x8xi32> to vector<16x16xf16>
+// CHECK:      %[[LOAD:.*]] = xegpu.load_nd %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK-SAME:     !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32>
+// CHECK-NEXT: %{{.*}} = vector.bitcast %[[LOAD]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-SAME:     vector<16x8xi32> to vector<16x16xf16>
 func.func @vector_bitcast_i32_to_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x8xi32>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -198,6 +206,38 @@ func.func @vector_bitcast_i32_to_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x8
   return
 }
 
+// -----
+// CHECK-LABEL: func.func @vector_bitcast_i16_to_i32(
+// CHECK:      %[[LOAD:.*]] = xegpu.load_nd %{{.*}}  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>}
+// CHECK-SAME:     !xegpu.tensor_desc<8x32xi16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>> -> vector<8x32xi16>
+// CHECK-NEXT: %{{.*}} = vector.bitcast %[[LOAD]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-SAME:     vector<8x32xi16> to vector<8x16xi32>
+func.func @vector_bitcast_i16_to_i32(%arg0: memref<8x32xi16>, %arg1: memref<8x16xi32>) {
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x32xi16> -> !xegpu.tensor_desc<8x32xi16>
+  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
+  %2 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x32xi16> -> vector<8x32xi16>
+  %3 = vector.bitcast %2 : vector<8x32xi16> to vector<8x16xi32>
+  xegpu.store_nd %3, %1  : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32>
+  return
+}
+
+// -----
+// CHECK-LABEL: func.func @vector_bitcast_require_cross_lane_shuffle(
+// CHECK:     %[[LOAD:.*]] = xegpu.load_nd %{{[0-9]+}} : !xegpu.tensor_desc<8x16xi32> -> vector<8x16xi32>
+// CHECK:     %{{.*}} = vector.bitcast %[[LOAD]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-SAME:     vector<8x16xi32> to vector<8x32xi16>
+func.func @vector_bitcast_require_cross_lane_shuffle(%arg0: memref<8x16xi32>, %arg1: memref<8x32xi16>) {
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32>
+  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<8x32xi16> -> !xegpu.tensor_desc<8x32xi16>
+  %2 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xi32> -> vector<8x16xi32>
+  %3 = vector.bitcast %2 : vector<8x16xi32> to vector<8x32xi16>
+  xegpu.store_nd %3, %1  : vector<8x32xi16>, !xegpu.tensor_desc<8x32xi16>
+  return
+}
+
+
 // -----
 // CHECK-LABEL: func.func @binary_op_one_use(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>,

>From 621122c50d7df5adb6ed33d94b8055fdc480ecdd Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 21 Aug 2025 23:14:40 +0000
Subject: [PATCH 07/15] save work

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 8dce63b80f373..d8c447dd46338 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -107,7 +107,6 @@ struct LayoutInfo {
 private:
   LaneLayout laneLayout;
   LaneData laneData;
-  xegpu::LayoutAttr layoutAttr;
 
 public:
   LayoutInfo() = default;
@@ -464,7 +463,7 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
   LayoutInfo resultLayout = results[0]->getValue();
   if (!resultLayout.isAssigned())
     return;
-  // Only consider 1D -> 2D broadcasts or 2D -> 2D broadcasts.
+  // Only consider vector to vector broadcasts for now.
   VectorType resultTy = broadcast.getResultVectorType();
   VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
   if (!sourceTy) {
@@ -472,7 +471,7 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
     return;
   }
 
-  // Only conside 2D -> 2D broadcast.
+  // Only consider 2D -> 2D broadcast.
   if (sourceTy.getRank() != 2 || resultTy.getRank() != 2) {
     broadcast.emitWarning("Expecting source type to be 2D vector and "
                           "result type to be 2D vector.");

>From 35c64895111db5d7019a64078fbe719dce317b95 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Fri, 22 Aug 2025 14:45:35 +0000
Subject: [PATCH 08/15] fix compilation error in clang

---
 mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 82fd70571c022..bad734dbfd9f0 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_XEGPU_UTILS_XEGPUUTILS_H_
 #define MLIR_DIALECT_XEGPU_UTILS_XEGPUUTILS_H_
 
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
 namespace mlir {

>From b912c21cf84eee0b574f4acc8db036270d9efb36 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Mon, 25 Aug 2025 22:08:11 +0000
Subject: [PATCH 09/15] save work

---
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 289 +++++++++++-------
 1 file changed, 173 insertions(+), 116 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index d8c447dd46338..5bba85dd4d3bc 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -21,6 +21,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/Value.h"
 #include "mlir/IR/Visitors.h"
@@ -29,6 +30,7 @@
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Casting.h"
@@ -36,6 +38,7 @@
 #include "llvm/Support/InterleavedRange.h"
 #include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/raw_ostream.h"
+#include <cstdint>
 
 namespace mlir {
 namespace xegpu {
@@ -58,30 +61,32 @@ namespace {
 
 /// Helper class to store the ND layout of lanes within a subgroup and data
 /// owned by each lane.
-struct Layout {
-  SmallVector<int64_t, 3> layout;
-  Layout() = default;
-  Layout(std::initializer_list<int64_t> list) : layout(list) {}
-  Layout(SmallVector<int64_t, 3> &list) : layout(list) {}
-  void print(llvm::raw_ostream &os) const;
-  size_t size() const { return layout.size(); }
-  int64_t operator[](size_t idx) const;
-};
-
-int64_t Layout::operator[](size_t idx) const {
-  assert(idx < layout.size() && "Index out of bounds");
-  return layout[idx];
-}
-
-void Layout::print(llvm::raw_ostream &os) const {
-  os << llvm::interleaved_array(layout);
-}
-
-/// LaneLayout represents the logical layout of lanes within a subgroup when it
-/// accesses some value. LaneData represents the logical layout of data owned by
-/// each work item.
-using LaneLayout = Layout;
-using LaneData = Layout;
+// struct Layout {
+//   SmallVector<int64_t, 3> layout;
+//   Layout() = default;
+//   Layout(std::initializer_list<int64_t> list) : layout(list) {}
+//   Layout(SmallVector<int64_t, 3> &list) : layout(list) {}
+//   void print(llvm::raw_ostream &os) const;
+//   size_t size() const { return layout.size(); }
+//   int64_t operator[](size_t idx) const;
+// };
+
+// int64_t Layout::operator[](size_t idx) const {
+//   assert(idx < layout.size() && "Index out of bounds");
+//   return layout[idx];
+// }
+
+// void Layout::print(llvm::raw_ostream &os) const {
+//   os << llvm::interleaved_array(layout);
+// }
+
+// /// LaneLayout represents the logical layout of lanes within a subgroup when
+// it
+// /// accesses some value. LaneData represents the logical layout of data owned
+// by
+// /// each work item.
+// using LaneLayout = Layout;
+// using LaneData = Layout;
 
 //===----------------------------------------------------------------------===//
 // LayoutInfo
@@ -105,13 +110,14 @@ using LaneData = Layout;
 
 struct LayoutInfo {
 private:
-  LaneLayout laneLayout;
-  LaneData laneData;
+  mlir::Attribute storage = nullptr;
 
 public:
   LayoutInfo() = default;
-  LayoutInfo(const LaneLayout &layout, const LaneData &data)
-      : laneLayout(layout), laneData(data) {}
+  LayoutInfo(const xegpu::LayoutAttr &layout) : storage(layout) {}
+  LayoutInfo(const xegpu::SliceAttr &slice) : storage(slice) {
+    storage = slice.flatten();
+  }
 
   // Two lattice values are equal if they have `some` layout. The actual
   // content of the layout does not matter.
@@ -125,24 +131,44 @@ struct LayoutInfo {
 
   void print(raw_ostream &os) const;
 
-  bool isAssigned() const {
-    return laneLayout.size() > 0 && laneData.size() > 0;
-  }
+  bool isAssigned() const { return storage != nullptr; }
+
+  LayoutInfo transpose(ArrayRef<int64_t> permutation) const;
 
-  LayoutInfo getTransposedLayout(ArrayRef<int64_t> permutation) const;
+  ArrayRef<int> getLaneLayout() const {
+    if (!isAssigned())
+      return {};
+    if (isa<xegpu::LayoutAttr>(storage))
+      return cast<xegpu::LayoutAttr>(storage).getLaneLayout().asArrayRef();
+    xegpu::SliceAttr slice = cast<xegpu::SliceAttr>(storage);
+    assert(isa<xegpu::LayoutAttr>(slice.getParent()) &&
+           "Slice parent must be a LayoutAttr");
+    auto parent = cast<xegpu::LayoutAttr>(slice.getParent());
+    return parent.getLaneLayout().asArrayRef();
+  }
+  ArrayRef<int> getLaneData() const {
+    if (!isAssigned())
+      return {};
+    if (isa<xegpu::LayoutAttr>(storage))
+      return cast<xegpu::LayoutAttr>(storage).getLaneData().asArrayRef();
+    xegpu::SliceAttr slice = cast<xegpu::SliceAttr>(storage);
+    assert(isa<xegpu::LayoutAttr>(slice.getParent()) &&
+           "Slice parent must be a LayoutAttr");
+    auto parent = cast<xegpu::LayoutAttr>(slice.getParent());
+    return parent.getLaneData().asArrayRef();
+  }
+  bool isSliceLayout() const {
+    if (!isAssigned())
+      return false;
+    return isa<xegpu::SliceAttr>(storage);
+  }
 
-  const LaneLayout &getLayout() const { return laneLayout; }
-  const LaneData &getData() const { return laneData; }
-  ArrayRef<int64_t> getLayoutAsArrayRef() const { return laneLayout.layout; }
-  ArrayRef<int64_t> getDataAsArrayRef() const { return laneData.layout; }
+  Attribute get() { return storage; }
 };
 
 void LayoutInfo::print(raw_ostream &os) const {
   if (isAssigned()) {
-    os << "lane_layout: ";
-    laneLayout.print(os);
-    os << ", lane_data: ";
-    laneData.print(os);
+    os << storage;
   } else {
     os << "Not assigned.";
   }
@@ -159,18 +185,30 @@ LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
   llvm_unreachable("Join should not be triggered by layout propagation.");
 }
 
-/// Get the transposed layout according to the given permutation.
-LayoutInfo
-LayoutInfo::getTransposedLayout(ArrayRef<int64_t> permutation) const {
+/// Construct a new layout with the transposed lane layout and lane data.
+LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
   if (!isAssigned())
     return {};
-  LaneLayout newLayout;
-  LaneData newData;
+  // A valid permutation is a bijection on [0, rank) of the lane layout.
+  int64_t rank = static_cast<int64_t>(getLaneLayout().size());
+  llvm::SmallSet<int64_t, 4> seen(permutation.begin(), permutation.end());
+  bool hasDuplicates = seen.size() != permutation.size();
+  bool withinRange = static_cast<int64_t>(permutation.size()) == rank &&
+                     llvm::all_of(permutation, [&](int64_t idx) {
+                       return idx >= 0 && idx < rank;
+                     });
+  if (!withinRange || hasDuplicates) {
+    assert(false && "Invalid permutation for transpose.");
+    return {};
+  }
+  SmallVector<int32_t> laneLayout;
+  SmallVector<int32_t> laneData;
   for (int64_t idx : permutation) {
-    newLayout.layout.push_back(laneLayout.layout[idx]);
-    newData.layout.push_back(laneData.layout[idx]);
+    laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
+    laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
   }
-  return LayoutInfo(newLayout, newData);
+  return LayoutInfo(
+      xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData));
 }
 
 //===----------------------------------------------------------------------===//
@@ -190,13 +228,15 @@ struct LayoutInfoLattice : public Lattice<LayoutInfo> {
 /// Helper Function to get the default layout for uniform values like constants.
 /// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1].
 /// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
-static LayoutInfo getDefaultSIMTLayoutInfo(unsigned rank) {
+static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
+                                           unsigned rank) {
   assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
-  if (rank == 1)
-    return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize}),
-                      LaneData({1}));
-  return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
-                    LaneData({1, 1}));
+  if (rank == 1) {
+    return LayoutInfo(
+        xegpu::LayoutAttr::get(ctx, {xegpu::targetinfo::subgroupSize}, {1}));
+  }
+  return LayoutInfo(xegpu::LayoutAttr::get(
+      ctx, {1, xegpu::targetinfo::subgroupSize}, {1, 1}));
 }
 
 /// Helper to get the default layout for a vector type.
@@ -209,14 +249,15 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) {
          "Expected int or float element type.");
   // If the rank is 1, then return default layout for 1D vector.
   if (vectorTy.getRank() == 1)
-    return getDefaultSIMTLayoutInfo(1);
+    return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1);
   // Packing factor is determined by the element type bitwidth.
   int packingFactor = 1;
   unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
   if (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
     packingFactor = xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth;
-  return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
-                    LaneData({1, packingFactor}));
+  return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
+                                           {1, xegpu::targetinfo::subgroupSize},
+                                           {1, packingFactor}));
 }
 
 /// Helper to get the default layout for a vector type.
@@ -229,7 +270,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy) {
          "Expected int or float element type.");
   // If the rank is 1, then return default layout for 1D vector.
   if (tdescTy.getRank() == 1)
-    return getDefaultSIMTLayoutInfo(1);
+    return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1);
   // Packing factor is determined by the element type bitwidth.
   unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
 
@@ -238,16 +279,18 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy) {
         bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
             ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
             : 1;
-    return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize, 1}),
-                      LaneData({1, packingFactor}));
+    return LayoutInfo(xegpu::LayoutAttr::get(
+        tdescTy.getContext(), {xegpu::targetinfo::subgroupSize, 1},
+        {1, packingFactor}));
   }
 
   int packingFactor =
       (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
           ? xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth
           : 1;
-  return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
-                    LaneData({1, packingFactor}));
+  return LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(),
+                                           {1, xegpu::targetinfo::subgroupSize},
+                                           {1, packingFactor}));
 }
 
 /// Helper Function to get the expected layouts for DPAS operands. `lane_data`
@@ -261,15 +304,17 @@ static LayoutInfo getSIMTLayoutInfoForDPASOperand(VectorType vectorTy,
   Type elementTy = vectorTy.getElementType();
   assert(elementTy.isIntOrFloat() &&
          "Expected int or float type in DPAS operands");
-  LaneLayout layout({1, xegpu::targetinfo::subgroupSize});
+  SmallVector<int32_t, 2> layout({1, xegpu::targetinfo::subgroupSize});
   // For B operand, data must be packed in minimum `packedDpasBSizeInBits` and
   // must have the VNNI format.
   if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() <
                              xegpu::targetinfo::packedSizeInBitsForDpasB) {
-    LaneData data({xegpu::targetinfo::packedSizeInBitsForDpasB /
-                       elementTy.getIntOrFloatBitWidth(),
-                   1});
-    return LayoutInfo(layout, data);
+    SmallVector<int32_t, 2> data(
+        {static_cast<int32_t>(xegpu::targetinfo::packedSizeInBitsForDpasB /
+                              elementTy.getIntOrFloatBitWidth()),
+         1});
+    return LayoutInfo(
+        xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data));
   }
   // Otherwise, return the default layout for the vector type.
   return getDefaultSIMTLayoutInfo(vectorTy);
@@ -450,7 +495,8 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
   }
   // Given that the result is 1D, the layout of the operand should be 2D with
   // default layout.
-  LayoutInfo operandLayout = getDefaultSIMTLayoutInfo(2);
+  LayoutInfo operandLayout =
+      getDefaultSIMTLayoutInfo(reduction->getContext(), 2);
   propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
   // Accumulator should have the same layout as the result.
   propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
@@ -494,43 +540,55 @@ void LayoutInfoPropagation::visitShapeCastOp(
   LayoutInfo resultLayout = results[0]->getValue();
   if (!resultLayout.isAssigned())
     return;
-  VectorType sourceTy = shapeCast.getSourceVectorType();
-  VectorType resultTy = shapeCast.getResultVectorType();
+  int64_t sourceRank = shapeCast.getSourceVectorType().getRank();
+  int64_t resultRank = shapeCast.getResultVectorType().getRank();
   // Expecting source rank to be 1D or 2D.
-  if (sourceTy.getRank() != 1 && sourceTy.getRank() != 2) {
+  if (sourceRank != 1 && sourceRank != 2) {
     shapeCast.emitWarning("Expecting source type to be 1D or 2D vector.");
     return;
   }
   // Expecting result rank to be 1D or 2D.
-  if (resultTy.getRank() != 1 && resultTy.getRank() != 2) {
+  if (resultRank != 1 && resultRank != 2) {
     shapeCast.emitWarning("Expecting result type to be 1D or 2D vector.");
     return;
   }
   // For 2D -> 2D shape cast, propagate the result layout to the source.
-  if (sourceTy.getRank() == 2 && resultTy.getRank() == 2) {
-    // Propagate the result layout to the source operand.
+  if (sourceRank == 2 && resultRank == 2) {
     propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
     return;
   }
-  auto resultLayoutArray = resultLayout.getLayoutAsArrayRef();
-  if (resultLayoutArray[0] != 1 && resultLayoutArray[1] != 1) {
+  auto resultLaneLayout = resultLayout.getLaneLayout();
+  if (resultLaneLayout[0] != 1 && resultLaneLayout[1] != 1) {
     shapeCast.emitWarning(
         "Expecting result layout to be of form [1, subgroupSize] "
         "or [subgroupSize, 1].");
     return;
   }
-  int64_t distributedDim = resultLayoutArray[0] == 1 ? 1 : 0;
-  // If the result shape can be evenly distributed in the distributed dimension,
-  // then the source layout should be [subgroupSize][1]. Otherwise, data is
-  // shared accross lanes (broadcasted). In that case, just assign [1][1] for
-  // now (TODO: Use slice for this case)
-  LayoutInfo sourceLayout =
-      resultTy.getShape()[distributedDim] % xegpu::targetinfo::subgroupSize == 0
-          ? LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize}),
-                       LaneData({1}))
-          : LayoutInfo(LaneLayout({1}), LaneData({1}));
-  // Propagate the source layout to the source operand.
-  propagateIfChanged(operands[0], operands[0]->meet(sourceLayout));
+  ArrayRef<int64_t> resultShape = shapeCast.getResultVectorType().getShape();
+  // For 2D -> 1D case, source gets the reusult's lane layout and lane data.
+  if (sourceRank == 2 && resultRank == 1) {
+    propagateIfChanged(operands[0],
+                       operands[0]->meet(LayoutInfo(xegpu::LayoutAttr::get(
+                           shapeCast->getContext(), resultLaneLayout,
+                           resultLayout.getLaneData()))));
+    return;
+  }
+
+  // For 1D -> 2D case, If the result shape can be evenly distributed in the
+  // distributed dimension, then the source layout should be [subgroupSize][1].
+  // Otherwise, data is shared accross lanes (broadcasted). We use slice
+  // attribute for the broadcast case.
+  int64_t distributedDim = resultLaneLayout[0] == 1 ? 1 : 0;
+  xegpu::LayoutAttr plainLayout = xegpu::LayoutAttr::get(
+      shapeCast->getContext(), resultLaneLayout, resultLayout.getLaneData());
+  if (resultShape[distributedDim] % xegpu::targetinfo::subgroupSize != 0) {
+    xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
+        shapeCast->getContext(), plainLayout,
+        DenseI64ArrayAttr::get(shapeCast->getContext(), {distributedDim}));
+    propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
+    return;
+  }
+  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(plainLayout)));
 }
 
 /// Propagate the layout of the result tensor to the source tensor descriptor in
@@ -591,7 +649,7 @@ void LayoutInfoPropagation::visitLoadNdOp(
   if (auto transpose = load.getTranspose()) {
     load.emitWarning("Transpose effect is not expected for LoadNdOp at "
                      "LayoutInfoPropagation stage.");
-    tensorDescLayout = valueLayout.getTransposedLayout(transpose.value());
+    tensorDescLayout = valueLayout.transpose(transpose.value());
   }
   // Propagate the new layout to the tensor descriptor operand.
   propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
@@ -606,8 +664,7 @@ void LayoutInfoPropagation::visitTransposeOp(
   LayoutInfo resultLayout = results[0]->getValue();
   if (!resultLayout.isAssigned())
     return;
-  LayoutInfo newLayout =
-      resultLayout.getTransposedLayout(transpose.getPermutation());
+  LayoutInfo newLayout = resultLayout.transpose(transpose.getPermutation());
   // Propagate the new layout to the vector operand.
   propagateIfChanged(operands[0], operands[0]->meet(newLayout));
 }
@@ -636,9 +693,9 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
   bool isNarrowing = inElemTyBitWidth > outElemTyBitWidth;
   int bitCastRatio = isNarrowing ? inElemTyBitWidth / outElemTyBitWidth
                                  : outElemTyBitWidth / inElemTyBitWidth;
-  const LaneLayout &sourceLaneLayout =
-      resultLayout.getLayout(); // source lane layout is unchanged.
-  ArrayRef<int64_t> outData = resultLayout.getDataAsArrayRef();
+  ArrayRef<int> sourceLaneLayout =
+      resultLayout.getLaneLayout(); // Lane layout does not change for bitcast.
+  ArrayRef<int> outData = resultLayout.getLaneData();
 
   // TODO: Currently we assume that bitcasts does not require cross lane
   // communication. So each lane must own the required number of elements to
@@ -650,12 +707,9 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
     return;
   }
   // Check if each lane owns a single element in all dimensions except the
-  // innermost dimension. For example, if the result layout is [1, 16][2, 1], we
-  // are not allowed to bitcast such vectors.
-  // TODO: Relax this based on use cases.
-  SmallVector<int64_t, 3> sourceLaneDataStorage(outData.begin(),
-                                                outData.end() - 1);
-  if (llvm::any_of(sourceLaneDataStorage, [](int64_t d) { return d != 1; })) {
+  // innermost dimension.
+  SmallVector<int> sourceLaneData(outData.begin(), outData.end() - 1);
+  if (llvm::any_of(sourceLaneData, [](int64_t d) { return d != 1; })) {
     bitcast.emitWarning(
         "Each lane must not own multiple elements in any dimension other than "
         "the innermost dimension.");
@@ -664,11 +718,12 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
   // Decide lane data based on whether the bitcast is narrowing or widening.
   int64_t innerMostLaneData = isNarrowing ? outData[rank - 1] / bitCastRatio
                                           : outData[rank - 1] * bitCastRatio;
-  sourceLaneDataStorage.push_back(innerMostLaneData);
-  LaneData sourceLaneData(sourceLaneDataStorage);
+  sourceLaneData.push_back(innerMostLaneData);
 
-  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(
-                                      sourceLaneLayout, sourceLaneData)));
+  propagateIfChanged(
+      operands[0],
+      operands[0]->meet(LayoutInfo(xegpu::LayoutAttr::get(
+          bitcast->getContext(), sourceLaneLayout, sourceLaneData))));
 }
 
 /// Propagate the layout of the result to the tensor descriptor and mask
@@ -680,7 +735,7 @@ void LayoutInfoPropagation::visitLoadGatherOp(
   LayoutInfo layout = getDefaultSIMTLayoutInfo(load.getTensorDescType());
 
   // Mask operand should have 1D default layout.
-  LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
+  LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1);
 
   // Propagate the new layout to the tensor descriptor operand.
   propagateIfChanged(operands[0], operands[0]->meet(layout));
@@ -698,7 +753,7 @@ void LayoutInfoPropagation::visitCreateDescOp(
   if (!descLayout.isAssigned())
     return;
   // For offset operand propagate 1D default layout.
-  LayoutInfo layout = getDefaultSIMTLayoutInfo(1);
+  LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1);
   propagateIfChanged(operands[1], operands[1]->meet(layout));
 }
 
@@ -725,7 +780,8 @@ void LayoutInfoPropagation::visitStoreScatterOp(
   // Propagate the tensor descriptor layout.
   propagateIfChanged(operands[1], operands[1]->meet(layout));
   // Use default 1D layout for mask operand.
-  LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
+  LayoutInfo maskLayout =
+      getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1);
   propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
 }
 
@@ -813,7 +869,7 @@ void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
     printFunctionResult(funcOp);
 }
 
-using GetLayoutFnTy = function_ref<xegpu::LayoutAttr(Value)>;
+using GetLayoutFnTy = function_ref<xegpu::LayoutTrait(Value)>;
 /// Update an operation with the layout of its results. If the result type is a
 /// vector type, a temporary layout attribute is added to the operation. If the
 /// result type is a tensor descriptor type, the type is updated with the layout
@@ -832,7 +888,7 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
     if (!isa<VectorType, xegpu::TensorDescType>(resultType))
       continue;
     // If the result has no layout but has users, emit a warning and continue.
-    xegpu::LayoutAttr layout = getLayoutOfValue(result);
+    xegpu::LayoutTrait layout = getLayoutOfValue(result);
     if (!layout && result.getNumUses() > 0) {
       op->emitWarning("op has users but no layout assigned for its result");
       continue;
@@ -898,8 +954,9 @@ updateControlFlowOps(mlir::OpBuilder &builder,
       // We only need to operate on tensor descriptor or vector types.
       if (!isa<xegpu::TensorDescType, VectorType>(inputType))
         continue;
-      xegpu::LayoutAttr successorInputLayout = getLayoutOfValue(successorInput);
-      xegpu::LayoutAttr successorOperandLayout =
+      xegpu::LayoutTrait successorInputLayout =
+          getLayoutOfValue(successorInput);
+      xegpu::LayoutTrait successorOperandLayout =
           getLayoutOfValue(successorOperand);
 
       // If either of the layouts is not assigned, we cannot proceed.
@@ -947,7 +1004,7 @@ static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder,
     newArgTypes.push_back(argType);
     if (!isa<VectorType, xegpu::TensorDescType>(argType))
       continue;
-    xegpu::LayoutAttr layout = getLayoutOfValue(arg);
+    xegpu::LayoutTrait layout = getLayoutOfValue(arg);
     if (!layout) {
       LLVM_DEBUG(DBGS() << "Expecting layout for function argument: " << arg
                         << " but got none.\n");
@@ -989,13 +1046,13 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
     return;
   }
   // Helper to convert LayoutInfo to xegpu::LayoutAttr.
-  auto getXeGPULayoutForValue = [&](Value val) -> xegpu::LayoutAttr {
+  auto getXeGPULayoutForValue = [&](Value val) -> xegpu::LayoutTrait {
     LayoutInfo layout = analysis.getLayoutInfo(val);
     if (!layout.isAssigned())
       return {};
-    return xegpu::LayoutAttr::get(
-        val.getContext(), llvm::to_vector_of<int>(layout.getLayoutAsArrayRef()),
-        llvm::to_vector_of<int>(layout.getDataAsArrayRef()));
+    if (layout.isSliceLayout())
+      return cast<xegpu::SliceAttr>(layout.get());
+    return cast<xegpu::LayoutAttr>(layout.get());
   };
 
   mlir::OpBuilder builder(&getContext());

>From 5a683b443e3160c2c81449338beeedebfe6ac229 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Mon, 25 Aug 2025 23:53:39 +0000
Subject: [PATCH 10/15] save work

---
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 142 ++++++++++--------
 .../Transforms/XeGPUSubgroupDistribute.cpp    |  26 ++--
 2 files changed, 94 insertions(+), 74 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 9a7c9570af6b6..0434566e21f4e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -110,14 +110,11 @@ namespace {
 
 struct LayoutInfo {
 private:
-  mlir::Attribute storage = nullptr;
+  xegpu::DistributeLayoutAttr storage = nullptr;
 
 public:
   LayoutInfo() = default;
-  LayoutInfo(const xegpu::LayoutAttr &layout) : storage(layout) {}
-  LayoutInfo(const xegpu::SliceAttr &slice) : storage(slice) {
-    storage = slice.flatten();
-  }
+  LayoutInfo(const xegpu::DistributeLayoutAttr &layout) : storage(layout) {}
 
   // Two lattice values are equal if they have `some` layout. The actual
   // content of the layout does not matter.
@@ -135,28 +132,26 @@ struct LayoutInfo {
 
   LayoutInfo transpose(ArrayRef<int64_t> permutation) const;
 
-  ArrayRef<int> getLaneLayout() const {
+  SmallVector<int> getLaneLayout() const {
     if (!isAssigned())
       return {};
-    if (isa<xegpu::LayoutAttr>(storage))
-      return cast<xegpu::LayoutAttr>(storage).getLaneLayout().asArrayRef();
-    xegpu::SliceAttr slice = cast<xegpu::SliceAttr>(storage);
-    assert(isa<xegpu::LayoutAttr>(slice.getParent()) &&
-           "Slice parent must be a LayoutAttr");
-    auto parent = cast<xegpu::LayoutAttr>(slice.getParent());
-    return parent.getLaneLayout().asArrayRef();
+    assert(storage.getLaneLayoutAsInt().has_value() &&
+           "Expected lane layout to be assigned");
+    return llvm::map_to_vector(
+        storage.getLaneLayoutAsInt().value(),
+        [](int64_t val) { return static_cast<int>(val); });
   }
-  ArrayRef<int> getLaneData() const {
+
+  SmallVector<int> getLaneData() const {
     if (!isAssigned())
       return {};
-    if (isa<xegpu::LayoutAttr>(storage))
-      return cast<xegpu::LayoutAttr>(storage).getLaneData().asArrayRef();
-    xegpu::SliceAttr slice = cast<xegpu::SliceAttr>(storage);
-    assert(isa<xegpu::LayoutAttr>(slice.getParent()) &&
-           "Slice parent must be a LayoutAttr");
-    auto parent = cast<xegpu::LayoutAttr>(slice.getParent());
-    return parent.getLaneData().asArrayRef();
+    assert(storage.getLaneDataAsInt().has_value() &&
+           "Expected lane data to be assigned");
+    return llvm::map_to_vector(
+        storage.getLaneDataAsInt().value(),
+        [](int64_t val) { return static_cast<int>(val); });
   }
+
   bool isSliceLayout() const {
     if (!isAssigned())
       return false;
@@ -558,26 +553,49 @@ void LayoutInfoPropagation::visitShapeCastOp(
     return;
   }
   auto resultLaneLayout = resultLayout.getLaneLayout();
-  if (resultLaneLayout[0] != 1 && resultLaneLayout[1] != 1) {
+  if (resultRank == 2 && resultLaneLayout[0] != 1 && resultLaneLayout[1] != 1) {
     shapeCast.emitWarning(
-        "Expecting result layout to be of form [1, subgroupSize] "
+        "Expecting 2D result layout to be of form [1, subgroupSize] "
         "or [subgroupSize, 1].");
     return;
   }
   ArrayRef<int64_t> resultShape = shapeCast.getResultVectorType().getShape();
-  // For 2D -> 1D case, source gets the reusult's lane layout and lane data.
+  ArrayRef<int64_t> sourceShape = shapeCast.getSourceVectorType().getShape();
+  // For 2D -> 1D case.
   if (sourceRank == 2 && resultRank == 1) {
-    propagateIfChanged(operands[0],
-                       operands[0]->meet(LayoutInfo(xegpu::LayoutAttr::get(
-                           shapeCast->getContext(), resultLaneLayout,
-                           resultLayout.getLaneData()))));
-    return;
+    // If the result has a slice layout, assign the parent layout of the slice.
+    if (resultLayout.isSliceLayout()) {
+      auto sliceAttr = cast<xegpu::SliceAttr>(resultLayout.get());
+      propagateIfChanged(operands[0],
+                         operands[0]->meet(LayoutInfo(sliceAttr.getParent())));
+      return;
+    }
+    // If the result has a regular 1D layout, then we find the first dimension
+    // that can be fully evenly distributed to lanes. This dimension becomes
+    // the distributed dimension for deciding the lane layout.
+    int sourceDistributedDim =
+        sourceShape[0] % xegpu::targetinfo::subgroupSize == 0
+            ? 0
+            : (sourceShape[1] % xegpu::targetinfo::subgroupSize ? 1 : -1);
+    if (sourceDistributedDim == -1) {
+      shapeCast.emitWarning(
+          "Source vector can not be evenly distributed across lanes.");
+      return;
+    }
+    SmallVector<int> sourceLaneLayout = {1, 1},
+                     laneData = {1, resultLayout.getLaneData()[0]};
+    sourceLaneLayout[sourceDistributedDim] = xegpu::targetinfo::subgroupSize;
+    propagateIfChanged(
+        operands[0],
+        operands[0]->meet(LayoutInfo(xegpu::LayoutAttr::get(
+            shapeCast->getContext(), sourceLaneLayout, laneData))));
+    return; // Do not fall through into the 1D -> 2D handling below.
   }
 
   // For 1D -> 2D case, If the result shape can be evenly distributed in the
-  // distributed dimension, then the source layout should be [subgroupSize][1].
-  // Otherwise, data is shared accross lanes (broadcasted). We use slice
-  // attribute for the broadcast case.
+  // distributed dimension, then the source layout should be
+  // [subgroupSize][1]. Otherwise, data is shared accross lanes (broadcasted).
+  // We use slice attribute for the broadcast case.
   int64_t distributedDim = resultLaneLayout[0] == 1 ? 1 : 0;
   xegpu::LayoutAttr plainLayout = xegpu::LayoutAttr::get(
       shapeCast->getContext(), resultLaneLayout, resultLayout.getLaneData());
@@ -591,8 +609,8 @@ void LayoutInfoPropagation::visitShapeCastOp(
   propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(plainLayout)));
 }
 
-/// Propagate the layout of the result tensor to the source tensor descriptor in
-/// UpdateNdOffsetOp.
+/// Propagate the layout of the result tensor to the source tensor descriptor
+/// in UpdateNdOffsetOp.
 void LayoutInfoPropagation::visitUpdateNdOffsetOp(
     xegpu::UpdateNdOffsetOp updateNdOffset,
     ArrayRef<LayoutInfoLattice *> operands,
@@ -710,9 +728,9 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
   // innermost dimension.
   SmallVector<int> sourceLaneData(outData.begin(), outData.end() - 1);
   if (llvm::any_of(sourceLaneData, [](int64_t d) { return d != 1; })) {
-    bitcast.emitWarning(
-        "Each lane must not own multiple elements in any dimension other than "
-        "the innermost dimension.");
+    bitcast.emitWarning("Each lane must not own multiple elements in any "
+                        "dimension other than "
+                        "the innermost dimension.");
     return;
   }
   // Decide lane data based on whether the bitcast is narrowing or widening.
@@ -869,15 +887,16 @@ void RunLayoutInfoPropagation::printAnalysisResult(llvm::raw_ostream &os) {
     printFunctionResult(funcOp);
 }
 
-using GetLayoutFnTy = function_ref<xegpu::LayoutTrait(Value)>;
-/// Update an operation with the layout of its results. If the result type is a
-/// vector type, a temporary layout attribute is added to the operation. If the
-/// result type is a tensor descriptor type, the type is updated with the layout
-/// attribute. The users of the result are also updated with the layout
+using GetLayoutFnTy = function_ref<xegpu::DistributeLayoutAttr(Value)>;
+/// Update an operation with the layout of its results. If the result type is
+/// a vector type, a temporary layout attribute is added to the operation. If
+/// the result type is a tensor descriptor type, the type is updated with the
+/// layout attribute. The users of the result are also updated with the layout
 /// attribute.
 static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
                               GetLayoutFnTy getLayoutOfValue) {
-  // Region ops (like scf.for) are already handled by the updateControlFlowOps.
+  // Region ops (like scf.for) are already handled by the
+  // updateControlFlowOps.
   if (mlir::isa<mlir::RegionBranchOpInterface>(op))
     return success();
 
@@ -888,7 +907,7 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
     if (!isa<VectorType, xegpu::TensorDescType>(resultType))
       continue;
     // If the result has no layout but has users, emit a warning and continue.
-    xegpu::LayoutTrait layout = getLayoutOfValue(result);
+    xegpu::DistributeLayoutAttr layout = getLayoutOfValue(result);
     if (!layout && result.getNumUses() > 0) {
       op->emitWarning("op has users but no layout assigned for its result");
       continue;
@@ -910,14 +929,14 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
 }
 
 /// Region ops like scf.for need special handling because they have blocks
-/// inside. If the blocks have tensor descriptor type as block arguments, thier
-/// types must be updated. Also region op can have results that may not have any
-/// users (e.g. A and B tiles). They are not assigned a layout by layout
-/// analysis because they have no users. However inside the region op
-/// corresponding block arguments for these results do have layouts. Therefore,
-/// in this case we still need to update the result types with the layout
-/// attribute. This function function updates the internal block arguments and
-/// the result types of the region op with the assigned layouts.
+/// inside. If the blocks have tensor descriptor type as block arguments,
+/// their types must be updated. Also region op can have results that may not
+/// have any users (e.g. A and B tiles). They are not assigned a layout by
+/// layout analysis because they have no users. However inside the region op
+/// corresponding block arguments for these results do have layouts.
+/// Therefore, in this case we still need to update the result types with the
+/// layout attribute. This function updates the internal block
+/// arguments and the result types of the region op with the assigned layouts.
 /// clang-format off
 /// Example: scf.for ... iter_args(...) -> (out types) {
 ///   ^bb0(block types):
@@ -929,8 +948,8 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
 /// regions. One is the ^bb0 (for loop body) and the other is the scf.for op
 /// itself (yield the results). So we update both the block arguments of the
 /// successor region (i.e. block types) and the result types of the scf.for op
-/// (i.e. out types). Note that yield types are updated by respective producers
-/// inside bb0.
+/// (i.e. out types). Note that yield types are updated by respective
+/// producers inside bb0.
 static LogicalResult
 updateControlFlowOps(mlir::OpBuilder &builder,
                      mlir::RegionBranchTerminatorOpInterface terminator,
@@ -954,17 +973,16 @@ updateControlFlowOps(mlir::OpBuilder &builder,
       // We only need to operate on tensor descriptor or vector types.
       if (!isa<xegpu::TensorDescType, VectorType>(inputType))
         continue;
-      xegpu::LayoutTrait successorInputLayout =
+      xegpu::DistributeLayoutAttr successorInputLayout =
           getLayoutOfValue(successorInput);
-      xegpu::LayoutTrait successorOperandLayout =
+      xegpu::DistributeLayoutAttr successorOperandLayout =
           getLayoutOfValue(successorOperand);
 
       // If either of the layouts is not assigned, we cannot proceed.
       if (!successorOperandLayout) {
-        LLVM_DEBUG(
-            DBGS()
-            << "No layout assigned for forwarded operand in branch terminator: "
-            << successorOperand << "\n");
+        LLVM_DEBUG(DBGS() << "No layout assigned for forwarded operand in "
+                             "branch terminator: "
+                          << successorOperand << "\n");
         return failure();
       }
       // We expect the layouts to match.
@@ -1004,7 +1022,7 @@ static LogicalResult updateFunctionOpInterface(mlir::OpBuilder &builder,
     newArgTypes.push_back(argType);
     if (!isa<VectorType, xegpu::TensorDescType>(argType))
       continue;
-    xegpu::LayoutTrait layout = getLayoutOfValue(arg);
+    xegpu::DistributeLayoutAttr layout = getLayoutOfValue(arg);
     if (!layout) {
       LLVM_DEBUG(DBGS() << "Expecting layout for function argument: " << arg
                         << " but got none.\n");
@@ -1046,7 +1064,7 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
     return;
   }
   // Helper to convert LayoutInfo to xegpu::LayoutAttr.
-  auto getXeGPULayoutForValue = [&](Value val) -> xegpu::LayoutTrait {
+  auto getXeGPULayoutForValue = [&](Value val) -> xegpu::DistributeLayoutAttr {
     LayoutInfo layout = analysis.getLayoutInfo(val);
     if (!layout.isAssigned())
       return {};
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 27b8fc1c2919d..31821ee07d418 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -76,12 +76,12 @@ namespace {
 /// | 32x16                 | [2, 8]      | 16x2                     |
 /// | 2x32x16               | [1, 16]     | 2x32x1                   |
 static FailureOr<VectorType>
-getDistVecTypeBasedOnLaneLayout(xegpu::LayoutAttr layout,
+getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
                                 VectorType originalType) {
   if (!layout)
     return failure();
 
-  auto laneLayout = layout.getLaneLayout().asArrayRef();
+  auto laneLayout = layout.getLaneLayoutAsInt().value();
   assert(originalType.getShape().size() >= laneLayout.size() &&
          "Rank of the original vector type should be greater or equal to the "
          "size of the lane layout to distribute the vector type.");
@@ -868,7 +868,7 @@ struct VectorBitcastDistribution final : public gpu::WarpDistributionPattern {
     unsigned operandIdx = operand->getOperandNumber();
     VectorType distributedSourceType =
         getDistVecTypeBasedOnLaneLayout(
-            xegpu::getLayoutAttr(bitcastOp.getSource()),
+            xegpu::getDistributeLayoutAttr(bitcastOp.getSource()),
             bitcastOp.getSourceVectorType())
             .value_or(VectorType());
     if (!distributedSourceType)
@@ -907,24 +907,30 @@ struct VectorTransposeDistribution final : public gpu::WarpDistributionPattern {
           warpOp, "warp result is not a vector::Transpose op");
     auto transposeOp = operand->get().getDefiningOp<vector::TransposeOp>();
     unsigned operandIdx = operand->getOperandNumber();
-    xegpu::LayoutAttr sourceLayout =
-        xegpu::getLayoutAttr(transposeOp.getVector());
-    xegpu::LayoutAttr resultLayout =
-        xegpu::getLayoutAttr(transposeOp.getResult());
+    xegpu::DistributeLayoutAttr sourceLayout =
+        xegpu::getDistributeLayoutAttr(transposeOp.getVector());
+    xegpu::DistributeLayoutAttr resultLayout =
+        xegpu::getDistributeLayoutAttr(transposeOp.getResult());
     if (!sourceLayout || !resultLayout)
       return rewriter.notifyMatchFailure(
           transposeOp,
           "the source or result vector of the transpose op lacks layout "
           "attribute");
-    ArrayRef<int> sourceLaneLayout = sourceLayout.getLaneLayout().asArrayRef();
-    ArrayRef<int> resultLaneLayout = resultLayout.getLaneLayout().asArrayRef();
-    ArrayRef<int> sourceLaneData = sourceLayout.getLaneData().asArrayRef();
-    ArrayRef<int> resultLaneData = resultLayout.getLaneData().asArrayRef();
+    // getLaneLayoutAsInt()/getLaneDataAsInt() return a temporary optional;
+    // keep owned copies, as an ArrayRef into that temporary would dangle.
+    SmallVector<int64_t> sourceLaneLayout =
+        sourceLayout.getLaneLayoutAsInt().value();
+    SmallVector<int64_t> resultLaneLayout =
+        resultLayout.getLaneLayoutAsInt().value();
+    SmallVector<int64_t> sourceLaneData =
+        sourceLayout.getLaneDataAsInt().value();
+    SmallVector<int64_t> resultLaneData =
+        resultLayout.getLaneDataAsInt().value();
     if (sourceLaneLayout.size() != 2 || resultLaneLayout.size() != 2)
       return rewriter.notifyMatchFailure(
           transposeOp, "the source or result vector of the transpose op "
                        "does not have 2D layout");
-    auto is2DTranspose = [](ArrayRef<int> input, ArrayRef<int> output) {
+    auto is2DTranspose = [](ArrayRef<int64_t> input, ArrayRef<int64_t> output) {
       return input.size() == 2 && output.size() == 2 && input[0] == output[1] &&
              input[1] == output[0];
     };

>From 2da2c6de6f3043462d871b8083a19e09738cc509 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 26 Aug 2025 19:54:35 +0000
Subject: [PATCH 11/15] save work

---
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 19 ++---
 mlir/test/Dialect/XeGPU/propagate-layout.mlir | 80 ++++++++++++++++++-
 2 files changed, 88 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 0434566e21f4e..3f30751875679 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -187,8 +187,8 @@ LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
   // Check if the permutation is valid.
   llvm::SmallSet<int64_t, 4> seen(permutation.begin(), permutation.end());
   bool hasDuplicates = seen.size() != permutation.size();
-  bool withinRange = llvm::all_of(permutation, [&](size_t idx) {
-    return idx >= 0 && idx < permutation.size();
+  bool withinRange = llvm::all_of(permutation, [&](int64_t idx) {
+    return idx >= 0 && idx < static_cast<int64_t>(permutation.size());
   });
 
   if (!withinRange || hasDuplicates) {
@@ -577,7 +577,7 @@ void LayoutInfoPropagation::visitShapeCastOp(
     int sourceDistributedDim =
         sourceShape[0] % xegpu::targetinfo::subgroupSize == 0
             ? 0
-            : (sourceShape[1] % xegpu::targetinfo::subgroupSize ? 1 : -1);
+            : (sourceShape[1] % xegpu::targetinfo::subgroupSize == 0 ? 1 : -1);
     if (sourceDistributedDim == -1) {
       shapeCast.emitWarning(
           "Source vector can not be evenly distributed across lanes.");
@@ -597,16 +597,17 @@ void LayoutInfoPropagation::visitShapeCastOp(
   // [subgroupSize][1]. Otherwise, data is shared accross lanes (broadcasted).
   // We use slice attribute for the broadcast case.
   int64_t distributedDim = resultLaneLayout[0] == 1 ? 1 : 0;
-  xegpu::LayoutAttr plainLayout = xegpu::LayoutAttr::get(
-      shapeCast->getContext(), resultLaneLayout, resultLayout.getLaneData());
   if (resultShape[distributedDim] % xegpu::targetinfo::subgroupSize != 0) {
+    xegpu::LayoutAttr parentLayout = xegpu::LayoutAttr::get(
+        shapeCast->getContext(), resultLaneLayout, resultLayout.getLaneData());
     xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
-        shapeCast->getContext(), plainLayout,
+        shapeCast->getContext(), parentLayout,
         DenseI64ArrayAttr::get(shapeCast->getContext(), {distributedDim}));
     propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
     return;
   }
-  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(plainLayout)));
+  propagateIfChanged(operands[0], operands[0]->meet(getDefaultSIMTLayoutInfo(
+                                      shapeCast.getSourceVectorType())));
 }
 
 /// Propagate the layout of the result tensor to the source tensor descriptor
@@ -711,9 +712,9 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
   bool isNarrowing = inElemTyBitWidth > outElemTyBitWidth;
   int bitCastRatio = isNarrowing ? inElemTyBitWidth / outElemTyBitWidth
                                  : outElemTyBitWidth / inElemTyBitWidth;
-  ArrayRef<int> sourceLaneLayout =
+  SmallVector<int> sourceLaneLayout =
       resultLayout.getLaneLayout(); // Lane layout does not change for bitcast.
-  ArrayRef<int> outData = resultLayout.getLaneData();
+  SmallVector<int> outData = resultLayout.getLaneData();
 
   // TODO: Currently we assume that bitcasts does not require cross lane
   // communication. So each lane must own the required number of elements to
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index 994fa44cab0b6..25d237c58e2ce 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -455,7 +455,7 @@ func.func @prefetch_1d(%arg0: memref<256xf16>){
 }
 
 // -----
-// CHECK-LABEL: func.func @test_scf_while_and_condition(
+// CHECK-LABEL: func.func @scf_while_and_condition(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf32>, %[[ARG1:[0-9a-zA-Z]+]]: memref<256xf32>) {
 // CHECK: %{{.*}}:3 = scf.while ({{.*}}) : (vector<16xf32>, i32, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>)
 // CHECK-SAME: -> (vector<16xf32>, i32, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>) {
@@ -464,7 +464,7 @@ func.func @prefetch_1d(%arg0: memref<256xf16>){
 // CHECK-NEXT: ^bb0(%{{.*}}: vector<16xf32>, %{{.*}}: i32, %{{.*}}: !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>):
 // CHECK:     scf.yield {{.*}} : vector<16xf32>, i32, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
 // CHECK-NEXT: } attributes {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
-func.func @test_scf_while_and_condition(%arg0: memref<256xf32>, %arg1: memref<256xf32>) {
+func.func @scf_while_and_condition(%arg0: memref<256xf32>, %arg1: memref<256xf32>) {
   %c0 = arith.constant 0 : i32
   %c16 = arith.constant 16 : i32
   %c256 = arith.constant 256 : i32
@@ -486,3 +486,79 @@ func.func @test_scf_while_and_condition(%arg0: memref<256xf32>, %arg1: memref<25
   }
   return
 }
+
+// -----
+// CHECK-LABEL: func.func @vector_shape_cast_2d_to_1d_dim0_distributed(
+// CHECK-SAME:      %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x1xf16, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>,
+// CHECK-SAME:      %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>) {
+// CHECK:           %[[LOAD:.*]] = xegpu.load_nd %[[ARG0]]
+// CHECK-SAME:        {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} :
+// CHECK-SAME:        !xegpu.tensor_desc<16x1xf16, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x1xf16>
+// CHECK-NEXT:      %{{.*}} = vector.shape_cast %[[LOAD]] {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+// CHECK-SAME:        : vector<16x1xf16> to vector<16xf16>
+func.func @vector_shape_cast_2d_to_1d_dim0_distributed(%arg0: !xegpu.tensor_desc<16x1xf16>, %arg1: !xegpu.tensor_desc<16xf16>) {
+  %c0 = arith.constant 0 : index
+  %3 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<16x1xf16> -> vector<16x1xf16>
+  %2 = vector.shape_cast %3 : vector<16x1xf16> to vector<16xf16>
+  xegpu.store_nd %2, %arg1  : vector<16xf16>, !xegpu.tensor_desc<16xf16>
+  return
+}
+
+// -----
+// CHECK-LABEL: func.func @vector_shape_cast_2d_to_1d_dim1_distributed(
+// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<1x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>,
+// CHECK-SAME:    %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>) {
+// CHECK:         %[[LOAD:.*]] = xegpu.load_nd %[[ARG0]]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-SAME:        !xegpu.tensor_desc<1x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<1x16xf16>
+// CHECK:         %{{.*}} = vector.shape_cast %[[LOAD]] {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+// CHECK-SAME:        vector<1x16xf16> to vector<16xf16>
+func.func @vector_shape_cast_2d_to_1d_dim1_distributed(%arg0: !xegpu.tensor_desc<1x16xf16>, %arg1: !xegpu.tensor_desc<16xf16>) {
+  %c0 = arith.constant 0 : index
+  %3 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<1x16xf16> -> vector<1x16xf16>
+  %2 = vector.shape_cast %3 : vector<1x16xf16> to vector<16xf16>
+  xegpu.store_nd %2, %arg1  : vector<16xf16>, !xegpu.tensor_desc<16xf16>
+  return
+}
+
+// -----
+// CHECK-LABEL: func.func @vector_shape_cast_1d_to_2d_dim1_distributed(
+// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>,
+// CHECK-SAME:    %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
+// CHECK:         %[[LOAD:.*]] = xegpu.load_nd %[[ARG0]]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-SAME:      !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
+// CHECK-NEXT:    %[[REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %{{[0-9a-zA-Z]+}}
+// CHECK-SAME:      {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} [0] : vector<16x16xf16> to vector<16xf16>
+// CHECK-NEXT:    %[[CAST:.*]] = vector.shape_cast %[[REDUCE]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-SAME:       vector<16xf16> to vector<1x16xf16>
+func.func @vector_shape_cast_1d_to_2d_dim1_distributed(%arg0: !xegpu.tensor_desc<16x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0.0000> : vector<16xf16>
+  %3 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %4 = vector.multi_reduction <add>, %3, %cst [0] : vector<16x16xf16> to vector<16xf16>
+  %2 = vector.shape_cast %4 : vector<16xf16> to vector<1x16xf16>
+  %5 = vector.broadcast %2 : vector<1x16xf16> to vector<16x16xf16>
+  xegpu.store_nd %5, %arg1  : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+  return
+}
+
+// -----
+// CHECK-LABEL: func.func @vector_shape_cast_1d_to_2d_dim0_broadcasted(
+// CHECK-SAME:     %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>,
+// CHECK-SAME:     %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
+// CHECK:          %[[LOAD:.*]] = xegpu.load_nd %arg0  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-SAME:        !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
+// CHECK-NEXT:     %[[REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %{{[0-9a-zA-Z]+}}
+// CHECK-SAME:        {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>} [1]
+// CHECK-SAME:        vector<16x16xf16> to vector<16xf16>
+// CHECK-NEXT:     %[[CAST:.*]] = vector.shape_cast %[[REDUCE]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-SAME:        vector<16xf16> to vector<16x1xf16>
+func.func @vector_shape_cast_1d_to_2d_dim0_broadcasted(%arg0: !xegpu.tensor_desc<16x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0.0000> : vector<16xf16>
+  %3 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %4 = vector.multi_reduction <add>, %3, %cst [1] : vector<16x16xf16> to vector<16xf16>
+  %2 = vector.shape_cast %4 : vector<16xf16> to vector<16x1xf16>
+  %5 = vector.broadcast %2 : vector<16x1xf16> to vector<16x16xf16>
+  xegpu.store_nd %5, %arg1  : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+  return
+}

>From 7eabad47a70eaac1c15207a62d844f01c4205b62 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 26 Aug 2025 23:10:04 +0000
Subject: [PATCH 12/15] save work

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    |  3 +
 .../Dialect/XeGPU/subgroup-distribute.mlir    | 96 +++++++++++++++++++
 2 files changed, 99 insertions(+)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 31821ee07d418..3e67e6406b956 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -827,6 +827,9 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+/// Sink a memref::ExtractAlignedPointerAsIndexOp that feeds into the yield op
+/// of an enclosing `gpu.warp_execute_on_lane_0` region. The op is simply moved
+/// outside of the warp op.
 struct MemrefExtractAlignedPointerAsIndexDistribution final
     : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 54ef56e013abb..690b13f5a2973 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -319,3 +319,99 @@ gpu.module @test {
     gpu.return
   }
 }
+
+// -----
+// CHECK-LABEL: gpu.func @memref_extract_aligned_pointer_as_index(
+// CHECK:  %{{.*}} = memref.extract_aligned_pointer_as_index %{{.*}} : memref<256x256xf16> -> index
+gpu.module @test {
+  gpu.func @memref_extract_aligned_pointer_as_index(%arg0 : memref<256x256xf16>) {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1.000000e+00> : vector<16xf16>
+    %ptr = memref.extract_aligned_pointer_as_index %arg0 : memref<256x256xf16> -> index
+    %ptr_i64 = arith.index_cast %ptr : index to i64
+    %tdesc = xegpu.create_nd_tdesc %ptr_i64[%c0], shape: [16], strides: [16] : i64
+      -> !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+    xegpu.store_nd %cst, %tdesc : vector<16xf16>, !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
+    gpu.return
+  }
+}
+
+
+// -----
+// CHECK-LABEL: gpu.func @vector_transpose(
+// CHECK:         %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<2xf32>
+// CHECK:         %[[DEST:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<2x16xf32> -> !xegpu.tensor_desc<2x16xf32>
+// CHECK:         xegpu.store_nd %[[CST]], %[[DEST]]  : vector<2xf32>, !xegpu.tensor_desc<2x16xf32>
+gpu.module @test {
+  gpu.func @vector_transpose(%arg0: memref<2x16xf32>) {
+    %cst = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} dense<1.000000e+00>
+      : vector<16x2xf32>
+    %c0 = arith.constant 0 : index
+    %transpose = vector.transpose %cst, [1, 0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+      : vector<16x2xf32> to vector<2x16xf32>
+    %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<2x16xf32>
+      -> !xegpu.tensor_desc<2x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.store_nd %transpose, %0 : vector<2x16xf32>,
+      !xegpu.tensor_desc<2x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @vector_bitcast(
+// CHECK:      %[[CAST:.*]] = vector.bitcast %{{.*}} : vector<4x2xi8> to vector<4x1xi16>
+// CHECK-NEXT: %[[DEST:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<4x16xi16> -> !xegpu.tensor_desc<4x16xi16>
+// CHECK-NEXT: %[[T0:.*]] = vector.shape_cast %[[CAST]] : vector<4x1xi16> to vector<4xi16>
+// CHECK-NEXT: xegpu.store_nd %[[T0]], %[[DEST]]  : vector<4xi16>, !xegpu.tensor_desc<4x16xi16>
+gpu.module @test {
+  gpu.func @vector_bitcast(%arg0: memref<4x16xi16>) {
+    %cst = "some_op"() {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>}
+      : () -> (vector<4x32xi8>)
+    %bitcast = vector.bitcast %cst {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+      : vector<4x32xi8> to vector<4x16xi16>
+    %c0 = arith.constant 0 : index
+    %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<4x16xi16>
+      -> !xegpu.tensor_desc<4x16xi16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.store_nd %bitcast, %0 : vector<4x16xi16>,
+      !xegpu.tensor_desc<4x16xi16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @mma_transpose_b(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x8xi32>,
+// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
+// CHECK:         %[[ADESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+// CHECK-NEXT:    %[[A:.*]] = xegpu.load_nd %[[ADESC]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
+// CHECK-NEXT:    %[[BDESC:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x8xi32> -> !xegpu.tensor_desc<16x8xi32>
+// CHECK-NEXT:    %[[B:.*]] = xegpu.load_nd %[[BDESC]] <{transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xi32> -> vector<8xi32>
+// CHECK-NEXT:    %[[BCAST0:.*]] = vector.shape_cast %[[B]] : vector<8xi32> to vector<1x8xi32>
+// CHECK-NEXT:    %[[BCAST1:.*]] = vector.bitcast %[[BCAST0]] : vector<1x8xi32> to vector<1x16xf16>
+// CHECK-NEXT:    %[[BCAST2:.*]] = vector.shape_cast %[[BCAST1]] : vector<1x16xf16> to vector<16xf16>
+// CHECK-NEXT:    %[[C:.*]] = xegpu.dpas %[[A]], %[[BCAST2]] : vector<8xf16>, vector<16xf16> -> vector<8xf32>
+gpu.module @test {
+  gpu.func @mma_transpose_b(%arg0: memref<8x16xf16>, %arg1: memref<16x8xi32>, %arg2: memref<8x16xf32>) {
+    %c0 = arith.constant 0 : index
+    %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16>
+      -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %1 = xegpu.load_nd %0  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+      : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16>
+    %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x8xi32>
+      -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+    %3 = xegpu.load_nd %2  {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+      : !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32>
+    %4 = vector.bitcast %3 {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+      : vector<16x8xi32> to vector<16x16xf16>
+    %5 = vector.transpose %4, [1, 0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}
+      : vector<16x16xf16> to vector<16x16xf16>
+    %6 = xegpu.dpas %1, %5 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+      : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+    %7 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32>
+      -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.store_nd %6, %7 : vector<8x16xf32>,
+      !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+
+  }
+}

>From 635a00679d1287f23a18594d8643811bbc6297f5 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 26 Aug 2025 23:20:54 +0000
Subject: [PATCH 13/15] save work

---
 mlir/test/Dialect/XeGPU/propagate-layout.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index 25d237c58e2ce..29592ec76f918 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -224,7 +224,7 @@ func.func @vector_bitcast_i16_to_i32(%arg0: memref<8x32xi16>, %arg1: memref<8x16
 
 // -----
 // CHECK-LABEL: func.func @vector_bitcast_require_cross_lane_shuffle(
-// CHECK-NOT: %[[LOAD:.*]] = xegpu.load_nd %{{.*}} {layout_result_0 = {{.*}}} : !xegpu.tensor_desc<8x16xi32> -> vector<8x16xi32>
+// CHECK:     %[[LOAD:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<8x16xi32> -> vector<8x16xi32>
 // CHECK:     %{{.*}} = vector.bitcast %[[LOAD]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
 // CHECK-SAME:     vector<8x16xi32> to vector<8x32xi16>
 func.func @vector_bitcast_require_cross_lane_shuffle(%arg0: memref<8x16xi32>, %arg1: memref<8x32xi16>) {

>From 74ab5a37ee0acee4d564c5eecb1fb0b564a5157b Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 26 Aug 2025 23:38:59 +0000
Subject: [PATCH 14/15] save work

---
 mlir/test/Dialect/XeGPU/subgroup-distribute.mlir | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 690b13f5a2973..8ecd080c96922 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -382,10 +382,10 @@ gpu.module @test {
 // CHECK-LABEL: gpu.func @mma_transpose_b(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x8xi32>,
 // CHECK-SAME: %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
-// CHECK:         %[[ADESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-// CHECK-NEXT:    %[[A:.*]] = xegpu.load_nd %[[ADESC]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
-// CHECK-NEXT:    %[[BDESC:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x8xi32> -> !xegpu.tensor_desc<16x8xi32>
-// CHECK-NEXT:    %[[B:.*]] = xegpu.load_nd %[[BDESC]] <{transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xi32> -> vector<8xi32>
+// CHECK-DAG:     %[[ADESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+// CHECK-DAG:     %[[BDESC:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x8xi32> -> !xegpu.tensor_desc<16x8xi32>
+// CHECK-DAG:     %[[A:.*]] = xegpu.load_nd %[[ADESC]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>
+// CHECK-DAG:     %[[B:.*]] = xegpu.load_nd %[[BDESC]] <{transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xi32> -> vector<8xi32>
 // CHECK-NEXT:    %[[BCAST0:.*]] = vector.shape_cast %[[B]] : vector<8xi32> to vector<1x8xi32>
 // CHECK-NEXT:    %[[BCAST1:.*]] = vector.bitcast %[[BCAST0]] : vector<1x8xi32> to vector<1x16xf16>
 // CHECK-NEXT:    %[[BCAST2:.*]] = vector.shape_cast %[[BCAST1]] : vector<1x16xf16> to vector<16xf16>

>From b36e109eb628a9262000ddfe4eb5e9c1e0d9bc5b Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 27 Aug 2025 00:00:59 +0000
Subject: [PATCH 15/15] save work

---
 mlir/test/Dialect/XeGPU/subgroup-distribute.mlir | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 8ecd080c96922..d2af6d064bb03 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -380,8 +380,7 @@ gpu.module @test {
 
 // -----
 // CHECK-LABEL: gpu.func @mma_transpose_b(
-// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x8xi32>,
-// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
+// CHECK: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x8xi32>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
 // CHECK-DAG:     %[[ADESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
 // CHECK-DAG:     %[[BDESC:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}] : memref<16x8xi32> -> !xegpu.tensor_desc<16x8xi32>
 // CHECK-DAG:     %[[A:.*]] = xegpu.load_nd %[[ADESC]]  : !xegpu.tensor_desc<8x16xf16> -> vector<8xf16>



More information about the Mlir-commits mailing list