[Mlir-commits] [mlir] [mlir][XeGPU][Transform] Add vectorlinearize transform pass. (PR #158084)
Md Abdullah Shahneous Bari
llvmlistbot at llvm.org
Thu Sep 18 15:09:47 PDT 2025
================
@@ -0,0 +1,120 @@
+//===-- XeGPUVectorLinearize.cpp - Linearizes n-D vectors to 1-D vectors --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <optional>
+
+// Emit the TableGen-generated base class definition
+// (xegpu::impl::XeGPUVectorLinearizeBase) that the pass below derives from.
+namespace mlir::xegpu {
+#define GEN_PASS_DEF_XEGPUVECTORLINEARIZE
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace mlir::xegpu
+
+// Debug tag for this file (usable with -debug-only=xegpu-vector-linearize).
+#define DEBUG_TYPE "xegpu-vector-linearize"
+
+using namespace mlir;
+
+namespace {
+struct XeGPUVectorLinearizePass final
+ : public xegpu::impl::XeGPUVectorLinearizeBase<XeGPUVectorLinearizePass> {
+ void runOnOperation() override {
+ // vector.broadcast and vector.gather requires progressive lowering
+ {
+ RewritePatternSet patterns(&getContext());
+ vector::populateVectorBroadcastLoweringPatterns(patterns);
+ vector::populateVectorGatherLoweringPatterns(patterns);
+ vector::populateVectorGatherToConditionalLoadPatterns(patterns);
+ // vector.transpose lowering
+ // Shuffle16x16 will fallback to Shuffle1D for non 16x16 sizes.
+ vector::populateVectorTransposeLoweringPatterns(
+ patterns, vector::VectorTransposeLowering::Shuffle16x16);
+ if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ return signalPassFailure();
+ }
+
+ // Unroll load/store from <d1xd2x...xdk> to (d1*d2*...*d(k-1)) slices of
+ // <1x1x...x1xdk>.
+ {
+ RewritePatternSet patterns(&getContext());
+ vector::UnrollVectorOptions vectorOptions;
+ vectorOptions.setNativeShapeFn(
+ [](Operation *op) -> std::optional<SmallVector<int64_t>> {
+ auto extractVectorType = [](Operation *op) -> VectorType {
+ if (auto loadOp = dyn_cast<vector::LoadOp>(op))
+ return loadOp.getVectorType();
+ if (auto storeOp = dyn_cast<vector::StoreOp>(op))
+ return storeOp.getVectorType();
+ return nullptr;
+ };
+
+ VectorType vecType = extractVectorType(op);
+ if (!vecType)
+ return std::nullopt;
+
+ // Only handle rank >= 2 so we actually unroll something.
+ int64_t rank = vecType.getRank();
+ if (rank < 2)
+ return std::nullopt;
+
+ ArrayRef<int64_t> shape = vecType.getShape();
+ // Bail if any of the (rank-1) leading dims are dynamic (can't fully
+ // unroll).
+ for (int64_t i = 0; i < rank - 1; ++i)
+ if (shape[i] == ShapedType::kDynamic) {
----------------
mshahneo wrote:
Fixed.
Sorry, I thought I had seen a code comment somewhere using vector<?x...>, but I was mistaken: MLIR vector types cannot have dynamic dimensions, so the dynamic-dim check is unnecessary.
https://github.com/llvm/llvm-project/pull/158084
More information about the Mlir-commits
mailing list