[Mlir-commits] [mlir] [MLIR][XeGPU] Add unroll patterns and blocking pass for XeGPU [2/N] (PR #140163)
Chao Chen
llvmlistbot at llvm.org
Tue May 27 08:28:01 PDT 2025
================
@@ -0,0 +1,337 @@
+//===---- XeGPUBlocking.cpp ---- XeGPU Blocking Pass ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUBLOCKING
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+#define DEBUG_TYPE "xegpu-blocking"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
+using namespace mlir;
+
+namespace {
+
+// Resolve the unrealized conversion cast ops generated when doing SCF
+// structural type conversion. The casts come in two forms: N:1 vector
+// casts and 1:N vector casts. vector::insert_strided_slice ops are used
+// to resolve the first case, and vector::extract_strided_slice ops are
+// used to resolve the second case.
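+// For example (illustrative IR, assuming 8x16 instruction tiles of an
+// original 16x16 vector):
+//   N:1 (unpack):
+//     %1 = unrealized_conversion_cast %a, %b
+//        : vector<8x16xf16>, vector<8x16xf16> to vector<16x16xf16>
+//     is resolved by inserting %a and %b into a 16x16 vector with
+//     vector.insert_strided_slice ops.
+//   1:N (pack):
+//     %a, %b = unrealized_conversion_cast %1
+//        : vector<16x16xf16> to vector<8x16xf16>, vector<8x16xf16>
+//     is resolved with vector.extract_strided_slice ops.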
+static void
+resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) {
+ ValueRange inputs = castOp.getInputs();
+ ValueRange outputs = castOp.getOutputs();
+
+  if (inputs.size() == 1 && outputs.size() == 1) {
+    castOp->replaceAllUsesWith(inputs);
+    castOp->erase();
+    return;
+  }
+
+ VectorType inputTy = dyn_cast<VectorType>(inputs[0].getType());
+ VectorType outputTy = dyn_cast<VectorType>(outputs[0].getType());
+ if (inputTy && outputTy) {
+ OpBuilder builder(castOp);
+    // unpack: recombine the N blocked inputs into the original large vector.
+ if (inputs.size() > 1 && outputs.size() == 1) {
+ ArrayRef<int64_t> shape = outputTy.getShape();
+ Value result = xegpu::createVectorWithShapeFromValues(
+ builder, castOp.getLoc(), inputs, shape);
+ castOp->replaceAllUsesWith(ValueRange(result));
+ castOp->erase();
+ }
+
+    // pack: split the original large vector into N blocked outputs.
+    if (inputs.size() == 1 && outputs.size() > 1) {
+ ArrayRef<int64_t> tileShape = outputTy.getShape();
+ SmallVector<Value> results = xegpu::extractVectorsWithShapeFromValue(
+ builder, castOp.getLoc(), inputs[0], tileShape);
+ castOp->replaceAllUsesWith(results);
+ castOp->erase();
+ }
+ }
+}
+
+/// Unroll XeGPU ops to their instruction-level representation.
+class XeGPUBlockingPass final
+ : public xegpu::impl::XeGPUBlockingBase<XeGPUBlockingPass> {
+public:
+ void runOnOperation() override;
+
+private:
+ // Get the tile shape for a given operand by examining the layout attribute.
+ // If layout is not present or is not a subgroup level layout, it returns
+ // std::nullopt.
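+  // For example (hypothetical attribute, for illustration only): an operand
+  // annotated with #xegpu.layout<inst_data = [8, 16]> yields the tile shape
+  // [8, 16]; an sg-level layout without inst_data falls back to the shape
+  // of the operand's type.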
+ std::optional<SmallVector<int64_t>> getTileShape(OpOperand &operand) const;
+
+ // Get the tile shape for a given result by examining the layout attribute.
+ // If layout is not present or is not a subgroup level layout, it returns
+ // std::nullopt.
+ std::optional<SmallVector<int64_t>> getTileShape(OpResult result) const;
+
+ // Get the tile shape for a given operation.
+ std::optional<SmallVector<int64_t>> getTileShape(Operation *op) const;
+
+ // Determine if the operation requires unrolling. Return false if all operands
+ // and results have tile shapes identical to their original types. Otherwise,
+ // return true.
+ bool needsUnroll(Operation *op) const;
+};
+} // namespace
+
+std::optional<SmallVector<int64_t>>
+XeGPUBlockingPass::getTileShape(OpOperand &operand) const {
+ xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operand);
+ if (layout && layout.isSgLayout()) {
+    if (auto instData = layout.getInstData())
+      return llvm::to_vector_of<int64_t>(instData.asArrayRef());
+
+ if (auto type = dyn_cast<ShapedType>(operand.get().getType()))
+ return llvm::to_vector(type.getShape());
+ }
+ LDBG("failed to getTileShape for operand: " << operand.get());
+ return std::nullopt;
+}
+
+std::optional<SmallVector<int64_t>>
+XeGPUBlockingPass::getTileShape(OpResult result) const {
+ xegpu::LayoutAttr layout = xegpu::getLayoutAttr(result);
+ if (layout && layout.isSgLayout()) {
+    if (auto instData = layout.getInstData())
+      return llvm::to_vector_of<int64_t>(instData.asArrayRef());
+
+ if (auto type = dyn_cast<ShapedType>(result.getType()))
+ return llvm::to_vector(type.getShape());
+ }
+ LDBG("failed to getTileShape for result: " << result);
+ return std::nullopt;
+}
+
+std::optional<SmallVector<int64_t>>
+XeGPUBlockingPass::getTileShape(Operation *op) const {
+ if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp>(op))
+ return getTileShape(op->getOpResult(0));
+ if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp>(op))
+ return getTileShape(op->getOpOperand(0));
+ if (isa<xegpu::StoreNdOp>(op))
+ return getTileShape(op->getOpOperand(1));
+
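+  // For dpas, C[m,n] = A[m,k] * B[k,n] (with C optionally accumulated), so
+  // the combined tile shape returned is [m, k, n]; e.g. an A tile of
+  // [8, 16] and a B tile of [16, 16] yield [8, 16, 16].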
+ if (isa<xegpu::DpasOp>(op)) {
+ std::optional<SmallVector<int64_t>> aTile =
+ getTileShape(op->getOpOperand(0));
+ std::optional<SmallVector<int64_t>> bTile =
+ getTileShape(op->getOpOperand(1));
+
+ if (!aTile || aTile->size() != 2 || !bTile || bTile->size() != 2)
+ return std::nullopt;
+
+    // Semantic check for A and B: the K dims must match (A is MxK, B is KxN).
+ if ((*aTile)[1] != (*bTile)[0])
+ return std::nullopt;
+
+    // Semantic check for C: its tile must be MxN.
+ if (op->getNumOperands() == 3) {
+ std::optional<SmallVector<int64_t>> cTile =
+ getTileShape(op->getOpOperand(2));
+ int64_t expectedCTile[2] = {(*aTile)[0], (*bTile)[1]};
+ if (!cTile || !llvm::equal(*cTile, expectedCTile))
+ return std::nullopt;
+ }
+
+ return SmallVector<int64_t>({(*aTile)[0], (*aTile)[1], (*bTile)[1]});
+ }
+
+ if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)
+ return getTileShape(op->getOpResult(0));
+
+ return std::nullopt;
+}
+
+bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
+ if (isa<LoopLikeOpInterface>(op))
+ return false;
+
+ auto isUnrollable = [&](Value value,
+ ArrayRef<int64_t> tileShape) -> std::optional<bool> {
----------------
chencha3 wrote:
It means `let the rest of the operands or results` determine it. The logic is: if an op contains an operand or result with a wg layout, it is considered an invalid op to lower. Otherwise, if the op contains an sg layout with an inst_data field, it needs to be lowered. I will think about it again and try to make this logic clearer.
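For reference, a minimal sketch of the rule described above (hypothetical code, not the patch's implementation; it assumes the xegpu::getLayoutAttr helper and the LayoutAttr::isWgLayout()/isSgLayout()/getInstData() accessors used elsewhere in this patch, and checks only operands; results would be handled the same way):

    // Hypothetical sketch of the described rule, not the PR's code.
    static bool needsLowering(Operation *op) {
      // Any operand with a wg-level layout makes the op invalid to lower.
      if (llvm::any_of(op->getOpOperands(), [](OpOperand &opr) {
            xegpu::LayoutAttr layout = xegpu::getLayoutAttr(opr);
            return layout && layout.isWgLayout();
          }))
        return false;
      // Otherwise, an sg-level layout carrying inst_data requests blocking.
      return llvm::any_of(op->getOpOperands(), [](OpOperand &opr) {
        xegpu::LayoutAttr layout = xegpu::getLayoutAttr(opr);
        return layout && layout.isSgLayout() && layout.getInstData();
      });
    }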
https://github.com/llvm/llvm-project/pull/140163