[Mlir-commits] [mlir] 176a0ea - [mlir][Linalg] NFC - Add option to hook vector.multi_reduction lowering to strategies.

Nicolas Vasilache llvmlistbot at llvm.org
Mon Oct 25 04:35:49 PDT 2021


Author: Nicolas Vasilache
Date: 2021-10-25T11:31:39Z
New Revision: 176a0ea535d4403b3a8482153f366e55141c7e50

URL: https://github.com/llvm/llvm-project/commit/176a0ea535d4403b3a8482153f366e55141c7e50
DIFF: https://github.com/llvm/llvm-project/commit/176a0ea535d4403b3a8482153f366e55141c7e50.diff

LOG: [mlir][Linalg] NFC - Add option to hook vector.multi_reduction lowering to strategies.

Differential Revision: https://reviews.llvm.org/D112414

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/Vector/VectorOps.h
    mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
    mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
    mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 4c2f004b5361e..640e1221aeb53 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -899,6 +899,12 @@ struct LinalgVectorLoweringOptions {
     contractionLowering = val;
     return *this;
   }
+  /// Enable lowering of vector.multi_reduction.
+  bool multiReductionLowering = false;
+  LinalgVectorLoweringOptions &enableMultiReductionLowering(bool val = true) {
+    multiReductionLowering = val;
+    return *this;
+  }
   /// Enable lowering of vector.transfer to scf.
   bool transferToSCFConversion = false;
   LinalgVectorLoweringOptions &enableTransferToSCFConversion(bool val = true) {

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index 5a5eadf6ad7d1..dd56cd1ea1926 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -40,6 +40,76 @@ namespace detail {
 struct BitmaskEnumStorage;
 } // namespace detail
 
+/// Enum to control the lowering of `vector.contract` operations.
+enum class VectorContractLowering {
+  /// Progressively lower to finer grained `vector.contract` and dot-products.
+  Dot = 0,
+  /// Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics.
+  Matmul = 1,
+  /// Lower to `vector.outerproduct`.
+  OuterProduct = 2,
+};
+/// Enum to control the lowering of `vector.multi_reduction` operations.
+enum class VectorMultiReductionLowering {
+  /// Lower multi_reduction into outer-reduction and inner-parallel ops.
+  InnerParallel = 0,
+  /// Lower multi_reduction into outer-parallel and inner-reduction ops.
+  InnerReduction = 1,
+};
+/// Enum to control the lowering of `vector.transpose` operations.
+enum class VectorTransposeLowering {
+  /// Lower transpose into element-wise extract and inserts.
+  EltWise = 0,
+  /// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix
+  /// intrinsics.
+  Flat = 1,
+};
+/// Enum to control the splitting of `vector.transfer` operations into
+/// in-bounds and out-of-bounds variants.
+enum class VectorTransferSplit {
+  /// Do not split vector transfer operations.
+  None = 0,
+  /// Split using in-bounds + out-of-bounds vector.transfer operations.
+  VectorTransfer = 1,
+  /// Split using an in-bounds vector.transfer + linalg.fill + linalg.copy
+  /// operations.
+  LinalgCopy = 2,
+  /// Do not split vector transfer operation but instead mark it as "in-bounds".
+  ForceInBounds = 3
+};
+/// Structure to control the behavior of vector transform patterns.
+struct VectorTransformsOptions {
+  /// Option to control the lowering of vector.contract.
+  VectorContractLowering vectorContractLowering = VectorContractLowering::Dot;
+  VectorTransformsOptions &
+  setVectorTransformsOptions(VectorContractLowering opt) {
+    vectorContractLowering = opt;
+    return *this;
+  }
+  /// Option to control the lowering of vector.multi_reduction.
+  VectorMultiReductionLowering vectorMultiReductionLowering =
+      VectorMultiReductionLowering::InnerParallel;
+  VectorTransformsOptions &
+  setVectorMultiReductionLowering(VectorMultiReductionLowering opt) {
+    vectorMultiReductionLowering = opt;
+    return *this;
+  }
+  /// Option to control the lowering of vector.transpose.
+  VectorTransposeLowering vectorTransposeLowering =
+      VectorTransposeLowering::EltWise;
+  VectorTransformsOptions &
+  setVectorTransposeLowering(VectorTransposeLowering opt) {
+    vectorTransposeLowering = opt;
+    return *this;
+  }
+  /// Option to control the splitting of vector transfers.
+  VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None;
+  VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) {
+    vectorTransferSplit = opt;
+    return *this;
+  }
+};
+
 /// Return whether `srcType` can be broadcast to `dstVectorType` under the
 /// semantics of the `vector.broadcast` op.
 enum class BroadcastableToResult {
@@ -114,7 +184,9 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
 /// the other patterns can kick in, thus fully exiting out of the
 /// vector.multi_reduction abstraction.
 void populateVectorMultiReductionLoweringPatterns(
-    RewritePatternSet &patterns, bool useInnerDimsForReduction = false);
+    RewritePatternSet &patterns,
+    VectorMultiReductionLowering options =
+        vector::VectorMultiReductionLowering::InnerParallel);
 
 /// Collect a set of patterns to propagate insert_map/extract_map in the ssa
 /// chain.
