[Mlir-commits] [mlir] [mlir][vector] Allow lowering multi-dim scatters to LLVM (PR #132227)
Kunwar Grover
llvmlistbot at llvm.org
Thu Mar 20 07:53:08 PDT 2025
https://github.com/Groverkss created https://github.com/llvm/llvm-project/pull/132227
This patch adds an UnrollScatter pattern for vector.scatter, exactly the same as UnrollGather for vector.gather, allowing us to lower multi-dimensional vector.scatter ops by unrolling them to 1-D vectors.
>From bba9af20afac5e39c5cdb08c0685d4723450b926 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Thu, 20 Mar 2025 13:27:56 +0000
Subject: [PATCH 1/3] [mlir][vector] Decouple unrolling gather and gather to
llvm lowering
---
.../Vector/Transforms/LoweringPatterns.h | 8 +++-
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 45 ++++++------------
.../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 1 +
.../Vector/Transforms/LowerVectorGather.cpp | 18 +++++---
.../vector-to-llvm-interface.mlir | 46 -------------------
.../VectorToLLVM/vector-to-llvm.mlir | 4 +-
.../Dialect/Vector/TestVectorTransforms.cpp | 1 +
7 files changed, 36 insertions(+), 87 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 601a65333d026..77d8b82b2bad0 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -244,13 +244,17 @@ void populateVectorStepLoweringPatterns(RewritePatternSet &patterns,
/// [FlattenGather]
/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
/// outermost dimension.
+void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
+/// Populate the pattern set with the following patterns:
///
/// [Gather1DToConditionalLoads]
/// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or
/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
/// loads/extracts are made conditional using `scf.if` ops.
-void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
- PatternBenefit benefit = 1);
+void populateVectorGatherToConditionalLoadPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
/// Populates instances of `MaskOpRewritePattern` to lower masked operations
/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 94efec61a466c..148c2ae26be35 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -269,6 +269,10 @@ class VectorGatherOpConversion
if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
return failure();
+ VectorType vType = gather.getVectorType();
+ if (vType.getRank() > 1)
+ return failure();
+
auto loc = gather->getLoc();
// Resolve alignment.
@@ -276,42 +280,21 @@ class VectorGatherOpConversion
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
return failure();
+ // Resolve address.
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
adaptor.getIndices(), rewriter);
Value base = adaptor.getBase();
- auto llvmNDVectorTy = adaptor.getIndexVec().getType();
// Handle the simple case of 1-D vector.
- if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
- auto vType = gather.getVectorType();
- // Resolve address.
- Value ptrs =
- getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
- base, ptr, adaptor.getIndexVec(), vType);
- // Replace with the gather intrinsic.
- rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
- gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
- adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
- return success();
- }
-
- const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
- auto callback = [align, memRefType, base, ptr, loc, &rewriter,
- &typeConverter](Type llvm1DVectorTy,
- ValueRange vectorOperands) {
- // Resolve address.
- Value ptrs = getIndexedPtrs(
- rewriter, loc, typeConverter, memRefType, base, ptr,
- /*index=*/vectorOperands[0], cast<VectorType>(llvm1DVectorTy));
- // Create the gather intrinsic.
- return rewriter.create<LLVM::masked_gather>(
- loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
- /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
- };
- SmallVector<Value> vectorOperands = {
- adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
- return LLVM::detail::handleMultidimensionalVectors(
- gather, vectorOperands, *getTypeConverter(), callback, rewriter);
+ // Resolve address.
+ Value ptrs =
+ getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
+ base, ptr, adaptor.getIndexVec(), vType);
+ // Replace with the gather intrinsic.
+ rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
+ gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
+ adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
+ return success();
}
};
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index eb1555df5d574..7082b92c95d1d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -81,6 +81,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorInsertExtractStridedSliceTransforms(patterns);
populateVectorStepLoweringPatterns(patterns);
populateVectorRankReducingFMAPattern(patterns);
+ populateVectorGatherLoweringPatterns(patterns);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 3b38505becd18..eff8ee0e9de7a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -38,7 +38,7 @@ using namespace mlir;
using namespace mlir::vector;
namespace {
-/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
+/// Unrolls 2 or more dimensional `vector.gather` ops by unrolling the
/// outermost dimension. For example:
/// ```
/// %g = vector.gather %base[%c0][%v], %mask, %pass_thru :
@@ -56,14 +56,14 @@ namespace {
/// When applied exhaustively, this will produce a sequence of 1-d gather ops.
///
/// Supports vector types with a fixed leading dimension.
-struct FlattenGather : OpRewritePattern<vector::GatherOp> {
+struct UnrollGather : OpRewritePattern<vector::GatherOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::GatherOp op,
PatternRewriter &rewriter) const override {
VectorType resultTy = op.getType();
if (resultTy.getRank() < 2)
- return rewriter.notifyMatchFailure(op, "already flat");
+ return rewriter.notifyMatchFailure(op, "already 1-D");
// Unrolling doesn't take vscale into account. Pattern is disabled for
// vectors with leading scalable dim(s).
@@ -107,7 +107,8 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
/// ```mlir
/// %subview = memref.subview %M (...)
/// : memref<100x3xf32> to memref<100xf32, strided<[3]>>
-/// %gather = vector.gather %subview[%idxs] (...) : memref<100xf32, strided<[3]>>
+/// %gather = vector.gather %subview[%idxs] (...) : memref<100xf32,
+/// strided<[3]>>
/// ```
/// ==>
/// ```mlir
@@ -269,6 +270,11 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
void mlir::vector::populateVectorGatherLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<FlattenGather, RemoveStrideFromGatherSource,
- Gather1DToConditionalLoads>(patterns.getContext(), benefit);
+ patterns.add<UnrollGather>(patterns.getContext(), benefit);
+}
+
+void mlir::vector::populateVectorGatherToConditionalLoadPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<RemoveStrideFromGatherSource, Gather1DToConditionalLoads>(
+ patterns.getContext(), benefit);
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index c3f06dd4d5dd1..44b4a25a051f1 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2074,52 +2074,6 @@ func.func @gather_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[3]xindex
// -----
-func.func @gather_2d_from_1d(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xi1>, %arg3: vector<2x3xf32>) -> vector<2x3xf32> {
- %0 = arith.constant 0: index
- %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
- return %1 : vector<2x3xf32>
-}
-
-// CHECK-LABEL: func @gather_2d_from_1d
-// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi32>>
-// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi1>>
-// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
-// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi32>>
-// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi1>>
-// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
-// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
-
-// -----
-
-func.func @gather_2d_from_1d_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xi1>, %arg3: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
- %0 = arith.constant 0: index
- %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32>
- return %1 : vector<2x[3]xf32>
-}
-
-// CHECK-LABEL: func @gather_2d_from_1d_scalable
-// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xi32>>
-// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xi1>>
-// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
-// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xi32>>
-// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xi1>>
-// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
-// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>>
-
-// -----
-
func.func @gather_1d_from_2d(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> {
%0 = arith.constant 3 : index
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 1ab28b9df2d19..4a36294b355e7 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1663,7 +1663,7 @@ func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
func.func @gather_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
%0 = arith.constant 0: index
- %1 = vector.constant_mask [1, 2] : vector<2x3xi1>
+ %1 = vector.constant_mask [2, 2] : vector<2x3xi1>
%2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
return %2 : vector<2x3xf32>
}
@@ -1679,7 +1679,7 @@ func.func @gather_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi
// vector.constant_mask only supports 'none set' or 'all set' scalable
// dimensions, hence [1, 3] rather than [1, 2] as in the example for fixed
// width vectors above.
- %1 = vector.constant_mask [1, 3] : vector<2x[3]xi1>
+ %1 = vector.constant_mask [2, 3] : vector<2x[3]xi1>
%2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32>
return %2 : vector<2x[3]xf32>
}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 74838bc0ca2fb..2cf1dde9bd1b8 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -782,6 +782,7 @@ struct TestVectorGatherLowering
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateVectorGatherLoweringPatterns(patterns);
+ populateVectorGatherToConditionalLoadPatterns(patterns);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};
>From e9b33383fce4c694d8d70ba779af6acc3c5cce8e Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Thu, 20 Mar 2025 14:04:28 +0000
Subject: [PATCH 2/3] [mlir][vector] Allow multi dim vectors in vector.scatter
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 12 +++----
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 5 ++-
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 4 +--
.../VectorToLLVM/vector-to-llvm.mlir | 36 ++++++++++++++++++-
mlir/test/Dialect/Vector/invalid.mlir | 2 +-
mlir/test/Dialect/Vector/ops.mlir | 18 +++++-----
6 files changed, 58 insertions(+), 19 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index fbbf817ecff98..5fab2ee1194e8 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2034,9 +2034,9 @@ def Vector_ScatterOp :
Vector_Op<"scatter">,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$indices,
- VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec,
- VectorOfRankAndType<[1], [I1]>:$mask,
- VectorOfRank<[1]>:$valueToStore)> {
+ VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
+ VectorOfNonZeroRankOf<[I1]>:$mask,
+ AnyVectorOfNonZeroRank:$valueToStore)> {
let summary = [{
scatters elements from a vector into memory as defined by an index vector
@@ -2044,9 +2044,9 @@ def Vector_ScatterOp :
}];
let description = [{
- The scatter operation stores elements from a 1-D vector into memory as
- defined by a base with indices and an additional 1-D index vector, but
- only if the corresponding bit in a 1-D mask vector is set. Otherwise, no
+ The scatter operation stores elements from an n-D vector into memory as
+ defined by a base with indices and an additional n-D index vector, but
+ only if the corresponding bit in an n-D mask vector is set. Otherwise, no
action is taken for that element. Informally the semantics are:
```
if (mask[0]) base[index[0]] = value[0]
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 148c2ae26be35..4127f5b065bc8 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -313,13 +313,16 @@ class VectorScatterOpConversion
if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
return failure();
+ VectorType vType = scatter.getVectorType();
+ if (vType.getRank() > 1)
+ return failure();
+
// Resolve alignment.
unsigned align;
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
return failure();
// Resolve address.
- VectorType vType = scatter.getVectorType();
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
adaptor.getIndices(), rewriter);
Value ptrs =
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8e0e723cf4ed3..59da2ebe4aae0 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5340,9 +5340,9 @@ LogicalResult ScatterOp::verify() {
return emitOpError("base and valueToStore element type should match");
if (llvm::size(getIndices()) != memType.getRank())
return emitOpError("requires ") << memType.getRank() << " indices";
- if (valueVType.getDimSize(0) != indVType.getDimSize(0))
+ if (valueVType.getShape() != indVType.getShape())
return emitOpError("expected valueToStore dim to match indices dim");
- if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
+ if (valueVType.getShape() != maskVType.getShape())
return emitOpError("expected valueToStore dim to match mask dim");
return success();
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 4a36294b355e7..f5c722e29420c 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1654,7 +1654,7 @@ func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
//===----------------------------------------------------------------------===//
// vector.gather
-//
+//
// NOTE: vector.constant_mask won't lower with
// * --convert-to-llvm="filter-dialects=vector",
// hence testing here.
@@ -1719,6 +1719,40 @@ func.func @gather_with_zero_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x
// -----
+//===----------------------------------------------------------------------===//
+// vector.scatter
+//===----------------------------------------------------------------------===//
+
+// Multi-Dimensional scatters are not supported yet. Check that we do not lower
+// them.
+
+func.func @scatter_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) {
+ %0 = arith.constant 0: index
+ %1 = vector.constant_mask [2, 2] : vector<2x3xi1>
+ vector.scatter %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32>
+ return
+}
+
+// CHECK-LABEL: func @scatter_with_mask
+// CHECK: vector.scatter
+
+// -----
+
+func.func @scatter_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xf32>) {
+ %0 = arith.constant 0: index
+ // vector.constant_mask only supports 'none set' or 'all set' scalable
+ // dimensions, hence [2, 3] rather than [2, 2] as in the example for fixed
+ // width vectors above.
+ %1 = vector.constant_mask [2, 3] : vector<2x[3]xi1>
+ vector.scatter %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32>
+ return
+}
+
+// CHECK-LABEL: func @scatter_with_mask_scalable
+// CHECK: vector.scatter
+
+// -----
+
//===----------------------------------------------------------------------===//
// vector.interleave
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 57e348c7d5991..1b89e8eb5069b 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1484,7 +1484,7 @@ func.func @scatter_memref_mismatch(%base: memref<?x?xf64>, %indices: vector<16xi
func.func @scatter_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %value: vector<2x16xf32>) {
%c0 = arith.constant 0 : index
- // expected-error@+1 {{'vector.scatter' op operand #4 must be of ranks 1, but got 'vector<2x16xf32>'}}
+ // expected-error@+1 {{'vector.scatter' op expected valueToStore dim to match indices dim}}
vector.scatter %base[%c0][%indices], %mask, %value
: memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<2x16xf32>
}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 67484e06f456d..279fd3e522775 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -882,6 +882,16 @@ func.func @gather_and_scatter2d(%base: memref<?x?xf32>, %v: vector<16xi32>, %mas
return
}
+// CHECK-LABEL: @gather_multi_dims
+func.func @gather_multi_dims(%base: memref<?xf32>, %v: vector<2x16xi32>, %mask: vector<2x16xi1>, %pass_thru: vector<2x16xf32>) -> vector<2x16xf32> {
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[X:.*]] = vector.gather %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
+ %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
+ // CHECK: vector.scatter %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %[[X]] : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32>
+ vector.scatter %base[%c0][%v], %mask, %0 : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32>
+ return %0 : vector<2x16xf32>
+}
+
// CHECK-LABEL: @gather_on_tensor
func.func @gather_on_tensor(%base: tensor<?xf32>, %v: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
%c0 = arith.constant 0 : index
@@ -890,14 +900,6 @@ func.func @gather_on_tensor(%base: tensor<?xf32>, %v: vector<16xi32>, %mask: vec
return %0 : vector<16xf32>
}
-// CHECK-LABEL: @gather_multi_dims
-func.func @gather_multi_dims(%base: tensor<?xf32>, %v: vector<2x16xi32>, %mask: vector<2x16xi1>, %pass_thru: vector<2x16xf32>) -> vector<2x16xf32> {
- %c0 = arith.constant 0 : index
- // CHECK: vector.gather %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %{{.*}} : tensor<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
- %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
- return %0 : vector<2x16xf32>
-}
-
// CHECK-LABEL: @expand_and_compress
func.func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = arith.constant 0 : index
>From 9cf516710a34db78eb82b4fdb2d8aed4f2a79938 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Thu, 20 Mar 2025 14:51:13 +0000
Subject: [PATCH 3/3] [mlir][vector] Allow lowering multi-dim scatters to LLVM
---
.../Vector/Transforms/LoweringPatterns.h | 12 ++--
.../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 2 +-
.../TransformOps/VectorTransformOps.cpp | 2 +-
.../Dialect/Vector/Transforms/CMakeLists.txt | 2 +-
...ather.cpp => LowerVectorGatherScatter.cpp} | 69 ++++++++++++++++---
.../VectorToLLVM/vector-to-llvm.mlir | 6 +-
.../Dialect/Vector/TestVectorTransforms.cpp | 2 +-
7 files changed, 74 insertions(+), 21 deletions(-)
rename mlir/lib/Dialect/Vector/Transforms/{LowerVectorGather.cpp => LowerVectorGatherScatter.cpp} (80%)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 77d8b82b2bad0..528de2340f7b7 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -241,11 +241,15 @@ void populateVectorStepLoweringPatterns(RewritePatternSet &patterns,
/// Populate the pattern set with the following patterns:
///
-/// [FlattenGather]
-/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
+/// [UnrollGather]
+/// Unrolls 2 or more dimensional `vector.gather` ops by unrolling the
/// outermost dimension.
-void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
- PatternBenefit benefit = 1);
+///
+/// [UnrollScatter]
+/// Unrolls 2 or more dimensional `vector.scatter` ops by unrolling the
+/// outermost dimension.
+void populateVectorGatherScatterLoweringPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
/// Populate the pattern set with the following patterns:
///
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 7082b92c95d1d..dfa188bdfc5cc 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -81,7 +81,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorInsertExtractStridedSliceTransforms(patterns);
populateVectorStepLoweringPatterns(patterns);
populateVectorRankReducingFMAPattern(patterns);
- populateVectorGatherLoweringPatterns(patterns);
+ populateVectorGatherScatterLoweringPatterns(patterns);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 20c577273d786..623b9aa83fff3 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -138,7 +138,7 @@ void transform::ApplyLowerOuterProductPatternsOp::populatePatterns(
void transform::ApplyLowerGatherPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
- vector::populateVectorGatherLoweringPatterns(patterns);
+ vector::populateVectorGatherScatterLoweringPatterns(patterns);
}
void transform::ApplyLowerScanPatternsOp::populatePatterns(
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 8ca5cb6c6dfab..8abaa6ac527eb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -3,7 +3,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
LowerVectorBitCast.cpp
LowerVectorBroadcast.cpp
LowerVectorContract.cpp
- LowerVectorGather.cpp
+ LowerVectorGatherScatter.cpp
LowerVectorInterleave.cpp
LowerVectorMask.cpp
LowerVectorMultiReduction.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGatherScatter.cpp
similarity index 80%
rename from mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
rename to mlir/lib/Dialect/Vector/Transforms/LowerVectorGatherScatter.cpp
index eff8ee0e9de7a..72892859df200 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGatherScatter.cpp
@@ -38,6 +38,7 @@ using namespace mlir;
using namespace mlir::vector;
namespace {
+
/// Unrolls 2 or more dimensional `vector.gather` ops by unrolling the
/// outermost dimension. For example:
/// ```
@@ -81,19 +82,14 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
VectorType subTy = VectorType::Builder(resultTy).dropDim(0);
for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
- int64_t thisIdx[1] = {i};
-
- Value indexSubVec =
- rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
- Value maskSubVec =
- rewriter.create<vector::ExtractOp>(loc, maskVec, thisIdx);
+ Value indexSubVec = rewriter.create<vector::ExtractOp>(loc, indexVec, i);
+ Value maskSubVec = rewriter.create<vector::ExtractOp>(loc, maskVec, i);
Value passThruSubVec =
- rewriter.create<vector::ExtractOp>(loc, passThruVec, thisIdx);
+ rewriter.create<vector::ExtractOp>(loc, passThruVec, i);
Value subGather = rewriter.create<vector::GatherOp>(
loc, subTy, op.getBase(), op.getIndices(), indexSubVec, maskSubVec,
passThruSubVec);
- result =
- rewriter.create<vector::InsertOp>(loc, subGather, result, thisIdx);
+ result = rewriter.create<vector::InsertOp>(loc, subGather, result, i);
}
rewriter.replaceOp(op, result);
@@ -101,6 +97,57 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
}
};
+/// Unrolls 2 or more dimensional `vector.scatter` ops by unrolling the
+/// outermost dimension. For example:
+/// ```
+/// vector.scatter %base[%c0][%v], %mask, %valueToStore : ...
+/// vector<2x3xf32>
+///
+/// ==>
+///
+/// %g0 = vector.extract %valueToStore[0] : vector<3xf32> from vector<2x3xf32>
+/// vector.scatter %base[%c0][%v0], %mask0, %g0
+/// %g1 = vector.extract %valueToStore[1] : vector<3xf32> from vector<2x3xf32>
+/// vector.scatter %base[%c0][%v1], %mask1, %g1
+/// ```
+///
+/// When applied exhaustively, this will produce a sequence of 1-d scatter ops.
+///
+/// Supports vector types with a fixed leading dimension.
+struct UnrollScatter : OpRewritePattern<vector::ScatterOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ScatterOp op,
+ PatternRewriter &rewriter) const override {
+ VectorType vectorTy = op.getVectorType();
+ if (vectorTy.getRank() < 2)
+ return rewriter.notifyMatchFailure(op, "already 1-D");
+
+ // Unrolling doesn't take vscale into account. Pattern is disabled for
+ // vectors with leading scalable dim(s).
+ if (vectorTy.getScalableDims().front())
+ return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");
+
+ Location loc = op.getLoc();
+ Value indexVec = op.getIndexVec();
+ Value maskVec = op.getMask();
+ Value valueToStoreVec = op.getValueToStore();
+
+ for (int64_t i = 0, e = vectorTy.getShape().front(); i < e; ++i) {
+ Value indexSubVec = rewriter.create<vector::ExtractOp>(loc, indexVec, i);
+ Value maskSubVec = rewriter.create<vector::ExtractOp>(loc, maskVec, i);
+ Value valueToStoreSubVec =
+ rewriter.create<vector::ExtractOp>(loc, valueToStoreVec, i);
+ rewriter.create<vector::ScatterOp>(loc, op.getBase(), op.getIndices(),
+ indexSubVec, maskSubVec,
+ valueToStoreSubVec);
+ }
+
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
/// Rewrites a vector.gather of a strided MemRef as a gather of a non-strided
/// MemRef with updated indices that model the strided access.
///
@@ -268,9 +315,9 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
};
} // namespace
-void mlir::vector::populateVectorGatherLoweringPatterns(
+void mlir::vector::populateVectorGatherScatterLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<UnrollGather>(patterns.getContext(), benefit);
+ patterns.add<UnrollGather, UnrollScatter>(patterns.getContext(), benefit);
}
void mlir::vector::populateVectorGatherToConditionalLoadPatterns(
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index f5c722e29420c..e8171e3f4853f 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1734,7 +1734,8 @@ func.func @scatter_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2
}
// CHECK-LABEL: func @scatter_with_mask
-// CHECK: vector.scatter
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into !llvm.vec<3 x ptr>
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into !llvm.vec<3 x ptr>
// -----
@@ -1749,7 +1750,8 @@ func.func @scatter_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]x
}
// CHECK-LABEL: func @scatter_with_mask_scalable
-// CHECK: vector.scatter
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into !llvm.vec<? x 3 x ptr>
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into !llvm.vec<? x 3 x ptr>
// -----
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 2cf1dde9bd1b8..df02b7675e167 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -781,7 +781,7 @@ struct TestVectorGatherLowering
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
- populateVectorGatherLoweringPatterns(patterns);
+ populateVectorGatherScatterLoweringPatterns(patterns);
populateVectorGatherToConditionalLoadPatterns(patterns);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
More information about the Mlir-commits
mailing list