[Mlir-commits] [mlir] cf0efb3 - [mlir][vector] Decouple unrolling gather and gather to llvm lowering (#132206)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Mar 24 05:25:21 PDT 2025
Author: Kunwar Grover
Date: 2025-03-24T12:25:17Z
New Revision: cf0efb31880dab5f5b2f20bda6634c68a42d6908
URL: https://github.com/llvm/llvm-project/commit/cf0efb31880dab5f5b2f20bda6634c68a42d6908
DIFF: https://github.com/llvm/llvm-project/commit/cf0efb31880dab5f5b2f20bda6634c68a42d6908.diff
LOG: [mlir][vector] Decouple unrolling gather and gather to llvm lowering (#132206)
This patch decouples unrolling vector.gather and lowering vector.gather
to llvm.masked.gather.
This is consistent with how vector.load, vector.store,
vector.maskedload, vector.maskedstore lower to LLVM.
Some interesting test changes from this patch:
- 2D vector.gather lowering to llvm tests are deleted. This is
consistent with other memory load/store ops.
- There are still tests for 2D vector.gather, but the constant mask for
these test is modified. This is because with the updated lowering, one
of the unrolled vector.gather disappears because it is masked off (also
demonstrating why this is a better lowering path)
Overall, this makes vector.gather take the same consistent path for
lowering to LLVM as other load/store ops.
Discourse Discussion:
https://discourse.llvm.org/t/rfc-improving-gather-codegen-for-vector-dialect/85011/13
Added:
Modified:
mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 601a65333d026..14cff4ff893b5 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -241,16 +241,20 @@ void populateVectorStepLoweringPatterns(RewritePatternSet &patterns,
/// Populate the pattern set with the following patterns:
///
-/// [FlattenGather]
-/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
+/// [UnrollGather]
+/// Unrolls 2 or more dimensional `vector.gather` ops by unrolling the
/// outermost dimension.
+void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
+/// Populate the pattern set with the following patterns:
///
/// [Gather1DToConditionalLoads]
/// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or
/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
/// loads/extracts are made conditional using `scf.if` ops.
-void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
- PatternBenefit benefit = 1);
+void populateVectorGatherToConditionalLoadPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
/// Populates instances of `MaskOpRewritePattern` to lower masked operations
/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 94efec61a466c..357152eba8003 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -269,49 +269,30 @@ class VectorGatherOpConversion
if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
return failure();
- auto loc = gather->getLoc();
+ VectorType vType = gather.getVectorType();
+ if (vType.getRank() > 1)
+ return failure();
+
+ Location loc = gather->getLoc();
// Resolve alignment.
unsigned align;
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
return failure();
+ // Resolve address.
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
adaptor.getIndices(), rewriter);
Value base = adaptor.getBase();
+ Value ptrs =
+ getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
+ base, ptr, adaptor.getIndexVec(), vType);
- auto llvmNDVectorTy = adaptor.getIndexVec().getType();
- // Handle the simple case of 1-D vector.
- if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) {
- auto vType = gather.getVectorType();
- // Resolve address.
- Value ptrs =
- getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
- base, ptr, adaptor.getIndexVec(), vType);
- // Replace with the gather intrinsic.
- rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
- gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
- adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
- return success();
- }
-
- const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
- auto callback = [align, memRefType, base, ptr, loc, &rewriter,
- &typeConverter](Type llvm1DVectorTy,
- ValueRange vectorOperands) {
- // Resolve address.
- Value ptrs = getIndexedPtrs(
- rewriter, loc, typeConverter, memRefType, base, ptr,
- /*index=*/vectorOperands[0], cast<VectorType>(llvm1DVectorTy));
- // Create the gather intrinsic.
- return rewriter.create<LLVM::masked_gather>(
- loc, llvm1DVectorTy, ptrs, /*mask=*/vectorOperands[1],
- /*passThru=*/vectorOperands[2], rewriter.getI32IntegerAttr(align));
- };
- SmallVector<Value> vectorOperands = {
- adaptor.getIndexVec(), adaptor.getMask(), adaptor.getPassThru()};
- return LLVM::detail::handleMultidimensionalVectors(
- gather, vectorOperands, *getTypeConverter(), callback, rewriter);
+ // Replace with the gather intrinsic.
+ rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
+ gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),
+ adaptor.getPassThru(), rewriter.getI32IntegerAttr(align));
+ return success();
}
};
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index eb1555df5d574..7082b92c95d1d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -81,6 +81,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorInsertExtractStridedSliceTransforms(patterns);
populateVectorStepLoweringPatterns(patterns);
populateVectorRankReducingFMAPattern(patterns);
+ populateVectorGatherLoweringPatterns(patterns);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 3b38505becd18..3000204c8ce17 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -38,7 +38,7 @@ using namespace mlir;
using namespace mlir::vector;
namespace {
-/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
+/// Unrolls 2 or more dimensional `vector.gather` ops by unrolling the
/// outermost dimension. For example:
/// ```
/// %g = vector.gather %base[%c0][%v], %mask, %pass_thru :
@@ -56,14 +56,14 @@ namespace {
/// When applied exhaustively, this will produce a sequence of 1-d gather ops.
///
/// Supports vector types with a fixed leading dimension.
-struct FlattenGather : OpRewritePattern<vector::GatherOp> {
+struct UnrollGather : OpRewritePattern<vector::GatherOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::GatherOp op,
PatternRewriter &rewriter) const override {
VectorType resultTy = op.getType();
if (resultTy.getRank() < 2)
- return rewriter.notifyMatchFailure(op, "already flat");
+ return rewriter.notifyMatchFailure(op, "already 1-D");
// Unrolling doesn't take vscale into account. Pattern is disabled for
// vectors with leading scalable dim(s).
@@ -107,7 +107,8 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
/// ```mlir
/// %subview = memref.subview %M (...)
/// : memref<100x3xf32> to memref<100xf32, strided<[3]>>
-/// %gather = vector.gather %subview[%idxs] (...) : memref<100xf32, strided<[3]>>
+/// %gather = vector.gather %subview[%idxs] (...)
+/// : memref<100xf32, strided<[3]>>
/// ```
/// ==>
/// ```mlir
@@ -269,6 +270,11 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
void mlir::vector::populateVectorGatherLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<FlattenGather, RemoveStrideFromGatherSource,
- Gather1DToConditionalLoads>(patterns.getContext(), benefit);
+ patterns.add<UnrollGather>(patterns.getContext(), benefit);
+}
+
+void mlir::vector::populateVectorGatherToConditionalLoadPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<RemoveStrideFromGatherSource, Gather1DToConditionalLoads>(
+ patterns.getContext(), benefit);
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index c3f06dd4d5dd1..44b4a25a051f1 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2074,52 +2074,6 @@ func.func @gather_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[3]xindex
// -----
-func.func @gather_2d_from_1d(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xi1>, %arg3: vector<2x3xf32>) -> vector<2x3xf32> {
- %0 = arith.constant 0: index
- %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
- return %1 : vector<2x3xf32>
-}
-
-// CHECK-LABEL: func @gather_2d_from_1d
-// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi32>>
-// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xi1>>
-// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
-// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi32>>
-// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xi1>>
-// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
-// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
-// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<3xf32>>
-
-// -----
-
-func.func @gather_2d_from_1d_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xi1>, %arg3: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
- %0 = arith.constant 0: index
- %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32>
- return %1 : vector<2x[3]xf32>
-}
-
-// CHECK-LABEL: func @gather_2d_from_1d_scalable
-// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[I0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xi32>>
-// CHECK: %[[M0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xi1>>
-// CHECK: %[[S0:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[P0:.*]] = llvm.getelementptr %[[B]][%[[I0]]] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
-// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %[[P0]], %[[M0]], %[[S0]] {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G0]], %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[I1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xi32>>
-// CHECK: %[[M1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xi1>>
-// CHECK: %[[S1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>>
-// CHECK: %[[P1:.*]] = llvm.getelementptr %[[B]][%[[I1]]] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
-// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %[[P1]], %[[M1]], %[[S1]] {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
-// CHECK: %{{.*}} = llvm.insertvalue %[[G1]], %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>>
-
-// -----
-
func.func @gather_1d_from_2d(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> {
%0 = arith.constant 3 : index
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 1ab28b9df2d19..5404fdda033ee 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1663,7 +1663,7 @@ func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
func.func @gather_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> {
%0 = arith.constant 0: index
- %1 = vector.constant_mask [1, 2] : vector<2x3xi1>
+ %1 = vector.constant_mask [2, 2] : vector<2x3xi1>
%2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x3xi32>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
return %2 : vector<2x3xf32>
}
@@ -1677,9 +1677,9 @@ func.func @gather_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2:
func.func @gather_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
%0 = arith.constant 0: index
// vector.constant_mask only supports 'none set' or 'all set' scalable
- // dimensions, hence [1, 3] rather than [1, 2] as in the example for fixed
+ // dimensions, hence [2, 3] rather than [2, 2] as in the example for fixed
// width vectors above.
- %1 = vector.constant_mask [1, 3] : vector<2x[3]xi1>
+ %1 = vector.constant_mask [2, 3] : vector<2x[3]xi1>
%2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32>
return %2 : vector<2x[3]xf32>
}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index c9bb34b692dd0..a54ae816570a8 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -782,6 +782,7 @@ struct TestVectorGatherLowering
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateVectorGatherLoweringPatterns(patterns);
+ populateVectorGatherToConditionalLoadPatterns(patterns);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};
More information about the Mlir-commits
mailing list