[Mlir-commits] [mlir] [mlir][xegpu] Add OptimizeTranspose pass. (PR #165483)

Charitha Saumya llvmlistbot at llvm.org
Tue Oct 28 17:13:03 PDT 2025


https://github.com/charithaintc updated https://github.com/llvm/llvm-project/pull/165483

>From 9d0341dae9da51c916405c6e14adf4dac9d20af5 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 21 Oct 2025 18:49:43 +0000
Subject: [PATCH 1/8] add pass

---
 .../mlir/Dialect/XeGPU/Transforms/Passes.td   | 10 +++
 .../Dialect/XeGPU/Transforms/Transforms.h     |  3 +-
 .../Dialect/XeGPU/Transforms/CMakeLists.txt   |  1 +
 .../Transforms/XeGPUOptimizeTranspose.cpp     | 69 +++++++++++++++++++
 4 files changed, 82 insertions(+), 1 deletion(-)
 create mode 100644 mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 564d9c4d5422b..d0185159f16ac 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -80,4 +80,14 @@ def XeGPUVectorLinearize : Pass<"xegpu-vector-linearize"> {
                            "scf::SCFDialect", "ub::UBDialect", "vector::VectorDialect"];
 }
 
+def XeGPUOptimizeTranspose : Pass<"xegpu-optimize-transpose"> {
+  let summary = "Optimize XeGPU loadNd operations feeding into vector.transpose";
+  let description = [{
+    This pass rewrites XeGPU loadNd operations that feed into vector.transpose
+    into more efficient forms to improve performance.
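+
+    For example (an illustrative sketch; the exact rewrite depends on the
+    lane data and the element bitwidth), a load feeding a transpose such as
+
+    ```mlir
+    %d = xegpu.create_nd_tdesc %src : memref<64x64xf16>
+      -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
+    %v = xegpu.load_nd %d[%c0, %c0] : !xegpu.tensor_desc<16x16xf16, ...> -> vector<16x16xf16>
+    %t = vector.transpose %v, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
+    ```
+
+    may be rewritten to load a 16x8xi32 block (wider elements, lane_data
+    [1, 1]) whose result is bitcast back to vector<16x16xf16>.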
+  }];
+  let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
+                           "vector::VectorDialect"];
+}
+
 #endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index a480195eebd00..cab1598457cbe 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -61,7 +61,8 @@ struct UnrollOptions {
 
 /// Appends patterns for folding aliasing ops into XeGPU ops into `patterns`.
 void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
-
+/// Appends patterns for optimizing transpose operations into `patterns`.
+void populateXeGPUOptimizeTransposePatterns(RewritePatternSet &patterns);
 /// Appends patterns for XeGPU SIMT distribution into `patterns`.
 void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns);
 /// Appends patterns for moving function body into gpu.warp_execute_on_lane0 op.
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index e6f76067094ce..a2ecbc3374ba9 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
   XeGPUWgToSgDistribute.cpp
   XeGPUPropagateLayout.cpp
   XeGPUVectorLinearize.cpp
+  XeGPUOptimizeTranspose.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
new file mode 100644
index 0000000000000..69ba9e6fca7ba
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
@@ -0,0 +1,69 @@
+//===- XeGPUOptimizeTranspose.cpp - XeGPU optimize transpose ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUOPTIMIZETRANSPOSE
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+#define DEBUG_TYPE "xegpu-optimize-transpose"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+using namespace mlir;
+
+namespace {
+
+class XeGPULoadNdPattern final : public OpConversionPattern<xegpu::LoadNdOp> {
+public:
+  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::LoadNdOp loadOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    return success();
+  }
+};
+} // namespace
+
+void xegpu::populateXeGPUOptimizeTransposePatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<XeGPULoadNdPattern>(patterns.getContext());
+}
+
+namespace {
+
+struct XeGPUOptimizeTransposePass final
+    : public xegpu::impl::XeGPUOptimizeTransposeBase<
+          XeGPUOptimizeTransposePass> {
+  void runOnOperation() override {
+    MLIRContext &context = getContext();
+    TypeConverter converter;
+    RewritePatternSet patterns(&context);
+    ConversionTarget target(context);
+    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
+                                                         target);
+    xegpu::populateXeGPUOptimizeTransposePatterns(patterns);
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns)))) {
+      DBGS() << "Optimize transpose pass failed.\n";
+      return signalPassFailure();
+    }
+  }
+};
+
+} // namespace

>From 76f7323b526079c585dda525d2fa7b3974ed4a49 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 22 Oct 2025 15:20:08 +0000
Subject: [PATCH 2/8] save work

---
 .../Transforms/XeGPUOptimizeTranspose.cpp     | 62 +++++++++++++++++--
 .../Dialect/XeGPU/optimize-transpose.mlir     | 18 ++++++
 2 files changed, 76 insertions(+), 4 deletions(-)
 create mode 100644 mlir/test/Dialect/XeGPU/optimize-transpose.mlir

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
index 69ba9e6fca7ba..83edbad3afce2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
@@ -12,8 +12,10 @@
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include <optional>
 
 namespace mlir {
 namespace xegpu {
@@ -29,12 +31,28 @@ using namespace mlir;
 
 namespace {
 
-class XeGPULoadNdPattern final : public OpConversionPattern<xegpu::LoadNdOp> {
+static std::optional<SmallVector<int64_t>>
+get2DLaneData(xegpu::TensorDescType tdescType) {
+  auto layout = tdescType.getLayoutAttr();
+  if (!layout)
+    return std::nullopt;
+  auto laneData = layout.getEffectiveLaneDataAsInt();
+  if (laneData.size() != 2)
+    return std::nullopt;
+  return laneData;
+}
+
+class XeGPUCreateNdDescOpPattern final
+    : public OpConversionPattern<xegpu::CreateNdDescOp> {
 public:
-  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
+  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(xegpu::LoadNdOp loadOp, OpAdaptor adaptor,
+  matchAndRewrite(xegpu::CreateNdDescOp createNdOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    auto tdescTy = createNdOp.getType();
+    auto convertType = this->getTypeConverter()->convertType(tdescTy);
+    if (convertType == tdescTy)
+      return failure();
     return success();
   }
 };
@@ -42,7 +60,7 @@ class XeGPULoadNdPattern final : public OpConversionPattern<xegpu::LoadNdOp> {
 
 void xegpu::populateXeGPUOptimizeTransposePatterns(
     RewritePatternSet &patterns) {
-  patterns.add<XeGPULoadNdPattern>(patterns.getContext());
+  patterns.add<XeGPUCreateNdDescOpPattern>(patterns.getContext());
 }
 
 namespace {
@@ -55,6 +73,42 @@ struct XeGPUOptimizeTransposePass final
     TypeConverter converter;
     RewritePatternSet patterns(&context);
     ConversionTarget target(context);
+
+    target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
+        [](xegpu::CreateNdDescOp createNdOp) {
+          auto optionalLaneData = get2DLaneData(createNdOp.getType());
+          if (!optionalLaneData)
+            return true;
+          auto laneData = optionalLaneData.value();
+          return laneData[0] != 1 || laneData[1] == 1;
+        });
+
+    converter.addConversion([](xegpu::TensorDescType tdescType) {
+      auto optionalLaneData = get2DLaneData(tdescType);
+      if (!optionalLaneData)
+        return tdescType;
+      auto laneData = optionalLaneData.value();
+      int64_t innerLaneData = laneData[1];
+      if (laneData[0] == 1 && innerLaneData != 1) {
+        int elementTyBitwidth =
+            tdescType.getElementType().getIntOrFloatBitWidth();
+        assert(elementTyBitwidth < 32 &&
+               "Expected element type bitwidth < 32 with laneData[1] != 1");
+        SmallVector<int64_t> newShape(tdescType.getShape());
+        newShape.back() = newShape.back() / innerLaneData;
+        Type newElemTy = IntegerType::get(tdescType.getContext(),
+                                          elementTyBitwidth * innerLaneData);
+        xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get(
+            tdescType.getContext(),
+            tdescType.getLayoutAttr().getLaneLayout().asArrayRef(), {1, 1});
+        return xegpu::TensorDescType::get(
+            newShape, newElemTy, tdescType.getArrayLength(),
+            tdescType.getBoundaryCheck(), tdescType.getMemorySpace(),
+            newLayout);
+      }
+      return tdescType;
+    });
+
     scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                          target);
     xegpu::populateXeGPUOptimizeTransposePatterns(patterns);
diff --git a/mlir/test/Dialect/XeGPU/optimize-transpose.mlir b/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
new file mode 100644
index 0000000000000..529211e286902
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
@@ -0,0 +1,18 @@
+func.func @gemm_b_transpose(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c256 = arith.constant 256 : index
+  %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+  %1 = xegpu.load_nd %0[%c0, %c0]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf32>
+  %2 = xegpu.create_nd_tdesc %arg0 : memref<256x256xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+  %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
+  %4 = scf.for %arg3 = %c0 to %c256 step %c16 iter_args(%arg4 = %1) -> (vector<8x16xf32>) {
+    %5 = xegpu.load_nd %2[%c0, %arg3]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16>
+    %6 = xegpu.load_nd %3[%c0, %arg3]  {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>> -> vector<16x16xf16>
+    %7 = vector.transpose %6, [1, 0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : vector<16x16xf16> to vector<16x16xf16>
+    %8 = xegpu.dpas %5, %7, %arg4 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+    scf.yield %8 : vector<8x16xf32>
+  } {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+  xegpu.store_nd %4, %0[%c0, %c0]  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+  return
+}

>From 43c35bef094cd8754f332be17689dbef9e1c8516 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 22 Oct 2025 22:35:50 +0000
Subject: [PATCH 3/8] add some tests

---
 .../Transforms/XeGPUOptimizeTranspose.cpp     | 188 +++++++++++++++---
 .../Dialect/XeGPU/optimize-transpose.mlir     | 119 ++++++++++-
 2 files changed, 264 insertions(+), 43 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
index 83edbad3afce2..9792da7d6ce79 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
@@ -6,13 +6,19 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include <optional>
@@ -42,6 +48,46 @@ get2DLaneData(xegpu::TensorDescType tdescType) {
   return laneData;
 }
 
+static xegpu::TensorDescType
+getModifiedTensorDescType(xegpu::TensorDescType tdescType) {
+  auto optionalLaneData = get2DLaneData(tdescType);
+  if (!optionalLaneData)
+    return tdescType;
+  auto laneData = optionalLaneData.value();
+  int64_t innerLaneData = laneData[1];
+  if (laneData[0] == 1 && innerLaneData != 1) {
+    int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
+    assert(elementTyBitwidth < 32 &&
+           "Expected element type bitwidth < 32 with laneData[1] != 1");
+    SmallVector<int64_t> newShape(tdescType.getShape());
+    newShape.back() = newShape.back() / innerLaneData;
+    Type newElemTy = IntegerType::get(tdescType.getContext(),
+                                      elementTyBitwidth * innerLaneData);
+    xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get(
+        tdescType.getContext(),
+        tdescType.getLayoutAttr().getLaneLayout().asArrayRef(), {1, 1});
+    return xegpu::TensorDescType::get(
+        newShape, newElemTy, tdescType.getArrayLength(),
+        tdescType.getBoundaryCheck(), tdescType.getMemorySpace(), newLayout);
+  }
+  return tdescType;
+}
+
+static Value convertToValue(ConversionPatternRewriter &rewriter, Location loc,
+                            OpFoldResult ofr) {
+  std::optional<int64_t> mayBeInt = getConstantIntValue(ofr);
+  if (mayBeInt)
+    return arith::ConstantIndexOp::create(rewriter, loc, *mayBeInt);
+  return llvm::cast<Value>(ofr);
+}
+
+static Value divideByConstant(ConversionPatternRewriter &rewriter, Location loc,
+                              Value val, int64_t constant) {
+  auto constantOp = arith::ConstantIndexOp::create(rewriter, loc, constant);
+  return arith::DivUIOp::create(rewriter, loc, val, constantOp.getResult())
+      .getResult();
+}
+
 class XeGPUCreateNdDescOpPattern final
     : public OpConversionPattern<xegpu::CreateNdDescOp> {
 public:
@@ -50,17 +96,106 @@ class XeGPUCreateNdDescOpPattern final
   matchAndRewrite(xegpu::CreateNdDescOp createNdOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto tdescTy = createNdOp.getType();
-    auto convertType = this->getTypeConverter()->convertType(tdescTy);
+    auto convertType = getModifiedTensorDescType(tdescTy);
     if (convertType == tdescTy)
       return failure();
+    auto strides = createNdOp.getMixedStrides();
+    auto maybeConstInnerStride = getConstantIntValue(strides.back());
+    // Only row-major memrefs are expected for now.
+    if (!maybeConstInnerStride || *maybeConstInnerStride != 1)
+      return failure();
+    Value source = createNdOp.getSource();
+    auto optionalLaneData = get2DLaneData(tdescTy);
+    assert(optionalLaneData && "Expected 2D lane data");
+    auto laneData = optionalLaneData.value();
+    int64_t innerLaneData = laneData[1];
+    auto memrefType = dyn_cast<MemRefType>(source.getType());
+    // Inner dimension of the shape must be adjusted based on innerLaneData.
+    SmallVector<OpFoldResult> modifiedShape(createNdOp.getMixedSizes());
+    modifiedShape.back() = divideByConstant(
+        rewriter, createNdOp.getLoc(),
+        convertToValue(rewriter, createNdOp.getLoc(), modifiedShape.back()),
+        innerLaneData);
+    // Similarly, the second-to-last stride must be adjusted.
+    assert(strides.size() >= 2 &&
+           "Expected at least 2 strides for CreateNdDescOp");
+    SmallVector<OpFoldResult> modifiedStrides(strides);
+    modifiedStrides[modifiedStrides.size() - 2] = divideByConstant(
+        rewriter, createNdOp.getLoc(),
+        convertToValue(rewriter, createNdOp.getLoc(),
+                       modifiedStrides[modifiedStrides.size() - 2]),
+        innerLaneData);
+
+    // If the source is a static memref, we need to extract the pointer to the
+    // base address.
+    if (memrefType && memrefType.hasStaticShape()) {
+      auto extractOp = memref::ExtractAlignedPointerAsIndexOp::create(
+          rewriter, createNdOp.getLoc(), source);
+      source = arith::IndexCastOp::create(
+                   rewriter, createNdOp.getLoc(),
+                   IntegerType::get(rewriter.getContext(), 64),
+                   extractOp.getResult())
+                   .getResult();
+    }
+    // Create a new CreateNdDescOp with the modified shape and converted type.
+    auto newCreateNdDescOp = xegpu::CreateNdDescOp::create(
+        rewriter, createNdOp.getLoc(), convertType, source, modifiedShape,
+        modifiedStrides);
+    rewriter.replaceOp(createNdOp, newCreateNdDescOp.getResult());
     return success();
   }
 };
 } // namespace
 
+class XeGPULoadNdDescOpPattern final
+    : public OpConversionPattern<xegpu::LoadNdOp> {
+public:
+  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::LoadNdOp loadNdOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto origTensorDescType = loadNdOp.getTensorDescType();
+    auto adaptorType =
+        cast<xegpu::TensorDescType>(adaptor.getTensorDesc().getType());
+    if (adaptorType == origTensorDescType)
+      return failure();
+    // Offsets must be adjusted based on innerLaneData.
+    auto optionalLaneData = get2DLaneData(loadNdOp.getTensorDescType());
+    assert(optionalLaneData && "Expected 2D lane data");
+    int64_t innerLaneData = optionalLaneData.value()[1];
+    auto offsets = loadNdOp.getMixedOffsets();
+    if (offsets.empty())
+      return rewriter.notifyMatchFailure(loadNdOp,
+                                         "Expecting offsets in LoadNd");
+    SmallVector<OpFoldResult> modifiedOffsets(offsets);
+    modifiedOffsets.back() = divideByConstant(
+        rewriter, loadNdOp.getLoc(),
+        convertToValue(rewriter, loadNdOp.getLoc(), modifiedOffsets.back()),
+        innerLaneData);
+    VectorType modifiedType =
+        VectorType::get(adaptorType.getShape(), adaptorType.getElementType());
+    // Create a new LoadNdOp with modified offsets and type.
+    auto newLoadNdOp = xegpu::LoadNdOp::create(
+        rewriter, loadNdOp->getLoc(), modifiedType, adaptor.getTensorDesc(),
+        modifiedOffsets, loadNdOp.getPackedAttr(), loadNdOp.getTransposeAttr(),
+        loadNdOp.getL1HintAttr(), loadNdOp.getL2HintAttr(),
+        loadNdOp.getL3HintAttr());
+    // Bitcast back to the original type.
+    auto castOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
+                                            loadNdOp.getType(), newLoadNdOp);
+    // Cast op must have the same layout as the original LoadNdOp result.
+    xegpu::setDistributeLayoutAttr(
+        castOp->getOpResult(0),
+        xegpu::getDistributeLayoutAttr(loadNdOp.getResult()));
+    rewriter.replaceOp(loadNdOp, castOp.getResult());
+    return success();
+  }
+};
+
 void xegpu::populateXeGPUOptimizeTransposePatterns(
     RewritePatternSet &patterns) {
-  patterns.add<XeGPUCreateNdDescOpPattern>(patterns.getContext());
+  patterns.add<XeGPUCreateNdDescOpPattern, XeGPULoadNdDescOpPattern>(
+      patterns.getContext());
 }
 
 namespace {
@@ -74,41 +209,28 @@ struct XeGPUOptimizeTransposePass final
     RewritePatternSet patterns(&context);
     ConversionTarget target(context);
 
+    auto checkValidInnerLaneData =
+        [](std::optional<SmallVector<int64_t>> optionalLaneData) -> bool {
+      if (!optionalLaneData)
+        return true;
+      auto laneData = optionalLaneData.value();
+      return laneData[0] != 1 || laneData[1] == 1;
+    };
+
     target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
-        [](xegpu::CreateNdDescOp createNdOp) {
+        [&](xegpu::CreateNdDescOp createNdOp) {
           auto optionalLaneData = get2DLaneData(createNdOp.getType());
-          if (!optionalLaneData)
-            return true;
-          auto laneData = optionalLaneData.value();
-          return laneData[0] != 1 || laneData[1] == 1;
+          return checkValidInnerLaneData(optionalLaneData);
         });
+    target.addDynamicallyLegalOp<xegpu::LoadNdOp>(
+        [&](xegpu::LoadNdOp loadNdOp) {
+          auto optionalLaneData = get2DLaneData(loadNdOp.getTensorDescType());
+          return checkValidInnerLaneData(optionalLaneData);
+        });
+    converter.addConversion([](Type type) { return type; });
 
-    converter.addConversion([](xegpu::TensorDescType tdescType) {
-      auto optionalLaneData = get2DLaneData(tdescType);
-      if (!optionalLaneData)
-        return tdescType;
-      auto laneData = optionalLaneData.value();
-      int64_t innerLaneData = laneData[1];
-      if (laneData[0] == 1 && innerLaneData != 1) {
-        int elementTyBitwidth =
-            tdescType.getElementType().getIntOrFloatBitWidth();
-        assert(elementTyBitwidth < 32 &&
-               "Expected element type bitwidth < 32 with laneData[1] != 1");
-        SmallVector<int64_t> newShape(tdescType.getShape());
-        newShape.back() = newShape.back() / innerLaneData;
-        Type newElemTy = IntegerType::get(tdescType.getContext(),
-                                          elementTyBitwidth * innerLaneData);
-        xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get(
-            tdescType.getContext(),
-            tdescType.getLayoutAttr().getLaneLayout().asArrayRef(), {1, 1});
-        return xegpu::TensorDescType::get(
-            newShape, newElemTy, tdescType.getArrayLength(),
-            tdescType.getBoundaryCheck(), tdescType.getMemorySpace(),
-            newLayout);
-      }
-      return tdescType;
-    });
-
+    target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect,
+                           vector::VectorDialect>();
     scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                          target);
     xegpu::populateXeGPUOptimizeTransposePatterns(patterns);
diff --git a/mlir/test/Dialect/XeGPU/optimize-transpose.mlir b/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
index 529211e286902..0bc74d75a4ad0 100644
--- a/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
+++ b/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
@@ -1,18 +1,117 @@
+
+
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
+func.func @no_scf(%arg0: memref<64x64xf16>, %arg1: vector<8x16xf16>) -> vector<8x16xf32> {
+  %c0 = arith.constant 0 : index
+  %c32 = arith.constant 32 : index
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<64x64xf16> -> !xegpu.tensor_desc<16x16xf16, #b>
+  %1 = xegpu.load_nd %0[%c0, %c32] { result_layout = #b } : !xegpu.tensor_desc<16x16xf16, #b> -> vector<16x16xf16>
+  %2 = vector.transpose %1, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+  %6 = xegpu.dpas %arg1, %2 { layout_result_0 = #a } : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  return %6 : vector<8x16xf32>
+}
+
+// -----
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 4]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]>
+func.func @no_scf_i8(%arg0: memref<64x64xi8>, %arg1: vector<8x32xi8>) -> vector<8x16xi32> {
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<64x64xi8> -> !xegpu.tensor_desc<16x32xi8, #b>
+  %1 = xegpu.load_nd %0[%c0, %c64] { result_layout = #b } : !xegpu.tensor_desc<16x32xi8, #b> -> vector<16x32xi8>
+  %2 = vector.transpose %1, [1, 0] { layout_result_0 = #bt } : vector<16x32xi8> to vector<32x16xi8>
+  %6 = xegpu.dpas %arg1, %2 { layout_result_0 = #a } : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
+  return %6 : vector<8x16xi32>
+}
+
+
+// -----
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
 func.func @gemm_b_transpose(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) {
   %c0 = arith.constant 0 : index
   %c16 = arith.constant 16 : index
   %c256 = arith.constant 256 : index
-  %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-  %1 = xegpu.load_nd %0[%c0, %c0]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf32>
-  %2 = xegpu.create_nd_tdesc %arg0 : memref<256x256xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-  %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
+  %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #a>
+  %1 = xegpu.load_nd %0[%c0, %c0]  { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf32, #a> -> vector<8x16xf32>
+  %2 = xegpu.create_nd_tdesc %arg0 : memref<256x256xf16> -> !xegpu.tensor_desc<8x16xf16, #a>
+  %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #b>
   %4 = scf.for %arg3 = %c0 to %c256 step %c16 iter_args(%arg4 = %1) -> (vector<8x16xf32>) {
-    %5 = xegpu.load_nd %2[%c0, %arg3]  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16>
-    %6 = xegpu.load_nd %3[%c0, %arg3]  {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>> -> vector<16x16xf16>
-    %7 = vector.transpose %6, [1, 0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : vector<16x16xf16> to vector<16x16xf16>
-    %8 = xegpu.dpas %5, %7, %arg4 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+    %5 = xegpu.load_nd %2[%c0, %arg3] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf16, #a> -> vector<8x16xf16>
+    %6 = xegpu.load_nd %3[%c0, %arg3]  { layout_result_0 = #b } : !xegpu.tensor_desc<16x16xf16, #b> -> vector<16x16xf16>
+    %7 = vector.transpose %6, [1, 0]  { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+    %8 = xegpu.dpas %5, %7, %arg4 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
     scf.yield %8 : vector<8x16xf32>
-  } {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
-  xegpu.store_nd %4, %0[%c0, %c0]  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+  } {layout_result_0 = #a}
+  xegpu.store_nd %4, %0[%c0, %c0]  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+  return
+}
+
+// -----
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
+func.func @large_loads(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %c256 = arith.constant 256 : index
+  %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #a>
+  %1 = xegpu.load_nd %0[%c0, %c0]  { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf32, #a> -> vector<8x16xf32>
+  %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16> -> !xegpu.tensor_desc<32x32xf16, #b>
+  %4:4 = scf.for %arg3 = %c0 to %c256 step %c32 iter_args(%arg4 = %1, %arg5 = %1, %arg6 = %1, %arg7 = %1)
+    -> (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) {
+    %6 = xegpu.load_nd %3[%c0, %arg3]  { layout_result_0 = #b } : !xegpu.tensor_desc<32x32xf16, #b> -> vector<32x32xf16>
+    %7 = vector.extract_strided_slice %6 {offsets = [0, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+      : vector<32x32xf16> to vector<16x16xf16>
+    %8 = vector.extract_strided_slice %6 {offsets = [0, 16], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+      : vector<32x32xf16> to vector<16x16xf16>
+    %9 = vector.extract_strided_slice %6 {offsets = [16, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+      : vector<32x32xf16> to vector<16x16xf16>
+    %10 = vector.extract_strided_slice %6 {offsets = [16, 16], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+      : vector<32x32xf16> to vector<16x16xf16>
+    %11 = vector.transpose %7, [1, 0]  { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+    %12 = vector.transpose %8, [1, 0]  { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+    %13 = vector.transpose %9, [1, 0]  { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+    %14 = vector.transpose %10, [1, 0]  { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+    %15 = xegpu.dpas %arg0, %11, %arg4 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+    %16 = xegpu.dpas %arg0, %12, %arg5 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+    %17 = xegpu.dpas %arg0, %13, %arg6 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+    %18 = xegpu.dpas %arg0, %14, %arg7 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+    scf.yield %15, %16, %17, %18 : vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>
+  } {layout_result_0 = #a, layout_result_1 = #a, layout_result_2 = #a, layout_result_3 = #a}
+  xegpu.store_nd %4#0, %0[%c0, %c0]  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+  xegpu.store_nd %4#1, %0[%c0, %c16]  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+  xegpu.store_nd %4#2, %0[%c16, %c0]  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+  xegpu.store_nd %4#3, %0[%c16, %c16]  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+  return
+}
+
+// -----
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
+func.func @nested_scf(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c256 = arith.constant 256 : index
+  scf.for %arg8 = %c0 to %c256 step %c16 {
+    %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #a>
+    %1 = xegpu.load_nd %0[%arg8, %c0]  { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf32, #a> -> vector<8x16xf32>
+    %2 = xegpu.create_nd_tdesc %arg0 : memref<256x256xf16> -> !xegpu.tensor_desc<8x16xf16, #a>
+    %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #b>
+    %4 = scf.for %arg3 = %c0 to %c256 step %c16 iter_args(%arg4 = %1) -> (vector<8x16xf32>) {
+      %5 = xegpu.load_nd %2[%arg8, %arg3] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf16, #a> -> vector<8x16xf16>
+      %6 = xegpu.load_nd %3[%arg8, %arg3]  { layout_result_0 = #b } : !xegpu.tensor_desc<16x16xf16, #b> -> vector<16x16xf16>
+      %7 = vector.transpose %6, [1, 0]  { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+      %8 = xegpu.dpas %5, %7, %arg4 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+      scf.yield %8 : vector<8x16xf32>
+    } {layout_result_0 = #a}
+    xegpu.store_nd %4, %0[%c0, %c0]  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+  }
   return
 }

>From f79d2a2b68fc337fe431196c08c7b3e1b52a0256 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 23 Oct 2025 17:46:55 +0000
Subject: [PATCH 4/8] add some tests

---
 .../Transforms/XeGPUOptimizeTranspose.cpp     | 15 ++++++
 .../Dialect/XeGPU/optimize-transpose.mlir     | 50 +++++++++----------
 2 files changed, 40 insertions(+), 25 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
index 9792da7d6ce79..e0a9e3c221400 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
@@ -35,6 +35,21 @@ namespace xegpu {
 
 using namespace mlir;
 
+struct TransposableBlockRange {
+  int minWidth, maxWidth, minHeight, maxHeight;
+};
+
+// TODO: Use uArch to get supported block ranges.
+static TransposableBlockRange getBlockRange(int bitWidth) {
+  switch (bitWidth) {
+  case 32:
+    return {/**min width**/ 1, /**max width**/ 8, /**min height**/ 1,
+            /**max height**/ 32};
+  default:
+    llvm_unreachable("Add support for other element bitwidths");
+  }
+}
+
 namespace {
 
 static std::optional<SmallVector<int64_t>>
diff --git a/mlir/test/Dialect/XeGPU/optimize-transpose.mlir b/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
index 0bc74d75a4ad0..35c9294c32f04 100644
--- a/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
+++ b/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
@@ -51,6 +51,31 @@ func.func @gemm_b_transpose(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf1
   return
 }
 
+// -----
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
+func.func @nested_scf(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c256 = arith.constant 256 : index
+  scf.for %arg8 = %c0 to %c256 step %c16 {
+    %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #a>
+    %1 = xegpu.load_nd %0[%arg8, %c0]  { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf32, #a> -> vector<8x16xf32>
+    %2 = xegpu.create_nd_tdesc %arg0 : memref<256x256xf16> -> !xegpu.tensor_desc<8x16xf16, #a>
+    %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #b>
+    %4 = scf.for %arg3 = %c0 to %c256 step %c16 iter_args(%arg4 = %1) -> (vector<8x16xf32>) {
+      %5 = xegpu.load_nd %2[%arg8, %arg3] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf16, #a> -> vector<8x16xf16>
+      %6 = xegpu.load_nd %3[%arg8, %arg3]  { layout_result_0 = #b } : !xegpu.tensor_desc<16x16xf16, #b> -> vector<16x16xf16>
+      %7 = vector.transpose %6, [1, 0]  { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+      %8 = xegpu.dpas %5, %7, %arg4 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+      scf.yield %8 : vector<8x16xf32>
+    } {layout_result_0 = #a}
+    xegpu.store_nd %4, %0[%c0, %c0]  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+  }
+  return
+}
+
 // -----
 #a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
 #b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
@@ -90,28 +115,3 @@ func.func @large_loads(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %arg
   xegpu.store_nd %4#3, %0[%c16, %c16]  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
   return
 }
-
-// -----
-#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
-#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
-#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
-func.func @nested_scf(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) {
-  %c0 = arith.constant 0 : index
-  %c16 = arith.constant 16 : index
-  %c256 = arith.constant 256 : index
-  scf.for %arg8 = %c0 to %c256 step %c16 {
-    %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #a>
-    %1 = xegpu.load_nd %0[%arg8, %c0]  { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf32, #a> -> vector<8x16xf32>
-    %2 = xegpu.create_nd_tdesc %arg0 : memref<256x256xf16> -> !xegpu.tensor_desc<8x16xf16, #a>
-    %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #b>
-    %4 = scf.for %arg3 = %c0 to %c256 step %c16 iter_args(%arg4 = %1) -> (vector<8x16xf32>) {
-      %5 = xegpu.load_nd %2[%arg8, %arg3] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf16, #a> -> vector<8x16xf16>
-      %6 = xegpu.load_nd %3[%arg8, %arg3]  { layout_result_0 = #b } : !xegpu.tensor_desc<16x16xf16, #b> -> vector<16x16xf16>
-      %7 = vector.transpose %6, [1, 0]  { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
-      %8 = xegpu.dpas %5, %7, %arg4 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-      scf.yield %8 : vector<8x16xf32>
-    } {layout_result_0 = #a}
-    xegpu.store_nd %4, %0[%c0, %c0]  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
-  }
-  return
-}

>From ca5d9024135f3e49c79e3dc503cd715bca4dc38d Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 24 Oct 2025 23:37:14 +0000
Subject: [PATCH 5/8] save work

---
 .../Transforms/XeGPUOptimizeTranspose.cpp     | 265 +++++++++++++-----
 .../Dialect/XeGPU/optimize-transpose.mlir     |  44 +++
 2 files changed, 246 insertions(+), 63 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
index e0a9e3c221400..a697c0433ed56 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
@@ -9,18 +9,23 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/Types.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include <algorithm>
 #include <optional>
 
 namespace mlir {
@@ -35,12 +40,14 @@ namespace xegpu {
 
 using namespace mlir;
 
-struct TransposableBlockRange {
-  int minWidth, maxWidth, minHeight, maxHeight;
+namespace {
+
+struct Allowed2DShapeRange {
+  int64_t minWidth, maxWidth, minHeight, maxHeight;
 };
 
 // TODO: Use uArch to get supported block ranges.
-static TransposableBlockRange getBlockRange(int bitWidth) {
+static Allowed2DShapeRange getTransposableBlockRange(int bitWidth) {
   switch (bitWidth) {
   case 32:
     return {/**min width**/ 1, /**max width**/ 8, /**min height**/ 1,
@@ -50,10 +57,8 @@ static TransposableBlockRange getBlockRange(int bitWidth) {
   }
 }
 
-namespace {
-
 static std::optional<SmallVector<int64_t>>
-get2DLaneData(xegpu::TensorDescType tdescType) {
+getMaybeLaneData(xegpu::TensorDescType tdescType) {
   auto layout = tdescType.getLayoutAttr();
   if (!layout)
     return std::nullopt;
@@ -63,44 +68,131 @@ get2DLaneData(xegpu::TensorDescType tdescType) {
   return laneData;
 }
 
+static std::optional<SmallVector<int64_t>>
+getMaybeLaneLayout(xegpu::TensorDescType tdescType) {
+  auto layout = tdescType.getLayoutAttr();
+  if (!layout)
+    return std::nullopt;
+  auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
+  if (laneLayout.size() != 2)
+    return std::nullopt;
+  return laneLayout;
+}
+
+// A transpose layout is invalid if the lane layout is transposed
+// (laneLayout[0] != 1 && laneLayout[1] == 1) but the lane data is [1, K] with
+// K != 1 (i.e. not [1, 1]).
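+// For example, lane_layout = [16, 1] with lane_data = [1, 2] on a 16-bit
+// element type is such an invalid transpose layout.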
+static bool hasInvalidTranposeLayout(xegpu::TensorDescType tdescType) {
+  // If the element type is 32 bits or wider, the layout is always valid.
+  int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
+  if (elementTyBitwidth >= 32)
+    return false;
+  auto maybeLaneLayout = getMaybeLaneLayout(tdescType);
+  auto maybeLaneData = getMaybeLaneData(tdescType);
+  if (!maybeLaneData || !maybeLaneLayout)
+    return false;
+  auto laneData = maybeLaneData.value();
+  auto laneLayout = maybeLaneLayout.value();
+  if (laneLayout[0] == 1 || laneLayout[1] != 1)
+    return false;
+  if (laneData[0] != 1 || laneData[1] == 1)
+    return false;
+  return true;
+}
+
 static xegpu::TensorDescType
-getModifiedTensorDescType(xegpu::TensorDescType tdescType) {
-  auto optionalLaneData = get2DLaneData(tdescType);
-  if (!optionalLaneData)
+tryConvertToTransposable(xegpu::TensorDescType tdescType) {
+  if (!hasInvalidTranposeLayout(tdescType))
     return tdescType;
-  auto laneData = optionalLaneData.value();
+  auto laneData = getMaybeLaneData(tdescType).value();
   int64_t innerLaneData = laneData[1];
-  if (laneData[0] == 1 && innerLaneData != 1) {
-    int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
-    assert(elementTyBitwidth < 32 &&
-           "Expected element type bitwidth < 32 with laneData[1] != 1");
-    SmallVector<int64_t> newShape(tdescType.getShape());
-    newShape.back() = newShape.back() / innerLaneData;
-    Type newElemTy = IntegerType::get(tdescType.getContext(),
-                                      elementTyBitwidth * innerLaneData);
-    xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get(
-        tdescType.getContext(),
-        tdescType.getLayoutAttr().getLaneLayout().asArrayRef(), {1, 1});
-    return xegpu::TensorDescType::get(
-        newShape, newElemTy, tdescType.getArrayLength(),
-        tdescType.getBoundaryCheck(), tdescType.getMemorySpace(), newLayout);
-  }
-  return tdescType;
+  int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
+  // The required shape is the total shape of the vector result that this
+  // tensor desc must eventually load, after adjusting for the new bitwidth
+  // and the array length.
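+  // For example, a 16x16xf16 tensor desc with lane_data [1, 2] and
+  // array_length 1 requires a 16x8 shape of 32-bit elements.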
+  SmallVector<int64_t> requiredShape(tdescType.getShape());
+  requiredShape.back() =
+      requiredShape.back() * tdescType.getArrayLength() / innerLaneData;
+  int newBitWidth = elementTyBitwidth * innerLaneData;
+  Type newElemTy = IntegerType::get(tdescType.getContext(), newBitWidth);
+  // The supported shape is the largest transpose shape supported by the
+  // hardware that is less than or equal to the required shape.
+  auto supportedHeight = std::min(
+      requiredShape[0], getTransposableBlockRange(newBitWidth).maxHeight);
+  auto supportedWidth = std::min(
+      requiredShape[1], getTransposableBlockRange(newBitWidth).maxWidth);
+  SmallVector<int64_t> supportedShape = {supportedHeight, supportedWidth};
+
+  // The required shape must be a multiple of the supported shape; otherwise,
+  // we cannot optimize it.
+  // TODO: Supported shape can be adjusted to handle non-multiple cases.
+  if (requiredShape[0] % supportedShape[0] != 0 ||
+      requiredShape[1] % supportedShape[1] != 0)
+    return tdescType;
+
+  xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get(
+      tdescType.getContext(),
+      tdescType.getLayoutAttr().getLaneLayout().asArrayRef(), {1, 1});
+  // Array length cannot be larger than 1 for the transpose case.
+  return xegpu::TensorDescType::get(
+      supportedShape, newElemTy, /**array length**/ 1,
+      tdescType.getBoundaryCheck(), tdescType.getMemorySpace(), newLayout);
+}
+
+static Value createConstantIndex(ConversionPatternRewriter &rewriter,
+                                 Location loc, int64_t value) {
+  return arith::ConstantIndexOp::create(rewriter, loc, value).getResult();
 }
 
 static Value convertToValue(ConversionPatternRewriter &rewriter, Location loc,
                             OpFoldResult ofr) {
   std::optional<int64_t> mayBeInt = getConstantIntValue(ofr);
   if (mayBeInt)
-    return arith::ConstantIndexOp::create(rewriter, loc, *mayBeInt);
+    return createConstantIndex(rewriter, loc, *mayBeInt);
   return llvm::cast<Value>(ofr);
 }
 
 static Value divideByConstant(ConversionPatternRewriter &rewriter, Location loc,
                               Value val, int64_t constant) {
-  auto constantOp = arith::ConstantIndexOp::create(rewriter, loc, constant);
-  return arith::DivUIOp::create(rewriter, loc, val, constantOp.getResult())
-      .getResult();
+  auto constantOp = createConstantIndex(rewriter, loc, constant);
+  return arith::DivUIOp::create(rewriter, loc, val, constantOp).getResult();
+}
+
+static Value generateLoads(ConversionPatternRewriter &rewriter,
+                           TypedValue<VectorType> data,
+                           SmallVector<int64_t> &shapeRatio,
+                           SmallVector<OpFoldResult> offsets,
+                           SmallVector<int64_t> &supportedShape,
+                           TypedValue<xegpu::TensorDescType> newTensorDesc,
+                           xegpu::LoadNdOp origLoadOp) {
+  Location loc = data.getLoc();
+  assert(offsets.size() >= 2 && "Expecting at least 2 offsets for 2D LoadNdOp");
+  Value offsetX = convertToValue(rewriter, loc, offsets[offsets.size() - 2]);
+  Value offsetY = convertToValue(rewriter, loc, offsets[offsets.size() - 1]);
+  for (int64_t h = 0; h < shapeRatio[0]; ++h) {
+    for (int64_t w = 0; w < shapeRatio[1]; ++w) {
+      int64_t localOffsetX = h * supportedShape[0];
+      int64_t localOffsetY = w * supportedShape[1];
+      Value loadOffsetX = arith::AddIOp::create(
+          rewriter, loc, offsetX,
+          createConstantIndex(rewriter, loc, localOffsetX));
+      Value loadOffsetY = arith::AddIOp::create(
+          rewriter, loc, offsetY,
+          createConstantIndex(rewriter, loc, localOffsetY));
+      auto loadOp = xegpu::LoadNdOp::create(
+          rewriter, loc,
+          VectorType::get(supportedShape, data.getType().getElementType()),
+          newTensorDesc, ArrayRef<OpFoldResult>{loadOffsetX, loadOffsetY},
+          origLoadOp.getPackedAttr(), origLoadOp.getTransposeAttr(),
+          origLoadOp.getL1HintAttr(), origLoadOp.getL2HintAttr(),
+          origLoadOp.getL3HintAttr());
+      // Insert the loaded block into the right position in data.
+      data = vector::InsertStridedSliceOp::create(
+          rewriter, loc, loadOp.getResult(), data,
+          ArrayRef<int64_t>{localOffsetX, localOffsetY},
+          ArrayRef<int64_t>{1, 1});
+    }
+  }
+  return data;
 }
 
 class XeGPUCreateNdDescOpPattern final
@@ -111,7 +203,7 @@ class XeGPUCreateNdDescOpPattern final
   matchAndRewrite(xegpu::CreateNdDescOp createNdOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto tdescTy = createNdOp.getType();
-    auto convertType = getModifiedTensorDescType(tdescTy);
+    auto convertType = tryConvertToTransposable(tdescTy);
     if (convertType == tdescTy)
       return failure();
     auto strides = createNdOp.getMixedStrides();
@@ -120,7 +212,7 @@ class XeGPUCreateNdDescOpPattern final
     if (!maybeConstInnerStride || *maybeConstInnerStride != 1)
       return failure();
     Value source = createNdOp.getSource();
-    auto optionalLaneData = get2DLaneData(tdescTy);
+    auto optionalLaneData = getMaybeLaneData(tdescTy);
     assert(optionalLaneData && "Expected 2D lane data");
     auto laneData = optionalLaneData.value();
     int64_t innerLaneData = laneData[1];
@@ -160,7 +252,6 @@ class XeGPUCreateNdDescOpPattern final
     return success();
   }
 };
-} // namespace
 
 class XeGPULoadNdDescOpPattern final
     : public OpConversionPattern<xegpu::LoadNdOp> {
@@ -175,9 +266,8 @@ class XeGPULoadNdDescOpPattern final
     if (adaptorType == origTensorDescType)
       return failure();
     // Offsets must be adjusted based on innerLaneData.
-    auto optionalLaneData = get2DLaneData(loadNdOp.getTensorDescType());
-    assert(optionalLaneData && "Expected 2D lane data");
-    int64_t innerLaneData = optionalLaneData.value()[1];
+    auto laneData = getMaybeLaneData(loadNdOp.getTensorDescType()).value();
+    int64_t innerLaneData = laneData[1];
     auto offsets = loadNdOp.getMixedOffsets();
     if (offsets.empty())
       return rewriter.notifyMatchFailure(loadNdOp,
@@ -187,25 +277,82 @@ class XeGPULoadNdDescOpPattern final
         rewriter, loadNdOp.getLoc(),
         convertToValue(rewriter, loadNdOp.getLoc(), modifiedOffsets.back()),
         innerLaneData);
-    VectorType modifiedType =
-        VectorType::get(adaptorType.getShape(), adaptorType.getElementType());
-    // Create a new LoadNdOp with modified offsets and type.
-    auto newLoadNdOp = xegpu::LoadNdOp::create(
-        rewriter, loadNdOp->getLoc(), modifiedType, adaptor.getTensorDesc(),
-        modifiedOffsets, loadNdOp.getPackedAttr(), loadNdOp.getTransposeAttr(),
-        loadNdOp.getL1HintAttr(), loadNdOp.getL2HintAttr(),
-        loadNdOp.getL3HintAttr());
-    // Bitcast back to the original type.
+    // Get the 2D data shape of this loadNdOp in its original element type.
+    // The array-length dimension, if present, is handled separately below.
+    SmallVector<int64_t> origDataShape(origTensorDescType.getShape());
+    // origDataShape.back() *= origTensorDescType.getArrayLength();
+    // Adjust the data shape based on innerLaneData.
+    origDataShape.back() /= innerLaneData;
+    // HW supported shape is the new tensor desc shape after conversion.
+    SmallVector<int64_t> hwSupportedShape(adaptorType.getShape());
+    // The shape ratio is 2D; it describes how many HW-supported blocks must
+    // be loaded to cover the original shape.
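+    // For example, origDataShape = [32, 16] with hwSupportedShape = [32, 8]
+    // gives a ratio of [1, 2] (two blocks along the width).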
+    auto ratio = computeShapeRatio(origDataShape, hwSupportedShape)
+                     .value(); // ratio must be defined if we reach here.
+    // Create a zero-initialized vector to hold all loaded blocks.
+    // TypedAttr zeroAttr = rewriter.getZeroAttr(adaptorType.getElementType());
+    VectorType origVectorType =
+        VectorType::get(origDataShape, adaptorType.getElementType());
+    Value data;
+    // The loaded data is 3D (array_length x rows x cols) for the array-length
+    // case.
+    if (origTensorDescType.getArrayLength() > 1) {
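+      // E.g., with array_length = 2 and a [32, 8] origDataShape, the holder
+      // vector shape is [2, 32, 8].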
+      SmallVector<int64_t> arrayLenDataShape(origDataShape);
+      arrayLenDataShape.insert(arrayLenDataShape.begin(),
+                               origTensorDescType.getArrayLength());
+      auto arrayLenVecType =
+          VectorType::get(arrayLenDataShape, adaptorType.getElementType());
+      data = arith::ConstantOp::create(rewriter, loadNdOp->getLoc(),
+                                       arrayLenVecType,
+                                       rewriter.getZeroAttr(arrayLenVecType));
+      for (int64_t i = 0; i < origTensorDescType.getArrayLength(); ++i) {
+        Value slice = arith::ConstantOp::create(
+            rewriter, loadNdOp->getLoc(),
+            VectorType::get(origDataShape, adaptorType.getElementType()),
+            rewriter.getZeroAttr(origVectorType));
+        // Increase the Y offset for each array slice.
+        Value offsetY = convertToValue(rewriter, loadNdOp->getLoc(),
+                                       modifiedOffsets.back());
+        modifiedOffsets.back() =
+            arith::AddIOp::create(rewriter, loadNdOp->getLoc(), offsetY,
+                                  createConstantIndex(rewriter,
+                                                      loadNdOp->getLoc(),
+                                                      i * origDataShape[1]))
+                .getResult();
+        slice = generateLoads(
+            rewriter, cast<TypedValue<VectorType>>(slice), ratio,
+            modifiedOffsets, hwSupportedShape,
+            cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()),
+            loadNdOp);
+        // Insert slice to data.
+        data = vector::InsertOp::create(rewriter, loadNdOp->getLoc(), slice,
+                                        data, ArrayRef<int64_t>{i});
+      }
+      // Cast back to the original type and replace all uses.
+      data = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
+                                       loadNdOp.getType(), data);
+      rewriter.replaceOp(loadNdOp, data);
+      return success();
+    }
+    data = arith::ConstantOp::create(
+        rewriter, loadNdOp->getLoc(),
+        VectorType::get(origDataShape, adaptorType.getElementType()),
+        rewriter.getZeroAttr(origVectorType));
+    data = generateLoads(
+        rewriter, cast<TypedValue<VectorType>>(data), ratio, modifiedOffsets,
+        hwSupportedShape,
+        cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()),
+        loadNdOp);
     auto castOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
-                                            loadNdOp.getType(), newLoadNdOp);
-    // Cast op must have the same layout as the original LoadNdOp result.
-    xegpu::setDistributeLayoutAttr(
-        castOp->getOpResult(0),
-        xegpu::getDistributeLayoutAttr(loadNdOp.getResult()));
-    rewriter.replaceOp(loadNdOp, castOp.getResult());
+                                            loadNdOp.getType(), data);
+    // // Cast op must have the same layout as the original LoadNdOp result.
+    // xegpu::setDistributeLayoutAttr(
+    //     castOp->getOpResult(0),
+    //     xegpu::getDistributeLayoutAttr(loadNdOp.getResult()));
+    rewriter.replaceOp(loadNdOp, castOp);
     return success();
   }
 };
+} // namespace
 
 void xegpu::populateXeGPUOptimizeTransposePatterns(
     RewritePatternSet &patterns) {
@@ -224,23 +371,15 @@ struct XeGPUOptimizeTransposePass final
     RewritePatternSet patterns(&context);
     ConversionTarget target(context);
 
-    auto checkValidInnerLaneData =
-        [](std::optional<SmallVector<int64_t>> optionalLaneData) -> bool {
-      if (!optionalLaneData)
-        return true;
-      auto laneData = optionalLaneData.value();
-      return laneData[0] != 1 || laneData[1] == 1;
-    };
-
+    // CreateNdDescOp and LoadNdOp with invalid transpose layout must be
+    // converted.
     target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
         [&](xegpu::CreateNdDescOp createNdOp) {
-          auto optionalLaneData = get2DLaneData(createNdOp.getType());
-          return checkValidInnerLaneData(optionalLaneData);
+          return !hasInvalidTranposeLayout(createNdOp.getType());
         });
     target.addDynamicallyLegalOp<xegpu::LoadNdOp>(
         [&](xegpu::LoadNdOp loadNdOp) {
-          auto optionalLaneData = get2DLaneData(loadNdOp.getTensorDescType());
-          return checkValidInnerLaneData(optionalLaneData);
+          return !hasInvalidTranposeLayout(loadNdOp.getTensorDescType());
         });
     converter.addConversion([](Type type) { return type; });
 
diff --git a/mlir/test/Dialect/XeGPU/optimize-transpose.mlir b/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
index 35c9294c32f04..a59ab2e14e36c 100644
--- a/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
+++ b/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
@@ -115,3 +115,47 @@ func.func @large_loads(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %arg
   xegpu.store_nd %4#3, %0[%c16, %c16]  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
   return
 }
+
+// -----
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
+func.func @array_length(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %c256 = arith.constant 256 : index
+  %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #a>
+  %1 = xegpu.load_nd %0[%c0, %c0]  { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf32, #a> -> vector<8x16xf32>
+  %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16>
+    -> !xegpu.tensor_desc<32x16xf16, #b, #xegpu.block_tdesc_attr<array_length = 2 : i64>>
+  %4:4 = scf.for %arg3 = %c0 to %c256 step %c32 iter_args(%arg4 = %1, %arg5 = %1, %arg6 = %1, %arg7 = %1)
+    -> (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) {
+    %6 = xegpu.load_nd %3[%c0, %arg3]  { layout_result_0 = #b }
+      : !xegpu.tensor_desc<32x16xf16, #b, #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<2x32x16xf16>
+    %19 = vector.extract %6[0] : vector<32x16xf16> from vector<2x32x16xf16>
+    %20 = vector.extract %6[1] : vector<32x16xf16> from vector<2x32x16xf16>
+    %7 = vector.extract_strided_slice %19 {offsets = [0, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+      : vector<32x16xf16> to vector<16x16xf16>
+    %8 = vector.extract_strided_slice %19 {offsets = [16, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+      : vector<32x16xf16> to vector<16x16xf16>
+    %9 = vector.extract_strided_slice %20 {offsets = [0, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+      : vector<32x16xf16> to vector<16x16xf16>
+    %10 = vector.extract_strided_slice %20 {offsets = [16, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+      : vector<32x16xf16> to vector<16x16xf16>
+    %11 = vector.transpose %7, [1, 0]  { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+    %12 = vector.transpose %8, [1, 0]  { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+    %13 = vector.transpose %9, [1, 0]  { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+    %14 = vector.transpose %10, [1, 0]  { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+    %15 = xegpu.dpas %arg0, %11, %arg4 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+    %16 = xegpu.dpas %arg0, %12, %arg5 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+    %17 = xegpu.dpas %arg0, %13, %arg6 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+    %18 = xegpu.dpas %arg0, %14, %arg7 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+    scf.yield %15, %16, %17, %18 : vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>
+  } {layout_result_0 = #a, layout_result_1 = #a, layout_result_2 = #a, layout_result_3 = #a}
+  xegpu.store_nd %4#0, %0[%c0, %c0]  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+  xegpu.store_nd %4#1, %0[%c0, %c16]  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+  xegpu.store_nd %4#2, %0[%c16, %c0]  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+  xegpu.store_nd %4#3, %0[%c16, %c16]  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+  return
+}

>From 35ca92befdf5830c2b7c25cfcf4b9456ec012a8b Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Sat, 25 Oct 2025 00:38:06 +0000
Subject: [PATCH 6/8] working version

---
 .../Transforms/XeGPUOptimizeTranspose.cpp     | 69 ++++++++++++++-----
 1 file changed, 50 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
index a697c0433ed56..c73f2e4607482 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
@@ -296,18 +296,18 @@ class XeGPULoadNdDescOpPattern final
     Value data;
     // Orig data shape is 3D for the array length case.
     if (origTensorDescType.getArrayLength() > 1) {
-      SmallVector<int64_t> arrayLenDataShape(origDataShape);
-      arrayLenDataShape.insert(arrayLenDataShape.begin(),
-                               origTensorDescType.getArrayLength());
-      auto arrayLenVecType =
-          VectorType::get(arrayLenDataShape, adaptorType.getElementType());
-      data = arith::ConstantOp::create(rewriter, loadNdOp->getLoc(),
-                                       arrayLenVecType,
-                                       rewriter.getZeroAttr(arrayLenVecType));
+      // SmallVector<int64_t> arrayLenDataShape(origDataShape);
+      // arrayLenDataShape.insert(arrayLenDataShape.begin(),
+      //                          origTensorDescType.getArrayLength());
+      // auto arrayLenVecType =
+      //     VectorType::get(arrayLenDataShape, adaptorType.getElementType());
+      // auto  = arith::ConstantOp::create(rewriter, loadNdOp->getLoc(),
+      //  arrayLenVecType,
+      //  rewriter.getZeroAttr(arrayLenVecType));
+      SmallVector<Value> arraySlices;
       for (int64_t i = 0; i < origTensorDescType.getArrayLength(); ++i) {
         Value slice = arith::ConstantOp::create(
-            rewriter, loadNdOp->getLoc(),
-            VectorType::get(origDataShape, adaptorType.getElementType()),
+            rewriter, loadNdOp->getLoc(), origVectorType,
             rewriter.getZeroAttr(origVectorType));
         // Increse the Y offset for each array slice.
         Value offsetY = convertToValue(rewriter, loadNdOp->getLoc(),
@@ -323,14 +323,20 @@ class XeGPULoadNdDescOpPattern final
             modifiedOffsets, hwSupportedShape,
             cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()),
             loadNdOp);
-        // Insert slice to data.
-        data = vector::InsertOp::create(rewriter, loadNdOp->getLoc(), slice,
-                                        data, ArrayRef<int64_t>{i});
+        // // Insert slice to data.
+        // data = vector::InsertOp::create(rewriter, loadNdOp->getLoc(), slice,
+        //                                 data, ArrayRef<int64_t>{i});
+        // Bitcast back to original load shape without array length.
+        auto bitcastType = VectorType::get(origTensorDescType.getShape(),
+                                           origTensorDescType.getElementType());
+        slice = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
+                                          bitcastType, slice);
+        arraySlices.push_back(slice);
       }
-      // Cast back to the original type and replace all uses.
-      data = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
-                                       loadNdOp.getType(), data);
-      rewriter.replaceOp(loadNdOp, data);
+      // // Cast back to the original type and replace all uses.
+      // data = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
+      //                                  loadNdOp.getType(), data);
+      rewriter.replaceOpWithMultiple(loadNdOp, {arraySlices});
       return success();
     }
     data = arith::ConstantOp::create(
@@ -352,12 +358,33 @@ class XeGPULoadNdDescOpPattern final
     return success();
   }
 };
+
+class VectorExtractOpPattern final
+    : public OpConversionPattern<vector::ExtractOp> {
+public:
+  using OpConversionPattern<vector::ExtractOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(vector::ExtractOp extractOp, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (adaptor.getSource().size() == 1)
+      return failure();
+    auto mixedPos = extractOp.getMixedPosition();
+    if (mixedPos.size() != 1)
+      return failure();
+    auto mayBeInt = getConstantIntValue(mixedPos[0]);
+    if (!mayBeInt)
+      return failure();
+    rewriter.replaceOp(extractOp, adaptor.getSource()[*mayBeInt]);
+    return success();
+  }
+};
+
 } // namespace
 
 void xegpu::populateXeGPUOptimizeTransposePatterns(
     RewritePatternSet &patterns) {
-  patterns.add<XeGPUCreateNdDescOpPattern, XeGPULoadNdDescOpPattern>(
-      patterns.getContext());
+  patterns.add<XeGPUCreateNdDescOpPattern, XeGPULoadNdDescOpPattern,
+               VectorExtractOpPattern>(patterns.getContext());
 }
 
 namespace {
@@ -381,6 +408,10 @@ struct XeGPUOptimizeTransposePass final
         [&](xegpu::LoadNdOp loadNdOp) {
           return !hasInvalidTranposeLayout(loadNdOp.getTensorDescType());
         });
+    target.addDynamicallyLegalOp<vector::ExtractOp>(
+        [&](vector::ExtractOp extractOp) {
+          return extractOp.getSourceVectorType().getRank() != 3;
+        });
     converter.addConversion([](Type type) { return type; });
 
     target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect,

>From 17fd7c8e6af94a9018e203a8c89f2f340c2763ca Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 28 Oct 2025 21:42:47 +0000
Subject: [PATCH 7/8] add tests

---
 .../Transforms/XeGPUOptimizeTranspose.cpp     |  84 +++++++------
 .../Dialect/XeGPU/optimize-transpose.mlir     | 114 +++++++++++++++++-
 2 files changed, 154 insertions(+), 44 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
index c73f2e4607482..462751e16ace1 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
@@ -79,9 +79,20 @@ getMaybeLaneLayout(xegpu::TensorDescType tdescType) {
   return laneLayout;
 }
 
+static bool canBeOptimized(ArrayRef<int64_t> laneLayout,
+                           ArrayRef<int64_t> laneData) {
+  if (laneLayout.size() != 2 || laneData.size() != 2)
+    return false;
+  if (laneLayout[0] == 1 || laneLayout[1] != 1)
+    return false;
+  if (laneData[0] != 1 || laneData[1] == 1)
+    return false;
+  return true;
+}
+
 // A transpose layout is invalid if lane layout is transposed (lane[0] != 1 &&
 // lane[1] == 1), but inner lane data is not equal to [1, 1].
-static bool hasInvalidTranposeLayout(xegpu::TensorDescType tdescType) {
+static bool canBeOptimized(xegpu::TensorDescType tdescType) {
   // If the dtype is greater or equal to 32 bits, layout must be valid.
   int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
   if (elementTyBitwidth >= 32)
@@ -90,18 +101,12 @@ static bool hasInvalidTranposeLayout(xegpu::TensorDescType tdescType) {
   auto maybeLaneData = getMaybeLaneData(tdescType);
   if (!maybeLaneData || !maybeLaneLayout)
     return false;
-  auto laneData = maybeLaneData.value();
-  auto laneLayout = maybeLaneLayout.value();
-  if (laneLayout[0] == 1 || laneLayout[1] != 1)
-    return false;
-  if (laneData[0] != 1 || laneData[1] == 1)
-    return false;
-  return true;
+  return canBeOptimized(*maybeLaneLayout, *maybeLaneData);
 }
 
 static xegpu::TensorDescType
 tryConvertToTransposable(xegpu::TensorDescType tdescType) {
-  if (!hasInvalidTranposeLayout(tdescType))
+  if (!canBeOptimized(tdescType))
     return tdescType;
   auto laneData = getMaybeLaneData(tdescType).value();
   int64_t innerLaneData = laneData[1];
@@ -185,11 +190,17 @@ static Value generateLoads(ConversionPatternRewriter &rewriter,
           origLoadOp.getPackedAttr(), origLoadOp.getTransposeAttr(),
           origLoadOp.getL1HintAttr(), origLoadOp.getL2HintAttr(),
           origLoadOp.getL3HintAttr());
+      // Set the layout for the loadOp.
+      auto layoutAttr = newTensorDesc.getType().getLayoutAttr();
+      xegpu::setDistributeLayoutAttr(loadOp->getOpResult(0), layoutAttr);
       // Insert the loaded block into the right position in data.
-      data = vector::InsertStridedSliceOp::create(
+      auto insertOp = vector::InsertStridedSliceOp::create(
           rewriter, loc, loadOp.getResult(), data,
           ArrayRef<int64_t>{localOffsetX, localOffsetY},
           ArrayRef<int64_t>{1, 1});
+      // InsertOp must have the same layout as newTensorDesc.
+      xegpu::setDistributeLayoutAttr(insertOp->getOpResult(0), layoutAttr);
+      data = insertOp.getResult();
     }
   }
   return data;
@@ -288,7 +299,7 @@ class XeGPULoadNdDescOpPattern final
     // Shape ratio is 2D and, it describes how many blocks need to be loaded in
     // HW supported shape to cover the original shape.
     auto ratio = computeShapeRatio(origDataShape, hwSupportedShape)
-                     .value(); // ratio must be defined if we reach here.
+                     .value(); // `ratio` must be defined if we reach here.
     // Create a zero-initialized vector to hold all loaded blocks.
     // TypedAttr zeroAttr = rewriter.getZeroAttr(adaptorType.getElementType());
     VectorType origVectorType =
@@ -296,20 +307,12 @@ class XeGPULoadNdDescOpPattern final
     Value data;
     // Orig data shape is 3D for the array length case.
     if (origTensorDescType.getArrayLength() > 1) {
-      // SmallVector<int64_t> arrayLenDataShape(origDataShape);
-      // arrayLenDataShape.insert(arrayLenDataShape.begin(),
-      //                          origTensorDescType.getArrayLength());
-      // auto arrayLenVecType =
-      //     VectorType::get(arrayLenDataShape, adaptorType.getElementType());
-      // auto  = arith::ConstantOp::create(rewriter, loadNdOp->getLoc(),
-      //  arrayLenVecType,
-      //  rewriter.getZeroAttr(arrayLenVecType));
       SmallVector<Value> arraySlices;
       for (int64_t i = 0; i < origTensorDescType.getArrayLength(); ++i) {
         Value slice = arith::ConstantOp::create(
             rewriter, loadNdOp->getLoc(), origVectorType,
             rewriter.getZeroAttr(origVectorType));
-        // Increse the Y offset for each array slice.
+        // Increase the Y offset for each array slice.
         Value offsetY = convertToValue(rewriter, loadNdOp->getLoc(),
                                        modifiedOffsets.back());
         modifiedOffsets.back() =
@@ -323,19 +326,16 @@ class XeGPULoadNdDescOpPattern final
             modifiedOffsets, hwSupportedShape,
             cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()),
             loadNdOp);
-        // // Insert slice to data.
-        // data = vector::InsertOp::create(rewriter, loadNdOp->getLoc(), slice,
-        //                                 data, ArrayRef<int64_t>{i});
-        // Bitcast back to original load shape without array length.
+        // BitCast back to original load shape without array length.
         auto bitcastType = VectorType::get(origTensorDescType.getShape(),
                                            origTensorDescType.getElementType());
-        slice = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
-                                          bitcastType, slice);
-        arraySlices.push_back(slice);
+        auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
+                                                   bitcastType, slice);
+        // BitCastOp must have the same layout as the original loadNdOp.
+        xegpu::setDistributeLayoutAttr(bitCastOp->getOpResult(0),
+                                       origTensorDescType.getLayoutAttr());
+        arraySlices.push_back(bitCastOp.getResult());
       }
-      // // Cast back to the original type and replace all uses.
-      // data = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
-      //                                  loadNdOp.getType(), data);
       rewriter.replaceOpWithMultiple(loadNdOp, {arraySlices});
       return success();
     }
@@ -348,13 +348,12 @@ class XeGPULoadNdDescOpPattern final
         hwSupportedShape,
         cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()),
         loadNdOp);
-    auto castOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
-                                            loadNdOp.getType(), data);
-    // // Cast op must have the same layout as the original LoadNdOp result.
-    // xegpu::setDistributeLayoutAttr(
-    //     castOp->getOpResult(0),
-    //     xegpu::getDistributeLayoutAttr(loadNdOp.getResult()));
-    rewriter.replaceOp(loadNdOp, castOp);
+    auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
+                                               loadNdOp.getType(), data);
+    // BitCastOp must have the same layout as the original loadNdOp.
+    xegpu::setDistributeLayoutAttr(bitCastOp->getOpResult(0),
+                                   origTensorDescType.getLayoutAttr());
+    rewriter.replaceOp(loadNdOp, bitCastOp);
     return success();
   }
 };
@@ -402,15 +401,20 @@ struct XeGPUOptimizeTransposePass final
     // converted.
     target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
         [&](xegpu::CreateNdDescOp createNdOp) {
-          return !hasInvalidTranposeLayout(createNdOp.getType());
+          return !canBeOptimized(createNdOp.getType());
         });
     target.addDynamicallyLegalOp<xegpu::LoadNdOp>(
         [&](xegpu::LoadNdOp loadNdOp) {
-          return !hasInvalidTranposeLayout(loadNdOp.getTensorDescType());
+          return !canBeOptimized(loadNdOp.getTensorDescType());
         });
     target.addDynamicallyLegalOp<vector::ExtractOp>(
         [&](vector::ExtractOp extractOp) {
-          return extractOp.getSourceVectorType().getRank() != 3;
+          auto layout = xegpu::getDistributeLayoutAttr(extractOp.getResult());
+          if (!layout)
+            return true;
+          auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
+          auto laneData = layout.getEffectiveLaneDataAsInt();
+          return !canBeOptimized(laneLayout, laneData);
         });
     converter.addConversion([](Type type) { return type; });
 
diff --git a/mlir/test/Dialect/XeGPU/optimize-transpose.mlir b/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
index a59ab2e14e36c..7b3bd1bf8e4fe 100644
--- a/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
+++ b/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
@@ -1,5 +1,18 @@
+// RUN: mlir-opt -xegpu-optimize-transpose -canonicalize -split-input-file %s | FileCheck %s
 
-
+// CHECK-LABEL: func.func @no_scf(
+// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]+]]: memref<64x64xf16>, %{{.*}}: vector<8x16xf16>) -> vector<8x16xf32> {
+// CHECK:         %[[C16:.*]] = arith.constant 16 : index
+// CHECK:         %[[C32:.*]] = arith.constant 32 : index
+// CHECK:         %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<64x64xf16> -> index
+// CHECK:         %[[T0:.*]] = arith.index_cast %[[PTR]] : index to i64
+// CHECK:         %[[BDESC:.*]] = xegpu.create_nd_tdesc %[[T0]], shape : [64, %[[C32]]], strides : [%[[C32]], 1] : i64
+// CHECK-SAME:      -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK-NEXT:    %[[B:.*]] = xegpu.load_nd %[[BDESC]][%{{.*}}, %[[C16]]]
+// CHECK-SAME:      {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK-SAME:      : !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32>
+// CHECK:         %[[BITCAST:.*]] = vector.bitcast %[[B]]
+// CHECK-SAME:      {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : vector<16x8xi32> to vector<16x16xf16>
 #a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
 #b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
 #bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
@@ -14,21 +27,51 @@ func.func @no_scf(%arg0: memref<64x64xf16>, %arg1: vector<8x16xf16>) -> vector<8
 }
 
 // -----
+// CHECK-LABEL: func.func @no_scf_i8(
+// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]+]]: memref<64x64xi8>, %{{.*}}: vector<8x32xi8>) -> vector<8x16xi32> {
+// CHECK:         %[[C16:.*]] = arith.constant 16 : index
+// CHECK:         %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<64x64xi8> -> index
+// CHECK:         %[[T0:.*]] = arith.index_cast %[[PTR]] : index to i64
+// CHECK:         %[[T1:.*]] = xegpu.create_nd_tdesc %[[T0]], shape : [64, %[[C16]]], strides : [%[[C16]], 1] : i64
+// CHECK-SAME:      -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK:         %[[T2:.*]] = xegpu.load_nd %[[T1]][%{{.*}}, %[[C16]]]
+// CHECK-SAME:      {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK-SAME:      : !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32>
+// CHECK:         %[[T3:.*]] = vector.bitcast %[[T2]]
+// CHECK-SAME:      {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 4]>} : vector<16x8xi32> to vector<16x32xi8>
 #a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>
 #b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 4]>
 #bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]>
+#c = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
 func.func @no_scf_i8(%arg0: memref<64x64xi8>, %arg1: vector<8x32xi8>) -> vector<8x16xi32> {
   %c0 = arith.constant 0 : index
   %c64 = arith.constant 64 : index
   %0 = xegpu.create_nd_tdesc %arg0 : memref<64x64xi8> -> !xegpu.tensor_desc<16x32xi8, #b>
   %1 = xegpu.load_nd %0[%c0, %c64] { result_layout = #b } : !xegpu.tensor_desc<16x32xi8, #b> -> vector<16x32xi8>
   %2 = vector.transpose %1, [1, 0] { layout_result_0 = #bt } : vector<16x32xi8> to vector<32x16xi8>
-  %6 = xegpu.dpas %arg1, %2 { layout_result_0 = #a } : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
+  %6 = xegpu.dpas %arg1, %2 { layout_result_0 = #c } : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
   return %6 : vector<8x16xi32>
 }
 
 
 // -----
+// CHECK-LABEL:   func.func @gemm_b_transpose(
+// CHECK-SAME:      %{{.*}}: memref<256x256xf16>, %[[ARG1:[a-zA-Z0-9]+]]: memref<256x256xf16>, %{{.*}}: memref<256x256xf32>) {
+// CHECK:           %[[C128:.*]] = arith.constant 128 : index
+// CHECK:           %[[C2:.*]] = arith.constant 2 : index
+// CHECK:           %[[C16:.*]] = arith.constant 16 : index
+// CHECK:           %[[C256:.*]] = arith.constant 256 : index
+// CHECK:           %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<256x256xf16> -> index
+// CHECK:           %[[T3:.*]] = arith.index_cast %[[PTR]] : index to i64
+// CHECK:           %[[T4:.*]] = xegpu.create_nd_tdesc %[[T3]], shape : [256, %[[C128]]], strides : [%[[C128]], 1]
+// CHECK-SAME:        : i64 -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK:           %{{.*}} = scf.for %[[K:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}}) -> (vector<8x16xf32>) {
+// CHECK:             %[[T7:.*]] = arith.divui %[[K]], %[[C2]] : index
+// CHECK-NEXT:        %[[T8:.*]] = xegpu.load_nd %[[T4]][%{{.*}}, %[[T7]]]
+// CHECK-SAME:          {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} :
+// CHECK-SAME:          !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32>
+// CHECK-NEXT:        %{{.*}} = vector.bitcast %[[T8]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-SAME:          : vector<16x8xi32> to vector<16x16xf16>
 #a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
 #b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
 #bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
@@ -52,6 +95,24 @@ func.func @gemm_b_transpose(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf1
 }
 
 // -----
+// CHECK-LABEL: func.func @nested_scf(
+// CHECK-SAME:     %{{.*}}: memref<256x256xf16>, %[[ARG1:[a-zA-Z0-9]+]]: memref<256x256xf16>, %{{.*}}: memref<256x256xf32>) {
+// CHECK:          %[[C128:.*]] = arith.constant 128 : index
+// CHECK:          %[[C2:.*]] = arith.constant 2 : index
+// CHECK:          %[[C16:.*]] = arith.constant 16 : index
+// CHECK:          %[[C256:.*]] = arith.constant 256 : index
+// CHECK:          scf.for %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK:            %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<256x256xf16> -> index
+// CHECK:            %[[T3:.*]] = arith.index_cast %[[PTR]] : index to i64
+// CHECK:            %[[T4:.*]] = xegpu.create_nd_tdesc %[[T3]], shape : [256, %[[C128]]], strides : [%[[C128]], 1] : i64
+// CHECK-SAME:          -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK:            %{{.*}} = scf.for %[[K:.*]] = %{{.*}} iter_args(%{{.*}}) -> (vector<8x16xf32>) {
+// CHECK:              %[[T7:.*]] = arith.divui %[[K]], %[[C2]] : index
+// CHECK-NEXT:         %[[T8:.*]] = xegpu.load_nd %[[T4]][%{{.*}}, %[[T7]]]  {layout_result_0 = #xegpu.layout<
+// CHECK-SAME:          lane_layout = [16, 1], lane_data = [1, 1]>} :
+// CHECK-SAME:          !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32>
+// CHECK-NEXT:         %{{.*}} = vector.bitcast %[[T8]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-SAME:          : vector<16x8xi32> to vector<16x16xf16>
 #a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
 #b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
 #bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
@@ -77,6 +138,31 @@ func.func @nested_scf(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %a
 }
 
 // -----
+// CHECK-LABEL:   func.func @large_loads(
+// CHECK-SAME:      %{{.*}}: vector<8x16xf16>, %[[ARG1:[a-zA-Z0-9]+]]: memref<256x256xf16>, %{{.*}}: memref<256x256xf32>) {
+// CHECK:           %[[C128:.*]] = arith.constant 128 : index
+// CHECK:           %[[C8:.*]] = arith.constant 8 : index
+// CHECK:           %[[CST:.*]] = arith.constant dense<0> : vector<32x16xi32>
+// CHECK:           %[[C2:.*]] = arith.constant 2 : index
+// CHECK:           %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<256x256xf16> -> index
+// CHECK:           %[[T2:.*]] = arith.index_cast %[[PTR]] : index to i64
+// CHECK:           %[[T3:.*]] = xegpu.create_nd_tdesc %[[T2]], shape : [256, %[[C128]]], strides : [%[[C128]], 1] : i64
+// CHECK-SAME:        -> !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK:           %{{.*}}:4 = scf.for %[[K:.*]] = %{{.*}} iter_args(%{{.*}}) -> (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) {
+// CHECK:             %[[T5:.*]] = arith.divui %[[K]], %[[C2]] : index
+// CHECK:             %[[T6:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T5]]]  {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK-SAME:          : !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<32x8xi32>
+// CHECK:             %[[T7:.*]] = vector.insert_strided_slice %[[T6]], %[[CST]]
+// CHECK-SAME:          {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, offsets = [0, 0], strides = [1, 1]}
+// CHECK-SAME:          : vector<32x8xi32> into vector<32x16xi32>
+// CHECK:             %[[T8:.*]] = arith.addi %[[T5]], %[[C8]] : index
+// CHECK:             %[[T9:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T8]]]  {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK-SAME:          : !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<32x8xi32>
+// CHECK:             %[[T10:.*]] = vector.insert_strided_slice %[[T9]], %[[T7]]
+// CHECK-SAME:          {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, offsets = [0, 8], strides = [1, 1]}
+// CHECK-SAME:          : vector<32x8xi32> into vector<32x16xi32>
+// CHECK:             %{{.*}} = vector.bitcast %[[T10]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-SAME:          : vector<32x16xi32> to vector<32x32xf16>
 #a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
 #b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
 #bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
@@ -117,6 +203,26 @@ func.func @large_loads(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %arg
 }
 
 // -----
+// CHECK-LABEL:  func.func @array_length(
+// CHECK-SAME:      %{{.*}}: vector<8x16xf16>, %[[ARG1:[a-zA-Z0-9]+]]: memref<256x256xf16>, %arg2: memref<256x256xf32>) {
+// CHECK:           %[[C128:.*]] = arith.constant 128 : index
+// CHECK:           %[[C8:.*]] = arith.constant 8 : index
+// CHECK:           %[[C2:.*]] = arith.constant 2 : index
+// CHECK:           %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<256x256xf16> -> index
+// CHECK:           %[[T2:.*]] = arith.index_cast %[[PTR]] : index to i64
+// CHECK:           %[[T3:.*]] = xegpu.create_nd_tdesc %[[T2]], shape : [256, %[[C128]]], strides : [%[[C128]], 1] : i64 ->
+// CHECK-SAME:        !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK:           %{{.*}}:4 = scf.for %[[K:.*]] = %{{.*}} iter_args(%{{.*}}) -> (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) {
+// CHECK:             %[[T5:.*]] = arith.divui %[[K]], %[[C2]] : index
+// CHECK:             %[[T6:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T5]]]  {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK-SAME:          : !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<32x8xi32>
+// CHECK:             %[[T7:.*]] = vector.bitcast %[[T6]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-SAME:          : vector<32x8xi32> to vector<32x16xf16>
+// CHECK:             %[[T8:.*]] = arith.addi %[[T5]], %[[C8]] : index
+// CHECK:             %[[T9:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T8]]]  {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK-SAME:          : !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<32x8xi32>
+// CHECK:             %[[T10:.*]] = vector.bitcast %[[T9]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-SAME:          : vector<32x8xi32> to vector<32x16xf16>
 #a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
 #b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
 #bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
@@ -133,8 +239,8 @@ func.func @array_length(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %ar
     -> (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) {
     %6 = xegpu.load_nd %3[%c0, %arg3]  { layout_result_0 = #b }
       : !xegpu.tensor_desc<32x16xf16, #b, #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<2x32x16xf16>
-    %19 = vector.extract %6[0] : vector<32x16xf16> from vector<2x32x16xf16>
-    %20 = vector.extract %6[1] : vector<32x16xf16> from vector<2x32x16xf16>
+    %19 = vector.extract %6[0] { layout_result_0 = #b } : vector<32x16xf16> from vector<2x32x16xf16>
+    %20 = vector.extract %6[1] { layout_result_0 = #b } : vector<32x16xf16> from vector<2x32x16xf16>
     %7 = vector.extract_strided_slice %19 {offsets = [0, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
       : vector<32x16xf16> to vector<16x16xf16>
     %8 = vector.extract_strided_slice %19 {offsets = [16, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }

>From cbcccf632b3d086f78e93e243d3847c3cbf3b62e Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 29 Oct 2025 00:12:05 +0000
Subject: [PATCH 8/8] add comments

---
 .../Transforms/XeGPUOptimizeTranspose.cpp     | 40 +++++++++++--------
 1 file changed, 23 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
index 462751e16ace1..385a0dd470f03 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
@@ -46,7 +46,8 @@ struct Allowed2DShapeRange {
   int64_t minWidth, maxWidth, minHeight, maxHeight;
 };
 
-// TODO: Use uArch to get supported block ranges.
+/// Helper to get the size range of a 2D block that can be transposed by HW.
+/// TODO: Use uArch to get supported block ranges.
 static Allowed2DShapeRange getTransposableBlockRange(int bitWidth) {
   switch (bitWidth) {
   case 32:
@@ -57,6 +58,7 @@ static Allowed2DShapeRange getTransposableBlockRange(int bitWidth) {
   }
 }
 
+/// Get the 2D lane data from a tensor desc type if it exists.
 static std::optional<SmallVector<int64_t>>
 getMaybeLaneData(xegpu::TensorDescType tdescType) {
   auto layout = tdescType.getLayoutAttr();
@@ -68,6 +70,7 @@ getMaybeLaneData(xegpu::TensorDescType tdescType) {
   return laneData;
 }
 
+/// Get the 2D lane layout from a tensor desc type if it exists.
 static std::optional<SmallVector<int64_t>>
 getMaybeLaneLayout(xegpu::TensorDescType tdescType) {
   auto layout = tdescType.getLayoutAttr();
@@ -79,6 +82,8 @@ getMaybeLaneLayout(xegpu::TensorDescType tdescType) {
   return laneLayout;
 }
 
+/// A layout can be optimized if its lane layout is transposed (lane[0] != 1 &&
+/// lane[1] == 1), but inner lane data is not equal to [1, 1].
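+/// For example, the transposed-B layout used in the tests,
+/// #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>, satisfies both
+/// conditions and is therefore a candidate for this optimization.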
 static bool canBeOptimized(ArrayRef<int64_t> laneLayout,
                            ArrayRef<int64_t> laneData) {
   if (laneLayout.size() != 2 || laneData.size() != 2)
@@ -90,8 +95,8 @@ static bool canBeOptimized(ArrayRef<int64_t> laneLayout,
   return true;
 }
 
-// A transpose layout is invalid if lane layout is transposed (lane[0] != 1 &&
-// lane[1] == 1), but inner lane data is not equal to [1, 1].
+/// A tensor desc type can be optimized if its element type is less than 32 bits
+/// and its layout can be optimized.
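+/// For example, !xegpu.tensor_desc<16x32xi8,
+/// #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 4]>> qualifies (i8 is
+/// 8 bits wide), whereas a descriptor with a 32-bit element type does not.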
 static bool canBeOptimized(xegpu::TensorDescType tdescType) {
   // If the dtype is greater or equal to 32 bits, layout must be valid.
   int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
@@ -104,8 +109,9 @@ static bool canBeOptimized(xegpu::TensorDescType tdescType) {
   return canBeOptimized(*maybeLaneLayout, *maybeLaneData);
 }
 
-static xegpu::TensorDescType
-tryConvertToTransposable(xegpu::TensorDescType tdescType) {
+/// Check if a tensor desc type can be optimized for transpose; if so, return a
+/// new, optimized tensor desc type with a valid transpose layout.
+static xegpu::TensorDescType tryOptimize(xegpu::TensorDescType tdescType) {
   if (!canBeOptimized(tdescType))
     return tdescType;
   auto laneData = getMaybeLaneData(tdescType).value();
@@ -143,11 +149,13 @@ tryConvertToTransposable(xegpu::TensorDescType tdescType) {
       tdescType.getBoundaryCheck(), tdescType.getMemorySpace(), newLayout);
 }
 
+/// Helper to create a constant index value.
 static Value createConstantIndex(ConversionPatternRewriter &rewriter,
                                  Location loc, int64_t value) {
   return arith::ConstantIndexOp::create(rewriter, loc, value).getResult();
 }
 
+/// Helper to convert an OpFoldResult to Value.
 static Value convertToValue(ConversionPatternRewriter &rewriter, Location loc,
                             OpFoldResult ofr) {
   std::optional<int64_t> mayBeInt = getConstantIntValue(ofr);
@@ -156,6 +164,7 @@ static Value convertToValue(ConversionPatternRewriter &rewriter, Location loc,
   return llvm::cast<Value>(ofr);
 }
 
+/// Helper to divide a Value by a constant integer.
 static Value divideByConstant(ConversionPatternRewriter &rewriter, Location loc,
                               Value val, int64_t constant) {
   auto constantOp = createConstantIndex(rewriter, loc, constant);
@@ -164,7 +173,6 @@ static Value divideByConstant(ConversionPatternRewriter &rewriter, Location loc,
 
 static Value generateLoads(ConversionPatternRewriter &rewriter,
                            TypedValue<VectorType> data,
-                           SmallVector<int64_t> &shapeRatio,
                            SmallVector<OpFoldResult> offsets,
                            SmallVector<int64_t> &supportedShape,
                            TypedValue<xegpu::TensorDescType> newTensorDesc,
@@ -173,6 +181,11 @@ static Value generateLoads(ConversionPatternRewriter &rewriter,
   assert(offsets.size() >= 2 && "Expecting at least 2 offsets for 2D LoadNdOp");
   Value offsetX = convertToValue(rewriter, loc, offsets[offsets.size() - 2]);
   Value offsetY = convertToValue(rewriter, loc, offsets[offsets.size() - 1]);
+  // Compute the ratio between the original shape and the HW-supported shape;
+  // loads are generated in this arrangement to cover the original shape.
+  auto shapeRatio = computeShapeRatio(data.getType().getShape(),
+                                      supportedShape)
+                        .value(); // `shapeRatio` must exist if we reach here.
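+  // For example, loading a 32x16 block of i32 with a HW-supported block of
+  // 32x8 gives a shape ratio of [1, 2], i.e. two loads along the width.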
   for (int64_t h = 0; h < shapeRatio[0]; ++h) {
     for (int64_t w = 0; w < shapeRatio[1]; ++w) {
       int64_t localOffsetX = h * supportedShape[0];
@@ -214,7 +227,7 @@ class XeGPUCreateNdDescOpPattern final
   matchAndRewrite(xegpu::CreateNdDescOp createNdOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto tdescTy = createNdOp.getType();
-    auto convertType = tryConvertToTransposable(tdescTy);
+    auto convertType = tryOptimize(tdescTy);
     if (convertType == tdescTy)
       return failure();
     auto strides = createNdOp.getMixedStrides();
@@ -291,17 +304,10 @@ class XeGPULoadNdDescOpPattern final
     // Get the 2D data shape of this loadNdOp in its original type including
     // array length.
     SmallVector<int64_t> origDataShape(origTensorDescType.getShape());
-    // origDataShape.back() *= origTensorDescType.getArrayLength();
     // Adjust the data shape based on innerLaneData.
     origDataShape.back() /= innerLaneData;
     // HW supported shape is the new tensor desc shape after conversion.
     SmallVector<int64_t> hwSupportedShape(adaptorType.getShape());
-    // Shape ratio is 2D and, it describes how many blocks need to be loaded in
-    // HW supported shape to cover the original shape.
-    auto ratio = computeShapeRatio(origDataShape, hwSupportedShape)
-                     .value(); // `ratio` must be defined if we reach here.
-    // Create a zero-initialized vector to hold all loaded blocks.
-    // TypedAttr zeroAttr = rewriter.getZeroAttr(adaptorType.getElementType());
     VectorType origVectorType =
         VectorType::get(origDataShape, adaptorType.getElementType());
     Value data;
@@ -322,8 +328,8 @@ class XeGPULoadNdDescOpPattern final
                                                       i * origDataShape[1]))
                 .getResult();
         slice = generateLoads(
-            rewriter, cast<TypedValue<VectorType>>(slice), ratio,
-            modifiedOffsets, hwSupportedShape,
+            rewriter, cast<TypedValue<VectorType>>(slice), modifiedOffsets,
+            hwSupportedShape,
             cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()),
             loadNdOp);
         // BitCast back to original load shape without array length.
@@ -344,7 +350,7 @@ class XeGPULoadNdDescOpPattern final
         VectorType::get(origDataShape, adaptorType.getElementType()),
         rewriter.getZeroAttr(origVectorType));
     data = generateLoads(
-        rewriter, cast<TypedValue<VectorType>>(data), ratio, modifiedOffsets,
+        rewriter, cast<TypedValue<VectorType>>(data), modifiedOffsets,
         hwSupportedShape,
         cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()),
         loadNdOp);


