[Mlir-commits] [mlir] [mlir][xegpu] Add OptimizeTranspose pass. (PR #165483)
Charitha Saumya
llvmlistbot at llvm.org
Tue Oct 28 17:13:03 PDT 2025
https://github.com/charithaintc updated https://github.com/llvm/llvm-project/pull/165483
>From 9d0341dae9da51c916405c6e14adf4dac9d20af5 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 21 Oct 2025 18:49:43 +0000
Subject: [PATCH 1/8] add pass
---
.../mlir/Dialect/XeGPU/Transforms/Passes.td | 10 +++
.../Dialect/XeGPU/Transforms/Transforms.h | 3 +-
.../Dialect/XeGPU/Transforms/CMakeLists.txt | 1 +
.../Transforms/XeGPUOptimizeTranspose.cpp | 69 +++++++++++++++++++
4 files changed, 82 insertions(+), 1 deletion(-)
create mode 100644 mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 564d9c4d5422b..d0185159f16ac 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -80,4 +80,14 @@ def XeGPUVectorLinearize : Pass<"xegpu-vector-linearize"> {
"scf::SCFDialect", "ub::UBDialect", "vector::VectorDialect"];
}
+def XeGPUOptimizeTranspose : Pass<"xegpu-optimize-transpose"> {
+ let summary = "Optimize XeGPU loadNd operations feeding into vector.transpose";
+ let description = [{
+    This pass rewrites XeGPU loadNd operations that feed into vector.transpose
+    into forms that map better onto the hardware's transposed-load support.
+ }];
+ let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
+ "vector::VectorDialect"];
+}
+
#endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
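For context, the kind of IR this pass targets looks roughly like the following sketch, adapted from the tests added later in this PR (layout_result_* attributes elided for brevity): a sub-32-bit load whose lane_data makes the transpose layout unsupported by hardware, with its result consumed only through vector.transpose on its way into xegpu.dpas.

  #b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
  %0 = xegpu.create_nd_tdesc %arg0 : memref<64x64xf16> -> !xegpu.tensor_desc<16x16xf16, #b>
  %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<16x16xf16, #b> -> vector<16x16xf16>
  // The loaded tile is only used transposed, feeding the DPAS B operand.
  %2 = vector.transpose %1, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
  %3 = xegpu.dpas %arg1, %2 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>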
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index a480195eebd00..cab1598457cbe 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -61,7 +61,8 @@ struct UnrollOptions {
/// Appends patterns for folding aliasing ops into XeGPU ops into `patterns`.
void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
-
+/// Appends patterns for optimizing transpose operations into `patterns`.
+void populateXeGPUOptimizeTransposePatterns(RewritePatternSet &patterns);
/// Appends patterns for XeGPU SIMT distribution into `patterns`.
void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns);
/// Appends patterns for moving function body into gpu.warp_execute_on_lane0 op.
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index e6f76067094ce..a2ecbc3374ba9 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
XeGPUWgToSgDistribute.cpp
XeGPUPropagateLayout.cpp
XeGPUVectorLinearize.cpp
+ XeGPUOptimizeTranspose.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
new file mode 100644
index 0000000000000..69ba9e6fca7ba
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
@@ -0,0 +1,69 @@
+//===- XeGPUOptimizeTranspose.cpp - XeGPU optimize transpose ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUOPTIMIZETRANSPOSE
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+#define DEBUG_TYPE "xegpu-optimize-transpose"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+using namespace mlir;
+
+namespace {
+
+class XeGPULoadNdPattern final : public OpConversionPattern<xegpu::LoadNdOp> {
+public:
+ using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::LoadNdOp loadOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ return success();
+ }
+};
+} // namespace
+
+void xegpu::populateXeGPUOptimizeTransposePatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<XeGPULoadNdPattern>(patterns.getContext());
+}
+
+namespace {
+
+struct XeGPUOptimizeTransposePass final
+ : public xegpu::impl::XeGPUOptimizeTransposeBase<
+ XeGPUOptimizeTransposePass> {
+ void runOnOperation() override {
+ MLIRContext &context = getContext();
+ TypeConverter converter;
+ RewritePatternSet patterns(&context);
+ ConversionTarget target(context);
+ scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
+ target);
+ xegpu::populateXeGPUOptimizeTransposePatterns(patterns);
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns)))) {
+ DBGS() << "Optimize transpose pass failed.\n";
+ return signalPassFailure();
+ }
+ }
+};
+
+} // namespace
>From 76f7323b526079c585dda525d2fa7b3974ed4a49 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 22 Oct 2025 15:20:08 +0000
Subject: [PATCH 2/8] save work
---
.../Transforms/XeGPUOptimizeTranspose.cpp | 62 +++++++++++++++++--
.../Dialect/XeGPU/optimize-transpose.mlir | 18 ++++++
2 files changed, 76 insertions(+), 4 deletions(-)
create mode 100644 mlir/test/Dialect/XeGPU/optimize-transpose.mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
index 69ba9e6fca7ba..83edbad3afce2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
@@ -12,8 +12,10 @@
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include <optional>
namespace mlir {
namespace xegpu {
@@ -29,12 +31,28 @@ using namespace mlir;
namespace {
-class XeGPULoadNdPattern final : public OpConversionPattern<xegpu::LoadNdOp> {
+static std::optional<SmallVector<int64_t>>
+get2DLaneData(xegpu::TensorDescType tdescType) {
+ auto layout = tdescType.getLayoutAttr();
+ if (!layout)
+ return std::nullopt;
+ auto laneData = layout.getEffectiveLaneDataAsInt();
+ if (laneData.size() != 2)
+ return std::nullopt;
+ return laneData;
+}
+
+class XeGPUCreateNdDescOpPattern final
+ : public OpConversionPattern<xegpu::CreateNdDescOp> {
public:
- using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
+ using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(xegpu::LoadNdOp loadOp, OpAdaptor adaptor,
+ matchAndRewrite(xegpu::CreateNdDescOp createNdOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ auto tdescTy = createNdOp.getType();
+ auto convertType = this->getTypeConverter()->convertType(tdescTy);
+ if (convertType == tdescTy)
+ return failure();
return success();
}
};
@@ -42,7 +60,7 @@ class XeGPULoadNdPattern final : public OpConversionPattern<xegpu::LoadNdOp> {
void xegpu::populateXeGPUOptimizeTransposePatterns(
RewritePatternSet &patterns) {
- patterns.add<XeGPULoadNdPattern>(patterns.getContext());
+ patterns.add<XeGPUCreateNdDescOpPattern>(patterns.getContext());
}
namespace {
@@ -55,6 +73,42 @@ struct XeGPUOptimizeTransposePass final
TypeConverter converter;
RewritePatternSet patterns(&context);
ConversionTarget target(context);
+
+ target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
+ [](xegpu::CreateNdDescOp createNdOp) {
+ auto optionalLaneData = get2DLaneData(createNdOp.getType());
+ if (!optionalLaneData)
+ return true;
+ auto laneData = optionalLaneData.value();
+ return laneData[0] != 1 || laneData[1] == 1;
+ });
+
+ converter.addConversion([](xegpu::TensorDescType tdescType) {
+ auto optionalLaneData = get2DLaneData(tdescType);
+ if (!optionalLaneData)
+ return tdescType;
+ auto laneData = optionalLaneData.value();
+ int64_t innerLaneData = laneData[1];
+ if (laneData[0] == 1 && innerLaneData != 1) {
+ int elementTyBitwidth =
+ tdescType.getElementType().getIntOrFloatBitWidth();
+ assert(elementTyBitwidth < 32 &&
+ "Expected element type bitwidth < 32 with laneData[1] != 1");
+ SmallVector<int64_t> newShape(tdescType.getShape());
+ newShape.back() = newShape.back() / innerLaneData;
+ Type newElemTy = IntegerType::get(tdescType.getContext(),
+ elementTyBitwidth * innerLaneData);
+ xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get(
+ tdescType.getContext(),
+ tdescType.getLayoutAttr().getLaneLayout().asArrayRef(), {1, 1});
+ return xegpu::TensorDescType::get(
+ newShape, newElemTy, tdescType.getArrayLength(),
+ tdescType.getBoundaryCheck(), tdescType.getMemorySpace(),
+ newLayout);
+ }
+ return tdescType;
+ });
+
scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
target);
xegpu::populateXeGPUOptimizeTransposePatterns(patterns);
diff --git a/mlir/test/Dialect/XeGPU/optimize-transpose.mlir b/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
new file mode 100644
index 0000000000000..529211e286902
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
@@ -0,0 +1,18 @@
+func.func @gemm_b_transpose(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) {
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c256 = arith.constant 256 : index
+ %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %1 = xegpu.load_nd %0[%c0, %c0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf32>
+ %2 = xegpu.create_nd_tdesc %arg0 : memref<256x256xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
+ %4 = scf.for %arg3 = %c0 to %c256 step %c16 iter_args(%arg4 = %1) -> (vector<8x16xf32>) {
+ %5 = xegpu.load_nd %2[%c0, %arg3] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16>
+ %6 = xegpu.load_nd %3[%c0, %arg3] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>> -> vector<16x16xf16>
+ %7 = vector.transpose %6, [1, 0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : vector<16x16xf16> to vector<16x16xf16>
+ %8 = xegpu.dpas %5, %7, %arg4 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ scf.yield %8 : vector<8x16xf32>
+ } {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ xegpu.store_nd %4, %0[%c0, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ return
+}
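To illustrate the intended effect of the type conversion added above (a sketch only, not output captured from the pass): with lane_data = [1, 2] on a 16-bit element type, pairs of f16 elements are folded into one i32, the inner dimension of the descriptor shape is halved, and the lane data becomes [1, 1].

  // Before: inner lane_data != 1 on a 16-bit element type.
  !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
  // After: inner dimension divided by 2, element type widened to i32, lane_data = [1, 1].
  !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>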
>From 43c35bef094cd8754f332be17689dbef9e1c8516 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 22 Oct 2025 22:35:50 +0000
Subject: [PATCH 3/8] add some tests
---
.../Transforms/XeGPUOptimizeTranspose.cpp | 188 +++++++++++++++---
.../Dialect/XeGPU/optimize-transpose.mlir | 119 ++++++++++-
2 files changed, 264 insertions(+), 43 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
index 83edbad3afce2..9792da7d6ce79 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
@@ -6,13 +6,19 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <optional>
@@ -42,6 +48,46 @@ get2DLaneData(xegpu::TensorDescType tdescType) {
return laneData;
}
+static xegpu::TensorDescType
+getModifiedTensorDescType(xegpu::TensorDescType tdescType) {
+ auto optionalLaneData = get2DLaneData(tdescType);
+ if (!optionalLaneData)
+ return tdescType;
+ auto laneData = optionalLaneData.value();
+ int64_t innerLaneData = laneData[1];
+ if (laneData[0] == 1 && innerLaneData != 1) {
+ int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
+ assert(elementTyBitwidth < 32 &&
+ "Expected element type bitwidth < 32 with laneData[1] != 1");
+ SmallVector<int64_t> newShape(tdescType.getShape());
+ newShape.back() = newShape.back() / innerLaneData;
+ Type newElemTy = IntegerType::get(tdescType.getContext(),
+ elementTyBitwidth * innerLaneData);
+ xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get(
+ tdescType.getContext(),
+ tdescType.getLayoutAttr().getLaneLayout().asArrayRef(), {1, 1});
+ return xegpu::TensorDescType::get(
+ newShape, newElemTy, tdescType.getArrayLength(),
+ tdescType.getBoundaryCheck(), tdescType.getMemorySpace(), newLayout);
+ }
+ return tdescType;
+}
+
+static Value convertToValue(ConversionPatternRewriter &rewriter, Location loc,
+ OpFoldResult ofr) {
+ std::optional<int64_t> mayBeInt = getConstantIntValue(ofr);
+ if (mayBeInt)
+ return arith::ConstantIndexOp::create(rewriter, loc, *mayBeInt);
+ return llvm::cast<Value>(ofr);
+}
+
+static Value divideByConstant(ConversionPatternRewriter &rewriter, Location loc,
+ Value val, int64_t constant) {
+ auto constantOp = arith::ConstantIndexOp::create(rewriter, loc, constant);
+ return arith::DivUIOp::create(rewriter, loc, val, constantOp.getResult())
+ .getResult();
+}
+
class XeGPUCreateNdDescOpPattern final
: public OpConversionPattern<xegpu::CreateNdDescOp> {
public:
@@ -50,17 +96,106 @@ class XeGPUCreateNdDescOpPattern final
matchAndRewrite(xegpu::CreateNdDescOp createNdOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto tdescTy = createNdOp.getType();
- auto convertType = this->getTypeConverter()->convertType(tdescTy);
+ auto convertType = getModifiedTensorDescType(tdescTy);
if (convertType == tdescTy)
return failure();
+ auto strides = createNdOp.getMixedStrides();
+ auto maybeConstInnerStride = getConstantIntValue(strides.back());
+ // Only row-major memrefs are expected for now.
+ if (!maybeConstInnerStride || *maybeConstInnerStride != 1)
+ return failure();
+ Value source = createNdOp.getSource();
+ auto optionalLaneData = get2DLaneData(tdescTy);
+ assert(optionalLaneData && "Expected 2D lane data");
+ auto laneData = optionalLaneData.value();
+ int64_t innerLaneData = laneData[1];
+ auto memrefType = dyn_cast<MemRefType>(source.getType());
+ // Inner dimension of the shape must be adjusted based on innerLaneData.
+ SmallVector<OpFoldResult> modifiedShape(createNdOp.getMixedSizes());
+ modifiedShape.back() = divideByConstant(
+ rewriter, createNdOp.getLoc(),
+ convertToValue(rewriter, createNdOp.getLoc(), modifiedShape.back()),
+ innerLaneData);
+ // Similarly, second to last stride must be adjusted.
+ assert(strides.size() >= 2 &&
+ "Expected at least 2 strides for CreateNdDescOp");
+ SmallVector<OpFoldResult> modifiedStrides(strides);
+ modifiedStrides[modifiedStrides.size() - 2] = divideByConstant(
+ rewriter, createNdOp.getLoc(),
+ convertToValue(rewriter, createNdOp.getLoc(),
+ modifiedStrides[modifiedStrides.size() - 2]),
+ innerLaneData);
+
+    // If the source is a static memref, we need to extract a pointer to its
+    // base address.
+ if (memrefType && memrefType.hasStaticShape()) {
+ auto extractOp = memref::ExtractAlignedPointerAsIndexOp::create(
+ rewriter, createNdOp.getLoc(), source);
+ source = arith::IndexCastOp::create(
+ rewriter, createNdOp.getLoc(),
+ IntegerType::get(rewriter.getContext(), 64),
+ extractOp.getResult())
+ .getResult();
+ }
+ // Create a new CreateNdDescOp with the modified shape and converted type.
+ auto newCreateNdDescOp = xegpu::CreateNdDescOp::create(
+ rewriter, createNdOp.getLoc(), convertType, source, modifiedShape,
+ modifiedStrides);
+ rewriter.replaceOp(createNdOp, newCreateNdDescOp.getResult());
return success();
}
};
} // namespace
+class XeGPULoadNdDescOpPattern final
+ : public OpConversionPattern<xegpu::LoadNdOp> {
+public:
+ using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::LoadNdOp loadNdOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto origTensorDescType = loadNdOp.getTensorDescType();
+ auto adaptorType =
+ cast<xegpu::TensorDescType>(adaptor.getTensorDesc().getType());
+ if (adaptorType == origTensorDescType)
+ return failure();
+ // Offsets must be adjusted based on innerLaneData.
+ auto optionalLaneData = get2DLaneData(loadNdOp.getTensorDescType());
+ assert(optionalLaneData && "Expected 2D lane data");
+ int64_t innerLaneData = optionalLaneData.value()[1];
+ auto offsets = loadNdOp.getMixedOffsets();
+ if (offsets.empty())
+ return rewriter.notifyMatchFailure(loadNdOp,
+ "Expecting offsets in LoadNd");
+ SmallVector<OpFoldResult> modifiedOffsets(offsets);
+ modifiedOffsets.back() = divideByConstant(
+ rewriter, loadNdOp.getLoc(),
+ convertToValue(rewriter, loadNdOp.getLoc(), modifiedOffsets.back()),
+ innerLaneData);
+ VectorType modifiedType =
+ VectorType::get(adaptorType.getShape(), adaptorType.getElementType());
+ // Create a new LoadNdOp with modified offsets and type.
+ auto newLoadNdOp = xegpu::LoadNdOp::create(
+ rewriter, loadNdOp->getLoc(), modifiedType, adaptor.getTensorDesc(),
+ modifiedOffsets, loadNdOp.getPackedAttr(), loadNdOp.getTransposeAttr(),
+ loadNdOp.getL1HintAttr(), loadNdOp.getL2HintAttr(),
+ loadNdOp.getL3HintAttr());
+ // Bitcast back to the original type.
+ auto castOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
+ loadNdOp.getType(), newLoadNdOp);
+ // Cast op must have the same layout as the original LoadNdOp result.
+ xegpu::setDistributeLayoutAttr(
+ castOp->getOpResult(0),
+ xegpu::getDistributeLayoutAttr(loadNdOp.getResult()));
+ rewriter.replaceOp(loadNdOp, castOp.getResult());
+ return success();
+ }
+};
+
void xegpu::populateXeGPUOptimizeTransposePatterns(
RewritePatternSet &patterns) {
- patterns.add<XeGPUCreateNdDescOpPattern>(patterns.getContext());
+ patterns.add<XeGPUCreateNdDescOpPattern, XeGPULoadNdDescOpPattern>(
+ patterns.getContext());
}
namespace {
@@ -74,41 +209,28 @@ struct XeGPUOptimizeTransposePass final
RewritePatternSet patterns(&context);
ConversionTarget target(context);
+ auto checkValidInnerLaneData =
+ [](std::optional<SmallVector<int64_t>> optionalLaneData) -> bool {
+ if (!optionalLaneData)
+ return true;
+ auto laneData = optionalLaneData.value();
+ return laneData[0] != 1 || laneData[1] == 1;
+ };
+
target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
- [](xegpu::CreateNdDescOp createNdOp) {
+ [&](xegpu::CreateNdDescOp createNdOp) {
auto optionalLaneData = get2DLaneData(createNdOp.getType());
- if (!optionalLaneData)
- return true;
- auto laneData = optionalLaneData.value();
- return laneData[0] != 1 || laneData[1] == 1;
+ return checkValidInnerLaneData(optionalLaneData);
});
+ target.addDynamicallyLegalOp<xegpu::LoadNdOp>(
+ [&](xegpu::LoadNdOp loadNdOp) {
+ auto optionalLaneData = get2DLaneData(loadNdOp.getTensorDescType());
+ return checkValidInnerLaneData(optionalLaneData);
+ });
+ converter.addConversion([](Type type) { return type; });
- converter.addConversion([](xegpu::TensorDescType tdescType) {
- auto optionalLaneData = get2DLaneData(tdescType);
- if (!optionalLaneData)
- return tdescType;
- auto laneData = optionalLaneData.value();
- int64_t innerLaneData = laneData[1];
- if (laneData[0] == 1 && innerLaneData != 1) {
- int elementTyBitwidth =
- tdescType.getElementType().getIntOrFloatBitWidth();
- assert(elementTyBitwidth < 32 &&
- "Expected element type bitwidth < 32 with laneData[1] != 1");
- SmallVector<int64_t> newShape(tdescType.getShape());
- newShape.back() = newShape.back() / innerLaneData;
- Type newElemTy = IntegerType::get(tdescType.getContext(),
- elementTyBitwidth * innerLaneData);
- xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get(
- tdescType.getContext(),
- tdescType.getLayoutAttr().getLaneLayout().asArrayRef(), {1, 1});
- return xegpu::TensorDescType::get(
- newShape, newElemTy, tdescType.getArrayLength(),
- tdescType.getBoundaryCheck(), tdescType.getMemorySpace(),
- newLayout);
- }
- return tdescType;
- });
-
+ target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect,
+ vector::VectorDialect>();
scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
target);
xegpu::populateXeGPUOptimizeTransposePatterns(patterns);
diff --git a/mlir/test/Dialect/XeGPU/optimize-transpose.mlir b/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
index 529211e286902..0bc74d75a4ad0 100644
--- a/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
+++ b/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
@@ -1,18 +1,117 @@
+
+
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
+func.func @no_scf(%arg0: memref<64x64xf16>, %arg1: vector<8x16xf16>) -> vector<8x16xf32> {
+ %c0 = arith.constant 0 : index
+ %c32 = arith.constant 32 : index
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<64x64xf16> -> !xegpu.tensor_desc<16x16xf16, #b>
+ %1 = xegpu.load_nd %0[%c0, %c32] { result_layout = #b } : !xegpu.tensor_desc<16x16xf16, #b> -> vector<16x16xf16>
+ %2 = vector.transpose %1, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+ %6 = xegpu.dpas %arg1, %2 { layout_result_0 = #a } : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+ return %6 : vector<8x16xf32>
+}
+
+// -----
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 4]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]>
+func.func @no_scf_i8(%arg0: memref<64x64xi8>, %arg1: vector<8x32xi8>) -> vector<8x16xi32> {
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<64x64xi8> -> !xegpu.tensor_desc<16x32xi8, #b>
+ %1 = xegpu.load_nd %0[%c0, %c64] { result_layout = #b } : !xegpu.tensor_desc<16x32xi8, #b> -> vector<16x32xi8>
+ %2 = vector.transpose %1, [1, 0] { layout_result_0 = #bt } : vector<16x32xi8> to vector<32x16xi8>
+ %6 = xegpu.dpas %arg1, %2 { layout_result_0 = #a } : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
+ return %6 : vector<8x16xi32>
+}
+
+
+// -----
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
func.func @gemm_b_transpose(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) {
%c0 = arith.constant 0 : index
%c16 = arith.constant 16 : index
%c256 = arith.constant 256 : index
- %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- %1 = xegpu.load_nd %0[%c0, %c0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf32>
- %2 = xegpu.create_nd_tdesc %arg0 : memref<256x256xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
+ %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #a>
+ %1 = xegpu.load_nd %0[%c0, %c0] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf32, #a> -> vector<8x16xf32>
+ %2 = xegpu.create_nd_tdesc %arg0 : memref<256x256xf16> -> !xegpu.tensor_desc<8x16xf16, #a>
+ %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #b>
%4 = scf.for %arg3 = %c0 to %c256 step %c16 iter_args(%arg4 = %1) -> (vector<8x16xf32>) {
- %5 = xegpu.load_nd %2[%c0, %arg3] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16>
- %6 = xegpu.load_nd %3[%c0, %arg3] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>> -> vector<16x16xf16>
- %7 = vector.transpose %6, [1, 0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : vector<16x16xf16> to vector<16x16xf16>
- %8 = xegpu.dpas %5, %7, %arg4 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ %5 = xegpu.load_nd %2[%c0, %arg3] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf16, #a> -> vector<8x16xf16>
+ %6 = xegpu.load_nd %3[%c0, %arg3] { layout_result_0 = #b } : !xegpu.tensor_desc<16x16xf16, #b> -> vector<16x16xf16>
+ %7 = vector.transpose %6, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+ %8 = xegpu.dpas %5, %7, %arg4 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
scf.yield %8 : vector<8x16xf32>
- } {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
- xegpu.store_nd %4, %0[%c0, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ } {layout_result_0 = #a}
+ xegpu.store_nd %4, %0[%c0, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+ return
+}
+
+// -----
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
+func.func @large_loads(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) {
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %c256 = arith.constant 256 : index
+ %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #a>
+ %1 = xegpu.load_nd %0[%c0, %c0] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf32, #a> -> vector<8x16xf32>
+ %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16> -> !xegpu.tensor_desc<32x32xf16, #b>
+ %4:4 = scf.for %arg3 = %c0 to %c256 step %c32 iter_args(%arg4 = %1, %arg5 = %1, %arg6 = %1, %arg7 = %1)
+ -> (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) {
+ %6 = xegpu.load_nd %3[%c0, %arg3] { layout_result_0 = #b } : !xegpu.tensor_desc<32x32xf16, #b> -> vector<32x32xf16>
+ %7 = vector.extract_strided_slice %6 {offsets = [0, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+ : vector<32x32xf16> to vector<16x16xf16>
+ %8 = vector.extract_strided_slice %6 {offsets = [0, 16], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+ : vector<32x32xf16> to vector<16x16xf16>
+ %9 = vector.extract_strided_slice %6 {offsets = [16, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+ : vector<32x32xf16> to vector<16x16xf16>
+ %10 = vector.extract_strided_slice %6 {offsets = [16, 16], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+ : vector<32x32xf16> to vector<16x16xf16>
+ %11 = vector.transpose %7, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+ %12 = vector.transpose %8, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+ %13 = vector.transpose %9, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+ %14 = vector.transpose %10, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+ %15 = xegpu.dpas %arg0, %11, %arg4 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ %16 = xegpu.dpas %arg0, %12, %arg5 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ %17 = xegpu.dpas %arg0, %13, %arg6 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ %18 = xegpu.dpas %arg0, %14, %arg7 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ scf.yield %15, %16, %17, %18 : vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>
+ } {layout_result_0 = #a, layout_result_1 = #a, layout_result_2 = #a, layout_result_3 = #a}
+ xegpu.store_nd %4#0, %0[%c0, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+ xegpu.store_nd %4#1, %0[%c0, %c16] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+ xegpu.store_nd %4#2, %0[%c16, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+ xegpu.store_nd %4#3, %0[%c16, %c16] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+ return
+}
+
+// -----
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
+func.func @nested_scf(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) {
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c256 = arith.constant 256 : index
+ scf.for %arg8 = %c0 to %c256 step %c16 {
+ %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #a>
+ %1 = xegpu.load_nd %0[%arg8, %c0] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf32, #a> -> vector<8x16xf32>
+ %2 = xegpu.create_nd_tdesc %arg0 : memref<256x256xf16> -> !xegpu.tensor_desc<8x16xf16, #a>
+ %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #b>
+ %4 = scf.for %arg3 = %c0 to %c256 step %c16 iter_args(%arg4 = %1) -> (vector<8x16xf32>) {
+ %5 = xegpu.load_nd %2[%arg8, %arg3] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf16, #a> -> vector<8x16xf16>
+ %6 = xegpu.load_nd %3[%arg8, %arg3] { layout_result_0 = #b } : !xegpu.tensor_desc<16x16xf16, #b> -> vector<16x16xf16>
+ %7 = vector.transpose %6, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+ %8 = xegpu.dpas %5, %7, %arg4 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ scf.yield %8 : vector<8x16xf32>
+ } {layout_result_0 = #a}
+ xegpu.store_nd %4, %0[%c0, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+ }
return
}
>From f79d2a2b68fc337fe431196c08c7b3e1b52a0256 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 23 Oct 2025 17:46:55 +0000
Subject: [PATCH 4/8] add some tests
---
.../Transforms/XeGPUOptimizeTranspose.cpp | 15 ++++++
.../Dialect/XeGPU/optimize-transpose.mlir | 50 +++++++++----------
2 files changed, 40 insertions(+), 25 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
index 9792da7d6ce79..e0a9e3c221400 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
@@ -35,6 +35,21 @@ namespace xegpu {
using namespace mlir;
+struct TransposableBlockRange {
+ int minWidth, maxWidth, minHeight, maxHeight;
+};
+
+// TODO: Use uArch to get supported block ranges.
+static TransposableBlockRange getBlockRange(int bitWidth) {
+ switch (bitWidth) {
+ case 32:
+ return {/**min width**/ 1, /**max width**/ 8, /**min height**/ 1,
+ /**max height**/ 32};
+ default:
+ llvm_unreachable("Add support for other element bitwidths");
+ }
+}
+
namespace {
static std::optional<SmallVector<int64_t>>
diff --git a/mlir/test/Dialect/XeGPU/optimize-transpose.mlir b/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
index 0bc74d75a4ad0..35c9294c32f04 100644
--- a/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
+++ b/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
@@ -51,6 +51,31 @@ func.func @gemm_b_transpose(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf1
return
}
+// -----
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
+func.func @nested_scf(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) {
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c256 = arith.constant 256 : index
+ scf.for %arg8 = %c0 to %c256 step %c16 {
+ %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #a>
+ %1 = xegpu.load_nd %0[%arg8, %c0] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf32, #a> -> vector<8x16xf32>
+ %2 = xegpu.create_nd_tdesc %arg0 : memref<256x256xf16> -> !xegpu.tensor_desc<8x16xf16, #a>
+ %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #b>
+ %4 = scf.for %arg3 = %c0 to %c256 step %c16 iter_args(%arg4 = %1) -> (vector<8x16xf32>) {
+ %5 = xegpu.load_nd %2[%arg8, %arg3] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf16, #a> -> vector<8x16xf16>
+ %6 = xegpu.load_nd %3[%arg8, %arg3] { layout_result_0 = #b } : !xegpu.tensor_desc<16x16xf16, #b> -> vector<16x16xf16>
+ %7 = vector.transpose %6, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+ %8 = xegpu.dpas %5, %7, %arg4 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ scf.yield %8 : vector<8x16xf32>
+ } {layout_result_0 = #a}
+ xegpu.store_nd %4, %0[%c0, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+ }
+ return
+}
+
// -----
#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
@@ -90,28 +115,3 @@ func.func @large_loads(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %arg
xegpu.store_nd %4#3, %0[%c16, %c16] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
return
}
-
-// -----
-#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
-#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
-#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
-func.func @nested_scf(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) {
- %c0 = arith.constant 0 : index
- %c16 = arith.constant 16 : index
- %c256 = arith.constant 256 : index
- scf.for %arg8 = %c0 to %c256 step %c16 {
- %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #a>
- %1 = xegpu.load_nd %0[%arg8, %c0] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf32, #a> -> vector<8x16xf32>
- %2 = xegpu.create_nd_tdesc %arg0 : memref<256x256xf16> -> !xegpu.tensor_desc<8x16xf16, #a>
- %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #b>
- %4 = scf.for %arg3 = %c0 to %c256 step %c16 iter_args(%arg4 = %1) -> (vector<8x16xf32>) {
- %5 = xegpu.load_nd %2[%arg8, %arg3] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf16, #a> -> vector<8x16xf16>
- %6 = xegpu.load_nd %3[%arg8, %arg3] { layout_result_0 = #b } : !xegpu.tensor_desc<16x16xf16, #b> -> vector<16x16xf16>
- %7 = vector.transpose %6, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
- %8 = xegpu.dpas %5, %7, %arg4 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
- scf.yield %8 : vector<8x16xf32>
- } {layout_result_0 = #a}
- xegpu.store_nd %4, %0[%c0, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
- }
- return
-}
>From ca5d9024135f3e49c79e3dc503cd715bca4dc38d Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Fri, 24 Oct 2025 23:37:14 +0000
Subject: [PATCH 5/8] save work
---
.../Transforms/XeGPUOptimizeTranspose.cpp | 265 +++++++++++++-----
.../Dialect/XeGPU/optimize-transpose.mlir | 44 +++
2 files changed, 246 insertions(+), 63 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
index e0a9e3c221400..a697c0433ed56 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
@@ -9,18 +9,23 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include <algorithm>
#include <optional>
namespace mlir {
@@ -35,12 +40,14 @@ namespace xegpu {
using namespace mlir;
-struct TransposableBlockRange {
- int minWidth, maxWidth, minHeight, maxHeight;
+namespace {
+
+struct Allowed2DShapeRange {
+ int64_t minWidth, maxWidth, minHeight, maxHeight;
};
// TODO: Use uArch to get supported block ranges.
-static TransposableBlockRange getBlockRange(int bitWidth) {
+static Allowed2DShapeRange getTransposableBlockRange(int bitWidth) {
switch (bitWidth) {
case 32:
return {/**min width**/ 1, /**max width**/ 8, /**min height**/ 1,
@@ -50,10 +57,8 @@ static TransposableBlockRange getBlockRange(int bitWidth) {
}
}
-namespace {
-
static std::optional<SmallVector<int64_t>>
-get2DLaneData(xegpu::TensorDescType tdescType) {
+getMaybeLaneData(xegpu::TensorDescType tdescType) {
auto layout = tdescType.getLayoutAttr();
if (!layout)
return std::nullopt;
@@ -63,44 +68,131 @@ get2DLaneData(xegpu::TensorDescType tdescType) {
return laneData;
}
+static std::optional<SmallVector<int64_t>>
+getMaybeLaneLayout(xegpu::TensorDescType tdescType) {
+ auto layout = tdescType.getLayoutAttr();
+ if (!layout)
+ return std::nullopt;
+ auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
+ if (laneLayout.size() != 2)
+ return std::nullopt;
+ return laneLayout;
+}
+
+// A transpose layout is invalid if the lane layout is transposed
+// (lane_layout = [n, 1] with n != 1) but the lane data is not [1, 1].
+static bool hasInvalidTranposeLayout(xegpu::TensorDescType tdescType) {
+  // If the element type is 32 bits or wider, the layout is always valid.
+ int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
+ if (elementTyBitwidth >= 32)
+ return false;
+ auto maybeLaneLayout = getMaybeLaneLayout(tdescType);
+ auto maybeLaneData = getMaybeLaneData(tdescType);
+ if (!maybeLaneData || !maybeLaneLayout)
+ return false;
+ auto laneData = maybeLaneData.value();
+ auto laneLayout = maybeLaneLayout.value();
+ if (laneLayout[0] == 1 || laneLayout[1] != 1)
+ return false;
+ if (laneData[0] != 1 || laneData[1] == 1)
+ return false;
+ return true;
+}
+
static xegpu::TensorDescType
-getModifiedTensorDescType(xegpu::TensorDescType tdescType) {
- auto optionalLaneData = get2DLaneData(tdescType);
- if (!optionalLaneData)
+tryConvertToTransposable(xegpu::TensorDescType tdescType) {
+ if (!hasInvalidTranposeLayout(tdescType))
return tdescType;
- auto laneData = optionalLaneData.value();
+ auto laneData = getMaybeLaneData(tdescType).value();
int64_t innerLaneData = laneData[1];
- if (laneData[0] == 1 && innerLaneData != 1) {
- int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
- assert(elementTyBitwidth < 32 &&
- "Expected element type bitwidth < 32 with laneData[1] != 1");
- SmallVector<int64_t> newShape(tdescType.getShape());
- newShape.back() = newShape.back() / innerLaneData;
- Type newElemTy = IntegerType::get(tdescType.getContext(),
- elementTyBitwidth * innerLaneData);
- xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get(
- tdescType.getContext(),
- tdescType.getLayoutAttr().getLaneLayout().asArrayRef(), {1, 1});
- return xegpu::TensorDescType::get(
- newShape, newElemTy, tdescType.getArrayLength(),
- tdescType.getBoundaryCheck(), tdescType.getMemorySpace(), newLayout);
- }
- return tdescType;
+ int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
+  // Required shape is the total shape of the vector result that this tensor
+  // desc must eventually load, after adjusting for the new bitwidth and array
+  // length.
+ SmallVector<int64_t> requiredShape(tdescType.getShape());
+ requiredShape.back() =
+ requiredShape.back() * tdescType.getArrayLength() / innerLaneData;
+ int newBitWidth = elementTyBitwidth * innerLaneData;
+ Type newElemTy = IntegerType::get(tdescType.getContext(), newBitWidth);
+  // Supported shape is the largest hardware-supported transpose shape that is
+  // less than or equal to the required shape.
+ auto supportedHeight = std::min(
+ requiredShape[0], getTransposableBlockRange(newBitWidth).maxHeight);
+ auto supportedWidth = std::min(
+ requiredShape[1], getTransposableBlockRange(newBitWidth).maxWidth);
+ SmallVector<int64_t> supportedShape = {supportedHeight, supportedWidth};
+
+  // Required shape must be a multiple of the supported shape. Otherwise, we
+  // cannot optimize it.
+ // TODO: Supported shape can be adjusted to handle non-multiple cases.
+ if (requiredShape[0] % supportedShape[0] != 0 ||
+ requiredShape[1] % supportedShape[1] != 0)
+ return tdescType;
+
+ xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get(
+ tdescType.getContext(),
+ tdescType.getLayoutAttr().getLaneLayout().asArrayRef(), {1, 1});
+  // Array length cannot be larger than 1 for the transpose case.
+ return xegpu::TensorDescType::get(
+ supportedShape, newElemTy, /**array length**/ 1,
+ tdescType.getBoundaryCheck(), tdescType.getMemorySpace(), newLayout);
+}
+
+static Value createConstantIndex(ConversionPatternRewriter &rewriter,
+ Location loc, int64_t value) {
+ return arith::ConstantIndexOp::create(rewriter, loc, value).getResult();
}
static Value convertToValue(ConversionPatternRewriter &rewriter, Location loc,
OpFoldResult ofr) {
std::optional<int64_t> mayBeInt = getConstantIntValue(ofr);
if (mayBeInt)
- return arith::ConstantIndexOp::create(rewriter, loc, *mayBeInt);
+ return createConstantIndex(rewriter, loc, *mayBeInt);
return llvm::cast<Value>(ofr);
}
static Value divideByConstant(ConversionPatternRewriter &rewriter, Location loc,
Value val, int64_t constant) {
- auto constantOp = arith::ConstantIndexOp::create(rewriter, loc, constant);
- return arith::DivUIOp::create(rewriter, loc, val, constantOp.getResult())
- .getResult();
+ auto constantOp = createConstantIndex(rewriter, loc, constant);
+ return arith::DivUIOp::create(rewriter, loc, val, constantOp).getResult();
+}
+
+static Value generateLoads(ConversionPatternRewriter &rewriter,
+ TypedValue<VectorType> data,
+ SmallVector<int64_t> &shapeRatio,
+ SmallVector<OpFoldResult> offsets,
+ SmallVector<int64_t> &supportedShape,
+ TypedValue<xegpu::TensorDescType> newTensorDesc,
+ xegpu::LoadNdOp origLoadOp) {
+ Location loc = data.getLoc();
+ assert(offsets.size() >= 2 && "Expecting at least 2 offsets for 2D LoadNdOp");
+ Value offsetX = convertToValue(rewriter, loc, offsets[offsets.size() - 2]);
+ Value offsetY = convertToValue(rewriter, loc, offsets[offsets.size() - 1]);
+ for (int64_t h = 0; h < shapeRatio[0]; ++h) {
+ for (int64_t w = 0; w < shapeRatio[1]; ++w) {
+ int64_t localOffsetX = h * supportedShape[0];
+ int64_t localOffsetY = w * supportedShape[1];
+ Value loadOffsetX = arith::AddIOp::create(
+ rewriter, loc, offsetX,
+ createConstantIndex(rewriter, loc, localOffsetX));
+ Value loadOffsetY = arith::AddIOp::create(
+ rewriter, loc, offsetY,
+ createConstantIndex(rewriter, loc, localOffsetY));
+ auto loadOp = xegpu::LoadNdOp::create(
+ rewriter, loc,
+ VectorType::get(supportedShape, data.getType().getElementType()),
+ newTensorDesc, ArrayRef<OpFoldResult>{loadOffsetX, loadOffsetY},
+ origLoadOp.getPackedAttr(), origLoadOp.getTransposeAttr(),
+ origLoadOp.getL1HintAttr(), origLoadOp.getL2HintAttr(),
+ origLoadOp.getL3HintAttr());
+ // Insert the loaded block into the right position in data.
+ data = vector::InsertStridedSliceOp::create(
+ rewriter, loc, loadOp.getResult(), data,
+ ArrayRef<int64_t>{localOffsetX, localOffsetY},
+ ArrayRef<int64_t>{1, 1});
+ }
+ }
+ return data;
}
class XeGPUCreateNdDescOpPattern final
@@ -111,7 +203,7 @@ class XeGPUCreateNdDescOpPattern final
matchAndRewrite(xegpu::CreateNdDescOp createNdOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto tdescTy = createNdOp.getType();
- auto convertType = getModifiedTensorDescType(tdescTy);
+ auto convertType = tryConvertToTransposable(tdescTy);
if (convertType == tdescTy)
return failure();
auto strides = createNdOp.getMixedStrides();
@@ -120,7 +212,7 @@ class XeGPUCreateNdDescOpPattern final
if (!maybeConstInnerStride || *maybeConstInnerStride != 1)
return failure();
Value source = createNdOp.getSource();
- auto optionalLaneData = get2DLaneData(tdescTy);
+ auto optionalLaneData = getMaybeLaneData(tdescTy);
assert(optionalLaneData && "Expected 2D lane data");
auto laneData = optionalLaneData.value();
int64_t innerLaneData = laneData[1];
@@ -160,7 +252,6 @@ class XeGPUCreateNdDescOpPattern final
return success();
}
};
-} // namespace
class XeGPULoadNdDescOpPattern final
: public OpConversionPattern<xegpu::LoadNdOp> {
@@ -175,9 +266,8 @@ class XeGPULoadNdDescOpPattern final
if (adaptorType == origTensorDescType)
return failure();
// Offsets must be adjusted based on innerLaneData.
- auto optionalLaneData = get2DLaneData(loadNdOp.getTensorDescType());
- assert(optionalLaneData && "Expected 2D lane data");
- int64_t innerLaneData = optionalLaneData.value()[1];
+ auto laneData = getMaybeLaneData(loadNdOp.getTensorDescType()).value();
+ int64_t innerLaneData = laneData[1];
auto offsets = loadNdOp.getMixedOffsets();
if (offsets.empty())
return rewriter.notifyMatchFailure(loadNdOp,
@@ -187,25 +277,82 @@ class XeGPULoadNdDescOpPattern final
rewriter, loadNdOp.getLoc(),
convertToValue(rewriter, loadNdOp.getLoc(), modifiedOffsets.back()),
innerLaneData);
- VectorType modifiedType =
- VectorType::get(adaptorType.getShape(), adaptorType.getElementType());
- // Create a new LoadNdOp with modified offsets and type.
- auto newLoadNdOp = xegpu::LoadNdOp::create(
- rewriter, loadNdOp->getLoc(), modifiedType, adaptor.getTensorDesc(),
- modifiedOffsets, loadNdOp.getPackedAttr(), loadNdOp.getTransposeAttr(),
- loadNdOp.getL1HintAttr(), loadNdOp.getL2HintAttr(),
- loadNdOp.getL3HintAttr());
- // Bitcast back to the original type.
+    // Get the 2D data shape loaded by this loadNdOp in its original element
+    // type (the array-length case is handled separately below).
+ SmallVector<int64_t> origDataShape(origTensorDescType.getShape());
+ // origDataShape.back() *= origTensorDescType.getArrayLength();
+ // Adjust the data shape based on innerLaneData.
+ origDataShape.back() /= innerLaneData;
+ // HW supported shape is the new tensor desc shape after conversion.
+ SmallVector<int64_t> hwSupportedShape(adaptorType.getShape());
+    // Shape ratio is 2D and describes how many blocks of the HW-supported
+    // shape need to be loaded to cover the original shape.
+ auto ratio = computeShapeRatio(origDataShape, hwSupportedShape)
+ .value(); // ratio must be defined if we reach here.
+ // Create a zero-initialized vector to hold all loaded blocks.
+ // TypedAttr zeroAttr = rewriter.getZeroAttr(adaptorType.getElementType());
+ VectorType origVectorType =
+ VectorType::get(origDataShape, adaptorType.getElementType());
+ Value data;
+ // Orig data shape is 3D for the array length case.
+ if (origTensorDescType.getArrayLength() > 1) {
+ SmallVector<int64_t> arrayLenDataShape(origDataShape);
+ arrayLenDataShape.insert(arrayLenDataShape.begin(),
+ origTensorDescType.getArrayLength());
+ auto arrayLenVecType =
+ VectorType::get(arrayLenDataShape, adaptorType.getElementType());
+ data = arith::ConstantOp::create(rewriter, loadNdOp->getLoc(),
+ arrayLenVecType,
+ rewriter.getZeroAttr(arrayLenVecType));
+ for (int64_t i = 0; i < origTensorDescType.getArrayLength(); ++i) {
+ Value slice = arith::ConstantOp::create(
+ rewriter, loadNdOp->getLoc(),
+ VectorType::get(origDataShape, adaptorType.getElementType()),
+ rewriter.getZeroAttr(origVectorType));
+          // Increase the Y offset for each array slice.
+ Value offsetY = convertToValue(rewriter, loadNdOp->getLoc(),
+ modifiedOffsets.back());
+ modifiedOffsets.back() =
+ arith::AddIOp::create(rewriter, loadNdOp->getLoc(), offsetY,
+ createConstantIndex(rewriter,
+ loadNdOp->getLoc(),
+ i * origDataShape[1]))
+ .getResult();
+ slice = generateLoads(
+ rewriter, cast<TypedValue<VectorType>>(slice), ratio,
+ modifiedOffsets, hwSupportedShape,
+ cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()),
+ loadNdOp);
+ // Insert slice to data.
+ data = vector::InsertOp::create(rewriter, loadNdOp->getLoc(), slice,
+ data, ArrayRef<int64_t>{i});
+ }
+ // Cast back to the original type and replace all uses.
+ data = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
+ loadNdOp.getType(), data);
+ rewriter.replaceOp(loadNdOp, data);
+ return success();
+ }
+ data = arith::ConstantOp::create(
+ rewriter, loadNdOp->getLoc(),
+ VectorType::get(origDataShape, adaptorType.getElementType()),
+ rewriter.getZeroAttr(origVectorType));
+ data = generateLoads(
+ rewriter, cast<TypedValue<VectorType>>(data), ratio, modifiedOffsets,
+ hwSupportedShape,
+ cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()),
+ loadNdOp);
auto castOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
- loadNdOp.getType(), newLoadNdOp);
- // Cast op must have the same layout as the original LoadNdOp result.
- xegpu::setDistributeLayoutAttr(
- castOp->getOpResult(0),
- xegpu::getDistributeLayoutAttr(loadNdOp.getResult()));
- rewriter.replaceOp(loadNdOp, castOp.getResult());
+ loadNdOp.getType(), data);
+ // // Cast op must have the same layout as the original LoadNdOp result.
+ // xegpu::setDistributeLayoutAttr(
+ // castOp->getOpResult(0),
+ // xegpu::getDistributeLayoutAttr(loadNdOp.getResult()));
+ rewriter.replaceOp(loadNdOp, castOp);
return success();
}
};
+} // namespace
void xegpu::populateXeGPUOptimizeTransposePatterns(
RewritePatternSet &patterns) {
@@ -224,23 +371,15 @@ struct XeGPUOptimizeTransposePass final
RewritePatternSet patterns(&context);
ConversionTarget target(context);
- auto checkValidInnerLaneData =
- [](std::optional<SmallVector<int64_t>> optionalLaneData) -> bool {
- if (!optionalLaneData)
- return true;
- auto laneData = optionalLaneData.value();
- return laneData[0] != 1 || laneData[1] == 1;
- };
-
+ // CreateNdDescOp and LoadNdOp with invalid transpose layout must be
+ // converted.
target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
[&](xegpu::CreateNdDescOp createNdOp) {
- auto optionalLaneData = get2DLaneData(createNdOp.getType());
- return checkValidInnerLaneData(optionalLaneData);
+ return !hasInvalidTranposeLayout(createNdOp.getType());
});
target.addDynamicallyLegalOp<xegpu::LoadNdOp>(
[&](xegpu::LoadNdOp loadNdOp) {
- auto optionalLaneData = get2DLaneData(loadNdOp.getTensorDescType());
- return checkValidInnerLaneData(optionalLaneData);
+ return !hasInvalidTranposeLayout(loadNdOp.getTensorDescType());
});
converter.addConversion([](Type type) { return type; });
diff --git a/mlir/test/Dialect/XeGPU/optimize-transpose.mlir b/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
index 35c9294c32f04..a59ab2e14e36c 100644
--- a/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
+++ b/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
@@ -115,3 +115,47 @@ func.func @large_loads(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %arg
xegpu.store_nd %4#3, %0[%c16, %c16] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
return
}
+
+// -----
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
+func.func @array_length(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) {
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %c256 = arith.constant 256 : index
+ %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #a>
+ %1 = xegpu.load_nd %0[%c0, %c0] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf32, #a> -> vector<8x16xf32>
+ %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16>
+ -> !xegpu.tensor_desc<32x16xf16, #b, #xegpu.block_tdesc_attr<array_length = 2 : i64>>
+ %4:4 = scf.for %arg3 = %c0 to %c256 step %c32 iter_args(%arg4 = %1, %arg5 = %1, %arg6 = %1, %arg7 = %1)
+ -> (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) {
+ %6 = xegpu.load_nd %3[%c0, %arg3] { layout_result_0 = #b }
+ : !xegpu.tensor_desc<32x16xf16, #b, #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<2x32x16xf16>
+ %19 = vector.extract %6[0] : vector<32x16xf16> from vector<2x32x16xf16>
+ %20 = vector.extract %6[1] : vector<32x16xf16> from vector<2x32x16xf16>
+ %7 = vector.extract_strided_slice %19 {offsets = [0, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+ : vector<32x16xf16> to vector<16x16xf16>
+ %8 = vector.extract_strided_slice %19 {offsets = [16, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+ : vector<32x16xf16> to vector<16x16xf16>
+ %9 = vector.extract_strided_slice %20 {offsets = [0, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+ : vector<32x16xf16> to vector<16x16xf16>
+ %10 = vector.extract_strided_slice %20 {offsets = [16, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+ : vector<32x16xf16> to vector<16x16xf16>
+ %11 = vector.transpose %7, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+ %12 = vector.transpose %8, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+ %13 = vector.transpose %9, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+ %14 = vector.transpose %10, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+ %15 = xegpu.dpas %arg0, %11, %arg4 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ %16 = xegpu.dpas %arg0, %12, %arg5 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ %17 = xegpu.dpas %arg0, %13, %arg6 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ %18 = xegpu.dpas %arg0, %14, %arg7 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ scf.yield %15, %16, %17, %18 : vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>
+ } {layout_result_0 = #a, layout_result_1 = #a, layout_result_2 = #a, layout_result_3 = #a}
+ xegpu.store_nd %4#0, %0[%c0, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+ xegpu.store_nd %4#1, %0[%c0, %c16] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+ xegpu.store_nd %4#2, %0[%c16, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+ xegpu.store_nd %4#3, %0[%c16, %c16] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+ return
+}
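
The layouts above (lane_data = [1, 2] on f16) imply that the transposed load can be expressed as a packed i32 load with half the row width and half the column offset. A minimal standalone sketch of that arithmetic (plain C++; all names and values are illustrative, not the pass's API):

#include <cassert>
#include <cstdint>

int main() {
  // Original tensor desc in the tests: 16 columns of f16, lane_data = [1, 2].
  int64_t widthF16 = 16, elemBits = 16, innerLaneData = 2;
  int64_t offsetYF16 = 32; // a column offset expressed in f16 elements

  // Packed i32 view used by the optimized descriptor.
  int64_t packedBits = elemBits * innerLaneData;      // 16 -> 32 bits
  int64_t packedWidth = widthF16 / innerLaneData;     // 16 -> 8 columns
  int64_t packedOffsetY = offsetYF16 / innerLaneData; // 32 -> 16

  assert(packedBits == 32 && packedWidth == 8 && packedOffsetY == 16);
  return 0;
}
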
>From 35ca92befdf5830c2b7c25cfcf4b9456ec012a8b Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Sat, 25 Oct 2025 00:38:06 +0000
Subject: [PATCH 6/8] working version
---
.../Transforms/XeGPUOptimizeTranspose.cpp | 69 ++++++++++++++-----
1 file changed, 50 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
index a697c0433ed56..c73f2e4607482 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
@@ -296,18 +296,18 @@ class XeGPULoadNdDescOpPattern final
Value data;
// Orig data shape is 3D for the array length case.
if (origTensorDescType.getArrayLength() > 1) {
- SmallVector<int64_t> arrayLenDataShape(origDataShape);
- arrayLenDataShape.insert(arrayLenDataShape.begin(),
- origTensorDescType.getArrayLength());
- auto arrayLenVecType =
- VectorType::get(arrayLenDataShape, adaptorType.getElementType());
- data = arith::ConstantOp::create(rewriter, loadNdOp->getLoc(),
- arrayLenVecType,
- rewriter.getZeroAttr(arrayLenVecType));
+ // SmallVector<int64_t> arrayLenDataShape(origDataShape);
+ // arrayLenDataShape.insert(arrayLenDataShape.begin(),
+ // origTensorDescType.getArrayLength());
+ // auto arrayLenVecType =
+ // VectorType::get(arrayLenDataShape, adaptorType.getElementType());
+ // auto = arith::ConstantOp::create(rewriter, loadNdOp->getLoc(),
+ // arrayLenVecType,
+ // rewriter.getZeroAttr(arrayLenVecType));
+ SmallVector<Value> arraySlices;
for (int64_t i = 0; i < origTensorDescType.getArrayLength(); ++i) {
Value slice = arith::ConstantOp::create(
- rewriter, loadNdOp->getLoc(),
- VectorType::get(origDataShape, adaptorType.getElementType()),
+ rewriter, loadNdOp->getLoc(), origVectorType,
rewriter.getZeroAttr(origVectorType));
// Increse the Y offset for each array slice.
Value offsetY = convertToValue(rewriter, loadNdOp->getLoc(),
@@ -323,14 +323,20 @@ class XeGPULoadNdDescOpPattern final
modifiedOffsets, hwSupportedShape,
cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()),
loadNdOp);
- // Insert slice to data.
- data = vector::InsertOp::create(rewriter, loadNdOp->getLoc(), slice,
- data, ArrayRef<int64_t>{i});
+ // // Insert slice to data.
+ // data = vector::InsertOp::create(rewriter, loadNdOp->getLoc(), slice,
+ // data, ArrayRef<int64_t>{i});
+ // Bitcast back to original load shape without array length.
+ auto bitcastType = VectorType::get(origTensorDescType.getShape(),
+ origTensorDescType.getElementType());
+ slice = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
+ bitcastType, slice);
+ arraySlices.push_back(slice);
}
- // Cast back to the original type and replace all uses.
- data = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
- loadNdOp.getType(), data);
- rewriter.replaceOp(loadNdOp, data);
+ // // Cast back to the original type and replace all uses.
+ // data = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
+ // loadNdOp.getType(), data);
+ rewriter.replaceOpWithMultiple(loadNdOp, {arraySlices});
return success();
}
data = arith::ConstantOp::create(
@@ -352,12 +358,33 @@ class XeGPULoadNdDescOpPattern final
return success();
}
};
+
+class VectorExtractOpPattern final
+ : public OpConversionPattern<vector::ExtractOp> {
+public:
+ using OpConversionPattern<vector::ExtractOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(vector::ExtractOp extractOp, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (adaptor.getSource().size() == 1)
+ return failure();
+ auto mixedPos = extractOp.getMixedPosition();
+ if (mixedPos.size() != 1)
+ return failure();
+ auto mayBeInt = getConstantIntValue(mixedPos[0]);
+ if (!mayBeInt)
+ return failure();
+ rewriter.replaceOp(extractOp, adaptor.getSource()[*mayBeInt]);
+ return success();
+ }
+};
+
} // namespace
void xegpu::populateXeGPUOptimizeTransposePatterns(
RewritePatternSet &patterns) {
- patterns.add<XeGPUCreateNdDescOpPattern, XeGPULoadNdDescOpPattern>(
- patterns.getContext());
+ patterns.add<XeGPUCreateNdDescOpPattern, XeGPULoadNdDescOpPattern,
+ VectorExtractOpPattern>(patterns.getContext());
}
namespace {
@@ -381,6 +408,10 @@ struct XeGPUOptimizeTransposePass final
[&](xegpu::LoadNdOp loadNdOp) {
return !hasInvalidTranposeLayout(loadNdOp.getTensorDescType());
});
+ target.addDynamicallyLegalOp<vector::ExtractOp>(
+ [&](vector::ExtractOp extractOp) {
+ return extractOp.getSourceVectorType().getRank() != 3;
+ });
converter.addConversion([](Type type) { return type; });
target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect,
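
The array-length handling introduced in this revision loads each slice separately, bumping the packed column offset by one slice width per slice. A minimal standalone sketch of that offset arithmetic (plain C++; names and values are illustrative only):

#include <cassert>
#include <cstdint>

int main() {
  // Packed (i32) view of one 32x16xf16 array slice: 32x8.
  int64_t packedSliceWidth = 8; // origDataShape[1] after dividing by lane data
  int64_t arrayLength = 2;      // block_tdesc_attr<array_length = 2>
  int64_t baseOffsetY = 0;      // packed column offset of the first slice

  for (int64_t i = 0; i < arrayLength; ++i) {
    // Slice i starts one packed slice width further along the row.
    int64_t sliceOffsetY = baseOffsetY + i * packedSliceWidth;
    assert(sliceOffsetY == i * 8); // slices land at columns 0 and 8
  }
  return 0;
}
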
>From 17fd7c8e6af94a9018e203a8c89f2f340c2763ca Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Tue, 28 Oct 2025 21:42:47 +0000
Subject: [PATCH 7/8] add tests
---
.../Transforms/XeGPUOptimizeTranspose.cpp | 84 +++++++------
.../Dialect/XeGPU/optimize-transpose.mlir | 114 +++++++++++++++++-
2 files changed, 154 insertions(+), 44 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
index c73f2e4607482..462751e16ace1 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
@@ -79,9 +79,20 @@ getMaybeLaneLayout(xegpu::TensorDescType tdescType) {
return laneLayout;
}
+static bool canBeOptimized(ArrayRef<int64_t> laneLayout,
+ ArrayRef<int64_t> laneData) {
+ if (laneLayout.size() != 2 || laneData.size() != 2)
+ return false;
+ if (laneLayout[0] == 1 || laneLayout[1] != 1)
+ return false;
+ if (laneData[0] != 1 || laneData[1] == 1)
+ return false;
+ return true;
+}
+
// A transpose layout is invalid if lane layout is transposed (lane[0] != 1 &&
// lane[1] == 1), but inner lane data is not equal to [1, 1].
-static bool hasInvalidTranposeLayout(xegpu::TensorDescType tdescType) {
+static bool canBeOptimized(xegpu::TensorDescType tdescType) {
// If the dtype is greater or equal to 32 bits, layout must be valid.
int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
if (elementTyBitwidth >= 32)
@@ -90,18 +101,12 @@ static bool hasInvalidTranposeLayout(xegpu::TensorDescType tdescType) {
auto maybeLaneData = getMaybeLaneData(tdescType);
if (!maybeLaneData || !maybeLaneLayout)
return false;
- auto laneData = maybeLaneData.value();
- auto laneLayout = maybeLaneLayout.value();
- if (laneLayout[0] == 1 || laneLayout[1] != 1)
- return false;
- if (laneData[0] != 1 || laneData[1] == 1)
- return false;
- return true;
+ return canBeOptimized(*maybeLaneLayout, *maybeLaneData);
}
static xegpu::TensorDescType
tryConvertToTransposable(xegpu::TensorDescType tdescType) {
- if (!hasInvalidTranposeLayout(tdescType))
+ if (!canBeOptimized(tdescType))
return tdescType;
auto laneData = getMaybeLaneData(tdescType).value();
int64_t innerLaneData = laneData[1];
@@ -185,11 +190,17 @@ static Value generateLoads(ConversionPatternRewriter &rewriter,
origLoadOp.getPackedAttr(), origLoadOp.getTransposeAttr(),
origLoadOp.getL1HintAttr(), origLoadOp.getL2HintAttr(),
origLoadOp.getL3HintAttr());
+ // Set the layout for the loadOp.
+ auto layoutAttr = newTensorDesc.getType().getLayoutAttr();
+ xegpu::setDistributeLayoutAttr(loadOp->getOpResult(0), layoutAttr);
// Insert the loaded block into the right position in data.
- data = vector::InsertStridedSliceOp::create(
+ auto insertOp = vector::InsertStridedSliceOp::create(
rewriter, loc, loadOp.getResult(), data,
ArrayRef<int64_t>{localOffsetX, localOffsetY},
ArrayRef<int64_t>{1, 1});
+ // InsertOp must have the same layout as newTensorDesc.
+ xegpu::setDistributeLayoutAttr(insertOp->getOpResult(0), layoutAttr);
+ data = insertOp.getResult();
}
}
return data;
@@ -288,7 +299,7 @@ class XeGPULoadNdDescOpPattern final
// Shape ratio is 2D and, it describes how many blocks need to be loaded in
// HW supported shape to cover the original shape.
auto ratio = computeShapeRatio(origDataShape, hwSupportedShape)
- .value(); // ratio must be defined if we reach here.
+ .value(); // `ratio` must be defined if we reach here.
// Create a zero-initialized vector to hold all loaded blocks.
// TypedAttr zeroAttr = rewriter.getZeroAttr(adaptorType.getElementType());
VectorType origVectorType =
@@ -296,20 +307,12 @@ class XeGPULoadNdDescOpPattern final
Value data;
// Orig data shape is 3D for the array length case.
if (origTensorDescType.getArrayLength() > 1) {
- // SmallVector<int64_t> arrayLenDataShape(origDataShape);
- // arrayLenDataShape.insert(arrayLenDataShape.begin(),
- // origTensorDescType.getArrayLength());
- // auto arrayLenVecType =
- // VectorType::get(arrayLenDataShape, adaptorType.getElementType());
- // auto = arith::ConstantOp::create(rewriter, loadNdOp->getLoc(),
- // arrayLenVecType,
- // rewriter.getZeroAttr(arrayLenVecType));
SmallVector<Value> arraySlices;
for (int64_t i = 0; i < origTensorDescType.getArrayLength(); ++i) {
Value slice = arith::ConstantOp::create(
rewriter, loadNdOp->getLoc(), origVectorType,
rewriter.getZeroAttr(origVectorType));
- // Increse the Y offset for each array slice.
+ // Increase the Y offset for each array slice.
Value offsetY = convertToValue(rewriter, loadNdOp->getLoc(),
modifiedOffsets.back());
modifiedOffsets.back() =
@@ -323,19 +326,16 @@ class XeGPULoadNdDescOpPattern final
modifiedOffsets, hwSupportedShape,
cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()),
loadNdOp);
- // // Insert slice to data.
- // data = vector::InsertOp::create(rewriter, loadNdOp->getLoc(), slice,
- // data, ArrayRef<int64_t>{i});
- // Bitcast back to original load shape without array length.
+ // BitCast back to original load shape without array length.
auto bitcastType = VectorType::get(origTensorDescType.getShape(),
origTensorDescType.getElementType());
- slice = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
- bitcastType, slice);
- arraySlices.push_back(slice);
+ auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
+ bitcastType, slice);
+ // BitCastOp must have the same layout as the original loadNdOp.
+ xegpu::setDistributeLayoutAttr(bitCastOp->getOpResult(0),
+ origTensorDescType.getLayoutAttr());
+ arraySlices.push_back(bitCastOp.getResult());
}
- // // Cast back to the original type and replace all uses.
- // data = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
- // loadNdOp.getType(), data);
rewriter.replaceOpWithMultiple(loadNdOp, {arraySlices});
return success();
}
@@ -348,13 +348,12 @@ class XeGPULoadNdDescOpPattern final
hwSupportedShape,
cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()),
loadNdOp);
- auto castOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
- loadNdOp.getType(), data);
- // // Cast op must have the same layout as the original LoadNdOp result.
- // xegpu::setDistributeLayoutAttr(
- // castOp->getOpResult(0),
- // xegpu::getDistributeLayoutAttr(loadNdOp.getResult()));
- rewriter.replaceOp(loadNdOp, castOp);
+ auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
+ loadNdOp.getType(), data);
+ // BitCastOp must have the same layout as the original loadNdOp.
+ xegpu::setDistributeLayoutAttr(bitCastOp->getOpResult(0),
+ origTensorDescType.getLayoutAttr());
+ rewriter.replaceOp(loadNdOp, bitCastOp);
return success();
}
};
@@ -402,15 +401,20 @@ struct XeGPUOptimizeTransposePass final
// converted.
target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
[&](xegpu::CreateNdDescOp createNdOp) {
- return !hasInvalidTranposeLayout(createNdOp.getType());
+ return !canBeOptimized(createNdOp.getType());
});
target.addDynamicallyLegalOp<xegpu::LoadNdOp>(
[&](xegpu::LoadNdOp loadNdOp) {
- return !hasInvalidTranposeLayout(loadNdOp.getTensorDescType());
+ return !canBeOptimized(loadNdOp.getTensorDescType());
});
target.addDynamicallyLegalOp<vector::ExtractOp>(
[&](vector::ExtractOp extractOp) {
- return extractOp.getSourceVectorType().getRank() != 3;
+ auto layout = xegpu::getDistributeLayoutAttr(extractOp.getResult());
+ if (!layout)
+ return true;
+ auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
+ auto laneData = layout.getEffectiveLaneDataAsInt();
+ return !canBeOptimized(laneLayout, laneData);
});
converter.addConversion([](Type type) { return type; });
diff --git a/mlir/test/Dialect/XeGPU/optimize-transpose.mlir b/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
index a59ab2e14e36c..7b3bd1bf8e4fe 100644
--- a/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
+++ b/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
@@ -1,5 +1,18 @@
+// RUN: mlir-opt -xegpu-optimize-transpose -canonicalize -split-input-file %s | FileCheck %s
-
+// CHECK-LABEL: func.func @no_scf(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<64x64xf16>, %{{.*}}: vector<8x16xf16>) -> vector<8x16xf32> {
+// CHECK: %[[C16:.*]] = arith.constant 16 : index
+// CHECK: %[[C32:.*]] = arith.constant 32 : index
+// CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<64x64xf16> -> index
+// CHECK: %[[T0:.*]] = arith.index_cast %[[PTR]] : index to i64
+// CHECK: %[[BDESC:.*]] = xegpu.create_nd_tdesc %[[T0]], shape : [64, %[[C32]]], strides : [%[[C32]], 1] : i64
+// CHECK-SAME: -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK-NEXT: %[[B:.*]] = xegpu.load_nd %[[BDESC]][%{{.*}}, %[[C16]]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK-SAME: : !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[B]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : vector<16x8xi32> to vector<16x16xf16>
#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
@@ -14,21 +27,51 @@ func.func @no_scf(%arg0: memref<64x64xf16>, %arg1: vector<8x16xf16>) -> vector<8
}
// -----
+// CHECK-LABEL: func.func @no_scf_i8(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<64x64xi8>, %{{.*}}: vector<8x32xi8>) -> vector<8x16xi32> {
+// CHECK: %[[C16:.*]] = arith.constant 16 : index
+// CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<64x64xi8> -> index
+// CHECK: %[[T0:.*]] = arith.index_cast %[[PTR]] : index to i64
+// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[T0]], shape : [64, %[[C16]]], strides : [%[[C16]], 1] : i64
+// CHECK-SAME: -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T1]][%{{.*}}, %[[C16]]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK-SAME: : !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32>
+// CHECK: %[[T3:.*]] = vector.bitcast %[[T2]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 4]>} : vector<16x8xi32> to vector<16x32xi8>
#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>
#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 4]>
#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]>
+#c = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
func.func @no_scf_i8(%arg0: memref<64x64xi8>, %arg1: vector<8x32xi8>) -> vector<8x16xi32> {
%c0 = arith.constant 0 : index
%c64 = arith.constant 64 : index
%0 = xegpu.create_nd_tdesc %arg0 : memref<64x64xi8> -> !xegpu.tensor_desc<16x32xi8, #b>
%1 = xegpu.load_nd %0[%c0, %c64] { result_layout = #b } : !xegpu.tensor_desc<16x32xi8, #b> -> vector<16x32xi8>
%2 = vector.transpose %1, [1, 0] { layout_result_0 = #bt } : vector<16x32xi8> to vector<32x16xi8>
- %6 = xegpu.dpas %arg1, %2 { layout_result_0 = #a } : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
+ %6 = xegpu.dpas %arg1, %2 { layout_result_0 = #c } : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
return %6 : vector<8x16xi32>
}
// -----
+// CHECK-LABEL: func.func @gemm_b_transpose(
+// CHECK-SAME: %{{.*}}: memref<256x256xf16>, %[[ARG1:[a-zA-Z0-9]+]]: memref<256x256xf16>, %{{.*}}: memref<256x256xf32>) {
+// CHECK: %[[C128:.*]] = arith.constant 128 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C16:.*]] = arith.constant 16 : index
+// CHECK: %[[C256:.*]] = arith.constant 256 : index
+// CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<256x256xf16> -> index
+// CHECK: %[[T3:.*]] = arith.index_cast %[[PTR]] : index to i64
+// CHECK: %[[T4:.*]] = xegpu.create_nd_tdesc %[[T3]], shape : [256, %[[C128]]], strides : [%[[C128]], 1]
+// CHECK-SAME: : i64 -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK: %{{.*}} = scf.for %[[K:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}}) -> (vector<8x16xf32>) {
+// CHECK: %[[T7:.*]] = arith.divui %[[K]], %[[C2]] : index
+// CHECK-NEXT: %[[T8:.*]] = xegpu.load_nd %[[T4]][%{{.*}}, %[[T7]]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} :
+// CHECK-SAME: !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32>
+// CHECK-NEXT: %{{.*}} = vector.bitcast %[[T8]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-SAME: : vector<16x8xi32> to vector<16x16xf16>
#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
@@ -52,6 +95,24 @@ func.func @gemm_b_transpose(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf1
}
// -----
+// CHECK-LABEL: func.func @nested_scf(
+// CHECK-SAME: %{{.*}}: memref<256x256xf16>, %[[ARG1:[a-zA-Z0-9]+]]: memref<256x256xf16>, %{{.*}}: memref<256x256xf32>) {
+// CHECK: %[[C128:.*]] = arith.constant 128 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C16:.*]] = arith.constant 16 : index
+// CHECK: %[[C256:.*]] = arith.constant 256 : index
+// CHECK: scf.for %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<256x256xf16> -> index
+// CHECK: %[[T3:.*]] = arith.index_cast %[[PTR]] : index to i64
+// CHECK: %[[T4:.*]] = xegpu.create_nd_tdesc %[[T3]], shape : [256, %[[C128]]], strides : [%[[C128]], 1] : i64
+// CHECK-SAME: -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK: %{{.*}} = scf.for %[[K:.*]] = %{{.*}} iter_args(%{{.*}}) -> (vector<8x16xf32>) {
+// CHECK: %[[T7:.*]] = arith.divui %[[K]], %[[C2]] : index
+// CHECK-NEXT: %[[T8:.*]] = xegpu.load_nd %[[T4]][%{{.*}}, %[[T7]]] {layout_result_0 = #xegpu.layout<
+// CHECK-SAME: lane_layout = [16, 1], lane_data = [1, 1]>} :
+// CHECK-SAME: !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32>
+// CHECK-NEXT: %{{.*}} = vector.bitcast %[[T8]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-SAME: : vector<16x8xi32> to vector<16x16xf16>
#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
@@ -77,6 +138,31 @@ func.func @nested_scf(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %a
}
// -----
+// CHECK-LABEL: func.func @large_loads(
+// CHECK-SAME: %{{.*}}: vector<8x16xf16>, %[[ARG1:[a-zA-Z0-9]+]]: memref<256x256xf16>, %{{.*}}: memref<256x256xf32>) {
+// CHECK: %[[C128:.*]] = arith.constant 128 : index
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
+// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<32x16xi32>
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<256x256xf16> -> index
+// CHECK: %[[T2:.*]] = arith.index_cast %[[PTR]] : index to i64
+// CHECK: %[[T3:.*]] = xegpu.create_nd_tdesc %[[T2]], shape : [256, %[[C128]]], strides : [%[[C128]], 1] : i64
+// CHECK-SAME: -> !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK: %{{.*}}:4 = scf.for %[[K:.*]] = %{{.*}} iter_args(%{{.*}}) -> (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) {
+// CHECK: %[[T5:.*]] = arith.divui %[[K]], %[[C2]] : index
+// CHECK: %[[T6:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T5]]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK-SAME: : !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<32x8xi32>
+// CHECK: %[[T7:.*]] = vector.insert_strided_slice %[[T6]], %[[CST]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, offsets = [0, 0], strides = [1, 1]}
+// CHECK-SAME: : vector<32x8xi32> into vector<32x16xi32>
+// CHECK: %[[T8:.*]] = arith.addi %[[T5]], %[[C8]] : index
+// CHECK: %[[T9:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T8]]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK-SAME: : !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<32x8xi32>
+// CHECK: %[[T10:.*]] = vector.insert_strided_slice %[[T9]], %[[T7]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, offsets = [0, 8], strides = [1, 1]}
+// CHECK-SAME: : vector<32x8xi32> into vector<32x16xi32>
+// CHECK: %{{.*}} = vector.bitcast %[[T10]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-SAME: : vector<32x16xi32> to vector<32x32xf16>
#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
@@ -117,6 +203,26 @@ func.func @large_loads(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %arg
}
// -----
+// CHECK-LABEL: func.func @array_length(
+// CHECK-SAME: %{{.*}}: vector<8x16xf16>, %[[ARG1:[a-zA-Z0-9]+]]: memref<256x256xf16>, %arg2: memref<256x256xf32>) {
+// CHECK: %[[C128:.*]] = arith.constant 128 : index
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<256x256xf16> -> index
+// CHECK: %[[T2:.*]] = arith.index_cast %[[PTR]] : index to i64
+// CHECK: %[[T3:.*]] = xegpu.create_nd_tdesc %[[T2]], shape : [256, %[[C128]]], strides : [%[[C128]], 1] : i64 ->
+// CHECK-SAME: !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK: %{{.*}}:4 = scf.for %[[K:.*]] = %{{.*}} iter_args(%{{.*}}) -> (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) {
+// CHECK: %[[T5:.*]] = arith.divui %[[K]], %[[C2]] : index
+// CHECK: %[[T6:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T5]]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK-SAME: : !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<32x8xi32>
+// CHECK: %[[T7:.*]] = vector.bitcast %[[T6]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-SAME: : vector<32x8xi32> to vector<32x16xf16>
+// CHECK: %[[T8:.*]] = arith.addi %[[T5]], %[[C8]] : index
+// CHECK: %[[T9:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T8]]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK-SAME: : !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<32x8xi32>
+// CHECK: %[[T10:.*]] = vector.bitcast %[[T9]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-SAME: : vector<32x8xi32> to vector<32x16xf16>
#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
@@ -133,8 +239,8 @@ func.func @array_length(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %ar
-> (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) {
%6 = xegpu.load_nd %3[%c0, %arg3] { layout_result_0 = #b }
: !xegpu.tensor_desc<32x16xf16, #b, #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<2x32x16xf16>
- %19 = vector.extract %6[0] : vector<32x16xf16> from vector<2x32x16xf16>
- %20 = vector.extract %6[1] : vector<32x16xf16> from vector<2x32x16xf16>
+ %19 = vector.extract %6[0] { layout_result_0 = #b } : vector<32x16xf16> from vector<2x32x16xf16>
+ %20 = vector.extract %6[1] { layout_result_0 = #b } : vector<32x16xf16> from vector<2x32x16xf16>
%7 = vector.extract_strided_slice %19 {offsets = [0, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
: vector<32x16xf16> to vector<16x16xf16>
%8 = vector.extract_strided_slice %19 {offsets = [16, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
>From cbcccf632b3d086f78e93e243d3847c3cbf3b62e Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 29 Oct 2025 00:12:05 +0000
Subject: [PATCH 8/8] add comments
---
.../Transforms/XeGPUOptimizeTranspose.cpp | 40 +++++++++++--------
1 file changed, 23 insertions(+), 17 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
index 462751e16ace1..385a0dd470f03 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp
@@ -46,7 +46,8 @@ struct Allowed2DShapeRange {
int64_t minWidth, maxWidth, minHeight, maxHeight;
};
-// TODO: Use uArch to get supported block ranges.
+/// Helper to get the size range of a 2D block that can be transposed by HW.
+/// TODO: Use uArch to get supported block ranges.
static Allowed2DShapeRange getTransposableBlockRange(int bitWidth) {
switch (bitWidth) {
case 32:
@@ -57,6 +58,7 @@ static Allowed2DShapeRange getTransposableBlockRange(int bitWidth) {
}
}
+/// Get the 2D lane data from a tensor desc type if it exists.
static std::optional<SmallVector<int64_t>>
getMaybeLaneData(xegpu::TensorDescType tdescType) {
auto layout = tdescType.getLayoutAttr();
@@ -68,6 +70,7 @@ getMaybeLaneData(xegpu::TensorDescType tdescType) {
return laneData;
}
+/// Get the 2D lane layout from a tensor desc type if it exists.
static std::optional<SmallVector<int64_t>>
getMaybeLaneLayout(xegpu::TensorDescType tdescType) {
auto layout = tdescType.getLayoutAttr();
@@ -79,6 +82,8 @@ getMaybeLaneLayout(xegpu::TensorDescType tdescType) {
return laneLayout;
}
+/// A layout can be optimized if its lane layout is transposed (lane[0] != 1 &&
+/// lane[1] == 1) and its lane data is [1, K] with K != 1.
static bool canBeOptimized(ArrayRef<int64_t> laneLayout,
ArrayRef<int64_t> laneData) {
if (laneLayout.size() != 2 || laneData.size() != 2)
@@ -90,8 +95,8 @@ static bool canBeOptimized(ArrayRef<int64_t> laneLayout,
return true;
}
-// A transpose layout is invalid if lane layout is transposed (lane[0] != 1 &&
-// lane[1] == 1), but inner lane data is not equal to [1, 1].
+/// A tensor desc type can be optimized if its element type is less than 32 bits
+/// and its layout can be optimized.
static bool canBeOptimized(xegpu::TensorDescType tdescType) {
// If the dtype is greater or equal to 32 bits, layout must be valid.
int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
@@ -104,8 +109,9 @@ static bool canBeOptimized(xegpu::TensorDescType tdescType) {
return canBeOptimized(*maybeLaneLayout, *maybeLaneData);
}
-static xegpu::TensorDescType
-tryConvertToTransposable(xegpu::TensorDescType tdescType) {
+/// Check if a tensor desc type can be optimized for transpose; if so, return
+/// the new optimized tensor desc type with a valid transpose layout.
+static xegpu::TensorDescType tryOptimize(xegpu::TensorDescType tdescType) {
if (!canBeOptimized(tdescType))
return tdescType;
auto laneData = getMaybeLaneData(tdescType).value();
@@ -143,11 +149,13 @@ tryConvertToTransposable(xegpu::TensorDescType tdescType) {
tdescType.getBoundaryCheck(), tdescType.getMemorySpace(), newLayout);
}
+/// Helper to create a constant index value.
static Value createConstantIndex(ConversionPatternRewriter &rewriter,
Location loc, int64_t value) {
return arith::ConstantIndexOp::create(rewriter, loc, value).getResult();
}
+/// Helper to convert an OpFoldResult to Value.
static Value convertToValue(ConversionPatternRewriter &rewriter, Location loc,
OpFoldResult ofr) {
std::optional<int64_t> mayBeInt = getConstantIntValue(ofr);
@@ -156,6 +164,7 @@ static Value convertToValue(ConversionPatternRewriter &rewriter, Location loc,
return llvm::cast<Value>(ofr);
}
+/// Helper to divide a Value by a constant integer.
static Value divideByConstant(ConversionPatternRewriter &rewriter, Location loc,
Value val, int64_t constant) {
auto constantOp = createConstantIndex(rewriter, loc, constant);
@@ -164,7 +173,6 @@ static Value divideByConstant(ConversionPatternRewriter &rewriter, Location loc,
static Value generateLoads(ConversionPatternRewriter &rewriter,
TypedValue<VectorType> data,
- SmallVector<int64_t> &shapeRatio,
SmallVector<OpFoldResult> offsets,
SmallVector<int64_t> &supportedShape,
TypedValue<xegpu::TensorDescType> newTensorDesc,
@@ -173,6 +181,11 @@ static Value generateLoads(ConversionPatternRewriter &rewriter,
assert(offsets.size() >= 2 && "Expecting at least 2 offsets for 2D LoadNdOp");
Value offsetX = convertToValue(rewriter, loc, offsets[offsets.size() - 2]);
Value offsetY = convertToValue(rewriter, loc, offsets[offsets.size() - 1]);
+ // Compute the ratio between the original shape and the HW-supported shape.
+ // One load is generated per block in this arrangement.
+ auto shapeRatio = computeShapeRatio(data.getType().getShape(),
+ supportedShape)
+ .value(); // `shapeRatio` must be defined if we reach here.
for (int64_t h = 0; h < shapeRatio[0]; ++h) {
for (int64_t w = 0; w < shapeRatio[1]; ++w) {
int64_t localOffsetX = h * supportedShape[0];
@@ -214,7 +227,7 @@ class XeGPUCreateNdDescOpPattern final
matchAndRewrite(xegpu::CreateNdDescOp createNdOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto tdescTy = createNdOp.getType();
- auto convertType = tryConvertToTransposable(tdescTy);
+ auto convertType = tryOptimize(tdescTy);
if (convertType == tdescTy)
return failure();
auto strides = createNdOp.getMixedStrides();
@@ -291,17 +304,10 @@ class XeGPULoadNdDescOpPattern final
// Get the 2D data shape of this loadNdOp in its original type including
// array length.
SmallVector<int64_t> origDataShape(origTensorDescType.getShape());
- // origDataShape.back() *= origTensorDescType.getArrayLength();
// Adjust the data shape based on innerLaneData.
origDataShape.back() /= innerLaneData;
// HW supported shape is the new tensor desc shape after conversion.
SmallVector<int64_t> hwSupportedShape(adaptorType.getShape());
- // Shape ratio is 2D and, it describes how many blocks need to be loaded in
- // HW supported shape to cover the original shape.
- auto ratio = computeShapeRatio(origDataShape, hwSupportedShape)
- .value(); // `ratio` must be defined if we reach here.
- // Create a zero-initialized vector to hold all loaded blocks.
- // TypedAttr zeroAttr = rewriter.getZeroAttr(adaptorType.getElementType());
VectorType origVectorType =
VectorType::get(origDataShape, adaptorType.getElementType());
Value data;
@@ -322,8 +328,8 @@ class XeGPULoadNdDescOpPattern final
i * origDataShape[1]))
.getResult();
slice = generateLoads(
- rewriter, cast<TypedValue<VectorType>>(slice), ratio,
- modifiedOffsets, hwSupportedShape,
+ rewriter, cast<TypedValue<VectorType>>(slice), modifiedOffsets,
+ hwSupportedShape,
cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()),
loadNdOp);
// BitCast back to original load shape without array length.
@@ -344,7 +350,7 @@ class XeGPULoadNdDescOpPattern final
VectorType::get(origDataShape, adaptorType.getElementType()),
rewriter.getZeroAttr(origVectorType));
data = generateLoads(
- rewriter, cast<TypedValue<VectorType>>(data), ratio, modifiedOffsets,
+ rewriter, cast<TypedValue<VectorType>>(data), modifiedOffsets,
hwSupportedShape,
cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()),
loadNdOp);
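
With computeShapeRatio now recomputed inside generateLoads, the load tiling reduces to covering the packed shape with HW-supported blocks, one load per block. A standalone sketch of that tiling with the large_loads shapes (plain C++; illustrative only, not the pass's API):

#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

int main() {
  // Packed data shape and HW-supported block shape from the large_loads test.
  int64_t dataShape[2] = {32, 16};
  int64_t hwShape[2] = {32, 8};

  // Analogue of computeShapeRatio: blocks needed to cover the data shape.
  int64_t ratio[2] = {dataShape[0] / hwShape[0], dataShape[1] / hwShape[1]};
  assert(ratio[0] == 1 && ratio[1] == 2);

  // One load per block; the local offsets match the insert_strided_slice
  // offsets [0, 0] and [0, 8] in the CHECK lines above.
  std::vector<std::pair<int64_t, int64_t>> localOffsets;
  for (int64_t h = 0; h < ratio[0]; ++h)
    for (int64_t w = 0; w < ratio[1]; ++w)
      localOffsets.push_back({h * hwShape[0], w * hwShape[1]});
  assert(localOffsets.size() == 2 && localOffsets[1].second == 8);
  return 0;
}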