@@ -136,61 +208,6 @@ class CombiningKindAttr
   static Attribute parse(DialectAsmParser &parser);
 };
 
-/// Enum to control the lowering of `vector.contract` operations.
-enum class VectorContractLowering {
-  /// Progressively lower to finer grained `vector.contract` and dot-products.
-  Dot = 0,
-  /// Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics.
-  Matmul = 1,
-  /// Lower to `vector.outerproduct`.
-  OuterProduct = 2,
-};
-/// Enum to control the lowering of `vector.transpose` operations.
-enum class VectorTransposeLowering {
-  /// Lower transpose into element-wise extract and inserts.
-  EltWise = 0,
-  /// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix
-  /// intrinsics.
-  Flat = 1,
-};
-/// Enum to control the splitting of `vector.transfer` operations into
-/// in-bounds and out-of-bounds variants.
-enum class VectorTransferSplit {
-  /// Do not split vector transfer operations.
-  None = 0,
-  /// Split using in-bounds + out-of-bounds vector.transfer operations.
-  VectorTransfer = 1,
-  /// Split using an in-bounds vector.transfer + linalg.fill + linalg.copy
-  /// operations.
-  LinalgCopy = 2,
-  /// Do not split vector transfer operation but instead mark it as "in-bounds".
-  ForceInBounds = 3
-};
-/// Structure to control the behavior of vector transform patterns.
-struct VectorTransformsOptions {
-  /// Option to control the lowering of vector.contract.
-  VectorContractLowering vectorContractLowering = VectorContractLowering::Dot;
-  VectorTransformsOptions &
-  setVectorTransformsOptions(VectorContractLowering opt) {
-    vectorContractLowering = opt;
-    return *this;
-  }
-  /// Option to control the lowering of vector.transpose.
-  VectorTransposeLowering vectorTransposeLowering =
-      VectorTransposeLowering::EltWise;
-  VectorTransformsOptions &
-  setVectorTransposeLowering(VectorTransposeLowering opt) {
-    vectorTransposeLowering = opt;
-    return *this;
-  }
-  /// Option to control the splitting of vector transfers.
-  VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None;
-  VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) {
-    vectorTransferSplit = opt;
-    return *this;
-  }
-};
-
 /// Collects patterns to progressively lower vector.broadcast ops on high-D
 /// vectors to low-D vector ops.
 void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
index 8cc06d486371b..b7506eb91aa9a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -263,6 +263,7 @@ struct LinalgStrategyLowerVectorsPass
 
     MLIRContext *context = funcOp.getContext();
     RewritePatternSet patterns(context);
+    vector::populateVectorToVectorCanonicalizationPatterns(patterns);
     if (options.transferLowering) {
       vector::populateVectorTransferLoweringPatterns(patterns,
                                                      options.maxTransferRank);
@@ -277,6 +278,11 @@ struct LinalgStrategyLowerVectorsPass
           options.vectorTransformOptions, context);
       vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
     }
