[Mlir-commits] [mlir] 0693b9e - [mlir][Vector] Clean up `populateVectorToLLVMConversionPatterns` (#119975)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Dec 17 02:37:20 PST 2024
Author: Matthias Springer
Date: 2024-12-17T11:37:17+01:00
New Revision: 0693b9e9ccdec5f09a3080b1bec73f5004a8dfa3
URL: https://github.com/llvm/llvm-project/commit/0693b9e9ccdec5f09a3080b1bec73f5004a8dfa3
DIFF: https://github.com/llvm/llvm-project/commit/0693b9e9ccdec5f09a3080b1bec73f5004a8dfa3.diff
LOG: [mlir][Vector] Clean up `populateVectorToLLVMConversionPatterns` (#119975)
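Clean up `populateVectorToLLVMConversionPatterns` so that it populates
only conversion patterns. All rewrite patterns that do not lower to LLVM
should be populated into a separate greedy pattern rewrite.
Callers that relied on these patterns being added implicitly must now
populate them themselves before running the conversion. These are:
- the strided-slice insert/extract transforms;
- the `vector.step` lowering;
- the n-D FMA rank reduction (via the new
  `populateVectorRankReducingFMAPattern`);
- the vector transfer lowering with `maxTransferRank=1`.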
The current combination of rewrite patterns and conversion patterns
triggered an edge case when merging the 1:1 and 1:N dialect conversions.
Depends on #119973.
Added:
Modified:
mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
mlir/test/Conversion/GPUCommon/lower-vector.mlir
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 3d643c96b45008..c507b23c6d4de6 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -292,6 +292,10 @@ void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
int64_t targetRank = 1,
PatternBenefit benefit = 1);
+/// Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where
+/// n > 1.
+void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
+
} // namespace vector
} // namespace mlir
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 1497d662dcdbdd..2ebf38c53e3936 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -32,10 +32,12 @@
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Error.h"
@@ -522,6 +524,18 @@ DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp)
void GpuToLLVMConversionPass::runOnOperation() {
MLIRContext *context = &getContext();
+
+ // Perform progressive lowering of vector transfer operations.
+ {
+ RewritePatternSet patterns(&getContext());
+ // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
+ vector::populateVectorTransferLoweringPatterns(patterns,
+ /*maxTransferRank=*/1);
+ if (failed(
+ applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+ return signalPassFailure();
+ }
+
LowerToLLVMOptions options(context);
options.useBarePtrCallConv = hostBarePtrCallConv;
RewritePatternSet patterns(context);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index a9a07c323c7358..9657f583c375bb 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1475,16 +1475,17 @@ class VectorTypeCastOpConversion
/// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
/// Non-scalable versions of this operation are handled in Vector Transforms.
-class VectorCreateMaskOpRewritePattern
- : public OpRewritePattern<vector::CreateMaskOp> {
+class VectorCreateMaskOpConversion
+ : public OpConversionPattern<vector::CreateMaskOp> {
public:
- explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
- bool enableIndexOpt)
- : OpRewritePattern<vector::CreateMaskOp>(context),
+ explicit VectorCreateMaskOpConversion(MLIRContext *context,
+ bool enableIndexOpt)
+ : OpConversionPattern<vector::CreateMaskOp>(context),
force32BitVectorIndices(enableIndexOpt) {}
- LogicalResult matchAndRewrite(vector::CreateMaskOp op,
- PatternRewriter &rewriter) const override {
+ LogicalResult
+ matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
auto dstType = op.getType();
if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
return failure();
@@ -1495,7 +1496,7 @@ class VectorCreateMaskOpRewritePattern
loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
/*isScalable=*/true));
auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
- op.getOperand(0));
+ adaptor.getOperands()[0]);
Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
indices, bounds);
@@ -1896,16 +1897,19 @@ struct VectorScalableStepOpLowering
} // namespace
+void mlir::vector::populateVectorRankReducingFMAPattern(
+ RewritePatternSet &patterns) {
+ patterns.add<VectorFMAOpNDRewritePattern>(patterns.getContext());
+}
+
/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
bool reassociateFPReductions, bool force32BitVectorIndices) {
+ // This function populates only ConversionPatterns, not RewritePatterns.
MLIRContext *ctx = converter.getDialect()->getContext();
- patterns.add<VectorFMAOpNDRewritePattern>(ctx);
- populateVectorInsertExtractStridedSliceTransforms(patterns);
- populateVectorStepLoweringPatterns(patterns);
patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
- patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
+ patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
VectorExtractElementOpConversion, VectorExtractOpConversion,
VectorFMAOp1DConversion, VectorInsertElementOpConversion,
@@ -1922,8 +1926,6 @@ void mlir::populateVectorToLLVMConversionPatterns(
MaskedReductionOpConversion, VectorInterleaveOpLowering,
VectorDeinterleaveOpLowering, VectorFromElementsLowering,
VectorScalableStepOpLowering>(converter);
- // Transfer ops with rank > 1 are handled by VectorToSCF.
- populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
}
void mlir::populateVectorToLLVMMatrixConversionPatterns(
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 64a9ad8e9bade0..2d94c2f2e85a08 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -62,7 +62,8 @@ struct ConvertVectorToLLVMPass
void ConvertVectorToLLVMPass::runOnOperation() {
// Perform progressive lowering of operations on slices and all contraction
- // operations. Also materializes masks, applies folding and DCE.
+ // operations. Also materializes masks, lowers vector.step, rank-reduces FMA,
+ // applies folding and DCE.
{
RewritePatternSet patterns(&getContext());
populateVectorToVectorCanonicalizationPatterns(patterns);
@@ -78,6 +79,9 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
populateVectorMaskMaterializationPatterns(patterns,
force32BitVectorIndices);
+ populateVectorInsertExtractStridedSliceTransforms(patterns);
+ populateVectorStepLoweringPatterns(patterns);
+ populateVectorRankReducingFMAPattern(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
diff --git a/mlir/test/Conversion/GPUCommon/lower-vector.mlir b/mlir/test/Conversion/GPUCommon/lower-vector.mlir
index 44deb45cd752b4..532a2383cea9ef 100644
--- a/mlir/test/Conversion/GPUCommon/lower-vector.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-vector.mlir
@@ -1,11 +1,11 @@
// RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s
module {
- func.func @func(%arg: vector<11xf32>) {
+ func.func @func(%arg: vector<11xf32>) -> vector<11xf32> {
%cst_41 = arith.constant dense<true> : vector<11xi1>
// CHECK: vector.mask
// CHECK-SAME: vector.yield %arg0
%127 = vector.mask %cst_41 { vector.yield %arg : vector<11xf32> } : vector<11xi1> -> vector<11xf32>
- return
+ return %127 : vector<11xf32>
}
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index ea88fece9e662d..f95e943250bd44 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2046,7 +2046,6 @@ func.func @extract_strided_slice_f32_2d_from_2d_scalable(%arg0: vector<4x[8]xf32
// CHECK-LABEL: @extract_strided_slice_f32_2d_from_2d_scalable(
// CHECK-SAME: %[[ARG:.*]]: vector<4x[8]xf32>)
// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x[8]xf32> to !llvm.array<4 x vector<[8]xf32>>
-// CHECK: %[[T2:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[T3:.*]] = arith.constant dense<0.000000e+00> : vector<2x[8]xf32>
// CHECK: %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T3]] : vector<2x[8]xf32> to !llvm.array<2 x vector<[8]xf32>>
// CHECK: %[[T5:.*]] = llvm.extractvalue %[[T1]][2] : !llvm.array<4 x vector<[8]xf32>>
@@ -2067,7 +2066,6 @@ func.func @insert_strided_slice_f32_2d_into_3d(%b: vector<4x4xf32>, %c: vector<4
return %0 : vector<4x4x4xf32>
}
// CHECK-LABEL: @insert_strided_slice_f32_2d_into_3d
-// CHECK: llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xf32>>>
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xf32>>>
// -----
@@ -2077,7 +2075,6 @@ func.func @insert_strided_slice_f32_2d_into_3d_scalable(%b: vector<4x[4]xf32>, %
return %0 : vector<4x4x[4]xf32>
}
// CHECK-LABEL: @insert_strided_slice_f32_2d_into_3d_scalable
-// CHECK: llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xf32>>>
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xf32>>>
// -----
@@ -2087,7 +2084,6 @@ func.func @insert_strided_index_slice_index_2d_into_3d(%b: vector<4x4xindex>, %c
return %0 : vector<4x4x4xindex>
}
// CHECK-LABEL: @insert_strided_index_slice_index_2d_into_3d
-// CHECK: llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xi64>>>
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xi64>>>
// -----
@@ -2097,7 +2093,6 @@ func.func @insert_strided_index_slice_index_2d_into_3d_scalable(%b: vector<4x[4]
return %0 : vector<4x4x[4]xindex>
}
// CHECK-LABEL: @insert_strided_index_slice_index_2d_into_3d_scalable
-// CHECK: llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xi64>>>
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xi64>>>
// -----
More information about the Mlir-commits
mailing list