[Mlir-commits] [mlir] [mlir][XeGPU][Transform] Add vectorlinearize transform pass. (PR #158084)

Md Abdullah Shahneous Bari llvmlistbot at llvm.org
Thu Sep 18 15:07:30 PDT 2025


================
@@ -0,0 +1,120 @@
+//===-- XeGPUVectorLinearize.cpp - Linearizes n-D vectors to 1-D vectors --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <optional>
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUVECTORLINEARIZE
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+#define DEBUG_TYPE "xegpu-vector-linearize"
+
+using namespace mlir;
+
+namespace {
+struct XeGPUVectorLinearizePass final
+    : public xegpu::impl::XeGPUVectorLinearizeBase<XeGPUVectorLinearizePass> {
+  void runOnOperation() override {
+    // vector.broadcast and vector.gather require progressive lowering.
+    {
+      RewritePatternSet patterns(&getContext());
+      vector::populateVectorBroadcastLoweringPatterns(patterns);
+      vector::populateVectorGatherLoweringPatterns(patterns);
+      vector::populateVectorGatherToConditionalLoadPatterns(patterns);
+      // vector.transpose lowering
+      // Shuffle16x16 will fall back to Shuffle1D for non-16x16 sizes.
+      vector::populateVectorTransposeLoweringPatterns(
+          patterns, vector::VectorTransposeLowering::Shuffle16x16);
+      if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+        return signalPassFailure();
+    }
+
+    // Unroll load/store from <d1xd2x...xdk> to (d1*d2*...*d(k-1)) slices of
+    // <1x1x...x1xdk>.
+    {
+      RewritePatternSet patterns(&getContext());
+      vector::UnrollVectorOptions vectorOptions;
+      vectorOptions.setNativeShapeFn(
+          [](Operation *op) -> std::optional<SmallVector<int64_t>> {
+            auto extractVectorType = [](Operation *op) -> VectorType {
+              if (auto loadOp = dyn_cast<vector::LoadOp>(op))
+                return loadOp.getVectorType();
+              if (auto storeOp = dyn_cast<vector::StoreOp>(op))
+                return storeOp.getVectorType();
+              return nullptr;
+            };
+
+            VectorType vecType = extractVectorType(op);
+            if (!vecType)
+              return std::nullopt;
+
+            // Only handle rank >= 2 so we actually unroll something.
+            int64_t rank = vecType.getRank();
+            if (rank < 2)
+              return std::nullopt;
+
+            ArrayRef<int64_t> shape = vecType.getShape();
+            // Bail if any of the (rank-1) leading dims are dynamic (can't fully
+            // unroll).
+            for (int64_t i = 0; i < rank - 1; ++i)
+              if (shape[i] == ShapedType::kDynamic) {
+                LLVM_DEBUG(llvm::dbgs()
+                           << "Dynamic leading dim " << i << " in " << vecType
+                           << " prevents full unroll.\n");
+                return std::nullopt;
+              }
+
+            // Produce native shape: 1 x 1 x ... x (original last dim).
+            SmallVector<int64_t> native(rank, 1);
+            native.back() = shape.back();
+            return native;
+          });
+      vector::populateVectorUnrollPatterns(patterns, vectorOptions);
+      if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+        LLVM_DEBUG(llvm::dbgs() << "Unroll failed.\n");
----------------
mshahneo wrote:

Thanks a lot, Adam, using LDBG now.

https://github.com/llvm/llvm-project/pull/158084


More information about the Mlir-commits mailing list