[Mlir-commits] [mlir] 9703bda - [mlir][xegpu] Add OptimizeBlockLoads pass. (#165483)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 4 13:15:36 PST 2025


Author: Charitha Saumya
Date: 2025-11-04T13:15:32-08:00
New Revision: 9703bda95b088bb6a455ef9faffdb41c537aff2f

URL: https://github.com/llvm/llvm-project/commit/9703bda95b088bb6a455ef9faffdb41c537aff2f
DIFF: https://github.com/llvm/llvm-project/commit/9703bda95b088bb6a455ef9faffdb41c537aff2f.diff

LOG: [mlir][xegpu] Add OptimizeBlockLoads pass.  (#165483)

This pass rewrites certain xegpu `CreateNd` and `LoadNd` operations that
feed into `vector.transpose` into a more optimal form to improve
performance. Specifically, low-precision (bitwidth < 32) `LoadNd` ops
that feed into transpose ops are rewritten as i32 loads with a valid
transpose layout, so that later passes can use the hardware
load-with-transpose feature to accelerate such loads.
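
For illustration, a condensed before/after sketch based on the `no_scf` case
in the added `optimize-transpose.mlir` test (layout attributes are abbreviated
with `...` in the operand types):

    // Before: f16 load with lane_data = [1, 2] feeding a transpose.
    %0 = xegpu.create_nd_tdesc %arg0 : memref<64x64xf16>
           -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
    %1 = xegpu.load_nd %0[%c0, %c32] : !xegpu.tensor_desc<16x16xf16, ...> -> vector<16x16xf16>
    %2 = vector.transpose %1, [1, 0] : vector<16x16xf16> to vector<16x16xf16>

    // After: i32 load with lane_data = [1, 1]; the descriptor shape, strides and
    // offsets are halved, and the loaded data is bitcast back to f16.
    %p = memref.extract_aligned_pointer_as_index %arg0 : memref<64x64xf16> -> index
    %b = arith.index_cast %p : index to i64
    %0 = xegpu.create_nd_tdesc %b, shape : [64, %c32], strides : [%c32, 1] : i64
           -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
    %1 = xegpu.load_nd %0[%c0, %c16] : !xegpu.tensor_desc<16x8xi32, ...> -> vector<16x8xi32>
    %2 = vector.bitcast %1 : vector<16x8xi32> to vector<16x16xf16>
    %3 = vector.transpose %2, [1, 0] : vector<16x16xf16> to vector<16x16xf16>

Here `%c32` in the rewritten form is the halved row size/stride (64 f16
elements become 32 i32 elements) and `%c16` is the halved column offset.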

**Update:**
The pass is renamed to `OptimizeBlockLoads` because we later plan to add
the array-length optimization to this pass as well. That optimization will
break a larger load (such as `32x32xf16`) into more DPAS-favorable
array-length loads (`32x16xf16` with array length = 2). Both
optimizations require rewriting `CreateNd` and `LoadNd`, so it makes
sense to have a common pass for both.
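
As a rough sketch of that planned array-length form (not implemented in this
patch; `#layout` below is a placeholder for the usual lane layout attribute),
a descriptor such as

    !xegpu.tensor_desc<32x32xf16, #layout>

would instead be created as

    !xegpu.tensor_desc<32x16xf16, #layout, #xegpu.block_tdesc_attr<array_length = 2 : i64>>

so that a single `LoadNd` yields a `vector<2x32x16xf16>` whose two 32x16 slices
map directly onto DPAS-sized operands. The `array_length` case in the added
test already exercises descriptors of this form for the transpose rewrite.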

Added: 
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp
    mlir/test/Dialect/XeGPU/optimize-transpose.mlir

Modified: 
    mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
    mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
    mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
    mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
    mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index eb05628d4772b..e42799689e490 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -85,4 +85,16 @@ def XeGPUVectorLinearize : Pass<"xegpu-vector-linearize"> {
                            "scf::SCFDialect", "ub::UBDialect", "vector::VectorDialect"];
 }
 
+def XeGPUOptimizeBlockLoads : Pass<"xegpu-optimize-block-loads"> {
+  let summary = "Optimize XeGPU block load operations";
+  let description = [{
+    This pass rewrites XeGPU loadNd operations into more optimal forms
+    to improve performance. This includes:
+    - Rewriting transposed B loads into forms that can use the HW block
+      transpose instructions for better performance.
+  }];
+  let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
+                           "vector::VectorDialect"];
+}
+
 #endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD

diff  --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index a480195eebd00..1776a209d0bf1 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -61,7 +61,8 @@ struct UnrollOptions {
 
 /// Appends patterns for folding aliasing ops into XeGPU ops into `patterns`.
 void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
-
+/// Appends patterns for optimizing block load operations into `patterns`.
+void populateXeGPUOptimizeBlockLoadsPatterns(RewritePatternSet &patterns);
 /// Appends patterns for XeGPU SIMT distribution into `patterns`.
 void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns);
 /// Appends patterns for moving function body into gpu.warp_execute_on_lane0 op.

diff  --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 620a2fe43d682..58092c3bb9ed2 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -166,6 +166,15 @@ SmallVector<OpFoldResult> addElementwise(OpBuilder &builder, Location loc,
 SmallVector<OpFoldResult> addWithRightAligned(OpBuilder &builder, Location loc,
                                               ArrayRef<OpFoldResult> lhs,
                                               ArrayRef<OpFoldResult> rhs);
+
+/// Helper function to find a proper instruction multiple for the user-supplied
+/// sg-level data shape (given by `dim`). `candidates` are uArch allowed shapes.
+/// `candidateMultiples` are uArch multiples of such shapes (i.e. block count or
+/// array length).
+template <typename T>
+int getLargestDivisor(T dim, ArrayRef<T> candidates,
+                      ArrayRef<T> candidateMultiples = {});
+
 } // namespace xegpu
 
 } // namespace mlir

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index e6f76067094ce..29b645feab2c6 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
   XeGPUWgToSgDistribute.cpp
   XeGPUPropagateLayout.cpp
   XeGPUVectorLinearize.cpp
+  XeGPUOptimizeBlockLoads.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp
new file mode 100644
index 0000000000000..4dc5ea4f7bb24
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp
@@ -0,0 +1,490 @@
+//===- XeGPUOptimizeBlockLoads.cpp - XeGPU optimize block loads -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
+#include "mlir/Dialect/XeGPU/uArch/uArchBase.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include <optional>
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUOPTIMIZEBLOCKLOADS
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+#define DEBUG_TYPE "xegpu-optimize-block-loads"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+using namespace mlir;
+
+namespace {
+
+/// Get the 2D lane data from a tensor desc type if it exists.
+static std::optional<SmallVector<int64_t>>
+getMaybeLaneData(xegpu::TensorDescType tdescType) {
+  auto layout = tdescType.getLayoutAttr();
+  if (!layout)
+    return std::nullopt;
+  auto laneData = layout.getEffectiveLaneDataAsInt();
+  if (laneData.size() != 2)
+    return std::nullopt;
+  return laneData;
+}
+
+/// Get the 2D lane layout from a tensor desc type if it exists.
+static std::optional<SmallVector<int64_t>>
+getMaybeLaneLayout(xegpu::TensorDescType tdescType) {
+  auto layout = tdescType.getLayoutAttr();
+  if (!layout)
+    return std::nullopt;
+  auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
+  if (laneLayout.size() != 2)
+    return std::nullopt;
+  return laneLayout;
+}
+
+/// A layout can be optimized if its lane layout is transposed (lane[0] != 1 &&
+/// lane[1] == 1) and its lane data is not equal to [1, 1].
+/// Example:
+///     !xegpu.tensor_desc<16x16xf16,
+///         #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
+/// In this case, lane layout is transposed (from the usual [1, SG_SIZE] form)
+/// indicating that this load requires a transpose effect. However, the lane
+/// data is [1, 2], meaning that each lane must grab 2 f16 elements from the
+/// inner dimension. We convert this to an optimized form by converting the
+/// tensor_desc to i32 type such that the lane data becomes [1, 1]. This lets
+/// later lowering easily use the load-with-transpose instruction.
+static bool canBeOptimizedForTranspose(ArrayRef<int64_t> laneLayout,
+                                       ArrayRef<int64_t> laneData) {
+  if (laneLayout.size() != 2 || laneData.size() != 2)
+    return false;
+  if (laneLayout[0] == 1 || laneLayout[1] != 1)
+    return false;
+  if (laneData[0] != 1 || laneData[1] == 1)
+    return false;
+  return true;
+}
+
+/// A tensor desc type can be optimized if its element type is less than 32 bits
+/// and its layout can be optimized.
+static bool canBeOptimizedForTranspose(xegpu::TensorDescType tdescType) {
+  // Element types that are 32 bits or wider already have a valid layout;
+  // nothing to optimize.
+  int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
+  if (elementTyBitwidth >= 32)
+    return false;
+  auto maybeLaneLayout = getMaybeLaneLayout(tdescType);
+  auto maybeLaneData = getMaybeLaneData(tdescType);
+  if (!maybeLaneData || !maybeLaneLayout)
+    return false;
+  return canBeOptimizedForTranspose(*maybeLaneLayout, *maybeLaneData);
+}
+
+/// Check if a tensor desc type can be optimized for transpose; if so, return
+/// the new optimized tensor desc type with a valid transpose layout.
+static xegpu::TensorDescType tryOptimize(xegpu::TensorDescType tdescType,
+                                         const uArch *targetuArch) {
+  if (!canBeOptimizedForTranspose(tdescType))
+    return tdescType;
+  auto laneData = getMaybeLaneData(tdescType)
+                      .value(); // Lane data must exist if we reach here.
+  int64_t innerLaneData = laneData[1];
+  int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
+  // Required shape is the total shape of the vector result that this tensor
+  // desc must eventually load after adjusting for the new bitwidth and array
+  // length.
+  SmallVector<int64_t> requiredShape(tdescType.getShape());
+  requiredShape.back() =
+      requiredShape.back() * tdescType.getArrayLength() / innerLaneData;
+  int newBitWidth = elementTyBitwidth * innerLaneData;
+  Type newElemTy = IntegerType::get(tdescType.getContext(), newBitWidth);
+  // Supported shape is the largest transpose shape supported by the hardware
+  // that is less than or equal to the required shape.
+  auto *blockLoadTarget = dyn_cast<Subgroup2DBlockLoadInstruction>(
+      targetuArch->getInstruction(InstructionKind::Subgroup2DBlockLoad));
+  auto maybeHWParams = blockLoadTarget->getBlockWidthHeightCount(
+      newElemTy, /** has transform */ false, /** has transpose */ true);
+  // If no HW params found, return the original type.
+  if (!maybeHWParams)
+    return tdescType;
+  auto [widths, heights, counts] = maybeHWParams.value();
+  // TODO: Currently we expect array length to be 1 for transpose case.
+  if (counts.size() != 1 || counts[0] != 1)
+    return tdescType;
+  int arrayLen = counts[0];
+  int supportedHeight =
+      xegpu::getLargestDivisor(static_cast<int>(requiredShape[0]), heights);
+  int supportedWidth =
+      xegpu::getLargestDivisor(static_cast<int>(requiredShape[1]), widths);
+  // If no supported height or width found, return the original type.
+  if (supportedHeight == -1 || supportedWidth == -1)
+    return tdescType;
+
+  SmallVector<int64_t> supportedShape = {supportedHeight, supportedWidth};
+  xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get(
+      tdescType.getContext(),
+      tdescType.getLayoutAttr().getLaneLayout().asArrayRef(), {1, 1});
+  // Array length cannot be larger than 1 for the transpose case.
+  return xegpu::TensorDescType::get(supportedShape, newElemTy, arrayLen,
+                                    tdescType.getBoundaryCheck(),
+                                    tdescType.getMemorySpace(), newLayout);
+}
+
+/// Helper to convert an OpFoldResult to Value.
+static Value convertToValue(ConversionPatternRewriter &rewriter, Location loc,
+                            OpFoldResult ofr) {
+  std::optional<int64_t> mayBeInt = getConstantIntValue(ofr);
+  if (mayBeInt)
+    return arith::ConstantIndexOp::create(rewriter, loc, *mayBeInt).getResult();
+  return llvm::cast<Value>(ofr);
+}
+
+/// Helper to divide a Value by a constant integer.
+static Value divideByConstant(ConversionPatternRewriter &rewriter, Location loc,
+                              Value val, int64_t constant) {
+  // If the constant is a power of 2, use right shift for division.
+  if (llvm::isPowerOf2_64(constant)) {
+    int64_t shiftAmount = llvm::Log2_64(constant);
+    return arith::ShRUIOp::create(
+               rewriter, loc, val,
+               arith::ConstantIndexOp::create(rewriter, loc, shiftAmount)
+                   .getResult())
+        .getResult();
+  }
+  auto constantOp =
+      arith::ConstantIndexOp::create(rewriter, loc, constant).getResult();
+  return arith::DivUIOp::create(rewriter, loc, val, constantOp).getResult();
+}
+
+/// This function takes a larger register block `data` and generates multiple
+/// smaller loads (size given by `newTensorDesc`) to fill in the `data` block
+/// starting from `offsets`.
+static Value generateLoads(ConversionPatternRewriter &rewriter,
+                           TypedValue<VectorType> data,
+                           SmallVector<OpFoldResult> offsets,
+                           TypedValue<xegpu::TensorDescType> newTensorDesc,
+                           xegpu::LoadNdOp origLoadOp) {
+  Location loc = data.getLoc();
+  assert(offsets.size() >= 2 && "Expecting at least 2 offsets for 2D LoadNdOp");
+  Value offsetDim0 = convertToValue(rewriter, loc, offsets[offsets.size() - 2]);
+  Value offsetDim1 = convertToValue(rewriter, loc, offsets[offsets.size() - 1]);
+  SmallVector<int64_t> supportedShape(newTensorDesc.getType().getShape());
+  // Compute the ratio between original shape and supported shape. We need to
+  // generate loads in this ratio arrangement.
+  auto shapeRatio = computeShapeRatio(data.getType().getShape(),
+                                      supportedShape)
+                        .value(); // `ratio` must be defined if we reach here.
+  for (int64_t h = 0; h < shapeRatio[0]; ++h) {
+    for (int64_t w = 0; w < shapeRatio[1]; ++w) {
+      int64_t localOffsetDim0 = h * supportedShape[0];
+      int64_t localOffsetDim1 = w * supportedShape[1];
+      Value loadOffsetX = arith::AddIOp::create(
+          rewriter, loc, offsetDim0,
+          arith::ConstantIndexOp::create(rewriter, loc, localOffsetDim0)
+              .getResult());
+      Value loadOffsetY = arith::AddIOp::create(
+          rewriter, loc, offsetDim1,
+          arith::ConstantIndexOp::create(rewriter, loc, localOffsetDim1)
+              .getResult());
+      auto loadOp = xegpu::LoadNdOp::create(
+          rewriter, loc,
+          VectorType::get(supportedShape, data.getType().getElementType()),
+          newTensorDesc, ArrayRef<OpFoldResult>{loadOffsetX, loadOffsetY},
+          origLoadOp.getPackedAttr(), origLoadOp.getTransposeAttr(),
+          origLoadOp.getL1HintAttr(), origLoadOp.getL2HintAttr(),
+          origLoadOp.getL3HintAttr());
+      // Set the layout for the loadOp.
+      auto layoutAttr = newTensorDesc.getType().getLayoutAttr();
+      xegpu::setDistributeLayoutAttr(loadOp->getOpResult(0), layoutAttr);
+      // Insert the loaded block into the right position in data.
+      auto insertOp = vector::InsertStridedSliceOp::create(
+          rewriter, loc, loadOp.getResult(), data,
+          ArrayRef<int64_t>{localOffsetDim0, localOffsetDim1},
+          ArrayRef<int64_t>{1, 1});
+      // InsertOp must have the same layout as newTensorDesc.
+      xegpu::setDistributeLayoutAttr(insertOp->getOpResult(0), layoutAttr);
+      data = insertOp.getResult();
+    }
+  }
+  return data;
+}
+
+/// Checks if a CreateNdDescOp can be optimized for transpose and, if so,
+/// creates a new CreateNdDescOp with an optimized tensor desc type. This
+/// involves extracting the base pointer from the original memory source and
+/// adjusting the shape and strides of the tensor desc to fit the new optimized
+/// transpose layout.
+class XeGPUCreateNdDescOpPattern final
+    : public OpConversionPattern<xegpu::CreateNdDescOp> {
+public:
+  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::CreateNdDescOp createNdOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto tdescTy = createNdOp.getType();
+    // Get the target uArch info.
+    auto chipStr = xegpu::getChipStr(createNdOp);
+    // Check if the chip is supported.
+    assert(
+        chipStr && (chipStr.value() == "pvc" || chipStr.value() == "bmg") &&
+        "Expecting target chip to be pvc or bmg for transpose optimization.");
+    const uArch *targetuArch = xegpu::uArch::getUArch(chipStr.value());
+
+    auto convertType = tryOptimize(tdescTy, targetuArch);
+    if (convertType == tdescTy)
+      return failure();
+    auto strides = createNdOp.getMixedStrides();
+    auto maybeConstInnerStride = getConstantIntValue(strides.back());
+    // Only row-major memrefs are expected for now.
+    if (!maybeConstInnerStride || *maybeConstInnerStride != 1)
+      return rewriter.notifyMatchFailure(
+          createNdOp, "Expecting row-major memref for transpose optimization.");
+    Value source = createNdOp.getSource();
+    auto optionalLaneData = getMaybeLaneData(tdescTy);
+    assert(optionalLaneData && "Expected 2D lane data");
+    auto laneData = optionalLaneData.value();
+    int64_t innerLaneData = laneData[1];
+    auto memrefType = dyn_cast<MemRefType>(source.getType());
+    // Inner dimension of the shape must be adjusted based on innerLaneData.
+    SmallVector<OpFoldResult> modifiedShape(createNdOp.getMixedSizes());
+    modifiedShape.back() = divideByConstant(
+        rewriter, createNdOp.getLoc(),
+        convertToValue(rewriter, createNdOp.getLoc(), modifiedShape.back()),
+        innerLaneData);
+    // Similarly, second to last stride must be adjusted.
+    assert(strides.size() >= 2 &&
+           "Expected at least 2 strides for CreateNdDescOp");
+    SmallVector<OpFoldResult> modifiedStrides(strides);
+    modifiedStrides[modifiedStrides.size() - 2] = divideByConstant(
+        rewriter, createNdOp.getLoc(),
+        convertToValue(rewriter, createNdOp.getLoc(),
+                       modifiedStrides[modifiedStrides.size() - 2]),
+        innerLaneData);
+
+    // If the source is a static memref, we need to extract a pointer to its
+    // base address.
+    if (memrefType && memrefType.hasStaticShape()) {
+      auto extractOp = memref::ExtractAlignedPointerAsIndexOp::create(
+          rewriter, createNdOp.getLoc(), source);
+      source = arith::IndexCastOp::create(rewriter, createNdOp.getLoc(),
+                                          rewriter.getI64Type(),
+                                          extractOp.getResult())
+                   .getResult();
+    }
+    // Create a new CreateNdDescOp with the modified shape and converted type.
+    auto newCreateNdDescOp = xegpu::CreateNdDescOp::create(
+        rewriter, createNdOp.getLoc(), convertType, source, modifiedShape,
+        modifiedStrides);
+    rewriter.replaceOp(createNdOp, newCreateNdDescOp.getResult());
+    return success();
+  }
+};
+
+/// Checks if a LoadNdOp consumes a tensor desc type that was rewritten for
+/// transpose optimization. If so, rewrites the LoadNdOp to align with the
+/// adjusted tensor desc type. This can result in multiple LoadNdOps being
+/// generated to fill in the original load shape.
+class XeGPULoadNdDescOpPattern final
+    : public OpConversionPattern<xegpu::LoadNdOp> {
+public:
+  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::LoadNdOp loadNdOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto origTensorDescType = loadNdOp.getTensorDescType();
+    auto adaptorType =
+        cast<xegpu::TensorDescType>(adaptor.getTensorDesc().getType());
+    if (adaptorType == origTensorDescType)
+      return failure();
+    // Offsets must be adjusted based on innerLaneData.
+    auto laneData = getMaybeLaneData(loadNdOp.getTensorDescType()).value();
+    int64_t innerLaneData = laneData[1];
+    auto offsets = loadNdOp.getMixedOffsets();
+    if (offsets.empty())
+      return rewriter.notifyMatchFailure(loadNdOp,
+                                         "Expecting offsets in LoadNd");
+    SmallVector<OpFoldResult> modifiedOffsets(offsets);
+    modifiedOffsets.back() = divideByConstant(
+        rewriter, loadNdOp.getLoc(),
+        convertToValue(rewriter, loadNdOp.getLoc(), modifiedOffsets.back()),
+        innerLaneData);
+    // Get the 2D data shape of this loadNdOp in its original type including
+    // array length.
+    SmallVector<int64_t> origDataShape(origTensorDescType.getShape());
+    // Adjust the data shape based on innerLaneData.
+    origDataShape.back() /= innerLaneData;
+    // HW supported shape is the new tensor desc shape after conversion.
+    SmallVector<int64_t> hwSupportedShape(adaptorType.getShape());
+    VectorType origVectorType =
+        VectorType::get(origDataShape, adaptorType.getElementType());
+    Value data;
+    // Orig data shape is 3D for the array length case.
+    if (origTensorDescType.getArrayLength() > 1) {
+      SmallVector<Value> arraySlices;
+      for (int64_t i = 0; i < origTensorDescType.getArrayLength(); ++i) {
+        Value slice = arith::ConstantOp::create(
+            rewriter, loadNdOp->getLoc(), origVectorType,
+            rewriter.getZeroAttr(origVectorType));
+        // Increase the Y offset for each array slice.
+        Value offsetY = convertToValue(rewriter, loadNdOp->getLoc(),
+                                       modifiedOffsets.back());
+        modifiedOffsets.back() =
+            arith::AddIOp::create(
+                rewriter, loadNdOp->getLoc(), offsetY,
+                arith::ConstantIndexOp::create(rewriter, loadNdOp->getLoc(),
+                                               i * origDataShape[1])
+                    .getResult())
+                .getResult();
+        slice = generateLoads(
+            rewriter, cast<TypedValue<VectorType>>(slice), modifiedOffsets,
+            cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()),
+            loadNdOp);
+        // BitCast back to original load shape without array length.
+        auto bitcastType = VectorType::get(origTensorDescType.getShape(),
+                                           origTensorDescType.getElementType());
+        auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
+                                                   bitcastType, slice);
+        // BitCastOp must have the same layout as the original loadNdOp.
+        xegpu::setDistributeLayoutAttr(bitCastOp->getOpResult(0),
+                                       origTensorDescType.getLayoutAttr());
+        arraySlices.push_back(bitCastOp.getResult());
+      }
+      rewriter.replaceOpWithMultiple(loadNdOp, {arraySlices});
+      return success();
+    }
+    data = arith::ConstantOp::create(
+        rewriter, loadNdOp->getLoc(),
+        VectorType::get(origDataShape, adaptorType.getElementType()),
+        rewriter.getZeroAttr(origVectorType));
+    data = generateLoads(
+        rewriter, cast<TypedValue<VectorType>>(data), modifiedOffsets,
+        cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()),
+        loadNdOp);
+    auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
+                                               loadNdOp.getType(), data);
+    // BitCastOp must have the same layout as the original loadNdOp.
+    xegpu::setDistributeLayoutAttr(bitCastOp->getOpResult(0),
+                                   origTensorDescType.getLayoutAttr());
+    rewriter.replaceOp(loadNdOp, bitCastOp);
+    return success();
+  }
+};
+
+/// Vector ExtractOp must be processed if the original tensor desc type has
+/// array length greater than 1. In this case, the LoadNdOp is replaced with
+/// multiple LoadNdOps, one per array slice, making the extraction unnecessary,
+/// so we simply remove the ExtractOp.
+class VectorExtractOpPattern final
+    : public OpConversionPattern<vector::ExtractOp> {
+public:
+  using OpConversionPattern<vector::ExtractOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(vector::ExtractOp extractOp, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Check if the source of the extraction is split to multiple values.
+    if (adaptor.getSource().size() == 1)
+      return failure();
+    auto mixedPos = extractOp.getMixedPosition();
+    if (mixedPos.size() != 1)
+      return failure();
+    auto mayBeInt = getConstantIntValue(mixedPos[0]);
+    if (!mayBeInt)
+      return failure();
+    rewriter.replaceOp(extractOp, adaptor.getSource()[*mayBeInt]);
+    return success();
+  }
+};
+
+} // namespace
+
+void xegpu::populateXeGPUOptimizeBlockLoadsPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<XeGPUCreateNdDescOpPattern, XeGPULoadNdDescOpPattern,
+               VectorExtractOpPattern>(patterns.getContext());
+}
+
+namespace {
+
+struct XeGPUOptimizeBlockLoadsPass final
+    : public xegpu::impl::XeGPUOptimizeBlockLoadsBase<
+          XeGPUOptimizeBlockLoadsPass> {
+  void runOnOperation() override {
+    MLIRContext &context = getContext();
+    TypeConverter converter;
+    RewritePatternSet patterns(&context);
+    ConversionTarget target(context);
+
+    // This pass is only meant for PVC and BMG targets. If an unsupported
+    // target is found, exit early.
+    bool isTargetSupported = false;
+    getOperation()->walk([&](gpu::GPUFuncOp funcOp) {
+      auto chipStr = xegpu::getChipStr(funcOp);
+      if (chipStr && (chipStr.value() == "pvc" || chipStr.value() == "bmg"))
+        isTargetSupported = true;
+    });
+
+    if (!isTargetSupported) {
+      DBGS() << "XeGPUOptimizeBlockLoadsPass only supports PVC and BMG targets."
+             << "\n";
+      return;
+    }
+
+    // CreateNdDescOp and LoadNdOp with optimizable tensor desc types must be
+    // converted.
+    target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
+        [&](xegpu::CreateNdDescOp createNdOp) {
+          return !canBeOptimizedForTranspose(createNdOp.getType());
+        });
+    target.addDynamicallyLegalOp<xegpu::LoadNdOp>(
+        [&](xegpu::LoadNdOp loadNdOp) {
+          return !canBeOptimizedForTranspose(loadNdOp.getTensorDescType());
+        });
+    // Vector ExtractOps can have optimizable layouts if they extract from
+    // LoadNdOps with array length greater than 1. These ExtractOps must be
+    // converted.
+    target.addDynamicallyLegalOp<vector::ExtractOp>(
+        [&](vector::ExtractOp extractOp) {
+          auto layout = xegpu::getDistributeLayoutAttr(extractOp.getResult());
+          if (!layout)
+            return true;
+          auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
+          auto laneData = layout.getEffectiveLaneDataAsInt();
+          return !canBeOptimizedForTranspose(laneLayout, laneData);
+        });
+    converter.addConversion([](Type type) { return type; });
+
+    target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect,
+                           vector::VectorDialect>();
+    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
+                                                         target);
+    xegpu::populateXeGPUOptimizeBlockLoadsPatterns(patterns);
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns)))) {
+      DBGS() << "Optimize block loads pass failed.\n";
+      return signalPassFailure();
+    }
+  }
+};
+
+} // namespace

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 14c49e7f45706..4e1a539771d2f 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -204,28 +204,6 @@ struct LayoutInfoLattice : public Lattice<LayoutInfo> {
   using Lattice::Lattice;
 };
 
-/// Helper Function to find a proper instruction multiple for the user-supplied
-/// sg-level data shape. `candidates` are uArch allowed shapes.
-/// `candidateMultiples` are uArch multiples of such shapes (e.g., block count).
-template <typename T>
-int getLargestDivisor(T dim, ArrayRef<T> candidates,
-                      ArrayRef<T> candidateMultiples = {}) {
-  static_assert(std::is_integral<T>::value, "T must be an integer type");
-  int largest = -1;
-  SmallVector<T> multiples = {1};
-  if (!candidateMultiples.empty())
-    multiples =
-        SmallVector<T>(candidateMultiples.begin(), candidateMultiples.end());
-  for (T candidate : candidates) {
-    for (T multiple : multiples) {
-      int value = static_cast<int>(candidate * multiple);
-      if (value != 0 && dim % value == 0 && value > largest)
-        largest = value;
-    }
-  }
-  return largest;
-}
-
 /// Helper Functions to get default layouts. A `default layout` is a layout that
 /// is assigned to a value when the layout is not fixed by some anchor operation
 /// (like DPAS).
@@ -505,7 +483,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
     prefetch.emitWarning("No known block params found for the element type.");
   auto [bWidth, bHeight, bCount] = blockWHC.value();
   SmallVector<int> instData;
-  int instWidth = getLargestDivisor(
+  int instWidth = xegpu::getLargestDivisor(
       static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth,
       bCount);
   if (instWidth == -1)
@@ -514,7 +492,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
   if (tdescTy.getRank() == 1)
     instData = {instWidth};
   else {
-    int instHeight = getLargestDivisor(
+    int instHeight = xegpu::getLargestDivisor(
         static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
     if (instHeight == -1)
       prefetch.emitWarning(
@@ -634,7 +612,7 @@ void LayoutInfoPropagation::visitDpasOp(
   const unsigned dataALen = aTy.getShape().front();
   auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
   const int maxALen =
-      getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
+      xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
   if (maxALen == -1)
     dpas.emitWarning(
         "No suitable instruction multiple found for the given shape.");
@@ -642,7 +620,7 @@ void LayoutInfoPropagation::visitDpasOp(
   const unsigned dataBLen = bTy.getShape().back();
   auto supportedBLen = uArchInstruction->getSupportedK(bTy.getElementType());
   const int maxBLen =
-      getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
+      xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
   if (maxBLen == -1)
     dpas.emitWarning(
         "No suitable instruction multiple found for the given shape.");
@@ -662,7 +640,7 @@ void LayoutInfoPropagation::visitDpasOp(
     const unsigned dataCLen = bTy.getShape().back();
     auto supportedCLen = uArchInstruction->getSupportedN(bTy.getElementType());
     const int maxCLen =
-        getLargestDivisor(dataCLen, ArrayRef<unsigned>(supportedCLen));
+        xegpu::getLargestDivisor(dataCLen, ArrayRef<unsigned>(supportedCLen));
     if (maxCLen == -1)
       dpas.emitWarning(
           "No suitable instruction multiple found for the given shape.");
@@ -691,7 +669,7 @@ void LayoutInfoPropagation::visitStoreNdOp(
     store.emitWarning("No known block params found for the element type.");
   auto [bWidth, bHeight, bCount] = blockWHC.value();
   SmallVector<int> instData;
-  int instWidth = getLargestDivisor(
+  int instWidth = xegpu::getLargestDivisor(
       static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth,
       bCount);
   if (instWidth == -1)
@@ -700,7 +678,7 @@ void LayoutInfoPropagation::visitStoreNdOp(
   if (dataTy.getRank() == 1)
     instData = {instWidth};
   else {
-    int instHeight = getLargestDivisor(
+    int instHeight = xegpu::getLargestDivisor(
         static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
     if (instHeight == -1)
       store.emitWarning(

diff  --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index d575a415a3035..de9e09d427665 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -555,3 +555,29 @@ xegpu::addWithRightAligned(OpBuilder &builder, Location loc,
   results.append(addElementwise(builder, loc, a, b));
   return results;
 }
+
+template <typename T>
+int xegpu::getLargestDivisor(T dim, ArrayRef<T> candidates,
+                             ArrayRef<T> candidateMultiples) {
+  static_assert(std::is_integral<T>::value, "T must be an integer type");
+  int largest = -1;
+  SmallVector<T> multiples = {1};
+  if (!candidateMultiples.empty())
+    multiples =
+        SmallVector<T>(candidateMultiples.begin(), candidateMultiples.end());
+  for (T candidate : candidates) {
+    for (T multiple : multiples) {
+      int value = static_cast<int>(candidate * multiple);
+      if (value != 0 && dim % value == 0 && value > largest)
+        largest = value;
+    }
+  }
+  return largest;
+}
+
+/// Explicit instantiations
+template int xegpu::getLargestDivisor<int>(int dim, ArrayRef<int> candidates,
+                                           ArrayRef<int> candidateMultiples);
+template int
+xegpu::getLargestDivisor<unsigned>(unsigned dim, ArrayRef<unsigned> candidates,
+                                   ArrayRef<unsigned> candidateMultiples);

diff  --git a/mlir/test/Dialect/XeGPU/optimize-transpose.mlir b/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
new file mode 100644
index 0000000000000..24a0de6ed48a5
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
@@ -0,0 +1,280 @@
+// RUN: mlir-opt --xevm-attach-target='module=xevm_* chip=pvc'  \
+// RUN:   --xegpu-optimize-block-loads --canonicalize --cse --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: gpu.func @no_scf(
+// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]+]]: memref<64x64xf16>, %{{.*}}: vector<8x16xf16>) -> vector<8x16xf32> {
+// CHECK:         %[[C16:.*]] = arith.constant 16 : index
+// CHECK:         %[[C32:.*]] = arith.constant 32 : index
+// CHECK:         %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<64x64xf16> -> index
+// CHECK:         %[[T0:.*]] = arith.index_cast %[[PTR]] : index to i64
+// CHECK:         %[[BDESC:.*]] = xegpu.create_nd_tdesc %[[T0]], shape : [64, %[[C32]]], strides : [%[[C32]], 1] : i64
+// CHECK-SAME:      -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK-NEXT:    %[[B:.*]] = xegpu.load_nd %[[BDESC]][%{{.*}}, %[[C16]]]
+// CHECK-SAME:      {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK-SAME:      : !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32>
+// CHECK:         %[[BITCAST:.*]] = vector.bitcast %[[B]]
+// CHECK-SAME:      {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : vector<16x8xi32> to vector<16x16xf16>
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
+gpu.module @xevm_module {
+gpu.func @no_scf(%arg0: memref<64x64xf16>, %arg1: vector<8x16xf16>) -> vector<8x16xf32> {
+  %c0 = arith.constant 0 : index
+  %c32 = arith.constant 32 : index
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<64x64xf16> -> !xegpu.tensor_desc<16x16xf16, #b>
+  %1 = xegpu.load_nd %0[%c0, %c32] { result_layout = #b } : !xegpu.tensor_desc<16x16xf16, #b> -> vector<16x16xf16>
+  %2 = vector.transpose %1, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+  %6 = xegpu.dpas %arg1, %2 { layout_result_0 = #a } : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  gpu.return %6 : vector<8x16xf32>
+}
+}
+
+// -----
+// CHECK-LABEL: gpu.func @no_scf_i8(
+// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]+]]: memref<64x64xi8>, %{{.*}}: vector<8x32xi8>) -> vector<8x16xi32> {
+// CHECK:         %[[C16:.*]] = arith.constant 16 : index
+// CHECK:         %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<64x64xi8> -> index
+// CHECK:         %[[T0:.*]] = arith.index_cast %[[PTR]] : index to i64
+// CHECK:         %[[T1:.*]] = xegpu.create_nd_tdesc %[[T0]], shape : [64, %[[C16]]], strides : [%[[C16]], 1] : i64
+// CHECK-SAME:      -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK:         %[[T2:.*]] = xegpu.load_nd %[[T1]][%{{.*}}, %[[C16]]]
+// CHECK-SAME:      {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK-SAME:      : !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32>
+// CHECK:         %[[T3:.*]] = vector.bitcast %[[T2]]
+// CHECK-SAME:      {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 4]>} : vector<16x8xi32> to vector<16x32xi8>
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 4]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]>
+#c = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+gpu.module @xevm_module {
+gpu.func @no_scf_i8(%arg0: memref<64x64xi8>, %arg1: vector<8x32xi8>) -> vector<8x16xi32> {
+  %c0 = arith.constant 0 : index
+  %c64 = arith.constant 64 : index
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<64x64xi8> -> !xegpu.tensor_desc<16x32xi8, #b>
+  %1 = xegpu.load_nd %0[%c0, %c64] { result_layout = #b } : !xegpu.tensor_desc<16x32xi8, #b> -> vector<16x32xi8>
+  %2 = vector.transpose %1, [1, 0] { layout_result_0 = #bt } : vector<16x32xi8> to vector<32x16xi8>
+  %6 = xegpu.dpas %arg1, %2 { layout_result_0 = #c } : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
+  gpu.return %6 : vector<8x16xi32>
+}
+}
+
+
+// -----
+// CHECK-LABEL:   gpu.func @gemm_b_transpose(
+// CHECK-SAME:      %{{.*}} memref<256x256xf16>, %[[ARG1:[a-zA-Z0-9]+]]: memref<256x256xf16>, %{{.*}}: memref<256x256xf32>) {
+// CHECK:           %[[C128:.*]] = arith.constant 128 : index
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[C16:.*]] = arith.constant 16 : index
+// CHECK:           %[[C256:.*]] = arith.constant 256 : index
+// CHECK:           %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<256x256xf16> -> index
+// CHECK:           %[[T3:.*]] = arith.index_cast %[[PTR]] : index to i64
+// CHECK:           %[[T4:.*]] = xegpu.create_nd_tdesc %[[T3]], shape : [256, %[[C128]]], strides : [%c128, 1]
+// CHECK-SAME:        : i64 -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK:           %{{.*}} = scf.for %[[K:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}}) -> (vector<8x16xf32>) {
+// CHECK:             %[[T7:.*]] = arith.shrui %[[K]], %[[C1]] : index
+// CHECK-NEXT:        %[[T8:.*]] = xegpu.load_nd %[[T4]][%{{.*}}, %[[T7]]]
+// CHECK-SAME:          {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} :
+// CHECK-SAME:          !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32>
+// CHECK-NEXT:        %{{.*}} = vector.bitcast %[[T8]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-SAME:          : vector<16x8xi32> to vector<16x16xf16>
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
+gpu.module @xevm_module {
+gpu.func @gemm_b_transpose(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c256 = arith.constant 256 : index
+  %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #a>
+  %1 = xegpu.load_nd %0[%c0, %c0]  { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf32, #a> -> vector<8x16xf32>
+  %2 = xegpu.create_nd_tdesc %arg0 : memref<256x256xf16> -> !xegpu.tensor_desc<8x16xf16, #a>
+  %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #b>
+  %4 = scf.for %arg3 = %c0 to %c256 step %c16 iter_args(%arg4 = %1) -> (vector<8x16xf32>) {
+    %5 = xegpu.load_nd %2[%c0, %arg3] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf16, #a> -> vector<8x16xf16>
+    %6 = xegpu.load_nd %3[%c0, %arg3]  { layout_result_0 = #b } : !xegpu.tensor_desc<16x16xf16, #b> -> vector<16x16xf16>
+    %7 = vector.transpose %6, [1, 0]  { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+    %8 = xegpu.dpas %5, %7, %arg4 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+    scf.yield %8 : vector<8x16xf32>
+  } {layout_result_0 = #a}
+  xegpu.store_nd %4, %0[%c0, %c0]  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+  gpu.return
+}
+}
+
+// -----
+// CHECK-LABEL: gpu.func @nested_scf(
+// CHECK-SAME:     %{{.*}}: memref<256x256xf16>, %[[ARG1:[a-zA-Z0-9]+]]: memref<256x256xf16>, %{{.*}}: memref<256x256xf32>) {
+// CHECK:          %[[C128:.*]] = arith.constant 128 : index
+// CHECK:          %[[C1:.*]] = arith.constant 1 : index
+// CHECK:          %[[C16:.*]] = arith.constant 16 : index
+// CHECK:          %[[C256:.*]] = arith.constant 256 : index
+// CHECK:          scf.for %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK:            %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<256x256xf16> -> index
+// CHECK:            %[[T3:.*]] = arith.index_cast %[[PTR]] : index to i64
+// CHECK:            %[[T4:.*]] = xegpu.create_nd_tdesc %[[T3]], shape : [256, %[[C128]]], strides : [%[[C128]], 1] : i64
+// CHECK-SAME:          -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK:            %{{.*}} = scf.for %[[K:.*]] = %{{.*}} iter_args(%{{.*}}) -> (vector<8x16xf32>) {
+// CHECK:              %[[T7:.*]] = arith.shrui %[[K]], %[[C1]] : index
+// CHECK-NEXT:         %[[T8:.*]] = xegpu.load_nd %[[T4]][%{{.*}}, %[[T7]]]  {layout_result_0 = #xegpu.layout<
+// CHECK-SAME:          lane_layout = [16, 1], lane_data = [1, 1]>} :
+// CHECK-SAME:          !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32>
+// CHECK-NEXT:         %{{.*}} = vector.bitcast %[[T8]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-SAME:          : vector<16x8xi32> to vector<16x16xf16>
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
+gpu.module @xevm_module {
+gpu.func @nested_scf(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c256 = arith.constant 256 : index
+  scf.for %arg8 = %c0 to %c256 step %c16 {
+    %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #a>
+    %1 = xegpu.load_nd %0[%arg8, %c0]  { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf32, #a> -> vector<8x16xf32>
+    %2 = xegpu.create_nd_tdesc %arg0 : memref<256x256xf16> -> !xegpu.tensor_desc<8x16xf16, #a>
+    %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #b>
+    %4 = scf.for %arg3 = %c0 to %c256 step %c16 iter_args(%arg4 = %1) -> (vector<8x16xf32>) {
+      %5 = xegpu.load_nd %2[%arg8, %arg3] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf16, #a> -> vector<8x16xf16>
+      %6 = xegpu.load_nd %3[%arg8, %arg3]  { layout_result_0 = #b } : !xegpu.tensor_desc<16x16xf16, #b> -> vector<16x16xf16>
+      %7 = vector.transpose %6, [1, 0]  { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+      %8 = xegpu.dpas %5, %7, %arg4 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+      scf.yield %8 : vector<8x16xf32>
+    } {layout_result_0 = #a}
+    xegpu.store_nd %4, %0[%c0, %c0]  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+  }
+  gpu.return
+}
+}
+
+// -----
+// CHECK-LABEL:   gpu.func @large_loads(
+// CHECK-SAME:      %{{.*}}: vector<8x16xf16>, %[[ARG1:[a-zA-Z0-9]+]]: memref<256x256xf16>, %{{.*}}: memref<256x256xf32>) {
+// CHECK:           %[[C128:.*]] = arith.constant 128 : index
+// CHECK:           %[[C8:.*]] = arith.constant 8 : index
+// CHECK:           %[[CST:.*]] = arith.constant dense<0> : vector<32x16xi32>
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<256x256xf16> -> index
+// CHECK:           %[[T2:.*]] = arith.index_cast %[[PTR]] : index to i64
+// CHECK:           %[[T3:.*]] = xegpu.create_nd_tdesc %[[T2]], shape : [256, %[[C128]]], strides : [%[[C128]], 1] : i64
+// CHECK-SAME:        -> !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK:           %{{.*}}:4 = scf.for %[[K:.*]] = %{{.*}} iter_args(%{{.*}}) -> (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) {
+// CHECK:             %[[T5:.*]] = arith.shrui %[[K]], %[[C1]] : index
+// CHECK:             %[[T6:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T5]]]  {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK-SAME:          : !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<32x8xi32>
+// CHECK:             %[[T7:.*]] = vector.insert_strided_slice %[[T6]], %[[CST]]
+// CHECK-SAME:          {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, offsets = [0, 0], strides = [1, 1]}
+// CHECK-SAME:          : vector<32x8xi32> into vector<32x16xi32>
+// CHECK:             %[[T8:.*]] = arith.addi %[[T5]], %[[C8]] : index
+// CHECK:             %[[T9:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T8]]]  {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK-SAME:          : !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<32x8xi32>
+// CHECK:             %[[T10:.*]] = vector.insert_strided_slice %[[T9]], %[[T7]]
+// CHECK-SAME:          {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, offsets = [0, 8], strides = [1, 1]}
+// CHECK-SAME:          : vector<32x8xi32> into vector<32x16xi32>
+// CHECK:             %{{.*}} = vector.bitcast %[[T10]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-SAME:          : vector<32x16xi32> to vector<32x32xf16>
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
+gpu.module @xevm_module {
+gpu.func @large_loads(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %c256 = arith.constant 256 : index
+  %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #a>
+  %1 = xegpu.load_nd %0[%c0, %c0]  { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf32, #a> -> vector<8x16xf32>
+  %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16> -> !xegpu.tensor_desc<32x32xf16, #b>
+  %4:4 = scf.for %arg3 = %c0 to %c256 step %c32 iter_args(%arg4 = %1, %arg5 = %1, %arg6 = %1, %arg7 = %1)
+    -> (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) {
+    %6 = xegpu.load_nd %3[%c0, %arg3]  { layout_result_0 = #b } : !xegpu.tensor_desc<32x32xf16, #b> -> vector<32x32xf16>
+    %7 = vector.extract_strided_slice %6 {offsets = [0, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+      : vector<32x32xf16> to vector<16x16xf16>
+    %8 = vector.extract_strided_slice %6 {offsets = [0, 16], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+      : vector<32x32xf16> to vector<16x16xf16>
+    %9 = vector.extract_strided_slice %6 {offsets = [16, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+      : vector<32x32xf16> to vector<16x16xf16>
+    %10 = vector.extract_strided_slice %6 {offsets = [16, 16], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+      : vector<32x32xf16> to vector<16x16xf16>
+    %11 = vector.transpose %7, [1, 0]  { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+    %12 = vector.transpose %8, [1, 0]  { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+    %13 = vector.transpose %9, [1, 0]  { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+    %14 = vector.transpose %10, [1, 0]  { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+    %15 = xegpu.dpas %arg0, %11, %arg4 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+    %16 = xegpu.dpas %arg0, %12, %arg5 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+    %17 = xegpu.dpas %arg0, %13, %arg6 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+    %18 = xegpu.dpas %arg0, %14, %arg7 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+    scf.yield %15, %16, %17, %18 : vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>
+  } {layout_result_0 = #a, layout_result_1 = #a, layout_result_2 = #a, layout_result_3 = #a}
+  xegpu.store_nd %4#0, %0[%c0, %c0]  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+  xegpu.store_nd %4#1, %0[%c0, %c16]  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+  xegpu.store_nd %4#2, %0[%c16, %c0]  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+  xegpu.store_nd %4#3, %0[%c16, %c16]  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+  gpu.return
+}
+}
+
+// -----
+// CHECK-LABEL:  gpu.func @array_length(
+// CHECK-SAME:      %{{.*}}: vector<8x16xf16>, %[[ARG1:[a-zA-Z0-9]+]]: memref<256x256xf16>, %arg2: memref<256x256xf32>) {
+// CHECK:           %[[C128:.*]] = arith.constant 128 : index
+// CHECK:           %[[C8:.*]] = arith.constant 8 : index
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<256x256xf16> -> index
+// CHECK:           %[[T2:.*]] = arith.index_cast %[[PTR]] : index to i64
+// CHECK:           %[[T3:.*]] = xegpu.create_nd_tdesc %[[T2]], shape : [256, %[[C128]]], strides : [%[[C128]], 1] : i64 ->
+// CHECK-SAME:        !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK:           %{{.*}}:4 = scf.for %[[K:.*]] = %{{.*}} iter_args(%{{.*}}) -> (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) {
+// CHECK:             %[[T5:.*]] = arith.shrui %[[K]], %[[C1]] : index
+// CHECK:             %[[T6:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T5]]]  {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK-SAME:          : !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<32x8xi32>
+// CHECK:             %[[T7:.*]] = vector.bitcast %[[T6]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-SAME:          : vector<32x8xi32> to vector<32x16xf16>
+// CHECK:             %[[T8:.*]] = arith.addi %[[T5]], %[[C8]] : index
+// CHECK:             %[[T9:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T8]]]  {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK-SAME:          : !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<32x8xi32>
+// CHECK:             %[[T10:.*]] = vector.bitcast %[[T9]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-SAME:          : vector<32x8xi32> to vector<32x16xf16>
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
+gpu.module @xevm_module {
+gpu.func @array_length(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %c256 = arith.constant 256 : index
+  %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #a>
+  %1 = xegpu.load_nd %0[%c0, %c0]  { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf32, #a> -> vector<8x16xf32>
+  %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16>
+    -> !xegpu.tensor_desc<32x16xf16, #b, #xegpu.block_tdesc_attr<array_length = 2 : i64>>
+  %4:4 = scf.for %arg3 = %c0 to %c256 step %c32 iter_args(%arg4 = %1, %arg5 = %1, %arg6 = %1, %arg7 = %1)
+    -> (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) {
+    %6 = xegpu.load_nd %3[%c0, %arg3]  { layout_result_0 = #b }
+      : !xegpu.tensor_desc<32x16xf16, #b, #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<2x32x16xf16>
+    %19 = vector.extract %6[0] { layout_result_0 = #b } : vector<32x16xf16> from vector<2x32x16xf16>
+    %20 = vector.extract %6[1] { layout_result_0 = #b } : vector<32x16xf16> from vector<2x32x16xf16>
+    %7 = vector.extract_strided_slice %19 {offsets = [0, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+      : vector<32x16xf16> to vector<16x16xf16>
+    %8 = vector.extract_strided_slice %19 {offsets = [16, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+      : vector<32x16xf16> to vector<16x16xf16>
+    %9 = vector.extract_strided_slice %20 {offsets = [0, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+      : vector<32x16xf16> to vector<16x16xf16>
+    %10 = vector.extract_strided_slice %20 {offsets = [16, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+      : vector<32x16xf16> to vector<16x16xf16>
+    %11 = vector.transpose %7, [1, 0]  { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+    %12 = vector.transpose %8, [1, 0]  { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+    %13 = vector.transpose %9, [1, 0]  { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+    %14 = vector.transpose %10, [1, 0]  { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+    %15 = xegpu.dpas %arg0, %11, %arg4 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+    %16 = xegpu.dpas %arg0, %12, %arg5 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+    %17 = xegpu.dpas %arg0, %13, %arg6 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+    %18 = xegpu.dpas %arg0, %14, %arg7 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+    scf.yield %15, %16, %17, %18 : vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>
+  } {layout_result_0 = #a, layout_result_1 = #a, layout_result_2 = #a, layout_result_3 = #a}
+  xegpu.store_nd %4#0, %0[%c0, %c0]  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+  xegpu.store_nd %4#1, %0[%c0, %c16]  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+  xegpu.store_nd %4#2, %0[%c16, %c0]  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+  xegpu.store_nd %4#3, %0[%c16, %c16]  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+  gpu.return
+}
+}

More information about the Mlir-commits mailing list