[Mlir-commits] [mlir] 0693b9e - [mlir][Vector] Clean up `populateVectorToLLVMConversionPatterns` (#119975)

llvmlistbot at llvm.org
Tue Dec 17 02:37:20 PST 2024


Author: Matthias Springer
Date: 2024-12-17T11:37:17+01:00
New Revision: 0693b9e9ccdec5f09a3080b1bec73f5004a8dfa3

URL: https://github.com/llvm/llvm-project/commit/0693b9e9ccdec5f09a3080b1bec73f5004a8dfa3
DIFF: https://github.com/llvm/llvm-project/commit/0693b9e9ccdec5f09a3080b1bec73f5004a8dfa3.diff

LOG: [mlir][Vector] Clean up `populateVectorToLLVMConversionPatterns` (#119975)

Clean up `populateVectorToLLVMConversionPatterns` so that it populates
only conversion patterns. Rewrite patterns that do not lower to LLVM
should instead be added to a separate greedy pattern rewrite.

The previous combination of rewrite patterns and conversion patterns
triggered an edge case when merging the 1:1 and 1:N dialect conversions.

Depends on #119973.
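
For downstream users, here is a minimal sketch of the two-stage structure
this change implies, modeled on the pass updates in this commit. It assumes
a pass body with `getOperation()`/`getContext()`, `using namespace mlir;`,
and the includes used in this commit (LoweringPatterns.h,
GreedyPatternRewriteDriver.h, etc.); it is an illustration, not code from
the commit itself:

    // Stage 1: greedily apply the rewrite patterns that used to be bundled
    // into populateVectorToLLVMConversionPatterns.
    {
      RewritePatternSet patterns(&getContext());
      vector::populateVectorInsertExtractStridedSliceTransforms(patterns);
      vector::populateVectorStepLoweringPatterns(patterns);
      vector::populateVectorRankReducingFMAPattern(patterns);
      // Transfer ops with rank > 1 should be lowered with VectorToSCF.
      vector::populateVectorTransferLoweringPatterns(patterns,
                                                     /*maxTransferRank=*/1);
      if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                              std::move(patterns))))
        return signalPassFailure();
    }

    // Stage 2: run the dialect conversion; the populate function now adds
    // conversion patterns only.
    LLVMTypeConverter converter(&getContext());
    RewritePatternSet convPatterns(&getContext());
    populateVectorToLLVMConversionPatterns(
        converter, convPatterns,
        /*reassociateFPReductions=*/false,
        /*force32BitVectorIndices=*/false);
    LLVMConversionTarget target(getContext());
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(convPatterns))))
      signalPassFailure();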

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
    mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
    mlir/test/Conversion/GPUCommon/lower-vector.mlir
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 3d643c96b45008..c507b23c6d4de6 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -292,6 +292,10 @@ void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
                                            int64_t targetRank = 1,
                                            PatternBenefit benefit = 1);
 
+/// Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where
+/// n > 1.
+void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
+
 } // namespace vector
 } // namespace mlir
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H

diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 1497d662dcdbdd..2ebf38c53e3936 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -32,10 +32,12 @@
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Error.h"
@@ -522,6 +524,18 @@ DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp)
 
 void GpuToLLVMConversionPass::runOnOperation() {
   MLIRContext *context = &getContext();
+
+  // Perform progressive lowering of vector transfer operations.
+  {
+    RewritePatternSet patterns(&getContext());
+    // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
+    vector::populateVectorTransferLoweringPatterns(patterns,
+                                                   /*maxTransferRank=*/1);
+    if (failed(
+            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+      return signalPassFailure();
+  }
+
   LowerToLLVMOptions options(context);
   options.useBarePtrCallConv = hostBarePtrCallConv;
   RewritePatternSet patterns(context);

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index a9a07c323c7358..9657f583c375bb 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1475,16 +1475,17 @@ class VectorTypeCastOpConversion
 
 /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
 /// Non-scalable versions of this operation are handled in Vector Transforms.
-class VectorCreateMaskOpRewritePattern
-    : public OpRewritePattern<vector::CreateMaskOp> {
+class VectorCreateMaskOpConversion
+    : public OpConversionPattern<vector::CreateMaskOp> {
 public:
-  explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
-                                            bool enableIndexOpt)
-      : OpRewritePattern<vector::CreateMaskOp>(context),
+  explicit VectorCreateMaskOpConversion(MLIRContext *context,
+                                        bool enableIndexOpt)
+      : OpConversionPattern<vector::CreateMaskOp>(context),
         force32BitVectorIndices(enableIndexOpt) {}
 
-  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
-                                PatternRewriter &rewriter) const override {
+  LogicalResult
+  matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
     auto dstType = op.getType();
     if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
       return failure();
@@ -1495,7 +1496,7 @@ class VectorCreateMaskOpRewritePattern
         loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
                                  /*isScalable=*/true));
     auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
-                                                 op.getOperand(0));
+                                                 adaptor.getOperands()[0]);
     Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
     Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                 indices, bounds);
@@ -1896,16 +1897,19 @@ struct VectorScalableStepOpLowering
 
 } // namespace
 
+void mlir::vector::populateVectorRankReducingFMAPattern(
+    RewritePatternSet &patterns) {
+  patterns.add<VectorFMAOpNDRewritePattern>(patterns.getContext());
+}
+
 /// Populate the given list with patterns that convert from Vector to LLVM.
 void mlir::populateVectorToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
     bool reassociateFPReductions, bool force32BitVectorIndices) {
+  // This function populates only ConversionPatterns, not RewritePatterns.
   MLIRContext *ctx = converter.getDialect()->getContext();
-  patterns.add<VectorFMAOpNDRewritePattern>(ctx);
-  populateVectorInsertExtractStridedSliceTransforms(patterns);
-  populateVectorStepLoweringPatterns(patterns);
   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
-  patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
+  patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
   patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
                VectorExtractElementOpConversion, VectorExtractOpConversion,
                VectorFMAOp1DConversion, VectorInsertElementOpConversion,
@@ -1922,8 +1926,6 @@ void mlir::populateVectorToLLVMConversionPatterns(
                MaskedReductionOpConversion, VectorInterleaveOpLowering,
                VectorDeinterleaveOpLowering, VectorFromElementsLowering,
                VectorScalableStepOpLowering>(converter);
-  // Transfer ops with rank > 1 are handled by VectorToSCF.
-  populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
 }
 
 void mlir::populateVectorToLLVMMatrixConversionPatterns(

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 64a9ad8e9bade0..2d94c2f2e85a08 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -62,7 +62,8 @@ struct ConvertVectorToLLVMPass
 
 void ConvertVectorToLLVMPass::runOnOperation() {
   // Perform progressive lowering of operations on slices and all contraction
-  // operations. Also materializes masks, applies folding and DCE.
+  // operations. Also materializes masks, lowers vector.step, rank-reduces FMA,
+  // applies folding and DCE.
   {
     RewritePatternSet patterns(&getContext());
     populateVectorToVectorCanonicalizationPatterns(patterns);
@@ -78,6 +79,9 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
     populateVectorMaskMaterializationPatterns(patterns,
                                               force32BitVectorIndices);
+    populateVectorInsertExtractStridedSliceTransforms(patterns);
+    populateVectorStepLoweringPatterns(patterns);
+    populateVectorRankReducingFMAPattern(patterns);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 

diff --git a/mlir/test/Conversion/GPUCommon/lower-vector.mlir b/mlir/test/Conversion/GPUCommon/lower-vector.mlir
index 44deb45cd752b4..532a2383cea9ef 100644
--- a/mlir/test/Conversion/GPUCommon/lower-vector.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-vector.mlir
@@ -1,11 +1,11 @@
 // RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s
 
 module {
-  func.func @func(%arg: vector<11xf32>) {
+  func.func @func(%arg: vector<11xf32>) -> vector<11xf32> {
     %cst_41 = arith.constant dense<true> : vector<11xi1>
     // CHECK: vector.mask
     // CHECK-SAME: vector.yield %arg0
     %127 = vector.mask %cst_41 { vector.yield %arg : vector<11xf32> } : vector<11xi1> -> vector<11xf32>
-    return
+    return %127 : vector<11xf32>
   }
 }

diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index ea88fece9e662d..f95e943250bd44 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2046,7 +2046,6 @@ func.func @extract_strided_slice_f32_2d_from_2d_scalable(%arg0: vector<4x[8]xf32
 // CHECK-LABEL: @extract_strided_slice_f32_2d_from_2d_scalable(
 //  CHECK-SAME:     %[[ARG:.*]]: vector<4x[8]xf32>)
 // CHECK:           %[[T1:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x[8]xf32> to !llvm.array<4 x vector<[8]xf32>>
-// CHECK:           %[[T2:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:           %[[T3:.*]] = arith.constant dense<0.000000e+00> : vector<2x[8]xf32>
 // CHECK:           %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T3]] : vector<2x[8]xf32> to !llvm.array<2 x vector<[8]xf32>>
 // CHECK:           %[[T5:.*]] = llvm.extractvalue %[[T1]][2] : !llvm.array<4 x vector<[8]xf32>>
@@ -2067,7 +2066,6 @@ func.func @insert_strided_slice_f32_2d_into_3d(%b: vector<4x4xf32>, %c: vector<4
   return %0 : vector<4x4x4xf32>
 }
 // CHECK-LABEL: @insert_strided_slice_f32_2d_into_3d
-//       CHECK:    llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xf32>>>
 //       CHECK:    llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xf32>>>
 
 // -----
@@ -2077,7 +2075,6 @@ func.func @insert_strided_slice_f32_2d_into_3d_scalable(%b: vector<4x[4]xf32>, %
   return %0 : vector<4x4x[4]xf32>
 }
 // CHECK-LABEL: @insert_strided_slice_f32_2d_into_3d_scalable
-//       CHECK:    llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xf32>>>
 //       CHECK:    llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xf32>>>
 
 // -----
@@ -2087,7 +2084,6 @@ func.func @insert_strided_index_slice_index_2d_into_3d(%b: vector<4x4xindex>, %c
   return %0 : vector<4x4x4xindex>
 }
 // CHECK-LABEL: @insert_strided_index_slice_index_2d_into_3d
-//       CHECK:    llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xi64>>>
 //       CHECK:    llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xi64>>>
 
 // -----
@@ -2097,7 +2093,6 @@ func.func @insert_strided_index_slice_index_2d_into_3d_scalable(%b: vector<4x[4]
   return %0 : vector<4x4x[4]xindex>
 }
 // CHECK-LABEL: @insert_strided_index_slice_index_2d_into_3d_scalable
-//       CHECK:    llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xi64>>>
 //       CHECK:    llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xi64>>>
 
 // -----

