[Mlir-commits] [mlir] e91a5ce - [mlir][vector] Add a custom builder for LowerVectorsOp

Quentin Colombet llvmlistbot at llvm.org
Thu Jan 19 03:15:41 PST 2023


Author: Quentin Colombet
Date: 2023-01-19T11:01:27Z
New Revision: e91a5ce278381771ea6d4e6d602e1e486de655f9

URL: https://github.com/llvm/llvm-project/commit/e91a5ce278381771ea6d4e6d602e1e486de655f9
DIFF: https://github.com/llvm/llvm-project/commit/e91a5ce278381771ea6d4e6d602e1e486de655f9.diff

LOG: [mlir][vector] Add a custom builder for LowerVectorsOp

The `lower_vectors` operation of the transform dialect takes a lot of
arguments to build.
To make C++ code that builds this operation easier to write, introduce a
new structure, named `LowerVectorsOptions`, that aggregates all the
options used to build this operation.

This enables patterns like:
```
LowerVectorsOptions opts;
opts.setOptZ(...)
  .setOptY(...)...;
builder.create<LowerVectorsOp>(loc, resultType, target, opts);
```

This avoids having to pass all N options directly to the builder in the
right order.

NFC

Differential Revision: https://reviews.llvm.org/D141923

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h
    mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h
index f2e37413d465..032ce2da8a89 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h
@@ -11,12 +11,14 @@
 
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/IR/OpImplementation.h"
 
 namespace mlir {
 namespace vector {
 class VectorOp;
+struct LowerVectorsOptions;
 } // namespace vector
 } // namespace mlir
 
@@ -32,6 +34,53 @@ class DialectRegistry;
 
 namespace vector {
 void registerTransformDialectExtension(DialectRegistry &registry);
+
+/// Helper structure that aggregates the different options of LowerVectorsOp.
+struct LowerVectorsOptions : public VectorTransformsOptions {
+  // Keep these defaults in sync with the LowerVectorsOp defaults (td file).
+  LowerVectorsOptions() : VectorTransformsOptions() {
+    setVectorTransformsOptions(VectorContractLowering::OuterProduct);
+    setVectorMultiReductionLowering(
+        VectorMultiReductionLowering::InnerParallel);
+    setVectorTransposeLowering(VectorTransposeLowering::EltWise);
+    setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
+  }
+
+  /// Re-expose the setters of VectorTransformsOptions, but returning a
+  /// LowerVectorsOptions & so that all the options can be set in any order
+  /// via chained setXXX calls. @{
+  LowerVectorsOptions &setVectorTransformsOptions(VectorContractLowering opt) {
+    VectorTransformsOptions::setVectorTransformsOptions(opt);
+    return *this;
+  }
+
+  LowerVectorsOptions &
+  setVectorMultiReductionLowering(VectorMultiReductionLowering opt) {
+    VectorTransformsOptions::setVectorMultiReductionLowering(opt);
+    return *this;
+  }
+  LowerVectorsOptions &setVectorTransposeLowering(VectorTransposeLowering opt) {
+    VectorTransformsOptions::setVectorTransposeLowering(opt);
+    return *this;
+  }
+  LowerVectorsOptions &setVectorTransferSplit(VectorTransferSplit opt) {
+    VectorTransformsOptions::setVectorTransferSplit(opt);
+    return *this;
+  }
+  /// @}
+
+  bool transposeAVX2Lowering = false;
+  LowerVectorsOptions &setTransposeAVX2Lowering(bool opt) {
+    transposeAVX2Lowering = opt;
+    return *this;
+  }
+
+  bool unrollVectorTransfers = true;
+  LowerVectorsOptions &setUnrollVectorTransfers(bool opt) {
+    unrollVectorTransfers = opt;
+    return *this;
+  }
+};
 } // namespace vector
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 6c98fc699467..060e6bcef5cf 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -46,6 +46,18 @@ def LowerVectorsOp : Op<Transform_Dialect, "vector.lower_vectors",
   );
   let results = (outs PDL_Operation:$results);
 
+  let builders = [ // Builds the op from an aggregated vector::LowerVectorsOptions.
+    OpBuilder<(ins "Type":$resultType, "Value":$target,
+      "const vector::LowerVectorsOptions &":$options), [{
+        return build($_builder, $_state, resultType, target,
+          options.vectorContractLowering, // Positional: must follow build()'s parameter order.
+          options.vectorMultiReductionLowering, options.vectorTransferSplit,
+          options.vectorTransposeLowering, options.transposeAVX2Lowering,
+          options.unrollVectorTransfers);
+      }]
+    >
+  ];
+
   let assemblyFormat = [{
     $target
     oilist (


        


More information about the Mlir-commits mailing list