[Mlir-commits] [mlir] [mlir][VectorToLLVM] Add support for unrolling and lowering multi-dim… (PR #160405)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Sep 24 03:14:43 PDT 2025
https://github.com/tyb0807 updated https://github.com/llvm/llvm-project/pull/160405
>From 610c8c60549c48258e3df4cc62f20c625375f313 Mon Sep 17 00:00:00 2001
From: tyb0807 <sontuan.vu119 at gmail.com>
Date: Tue, 23 Sep 2025 23:48:50 +0200
Subject: [PATCH 1/2] [mlir][VectorToLLVM] Add support for unrolling and
lowering multi-dimensional vector.scatter operations
---
.../Vector/Transforms/LoweringPatterns.h | 8 ++
.../mlir/Dialect/Vector/Utils/VectorUtils.h | 3 +-
.../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 1 +
.../Dialect/Vector/Transforms/CMakeLists.txt | 1 +
.../Vector/Transforms/LowerVectorScatter.cpp | 101 ++++++++++++++++++
mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 46 ++++++--
.../VectorToLLVM/vector-to-llvm.mlir | 15 ++-
7 files changed, 158 insertions(+), 17 deletions(-)
create mode 100644 mlir/lib/Dialect/Vector/Transforms/LowerVectorScatter.cpp
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 7bd96c8a6d1a1..83c08d4103177 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -239,6 +239,14 @@ void populateVectorScanLoweringPatterns(RewritePatternSet &patterns,
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
+/// Populate the pattern set with the following patterns:
+///
+/// [UnrollScatter]
+/// Unrolls 2 or more dimensional `vector.scatter` ops by unrolling the
+/// outermost dimension.
+void populateVectorScatterLoweringPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
/// Populate the pattern set with the following patterns:
///
/// [UnrollGather]
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index a57aadcdcc5b0..c6a1c62afc92f 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -254,7 +254,8 @@ using UnrollVectorOpFn =
function_ref<Value(PatternRewriter &, Location, VectorType, int64_t)>;
LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter,
- UnrollVectorOpFn unrollFn);
+ UnrollVectorOpFn unrollFn,
+ VectorType vectorTy = nullptr);
/// Generic utility for unrolling values of type vector<NxAxBx...>
/// to N values of type vector<AxBx...> using vector.extract. If the input
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index f958edf2746e9..1750e0430fd2b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -94,6 +94,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorStepLoweringPatterns(patterns);
populateVectorRankReducingFMAPattern(patterns);
populateVectorGatherLoweringPatterns(patterns);
+ populateVectorScatterLoweringPatterns(patterns);
populateVectorFromElementsUnrollPatterns(patterns);
populateVectorToElementsUnrollPatterns(patterns);
if (armI8MM) {
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 4e0f07af95984..cd7d4f5d1c69c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
LowerVectorMask.cpp
LowerVectorMultiReduction.cpp
LowerVectorScan.cpp
+ LowerVectorScatter.cpp
LowerVectorShapeCast.cpp
LowerVectorShuffle.cpp
LowerVectorStep.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScatter.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScatter.cpp
new file mode 100644
index 0000000000000..d236c2d23b3b9
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScatter.cpp
@@ -0,0 +1,101 @@
+//===- LowerVectorScatter.cpp - Lower 'vector.scatter' operation ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.scatter' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+
+#define DEBUG_TYPE "vector-scatter-lowering"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+/// Unrolls 2 or more dimensional `vector.scatter` ops by unrolling the
+/// outermost dimension. For example:
+/// ```
+/// vector.scatter %base[%c0][%idx], %mask, %value :
+/// memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32>
+///
+/// ==>
+///
+/// %v0 = vector.extract %value[0] : vector<3xf32> from vector<2x3xf32>
+/// %m0 = vector.extract %mask[0] : vector<3xi1> from vector<2x3xi1>
+/// %i0 = vector.extract %idx[0] : vector<3xi32> from vector<2x3xi32>
+/// vector.scatter %base[%c0][%i0], %m0, %v0 :
+/// memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
+///
+/// %v1 = vector.extract %value[1] : vector<3xf32> from vector<2x3xf32>
+/// %m1 = vector.extract %mask[1] : vector<3xi1> from vector<2x3xi1>
+/// %i1 = vector.extract %idx[1] : vector<3xi32> from vector<2x3xi32>
+/// vector.scatter %base[%c0][%i1], %m1, %v1 :
+/// memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
+/// ```
+///
+/// When applied exhaustively, this will produce a sequence of 1-d scatter ops.
+///
+/// Supports vector types with a fixed leading dimension.
+struct UnrollScatter : OpRewritePattern<vector::ScatterOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ScatterOp op,
+ PatternRewriter &rewriter) const override {
+ Value indexVec = op.getIndices();
+ Value maskVec = op.getMask();
+ Value valueVec = op.getValueToStore();
+
+ // Get the vector type from one of the vector operands
+ VectorType vectorTy = dyn_cast<VectorType>(indexVec.getType());
+ if (!vectorTy)
+ return failure();
+
+ auto unrollScatterFn = [&](PatternRewriter &rewriter, Location loc,
+ VectorType subTy, int64_t index) {
+ int64_t thisIdx[1] = {index};
+
+ Value indexSubVec =
+ vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);
+ Value maskSubVec =
+ vector::ExtractOp::create(rewriter, loc, maskVec, thisIdx);
+ Value valueSubVec =
+ vector::ExtractOp::create(rewriter, loc, valueVec, thisIdx);
+ // Scatter this slice, reusing the original base, offsets and alignment.
+ vector::ScatterOp::create(rewriter, loc, op.getBase(), op.getOffsets(),
+ indexSubVec, maskSubVec, valueSubVec,
+ op.getAlignmentAttr());
+
+ // Return a dummy value since unrollVectorOp expects a Value
+ return rewriter.create<ub::PoisonOp>(loc, subTy);
+ };
+
+ return unrollVectorOp(op, rewriter, unrollScatterFn, vectorTy);
+ }
+};
+
+} // namespace
+
+void mlir::vector::populateVectorScatterLoweringPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<UnrollScatter>(patterns.getContext(), benefit);
+}
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 025ee9a04a1de..53ac3d50e1d21 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -431,27 +431,51 @@ vector::unrollVectorValue(TypedValue<VectorType> vector,
}
LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter,
- vector::UnrollVectorOpFn unrollFn) {
- assert(op->getNumResults() == 1 && "expected single result");
- assert(isa<VectorType>(op->getResult(0).getType()) && "expected vector type");
- VectorType resultTy = cast<VectorType>(op->getResult(0).getType());
- if (resultTy.getRank() < 2)
+ vector::UnrollVectorOpFn unrollFn,
+ VectorType vectorTy) {
+ // If vector type is not provided, get it from the result
+ if (!vectorTy) {
+ if (op->getNumResults() != 1)
+ return rewriter.notifyMatchFailure(
+ op, "expected single result when vector type not provided");
+
+ vectorTy = dyn_cast<VectorType>(op->getResult(0).getType());
+ if (!vectorTy)
+ return rewriter.notifyMatchFailure(op, "expected vector type");
+ }
+
+ if (vectorTy.getRank() < 2)
return rewriter.notifyMatchFailure(op, "already 1-D");
// Unrolling doesn't take vscale into account. Pattern is disabled for
// vectors with leading scalable dim(s).
- if (resultTy.getScalableDims().front())
+ if (vectorTy.getScalableDims().front())
return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");
Location loc = op->getLoc();
- Value result = ub::PoisonOp::create(rewriter, loc, resultTy);
- VectorType subTy = VectorType::Builder(resultTy).dropDim(0);
- for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
+ // Only create result value if the operation produces results
+ Value result;
+ if (op->getNumResults() > 0) {
+ result = ub::PoisonOp::create(rewriter, loc, vectorTy);
+ }
+
+ VectorType subTy = VectorType::Builder(vectorTy).dropDim(0);
+
+ for (int64_t i = 0, e = vectorTy.getShape().front(); i < e; ++i) {
Value subVector = unrollFn(rewriter, loc, subTy, i);
- result = vector::InsertOp::create(rewriter, loc, subVector, result, i);
+
+ // Only insert if we have a result to build
+ if (op->getNumResults() > 0) {
+ result = vector::InsertOp::create(rewriter, loc, subVector, result, i);
+ }
+ }
+
+ if (op->getNumResults() > 0) {
+ rewriter.replaceOp(op, result);
+ } else {
+ rewriter.eraseOp(op);
}
- rewriter.replaceOp(op, result);
return success();
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 2d33888854ea7..6ba37bd56083f 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1643,9 +1643,6 @@ func.func @gather_with_zero_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x
// vector.scatter
//===----------------------------------------------------------------------===//
-// Multi-Dimensional scatters are not supported yet. Check that we do not lower
-// them.
-
func.func @scatter_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) {
%0 = arith.constant 0: index
%1 = vector.constant_mask [2, 2] : vector<2x3xi1>
@@ -1654,7 +1651,11 @@ func.func @scatter_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2
}
// CHECK-LABEL: func @scatter_with_mask
-// CHECK: vector.scatter
+// CHECK: llvm.extractvalue {{.*}}[0]
+// CHECK: llvm.intr.masked.scatter {{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
+// CHECK: llvm.extractvalue {{.*}}[1]
+// CHECK: llvm.intr.masked.scatter {{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
+// CHECK-NOT: vector.scatter
// -----
@@ -1669,7 +1670,11 @@ func.func @scatter_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]x
}
// CHECK-LABEL: func @scatter_with_mask_scalable
-// CHECK: vector.scatter
+// CHECK: llvm.extractvalue {{.*}}[0]
+// CHECK: llvm.intr.masked.scatter {{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into vector<[3]x!llvm.ptr>
+// CHECK: llvm.extractvalue {{.*}}[1]
+// CHECK: llvm.intr.masked.scatter {{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into vector<[3]x!llvm.ptr>
+// CHECK-NOT: vector.scatter
// -----
>From f72fb10ff5d13bc496b4f247a0cb830eea387ac9 Mon Sep 17 00:00:00 2001
From: tyb0807 <sontuan.vu119 at gmail.com>
Date: Wed, 24 Sep 2025 12:12:32 +0200
Subject: [PATCH 2/2] Address comments
---
.../mlir/Dialect/Vector/Utils/VectorUtils.h | 4 ++++
.../Vector/Transforms/LowerVectorScatter.cpp | 6 +++---
mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 15 ++++++---------
3 files changed, 13 insertions(+), 12 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index c6a1c62afc92f..8f609acd2fdb7 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -250,6 +250,10 @@ LogicalResult isValidMaskedInputVector(ArrayRef<int64_t> shape,
/// create sub vectors.
/// 5. Insert the sub vectors back into the final vector.
/// 6. Replace the original op with the new result.
+///
+/// Expects the operation to be unrolled to have at most 1 result. When there
+/// is no result, the caller must pass `vectorTy` to supply the unroll factor;
+/// when both are present, `vectorTy` must match the result type.
using UnrollVectorOpFn =
function_ref<Value(PatternRewriter &, Location, VectorType, int64_t)>;
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScatter.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScatter.cpp
index d236c2d23b3b9..af17136f21da0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScatter.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScatter.cpp
@@ -65,7 +65,7 @@ struct UnrollScatter : OpRewritePattern<vector::ScatterOp> {
Value maskVec = op.getMask();
Value valueVec = op.getValueToStore();
- // Get the vector type from one of the vector operands
+ // Get the vector type from one of the vector operands.
VectorType vectorTy = dyn_cast<VectorType>(indexVec.getType());
if (!vectorTy)
return failure();
@@ -85,8 +85,8 @@ struct UnrollScatter : OpRewritePattern<vector::ScatterOp> {
indexSubVec, maskSubVec, valueSubVec,
op.getAlignmentAttr());
- // Return a dummy value since unrollVectorOp expects a Value
- return rewriter.create<ub::PoisonOp>(loc, subTy);
+ // Null Value: unrollVectorOp only consumes it when the op has results.
+ return Value();
};
return unrollVectorOp(op, rewriter, unrollScatterFn, vectorTy);
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 53ac3d50e1d21..02f9382c760be 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -433,15 +433,12 @@ vector::unrollVectorValue(TypedValue<VectorType> vector,
LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter,
vector::UnrollVectorOpFn unrollFn,
VectorType vectorTy) {
- // If vector type is not provided, get it from the result
+ // If vector type is not provided, get it from the result.
if (!vectorTy) {
- if (op->getNumResults() != 1)
- return rewriter.notifyMatchFailure(
- op, "expected single result when vector type not provided");
-
+ assert(op->getNumResults() == 1 &&
+ "expected single result when vector type not provided");
vectorTy = dyn_cast<VectorType>(op->getResult(0).getType());
- if (!vectorTy)
- return rewriter.notifyMatchFailure(op, "expected vector type");
+ assert(vectorTy && "expected result to have vector type");
}
if (vectorTy.getRank() < 2)
@@ -454,7 +451,7 @@ LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter,
Location loc = op->getLoc();
- // Only create result value if the operation produces results
+ // Only create result value if the operation produces results.
Value result;
if (op->getNumResults() > 0) {
result = ub::PoisonOp::create(rewriter, loc, vectorTy);
@@ -465,7 +462,7 @@ LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter,
for (int64_t i = 0, e = vectorTy.getShape().front(); i < e; ++i) {
Value subVector = unrollFn(rewriter, loc, subTy, i);
- // Only insert if we have a result to build
+ // Only insert if we have a result to build.
if (op->getNumResults() > 0) {
result = vector::InsertOp::create(rewriter, loc, subVector, result, i);
}
More information about the Mlir-commits
mailing list