[Mlir-commits] [mlir] 9703bda - [mlir][xegpu] Add OptimizeBlockLoads pass. (#165483)
llvmlistbot at llvm.org
Tue Nov 4 13:15:36 PST 2025
Author: Charitha Saumya
Date: 2025-11-04T13:15:32-08:00
New Revision: 9703bda95b088bb6a455ef9faffdb41c537aff2f
URL: https://github.com/llvm/llvm-project/commit/9703bda95b088bb6a455ef9faffdb41c537aff2f
DIFF: https://github.com/llvm/llvm-project/commit/9703bda95b088bb6a455ef9faffdb41c537aff2f.diff
LOG: [mlir][xegpu] Add OptimizeBlockLoads pass. (#165483)
This pass rewrites certain xegpu `CreateNd` and `LoadNd` operations that
feed into `vector.transpose` into a more optimal form to improve
performance. Specifically, low-precision (bitwidth < 32) `LoadNd` ops
that feed into transpose ops are rewritten as i32 loads with a valid
transpose layout, so that later passes can use the HW load-with-transpose
feature to accelerate such loads.
**Update:**
The pass is renamed to `OptimizeBlockLoads` because we later plan to add
the array-length optimization to this pass as well. That optimization
will break a larger load (such as `32x32xf16`) into more DPAS-favorable
array-length loads (`32x16xf16` with array length = 2). Both of these
optimizations require rewriting `CreateNd` and `LoadNd`, so it makes
sense to have a common pass for both.
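For reference, an array-length block load in XeGPU looks like the following
(adapted from the `array_length` test case in this patch; layout attributes
elided, and `%k` stands for a loop-carried offset). This is roughly the form
the planned optimization would produce from a single `32x32xf16` load:

    %desc = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16>
              -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>>
    %v = xegpu.load_nd %desc[%c0, %k]
           : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>>
           -> vector<2x32x16xf16>
    %slice0 = vector.extract %v[0] : vector<32x16xf16> from vector<2x32x16xf16>
    %slice1 = vector.extract %v[1] : vector<32x16xf16> from vector<2x32x16xf16>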
Added:
mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp
mlir/test/Dialect/XeGPU/optimize-transpose.mlir
Modified:
mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index eb05628d4772b..e42799689e490 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -85,4 +85,16 @@ def XeGPUVectorLinearize : Pass<"xegpu-vector-linearize"> {
"scf::SCFDialect", "ub::UBDialect", "vector::VectorDialect"];
}
+def XeGPUOptimizeBlockLoads : Pass<"xegpu-optimize-block-loads"> {
+ let summary = "Optimize XeGPU block load operations";
+ let description = [{
+ This pass rewrites XeGPU loadNd operations into more optimal forms
+ to improve performance. This includes:
+ - Rewriting transpose B loads into forms that can use the HW block
+ transpose instructions.
+ }];
+ let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
+ "vector::VectorDialect"];
+}
+
#endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
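As a usage note (taken from the RUN line of the test added below), the pass is
exposed as `--xegpu-optimize-block-loads` and expects the target chip (pvc or
bmg) to be attached to the module, for example:

    mlir-opt --xevm-attach-target='module=xevm_* chip=pvc' \
             --xegpu-optimize-block-loads input.mlir

(`input.mlir` is a placeholder for the file to process.)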
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index a480195eebd00..1776a209d0bf1 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -61,7 +61,8 @@ struct UnrollOptions {
/// Appends patterns for folding aliasing ops into XeGPU ops into `patterns`.
void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
-
+/// Appends patterns for optimizing block load operations into `patterns`.
+void populateXeGPUOptimizeBlockLoadsPatterns(RewritePatternSet &patterns);
/// Appends patterns for XeGPU SIMT distribution into `patterns`.
void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns);
/// Appends patterns for moving function body into gpu.warp_execute_on_lane0 op.
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 620a2fe43d682..58092c3bb9ed2 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -166,6 +166,15 @@ SmallVector<OpFoldResult> addElementwise(OpBuilder &builder, Location loc,
SmallVector<OpFoldResult> addWithRightAligned(OpBuilder &builder, Location loc,
ArrayRef<OpFoldResult> lhs,
ArrayRef<OpFoldResult> rhs);
+
+/// Helper Function to find a proper instruction multiple for the user-supplied
+/// sg-level data shape (given by `dim`). `candidates` are uArch allowed shapes.
+/// `candidateMultiples` are uArch multiples of such shapes (i.e. block count or
+/// array length).
+template <typename T>
+int getLargestDivisor(T dim, ArrayRef<T> candidates,
+ ArrayRef<T> candidateMultiples = {});
+
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index e6f76067094ce..29b645feab2c6 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
XeGPUWgToSgDistribute.cpp
XeGPUPropagateLayout.cpp
XeGPUVectorLinearize.cpp
+ XeGPUOptimizeBlockLoads.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp
new file mode 100644
index 0000000000000..4dc5ea4f7bb24
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp
@@ -0,0 +1,490 @@
+//===- XeGPUOptimizeBlockLoads.cpp - XeGPU optimize block loads -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
+#include "mlir/Dialect/XeGPU/uArch/uArchBase.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include <optional>
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUOPTIMIZEBLOCKLOADS
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+#define DEBUG_TYPE "xegpu-optimize-block-loads"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+using namespace mlir;
+
+namespace {
+
+/// Get the 2D lane data from a tensor desc type if it exists.
+static std::optional<SmallVector<int64_t>>
+getMaybeLaneData(xegpu::TensorDescType tdescType) {
+ auto layout = tdescType.getLayoutAttr();
+ if (!layout)
+ return std::nullopt;
+ auto laneData = layout.getEffectiveLaneDataAsInt();
+ if (laneData.size() != 2)
+ return std::nullopt;
+ return laneData;
+}
+
+/// Get the 2D lane layout from a tensor desc type if it exists.
+static std::optional<SmallVector<int64_t>>
+getMaybeLaneLayout(xegpu::TensorDescType tdescType) {
+ auto layout = tdescType.getLayoutAttr();
+ if (!layout)
+ return std::nullopt;
+ auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
+ if (laneLayout.size() != 2)
+ return std::nullopt;
+ return laneLayout;
+}
+
+/// A layout can be optimized if its lane layout is transposed (lane[0] != 1 &&
+/// lane[1] == 1), but inner lane data is not equal to [1, 1].
+/// Example:
+/// !xegpu.tensor_desc<16x16xf16,
+/// #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
+/// In this case, lane layout is transposed (from the usual [1, SG_SIZE] form)
+/// indicating that this is a load that requires a transpose. However, lane
+/// data is [1, 2], meaning that each lane must grab 2 f16 elements from the
+/// inner dimension. We convert this to an optimized form by converting the
+/// tensor_desc to i32 type such that lane data becomes [1, 1]. This lets the
+/// later lowering directly use the load-with-transpose instruction.
+static bool canBeOptimizedForTranspose(ArrayRef<int64_t> laneLayout,
+ ArrayRef<int64_t> laneData) {
+ if (laneLayout.size() != 2 || laneData.size() != 2)
+ return false;
+ if (laneLayout[0] == 1 || laneLayout[1] != 1)
+ return false;
+ if (laneData[0] != 1 || laneData[1] == 1)
+ return false;
+ return true;
+}
+
+/// A tensor desc type can be optimized if its element type is less than 32 bits
+/// and its layout can be optimized.
+static bool canBeOptimizedForTranspose(xegpu::TensorDescType tdescType) {
+ // If the dtype is greater than or equal to 32 bits, the layout must be valid.
+ int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
+ if (elementTyBitwidth >= 32)
+ return false;
+ auto maybeLaneLayout = getMaybeLaneLayout(tdescType);
+ auto maybeLaneData = getMaybeLaneData(tdescType);
+ if (!maybeLaneData || !maybeLaneLayout)
+ return false;
+ return canBeOptimizedForTranspose(*maybeLaneLayout, *maybeLaneData);
+}
+
+/// Check if a tensor desc type can be optimized for transpose; if so, return
+/// the new optimized tensor desc type with a valid transpose layout.
+static xegpu::TensorDescType tryOptimize(xegpu::TensorDescType tdescType,
+ const uArch *targetuArch) {
+ if (!canBeOptimizedForTranspose(tdescType))
+ return tdescType;
+ auto laneData = getMaybeLaneData(tdescType)
+ .value(); // Lane data must exist if we reach here.
+ int64_t innerLaneData = laneData[1];
+ int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
+ // Required shape is total shape of the vector result that this tensor desc
+ // must eventually load after adjusting for the new bitwidth and array
+ // length.
+ SmallVector<int64_t> requiredShape(tdescType.getShape());
+ requiredShape.back() =
+ requiredShape.back() * tdescType.getArrayLength() / innerLaneData;
+ int newBitWidth = elementTyBitwidth * innerLaneData;
+ Type newElemTy = IntegerType::get(tdescType.getContext(), newBitWidth);
+ // Supported shape is the max transpose shape that can be supported by
+ // hardware that is less than or equal to required shape.
+ auto *blockLoadTarget = dyn_cast<Subgroup2DBlockLoadInstruction>(
+ targetuArch->getInstruction(InstructionKind::Subgroup2DBlockLoad));
+ auto maybeHWParams = blockLoadTarget->getBlockWidthHeightCount(
+ newElemTy, /** has transform */ false, /** has transpose */ true);
+ // If no HW params found, return the original type.
+ if (!maybeHWParams)
+ return tdescType;
+ auto [widths, heights, counts] = maybeHWParams.value();
+ // TODO: Currently we expect array length to be 1 for transpose case.
+ if (counts.size() != 1 || counts[0] != 1)
+ return tdescType;
+ int arrayLen = counts[0];
+ int supportedHeight =
+ xegpu::getLargestDivisor(static_cast<int>(requiredShape[0]), heights);
+ int supportedWidth =
+ xegpu::getLargestDivisor(static_cast<int>(requiredShape[1]), widths);
+ // If no supported height or width found, return the original type.
+ if (supportedHeight == -1 || supportedWidth == -1)
+ return tdescType;
+
+ SmallVector<int64_t> supportedShape = {supportedHeight, supportedWidth};
+ xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get(
+ tdescType.getContext(),
+ tdescType.getLayoutAttr().getLaneLayout().asArrayRef(), {1, 1});
+ // Array length cannot be larger than 1 for the transpose case.
+ return xegpu::TensorDescType::get(supportedShape, newElemTy, arrayLen,
+ tdescType.getBoundaryCheck(),
+ tdescType.getMemorySpace(), newLayout);
+}
+
+/// Helper to convert an OpFoldResult to Value.
+static Value convertToValue(ConversionPatternRewriter &rewriter, Location loc,
+ OpFoldResult ofr) {
+ std::optional<int64_t> mayBeInt = getConstantIntValue(ofr);
+ if (mayBeInt)
+ return arith::ConstantIndexOp::create(rewriter, loc, *mayBeInt).getResult();
+ return llvm::cast<Value>(ofr);
+}
+
+/// Helper to divide a Value by a constant integer.
+static Value divideByConstant(ConversionPatternRewriter &rewriter, Location loc,
+ Value val, int64_t constant) {
+ // If the constant is a power of 2, use right shift for division.
+ if (llvm::isPowerOf2_64(constant)) {
+ int64_t shiftAmount = llvm::Log2_64(constant);
+ return arith::ShRUIOp::create(
+ rewriter, loc, val,
+ arith::ConstantIndexOp::create(rewriter, loc, shiftAmount)
+ .getResult())
+ .getResult();
+ }
+ auto constantOp =
+ arith::ConstantIndexOp::create(rewriter, loc, constant).getResult();
+ return arith::DivUIOp::create(rewriter, loc, val, constantOp).getResult();
+}
+
+/// This function takes a larger register block `data` and generates multiple
+/// smaller loads (size given by `newTensorDesc`) to fill in the `data` block
+/// starting from `offsets`.
+static Value generateLoads(ConversionPatternRewriter &rewriter,
+ TypedValue<VectorType> data,
+ SmallVector<OpFoldResult> offsets,
+ TypedValue<xegpu::TensorDescType> newTensorDesc,
+ xegpu::LoadNdOp origLoadOp) {
+ Location loc = data.getLoc();
+ assert(offsets.size() >= 2 && "Expecting at least 2 offsets for 2D LoadNdOp");
+ Value offsetDim0 = convertToValue(rewriter, loc, offsets[offsets.size() - 2]);
+ Value offsetDim1 = convertToValue(rewriter, loc, offsets[offsets.size() - 1]);
+ SmallVector<int64_t> supportedShape(newTensorDesc.getType().getShape());
+ // Compute the ratio between original shape and supported shape. We need to
+ // generate loads in this ratio arrangement.
+ auto shapeRatio = computeShapeRatio(data.getType().getShape(),
+ supportedShape)
+ .value(); // `ratio` must be defined if we reach here.
+ for (int64_t h = 0; h < shapeRatio[0]; ++h) {
+ for (int64_t w = 0; w < shapeRatio[1]; ++w) {
+ int64_t localOffsetDim0 = h * supportedShape[0];
+ int64_t localOffsetDim1 = w * supportedShape[1];
+ Value loadOffsetX = arith::AddIOp::create(
+ rewriter, loc, offsetDim0,
+ arith::ConstantIndexOp::create(rewriter, loc, localOffsetDim0)
+ .getResult());
+ Value loadOffsetY = arith::AddIOp::create(
+ rewriter, loc, offsetDim1,
+ arith::ConstantIndexOp::create(rewriter, loc, localOffsetDim1)
+ .getResult());
+ auto loadOp = xegpu::LoadNdOp::create(
+ rewriter, loc,
+ VectorType::get(supportedShape, data.getType().getElementType()),
+ newTensorDesc, ArrayRef<OpFoldResult>{loadOffsetX, loadOffsetY},
+ origLoadOp.getPackedAttr(), origLoadOp.getTransposeAttr(),
+ origLoadOp.getL1HintAttr(), origLoadOp.getL2HintAttr(),
+ origLoadOp.getL3HintAttr());
+ // Set the layout for the loadOp.
+ auto layoutAttr = newTensorDesc.getType().getLayoutAttr();
+ xegpu::setDistributeLayoutAttr(loadOp->getOpResult(0), layoutAttr);
+ // Insert the loaded block into the right position in data.
+ auto insertOp = vector::InsertStridedSliceOp::create(
+ rewriter, loc, loadOp.getResult(), data,
+ ArrayRef<int64_t>{localOffsetDim0, localOffsetDim1},
+ ArrayRef<int64_t>{1, 1});
+ // InsertOp must have the same layout as newTensorDesc.
+ xegpu::setDistributeLayoutAttr(insertOp->getOpResult(0), layoutAttr);
+ data = insertOp.getResult();
+ }
+ }
+ return data;
+}
+
+/// Checks if a CreateNdDescOp can be optimized for transpose; if so, creates
+/// a new CreateNdDescOp with an optimized tensor desc type. This involves
+/// extracting the base pointer from the original memory source and adjusting
+/// the shape and strides of the tensor desc to fit the new transpose layout.
+class XeGPUCreateNdDescOpPattern final
+ : public OpConversionPattern<xegpu::CreateNdDescOp> {
+public:
+ using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::CreateNdDescOp createNdOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto tdescTy = createNdOp.getType();
+ // Get the target uArch info.
+ auto chipStr = xegpu::getChipStr(createNdOp);
+ // Check if the chip is supported.
+ assert(
+ chipStr && (chipStr.value() == "pvc" || chipStr.value() == "bmg") &&
+ "Expecting target chip to be pvc or bmg for transpose optimization.");
+ const uArch *targetuArch = xegpu::uArch::getUArch(chipStr.value());
+
+ auto convertType = tryOptimize(tdescTy, targetuArch);
+ if (convertType == tdescTy)
+ return failure();
+ auto strides = createNdOp.getMixedStrides();
+ auto maybeConstInnerStride = getConstantIntValue(strides.back());
+ // Only row-major memrefs are expected for now.
+ if (!maybeConstInnerStride || *maybeConstInnerStride != 1)
+ return rewriter.notifyMatchFailure(
+ createNdOp, "Expecting row-major memref for transpose optimization.");
+ Value source = createNdOp.getSource();
+ auto optionalLaneData = getMaybeLaneData(tdescTy);
+ assert(optionalLaneData && "Expected 2D lane data");
+ auto laneData = optionalLaneData.value();
+ int64_t innerLaneData = laneData[1];
+ auto memrefType = dyn_cast<MemRefType>(source.getType());
+ // Inner dimension of the shape must be adjusted based on innerLaneData.
+ SmallVector<OpFoldResult> modifiedShape(createNdOp.getMixedSizes());
+ modifiedShape.back() = divideByConstant(
+ rewriter, createNdOp.getLoc(),
+ convertToValue(rewriter, createNdOp.getLoc(), modifiedShape.back()),
+ innerLaneData);
+ // Similarly, second to last stride must be adjusted.
+ assert(strides.size() >= 2 &&
+ "Expected at least 2 strides for CreateNdDescOp");
+ SmallVector<OpFoldResult> modifiedStrides(strides);
+ modifiedStrides[modifiedStrides.size() - 2] = divideByConstant(
+ rewriter, createNdOp.getLoc(),
+ convertToValue(rewriter, createNdOp.getLoc(),
+ modifiedStrides[modifiedStrides.size() - 2]),
+ innerLaneData);
+
+ // If the source is a static memref, we need to extract the pointer to the
+ // base address.
+ if (memrefType && memrefType.hasStaticShape()) {
+ auto extractOp = memref::ExtractAlignedPointerAsIndexOp::create(
+ rewriter, createNdOp.getLoc(), source);
+ source = arith::IndexCastOp::create(rewriter, createNdOp.getLoc(),
+ rewriter.getI64Type(),
+ extractOp.getResult())
+ .getResult();
+ }
+ // Create a new CreateNdDescOp with the modified shape and converted type.
+ auto newCreateNdDescOp = xegpu::CreateNdDescOp::create(
+ rewriter, createNdOp.getLoc(), convertType, source, modifiedShape,
+ modifiedStrides);
+ rewriter.replaceOp(createNdOp, newCreateNdDescOp.getResult());
+ return success();
+ }
+};
+
+/// Checks if a LoadNdOp consumes a tensor desc type that was rewritten for
+/// transpose optimization. If so, rewrites the LoadNdOp to align with the
+/// adjusted tensor desc type. This can result in multiple LoadNdOps being
+/// generated to fill in the original load shape.
+class XeGPULoadNdDescOpPattern final
+ : public OpConversionPattern<xegpu::LoadNdOp> {
+public:
+ using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::LoadNdOp loadNdOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto origTensorDescType = loadNdOp.getTensorDescType();
+ auto adaptorType =
+ cast<xegpu::TensorDescType>(adaptor.getTensorDesc().getType());
+ if (adaptorType == origTensorDescType)
+ return failure();
+ // Offsets must be adjusted based on innerLaneData.
+ auto laneData = getMaybeLaneData(loadNdOp.getTensorDescType()).value();
+ int64_t innerLaneData = laneData[1];
+ auto offsets = loadNdOp.getMixedOffsets();
+ if (offsets.empty())
+ return rewriter.notifyMatchFailure(loadNdOp,
+ "Expecting offsets in LoadNd");
+ SmallVector<OpFoldResult> modifiedOffsets(offsets);
+ modifiedOffsets.back() = divideByConstant(
+ rewriter, loadNdOp.getLoc(),
+ convertToValue(rewriter, loadNdOp.getLoc(), modifiedOffsets.back()),
+ innerLaneData);
+ // Get the 2D data shape of this loadNdOp in its original type including
+ // array length.
+ SmallVector<int64_t> origDataShape(origTensorDescType.getShape());
+ // Adjust the data shape based on innerLaneData.
+ origDataShape.back() /= innerLaneData;
+ // HW supported shape is the new tensor desc shape after conversion.
+ SmallVector<int64_t> hwSupportedShape(adaptorType.getShape());
+ VectorType origVectorType =
+ VectorType::get(origDataShape, adaptorType.getElementType());
+ Value data;
+ // Orig data shape is 3D for the array length case.
+ if (origTensorDescType.getArrayLength() > 1) {
+ SmallVector<Value> arraySlices;
+ for (int64_t i = 0; i < origTensorDescType.getArrayLength(); ++i) {
+ Value slice = arith::ConstantOp::create(
+ rewriter, loadNdOp->getLoc(), origVectorType,
+ rewriter.getZeroAttr(origVectorType));
+ // Increase the Y offset for each array slice.
+ Value offsetY = convertToValue(rewriter, loadNdOp->getLoc(),
+ modifiedOffsets.back());
+ modifiedOffsets.back() =
+ arith::AddIOp::create(
+ rewriter, loadNdOp->getLoc(), offsetY,
+ arith::ConstantIndexOp::create(rewriter, loadNdOp->getLoc(),
+ i * origDataShape[1])
+ .getResult())
+ .getResult();
+ slice = generateLoads(
+ rewriter, cast<TypedValue<VectorType>>(slice), modifiedOffsets,
+ cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()),
+ loadNdOp);
+ // BitCast back to original load shape without array length.
+ auto bitcastType = VectorType::get(origTensorDescType.getShape(),
+ origTensorDescType.getElementType());
+ auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
+ bitcastType, slice);
+ // BitCastOp must have the same layout as the original loadNdOp.
+ xegpu::setDistributeLayoutAttr(bitCastOp->getOpResult(0),
+ origTensorDescType.getLayoutAttr());
+ arraySlices.push_back(bitCastOp.getResult());
+ }
+ rewriter.replaceOpWithMultiple(loadNdOp, {arraySlices});
+ return success();
+ }
+ data = arith::ConstantOp::create(
+ rewriter, loadNdOp->getLoc(),
+ VectorType::get(origDataShape, adaptorType.getElementType()),
+ rewriter.getZeroAttr(origVectorType));
+ data = generateLoads(
+ rewriter, cast<TypedValue<VectorType>>(data), modifiedOffsets,
+ cast<TypedValue<xegpu::TensorDescType>>(adaptor.getTensorDesc()),
+ loadNdOp);
+ auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(),
+ loadNdOp.getType(), data);
+ // BitCastOp must have the same layout as the original loadNdOp.
+ xegpu::setDistributeLayoutAttr(bitCastOp->getOpResult(0),
+ origTensorDescType.getLayoutAttr());
+ rewriter.replaceOp(loadNdOp, bitCastOp);
+ return success();
+ }
+};
+
+/// A vector ExtractOp must be processed if the original tensor desc type has
+/// array length greater than 1. In this case, the LoadNdOp is replaced with
+/// multiple LoadNdOps (one per array slice), making the extraction
+/// unnecessary, so we simply remove the ExtractOp.
+class VectorExtractOpPattern final
+ : public OpConversionPattern<vector::ExtractOp> {
+public:
+ using OpConversionPattern<vector::ExtractOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(vector::ExtractOp extractOp, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Check if the source of the extraction is split to multiple values.
+ if (adaptor.getSource().size() == 1)
+ return failure();
+ auto mixedPos = extractOp.getMixedPosition();
+ if (mixedPos.size() != 1)
+ return failure();
+ auto mayBeInt = getConstantIntValue(mixedPos[0]);
+ if (!mayBeInt)
+ return failure();
+ rewriter.replaceOp(extractOp, adaptor.getSource()[*mayBeInt]);
+ return success();
+ }
+};
+
+} // namespace
+
+void xegpu::populateXeGPUOptimizeBlockLoadsPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<XeGPUCreateNdDescOpPattern, XeGPULoadNdDescOpPattern,
+ VectorExtractOpPattern>(patterns.getContext());
+}
+
+namespace {
+
+struct XeGPUOptimizeBlockLoadsPass final
+ : public xegpu::impl::XeGPUOptimizeBlockLoadsBase<
+ XeGPUOptimizeBlockLoadsPass> {
+ void runOnOperation() override {
+ MLIRContext &context = getContext();
+ TypeConverter converter;
+ RewritePatternSet patterns(&context);
+ ConversionTarget target(context);
+
+ // This pass is only meant for PVC and BMG targets. If an unsupported
+ // target is found, exit early.
+ bool isTargetSupported = false;
+ getOperation()->walk([&](gpu::GPUFuncOp funcOp) {
+ auto chipStr = xegpu::getChipStr(funcOp);
+ if (chipStr && (chipStr.value() == "pvc" || chipStr.value() == "bmg"))
+ isTargetSupported = true;
+ });
+
+ if (!isTargetSupported) {
+ DBGS() << "XeGPUOptimizeBlockLoadsPass only supports PVC and BMG targets."
+ << "\n";
+ return;
+ }
+
+ // CreateNdDescOp and LoadNdOp with optimizable tensor desc types must be
+ // converted.
+ target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
+ [&](xegpu::CreateNdDescOp createNdOp) {
+ return !canBeOptimizedForTranspose(createNdOp.getType());
+ });
+ target.addDynamicallyLegalOp<xegpu::LoadNdOp>(
+ [&](xegpu::LoadNdOp loadNdOp) {
+ return !canBeOptimizedForTranspose(loadNdOp.getTensorDescType());
+ });
+ // Vector ExtractOps can have optimizable layouts if they extract from
+ // LoadNdOps with array length greater than 1. These ExtractOps must be
+ // converted.
+ target.addDynamicallyLegalOp<vector::ExtractOp>(
+ [&](vector::ExtractOp extractOp) {
+ auto layout = xegpu::getDistributeLayoutAttr(extractOp.getResult());
+ if (!layout)
+ return true;
+ auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
+ auto laneData = layout.getEffectiveLaneDataAsInt();
+ return !canBeOptimizedForTranspose(laneLayout, laneData);
+ });
+ converter.addConversion([](Type type) { return type; });
+
+ target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect,
+ vector::VectorDialect>();
+ scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
+ target);
+ xegpu::populateXeGPUOptimizeBlockLoadsPatterns(patterns);
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns)))) {
+ DBGS() << "Optimize block loads pass failed.\n";
+ return signalPassFailure();
+ }
+ }
+};
+
+} // namespace
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 14c49e7f45706..4e1a539771d2f 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -204,28 +204,6 @@ struct LayoutInfoLattice : public Lattice<LayoutInfo> {
using Lattice::Lattice;
};
-/// Helper Function to find a proper instruction multiple for the user-supplied
-/// sg-level data shape. `candidates` are uArch allowed shapes.
-/// `candidateMultiples` are uArch multiples of such shapes (e.g., block count).
-template <typename T>
-int getLargestDivisor(T dim, ArrayRef<T> candidates,
- ArrayRef<T> candidateMultiples = {}) {
- static_assert(std::is_integral<T>::value, "T must be an integer type");
- int largest = -1;
- SmallVector<T> multiples = {1};
- if (!candidateMultiples.empty())
- multiples =
- SmallVector<T>(candidateMultiples.begin(), candidateMultiples.end());
- for (T candidate : candidates) {
- for (T multiple : multiples) {
- int value = static_cast<int>(candidate * multiple);
- if (value != 0 && dim % value == 0 && value > largest)
- largest = value;
- }
- }
- return largest;
-}
-
/// Helper Functions to get default layouts. A `default layout` is a layout that
/// is assigned to a value when the layout is not fixed by some anchor operation
/// (like DPAS).
@@ -505,7 +483,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
prefetch.emitWarning("No known block params found for the element type.");
auto [bWidth, bHeight, bCount] = blockWHC.value();
SmallVector<int> instData;
- int instWidth = getLargestDivisor(
+ int instWidth = xegpu::getLargestDivisor(
static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth,
bCount);
if (instWidth == -1)
@@ -514,7 +492,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
if (tdescTy.getRank() == 1)
instData = {instWidth};
else {
- int instHeight = getLargestDivisor(
+ int instHeight = xegpu::getLargestDivisor(
static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
if (instHeight == -1)
prefetch.emitWarning(
@@ -634,7 +612,7 @@ void LayoutInfoPropagation::visitDpasOp(
const unsigned dataALen = aTy.getShape().front();
auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
const int maxALen =
- getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
+ xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
if (maxALen == -1)
dpas.emitWarning(
"No suitable instruction multiple found for the given shape.");
@@ -642,7 +620,7 @@ void LayoutInfoPropagation::visitDpasOp(
const unsigned dataBLen = bTy.getShape().back();
auto supportedBLen = uArchInstruction->getSupportedK(bTy.getElementType());
const int maxBLen =
- getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
+ xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
if (maxBLen == -1)
dpas.emitWarning(
"No suitable instruction multiple found for the given shape.");
@@ -662,7 +640,7 @@ void LayoutInfoPropagation::visitDpasOp(
const unsigned dataCLen = bTy.getShape().back();
auto supportedCLen = uArchInstruction->getSupportedN(bTy.getElementType());
const int maxCLen =
- getLargestDivisor(dataCLen, ArrayRef<unsigned>(supportedCLen));
+ xegpu::getLargestDivisor(dataCLen, ArrayRef<unsigned>(supportedCLen));
if (maxCLen == -1)
dpas.emitWarning(
"No suitable instruction multiple found for the given shape.");
@@ -691,7 +669,7 @@ void LayoutInfoPropagation::visitStoreNdOp(
store.emitWarning("No known block params found for the element type.");
auto [bWidth, bHeight, bCount] = blockWHC.value();
SmallVector<int> instData;
- int instWidth = getLargestDivisor(
+ int instWidth = xegpu::getLargestDivisor(
static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth,
bCount);
if (instWidth == -1)
@@ -700,7 +678,7 @@ void LayoutInfoPropagation::visitStoreNdOp(
if (dataTy.getRank() == 1)
instData = {instWidth};
else {
- int instHeight = getLargestDivisor(
+ int instHeight = xegpu::getLargestDivisor(
static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
if (instHeight == -1)
store.emitWarning(
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index d575a415a3035..de9e09d427665 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -555,3 +555,29 @@ xegpu::addWithRightAligned(OpBuilder &builder, Location loc,
results.append(addElementwise(builder, loc, a, b));
return results;
}
+
+template <typename T>
+int xegpu::getLargestDivisor(T dim, ArrayRef<T> candidates,
+ ArrayRef<T> candidateMultiples) {
+ static_assert(std::is_integral<T>::value, "T must be an integer type");
+ int largest = -1;
+ SmallVector<T> multiples = {1};
+ if (!candidateMultiples.empty())
+ multiples =
+ SmallVector<T>(candidateMultiples.begin(), candidateMultiples.end());
+ for (T candidate : candidates) {
+ for (T multiple : multiples) {
+ int value = static_cast<int>(candidate * multiple);
+ if (value != 0 && dim % value == 0 && value > largest)
+ largest = value;
+ }
+ }
+ return largest;
+}
+
+/// Explicit instantiations
+template int xegpu::getLargestDivisor<int>(int dim, ArrayRef<int> candidates,
+ ArrayRef<int> candidateMultiples);
+template int
+xegpu::getLargestDivisor<unsigned>(unsigned dim, ArrayRef<unsigned> candidates,
+ ArrayRef<unsigned> candidateMultiples);
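For intuition, a couple of hypothetical invocations of this helper (values
chosen for illustration only):

    // getLargestDivisor(48, {8, 16, 32})         == 16  (largest candidate dividing 48)
    // getLargestDivisor(48, {8, 16, 32}, {1, 2}) == 16  (16*2 = 32 and 32*2 = 64 do not divide 48)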
diff --git a/mlir/test/Dialect/XeGPU/optimize-transpose.mlir b/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
new file mode 100644
index 0000000000000..24a0de6ed48a5
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
@@ -0,0 +1,280 @@
+// RUN: mlir-opt --xevm-attach-target='module=xevm_* chip=pvc' \
+// RUN: --xegpu-optimize-block-loads --canonicalize --cse --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: gpu.func @no_scf(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<64x64xf16>, %{{.*}}: vector<8x16xf16>) -> vector<8x16xf32> {
+// CHECK: %[[C16:.*]] = arith.constant 16 : index
+// CHECK: %[[C32:.*]] = arith.constant 32 : index
+// CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<64x64xf16> -> index
+// CHECK: %[[T0:.*]] = arith.index_cast %[[PTR]] : index to i64
+// CHECK: %[[BDESC:.*]] = xegpu.create_nd_tdesc %[[T0]], shape : [64, %[[C32]]], strides : [%[[C32]], 1] : i64
+// CHECK-SAME: -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK-NEXT: %[[B:.*]] = xegpu.load_nd %[[BDESC]][%{{.*}}, %[[C16]]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK-SAME: : !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[B]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : vector<16x8xi32> to vector<16x16xf16>
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
+gpu.module @xevm_module {
+gpu.func @no_scf(%arg0: memref<64x64xf16>, %arg1: vector<8x16xf16>) -> vector<8x16xf32> {
+ %c0 = arith.constant 0 : index
+ %c32 = arith.constant 32 : index
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<64x64xf16> -> !xegpu.tensor_desc<16x16xf16, #b>
+ %1 = xegpu.load_nd %0[%c0, %c32] { result_layout = #b } : !xegpu.tensor_desc<16x16xf16, #b> -> vector<16x16xf16>
+ %2 = vector.transpose %1, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+ %6 = xegpu.dpas %arg1, %2 { layout_result_0 = #a } : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+ gpu.return %6 : vector<8x16xf32>
+}
+}
+
+// -----
+// CHECK-LABEL: gpu.func @no_scf_i8(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<64x64xi8>, %{{.*}}: vector<8x32xi8>) -> vector<8x16xi32> {
+// CHECK: %[[C16:.*]] = arith.constant 16 : index
+// CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<64x64xi8> -> index
+// CHECK: %[[T0:.*]] = arith.index_cast %[[PTR]] : index to i64
+// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[T0]], shape : [64, %[[C16]]], strides : [%[[C16]], 1] : i64
+// CHECK-SAME: -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T1]][%{{.*}}, %[[C16]]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK-SAME: : !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32>
+// CHECK: %[[T3:.*]] = vector.bitcast %[[T2]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 4]>} : vector<16x8xi32> to vector<16x32xi8>
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 4]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]>
+#c = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+gpu.module @xevm_module {
+gpu.func @no_scf_i8(%arg0: memref<64x64xi8>, %arg1: vector<8x32xi8>) -> vector<8x16xi32> {
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<64x64xi8> -> !xegpu.tensor_desc<16x32xi8, #b>
+ %1 = xegpu.load_nd %0[%c0, %c64] { result_layout = #b } : !xegpu.tensor_desc<16x32xi8, #b> -> vector<16x32xi8>
+ %2 = vector.transpose %1, [1, 0] { layout_result_0 = #bt } : vector<16x32xi8> to vector<32x16xi8>
+ %6 = xegpu.dpas %arg1, %2 { layout_result_0 = #c } : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32>
+ gpu.return %6 : vector<8x16xi32>
+}
+}
+
+
+// -----
+// CHECK-LABEL: gpu.func @gemm_b_transpose(
+// CHECK-SAME: %{{.*}} memref<256x256xf16>, %[[ARG1:[a-zA-Z0-9]+]]: memref<256x256xf16>, %{{.*}}: memref<256x256xf32>) {
+// CHECK: %[[C128:.*]] = arith.constant 128 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C16:.*]] = arith.constant 16 : index
+// CHECK: %[[C256:.*]] = arith.constant 256 : index
+// CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<256x256xf16> -> index
+// CHECK: %[[T3:.*]] = arith.index_cast %[[PTR]] : index to i64
+// CHECK: %[[T4:.*]] = xegpu.create_nd_tdesc %[[T3]], shape : [256, %[[C128]]], strides : [%c128, 1]
+// CHECK-SAME: : i64 -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK: %{{.*}} = scf.for %[[K:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}}) -> (vector<8x16xf32>) {
+// CHECK: %[[T7:.*]] = arith.shrui %[[K]], %[[C1]] : index
+// CHECK-NEXT: %[[T8:.*]] = xegpu.load_nd %[[T4]][%{{.*}}, %[[T7]]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} :
+// CHECK-SAME: !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32>
+// CHECK-NEXT: %{{.*}} = vector.bitcast %[[T8]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-SAME: : vector<16x8xi32> to vector<16x16xf16>
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
+gpu.module @xevm_module {
+gpu.func @gemm_b_transpose(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) {
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c256 = arith.constant 256 : index
+ %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #a>
+ %1 = xegpu.load_nd %0[%c0, %c0] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf32, #a> -> vector<8x16xf32>
+ %2 = xegpu.create_nd_tdesc %arg0 : memref<256x256xf16> -> !xegpu.tensor_desc<8x16xf16, #a>
+ %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #b>
+ %4 = scf.for %arg3 = %c0 to %c256 step %c16 iter_args(%arg4 = %1) -> (vector<8x16xf32>) {
+ %5 = xegpu.load_nd %2[%c0, %arg3] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf16, #a> -> vector<8x16xf16>
+ %6 = xegpu.load_nd %3[%c0, %arg3] { layout_result_0 = #b } : !xegpu.tensor_desc<16x16xf16, #b> -> vector<16x16xf16>
+ %7 = vector.transpose %6, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+ %8 = xegpu.dpas %5, %7, %arg4 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ scf.yield %8 : vector<8x16xf32>
+ } {layout_result_0 = #a}
+ xegpu.store_nd %4, %0[%c0, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+ gpu.return
+}
+}
+
+// -----
+// CHECK-LABEL: gpu.func @nested_scf(
+// CHECK-SAME: %{{.*}}: memref<256x256xf16>, %[[ARG1:[a-zA-Z0-9]+]]: memref<256x256xf16>, %{{.*}}: memref<256x256xf32>) {
+// CHECK: %[[C128:.*]] = arith.constant 128 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C16:.*]] = arith.constant 16 : index
+// CHECK: %[[C256:.*]] = arith.constant 256 : index
+// CHECK: scf.for %{{.*}} to %{{.*}} step %{{.*}} {
+// CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<256x256xf16> -> index
+// CHECK: %[[T3:.*]] = arith.index_cast %[[PTR]] : index to i64
+// CHECK: %[[T4:.*]] = xegpu.create_nd_tdesc %[[T3]], shape : [256, %[[C128]]], strides : [%[[C128]], 1] : i64
+// CHECK-SAME: -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK: %{{.*}} = scf.for %[[K:.*]] = %{{.*}} iter_args(%{{.*}}) -> (vector<8x16xf32>) {
+// CHECK: %[[T7:.*]] = arith.shrui %[[K]], %[[C1]] : index
+// CHECK-NEXT: %[[T8:.*]] = xegpu.load_nd %[[T4]][%{{.*}}, %[[T7]]] {layout_result_0 = #xegpu.layout<
+// CHECK-SAME: lane_layout = [16, 1], lane_data = [1, 1]>} :
+// CHECK-SAME: !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32>
+// CHECK-NEXT: %{{.*}} = vector.bitcast %[[T8]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-SAME: : vector<16x8xi32> to vector<16x16xf16>
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
+gpu.module @xevm_module {
+gpu.func @nested_scf(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) {
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c256 = arith.constant 256 : index
+ scf.for %arg8 = %c0 to %c256 step %c16 {
+ %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #a>
+ %1 = xegpu.load_nd %0[%arg8, %c0] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf32, #a> -> vector<8x16xf32>
+ %2 = xegpu.create_nd_tdesc %arg0 : memref<256x256xf16> -> !xegpu.tensor_desc<8x16xf16, #a>
+ %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #b>
+ %4 = scf.for %arg3 = %c0 to %c256 step %c16 iter_args(%arg4 = %1) -> (vector<8x16xf32>) {
+ %5 = xegpu.load_nd %2[%arg8, %arg3] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf16, #a> -> vector<8x16xf16>
+ %6 = xegpu.load_nd %3[%arg8, %arg3] { layout_result_0 = #b } : !xegpu.tensor_desc<16x16xf16, #b> -> vector<16x16xf16>
+ %7 = vector.transpose %6, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+ %8 = xegpu.dpas %5, %7, %arg4 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ scf.yield %8 : vector<8x16xf32>
+ } {layout_result_0 = #a}
+ xegpu.store_nd %4, %0[%c0, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+ }
+ gpu.return
+}
+}
+
+// -----
+// CHECK-LABEL: gpu.func @large_loads(
+// CHECK-SAME: %{{.*}}: vector<8x16xf16>, %[[ARG1:[a-zA-Z0-9]+]]: memref<256x256xf16>, %{{.*}}: memref<256x256xf32>) {
+// CHECK: %[[C128:.*]] = arith.constant 128 : index
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
+// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<32x16xi32>
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<256x256xf16> -> index
+// CHECK: %[[T2:.*]] = arith.index_cast %[[PTR]] : index to i64
+// CHECK: %[[T3:.*]] = xegpu.create_nd_tdesc %[[T2]], shape : [256, %[[C128]]], strides : [%[[C128]], 1] : i64
+// CHECK-SAME: -> !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK: %{{.*}}:4 = scf.for %[[K:.*]] = %{{.*}} iter_args(%{{.*}}) -> (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) {
+// CHECK: %[[T5:.*]] = arith.shrui %[[K]], %[[C1]] : index
+// CHECK: %[[T6:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T5]]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK-SAME: : !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<32x8xi32>
+// CHECK: %[[T7:.*]] = vector.insert_strided_slice %[[T6]], %[[CST]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, offsets = [0, 0], strides = [1, 1]}
+// CHECK-SAME: : vector<32x8xi32> into vector<32x16xi32>
+// CHECK: %[[T8:.*]] = arith.addi %[[T5]], %[[C8]] : index
+// CHECK: %[[T9:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T8]]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK-SAME: : !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<32x8xi32>
+// CHECK: %[[T10:.*]] = vector.insert_strided_slice %[[T9]], %[[T7]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, offsets = [0, 8], strides = [1, 1]}
+// CHECK-SAME: : vector<32x8xi32> into vector<32x16xi32>
+// CHECK: %{{.*}} = vector.bitcast %[[T10]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-SAME: : vector<32x16xi32> to vector<32x32xf16>
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
+gpu.module @xevm_module {
+gpu.func @large_loads(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) {
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %c256 = arith.constant 256 : index
+ %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #a>
+ %1 = xegpu.load_nd %0[%c0, %c0] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf32, #a> -> vector<8x16xf32>
+ %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16> -> !xegpu.tensor_desc<32x32xf16, #b>
+ %4:4 = scf.for %arg3 = %c0 to %c256 step %c32 iter_args(%arg4 = %1, %arg5 = %1, %arg6 = %1, %arg7 = %1)
+ -> (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) {
+ %6 = xegpu.load_nd %3[%c0, %arg3] { layout_result_0 = #b } : !xegpu.tensor_desc<32x32xf16, #b> -> vector<32x32xf16>
+ %7 = vector.extract_strided_slice %6 {offsets = [0, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+ : vector<32x32xf16> to vector<16x16xf16>
+ %8 = vector.extract_strided_slice %6 {offsets = [0, 16], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+ : vector<32x32xf16> to vector<16x16xf16>
+ %9 = vector.extract_strided_slice %6 {offsets = [16, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+ : vector<32x32xf16> to vector<16x16xf16>
+ %10 = vector.extract_strided_slice %6 {offsets = [16, 16], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+ : vector<32x32xf16> to vector<16x16xf16>
+ %11 = vector.transpose %7, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+ %12 = vector.transpose %8, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+ %13 = vector.transpose %9, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+ %14 = vector.transpose %10, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+ %15 = xegpu.dpas %arg0, %11, %arg4 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ %16 = xegpu.dpas %arg0, %12, %arg5 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ %17 = xegpu.dpas %arg0, %13, %arg6 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ %18 = xegpu.dpas %arg0, %14, %arg7 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ scf.yield %15, %16, %17, %18 : vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>
+ } {layout_result_0 = #a, layout_result_1 = #a, layout_result_2 = #a, layout_result_3 = #a}
+ xegpu.store_nd %4#0, %0[%c0, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+ xegpu.store_nd %4#1, %0[%c0, %c16] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+ xegpu.store_nd %4#2, %0[%c16, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+ xegpu.store_nd %4#3, %0[%c16, %c16] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+ gpu.return
+}
+}
+
+// -----
+// CHECK-LABEL: gpu.func @array_length(
+// CHECK-SAME: %{{.*}}: vector<8x16xf16>, %[[ARG1:[a-zA-Z0-9]+]]: memref<256x256xf16>, %arg2: memref<256x256xf32>) {
+// CHECK: %[[C128:.*]] = arith.constant 128 : index
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<256x256xf16> -> index
+// CHECK: %[[T2:.*]] = arith.index_cast %[[PTR]] : index to i64
+// CHECK: %[[T3:.*]] = xegpu.create_nd_tdesc %[[T2]], shape : [256, %[[C128]]], strides : [%[[C128]], 1] : i64 ->
+// CHECK-SAME: !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK: %{{.*}}:4 = scf.for %[[K:.*]] = %{{.*}} iter_args(%{{.*}}) -> (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) {
+// CHECK: %[[T5:.*]] = arith.shrui %[[K]], %[[C1]] : index
+// CHECK: %[[T6:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T5]]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK-SAME: : !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<32x8xi32>
+// CHECK: %[[T7:.*]] = vector.bitcast %[[T6]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-SAME: : vector<32x8xi32> to vector<32x16xf16>
+// CHECK: %[[T8:.*]] = arith.addi %[[T5]], %[[C8]] : index
+// CHECK: %[[T9:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T8]]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}
+// CHECK-SAME: : !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<32x8xi32>
+// CHECK: %[[T10:.*]] = vector.bitcast %[[T9]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-SAME: : vector<32x8xi32> to vector<32x16xf16>
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
+gpu.module @xevm_module {
+gpu.func @array_length(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) {
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %c256 = arith.constant 256 : index
+ %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #a>
+ %1 = xegpu.load_nd %0[%c0, %c0] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf32, #a> -> vector<8x16xf32>
+ %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16>
+ -> !xegpu.tensor_desc<32x16xf16, #b, #xegpu.block_tdesc_attr<array_length = 2 : i64>>
+ %4:4 = scf.for %arg3 = %c0 to %c256 step %c32 iter_args(%arg4 = %1, %arg5 = %1, %arg6 = %1, %arg7 = %1)
+ -> (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) {
+ %6 = xegpu.load_nd %3[%c0, %arg3] { layout_result_0 = #b }
+ : !xegpu.tensor_desc<32x16xf16, #b, #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<2x32x16xf16>
+ %19 = vector.extract %6[0] { layout_result_0 = #b } : vector<32x16xf16> from vector<2x32x16xf16>
+ %20 = vector.extract %6[1] { layout_result_0 = #b } : vector<32x16xf16> from vector<2x32x16xf16>
+ %7 = vector.extract_strided_slice %19 {offsets = [0, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+ : vector<32x16xf16> to vector<16x16xf16>
+ %8 = vector.extract_strided_slice %19 {offsets = [16, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+ : vector<32x16xf16> to vector<16x16xf16>
+ %9 = vector.extract_strided_slice %20 {offsets = [0, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+ : vector<32x16xf16> to vector<16x16xf16>
+ %10 = vector.extract_strided_slice %20 {offsets = [16, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b }
+ : vector<32x16xf16> to vector<16x16xf16>
+ %11 = vector.transpose %7, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+ %12 = vector.transpose %8, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+ %13 = vector.transpose %9, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+ %14 = vector.transpose %10, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+ %15 = xegpu.dpas %arg0, %11, %arg4 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ %16 = xegpu.dpas %arg0, %12, %arg5 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ %17 = xegpu.dpas %arg0, %13, %arg6 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ %18 = xegpu.dpas %arg0, %14, %arg7 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+ scf.yield %15, %16, %17, %18 : vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>
+ } {layout_result_0 = #a, layout_result_1 = #a, layout_result_2 = #a, layout_result_3 = #a}
+ xegpu.store_nd %4#0, %0[%c0, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+ xegpu.store_nd %4#1, %0[%c0, %c16] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+ xegpu.store_nd %4#2, %0[%c16, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+ xegpu.store_nd %4#3, %0[%c16, %c16] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a>
+ gpu.return
+}
+}