[Mlir-commits] [mlir] [mlir][XeGPU][Transform] Add vectorlinearize transform pass. (PR #158084)

Md Abdullah Shahneous Bari llvmlistbot at llvm.org
Thu Sep 18 15:09:47 PDT 2025


================
@@ -0,0 +1,120 @@
+//===-- XeGPUVectorLinearize.cpp - Linearizes n-D vectors to 1-D vectors --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <optional>
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUVECTORLINEARIZE
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+#define DEBUG_TYPE "xegpu-vector-linearize"
+
+using namespace mlir;
+
+namespace {
+struct XeGPUVectorLinearizePass final
+    : public xegpu::impl::XeGPUVectorLinearizeBase<XeGPUVectorLinearizePass> {
+  void runOnOperation() override {
+    // vector.broadcast and vector.gather requires progressive lowering
+    {
+      RewritePatternSet patterns(&getContext());
+      vector::populateVectorBroadcastLoweringPatterns(patterns);
+      vector::populateVectorGatherLoweringPatterns(patterns);
+      vector::populateVectorGatherToConditionalLoadPatterns(patterns);
+      // vector.transpose lowering
+      // Shuffle16x16 will fallback to Shuffle1D for non 16x16 sizes.
+      vector::populateVectorTransposeLoweringPatterns(
+          patterns, vector::VectorTransposeLowering::Shuffle16x16);
+      if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+        return signalPassFailure();
+    }
+
+    // Unroll load/store from <d1xd2x...xdk> to (d1*d2*...*d(k-1)) slices of
+    // <1x1x...x1xdk>.
+    {
+      RewritePatternSet patterns(&getContext());
+      vector::UnrollVectorOptions vectorOptions;
+      vectorOptions.setNativeShapeFn(
+          [](Operation *op) -> std::optional<SmallVector<int64_t>> {
+            auto extractVectorType = [](Operation *op) -> VectorType {
+              if (auto loadOp = dyn_cast<vector::LoadOp>(op))
+                return loadOp.getVectorType();
+              if (auto storeOp = dyn_cast<vector::StoreOp>(op))
+                return storeOp.getVectorType();
+              return nullptr;
+            };
+
+            VectorType vecType = extractVectorType(op);
+            if (!vecType)
+              return std::nullopt;
+
+            // Only handle rank >= 2 so we actually unroll something.
+            int64_t rank = vecType.getRank();
+            if (rank < 2)
+              return std::nullopt;
+
+            ArrayRef<int64_t> shape = vecType.getShape();
+            // Bail if any of the (rank-1) leading dims are dynamic (can't fully
+            // unroll).
+            for (int64_t i = 0; i < rank - 1; ++i)
+              if (shape[i] == ShapedType::kDynamic) {
----------------
mshahneo wrote:

Fixed.
Sorry, I thought I had seen a code comment somewhere using vector<?x>, but I must have been mistaken.

https://github.com/llvm/llvm-project/pull/158084


More information about the Mlir-commits mailing list