+    if (options.multiReductionLowering) {
+      vector::populateVectorMultiReductionLoweringPatterns(
+          patterns,
+          options.vectorTransformOptions.vectorMultiReductionLowering);
+    }
     if (options.transferToSCFConversion) {
       populateVectorToSCFConversionPatterns(patterns,
                                             options.vectorTransferToSCFOptions);

diff  --git a/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
index f738364529776..67d0db4d2cd45 100644
--- a/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
@@ -35,10 +35,11 @@ class InnerOuterDimReductionConversion
 public:
   using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
 
-  explicit InnerOuterDimReductionConversion(MLIRContext *context,
-                                            bool useInnerDimsForReduction)
+  explicit InnerOuterDimReductionConversion(
+      MLIRContext *context, vector::VectorMultiReductionLowering options)
       : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
-        useInnerDimsForReduction(useInnerDimsForReduction) {}
+        useInnerDimsForReduction(
+            options == vector::VectorMultiReductionLowering::InnerReduction) {}
 
   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                 PatternRewriter &rewriter) const override {
@@ -103,10 +104,11 @@ class ReduceMultiDimReductionRank
 public:
   using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
 
-  explicit ReduceMultiDimReductionRank(MLIRContext *context,
-                                       bool useInnerDimsForReduction)
+  explicit ReduceMultiDimReductionRank(
+      MLIRContext *context, vector::VectorMultiReductionLowering options)
       : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
-        useInnerDimsForReduction(useInnerDimsForReduction) {}
+        useInnerDimsForReduction(
+            options == vector::VectorMultiReductionLowering::InnerReduction) {}
 
   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                 PatternRewriter &rewriter) const override {
@@ -398,11 +400,11 @@ struct OneDimMultiReductionToTwoDim
 };
 
 void mlir::vector::populateVectorMultiReductionLoweringPatterns(
-    RewritePatternSet &patterns, bool useInnerDimsForReduction) {
-  patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank,
-               OneDimMultiReductionToTwoDim>(patterns.getContext(),
-                                             useInnerDimsForReduction);
-  if (useInnerDimsForReduction)
+    RewritePatternSet &patterns, VectorMultiReductionLowering options) {
+  patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
+      patterns.getContext(), options);
+  patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext());
+  if (options == VectorMultiReductionLowering ::InnerReduction)
     patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext());
   else
     patterns.add<TwoDimMultiReductionToElementWise>(patterns.getContext());

diff  --git a/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp b/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
index 45c985bbef1ea..f73cd6c0a1eb0 100644
--- a/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
@@ -94,7 +94,8 @@ void TestConvVectorization::runOnOperation() {
   //===--------------------------------------------------------------------===//
 
   VectorTransformsOptions vectorTransformOptions{
-      VectorContractLowering::Dot, VectorTransposeLowering::EltWise};
+      VectorContractLowering::Dot, VectorMultiReductionLowering::InnerParallel,
+      VectorTransposeLowering::EltWise};
 
   RewritePatternSet vectorTransferPatterns(context);
   // Pattern is not applied because rank-reducing vector transfer is not yet

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index d5538b09dadda..b95c45d21633a 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -159,11 +159,14 @@ struct TestVectorContractionConversion
     VectorContractLowering contractLowering = VectorContractLowering::Dot;
     if (lowerToFlatMatrix)
       contractLowering = VectorContractLowering::Matmul;
+    VectorMultiReductionLowering vectorMultiReductionLowering =
+        VectorMultiReductionLowering::InnerParallel;
     VectorTransposeLowering transposeLowering =
         VectorTransposeLowering::EltWise;
     if (lowerToFlatTranspose)
       transposeLowering = VectorTransposeLowering::Flat;
-    VectorTransformsOptions options{contractLowering, transposeLowering};
+    VectorTransformsOptions options{
+        contractLowering, vectorMultiReductionLowering, transposeLowering};
     populateVectorBroadcastLoweringPatterns(patterns);
     populateVectorContractLoweringPatterns(patterns, options);
     populateVectorMaskOpLoweringPatterns(patterns);
@@ -461,7 +464,10 @@ struct TestVectorMultiReductionLoweringPatterns
       llvm::cl::init(false)};
   void runOnFunction() override {
     RewritePatternSet patterns(&getContext());
-    populateVectorMultiReductionLoweringPatterns(patterns, !useOuterReductions);
+    populateVectorMultiReductionLoweringPatterns(
+        patterns, useOuterReductions
+                      ? vector::VectorMultiReductionLowering::InnerParallel
+                      : vector::VectorMultiReductionLowering::InnerReduction);
     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }
 };


        


More information about the Mlir-commits mailing list