[Mlir-commits] [mlir] e91a5ce - [mlir][vector] Add a custom builder for LowerVectorsOp
Quentin Colombet
llvmlistbot at llvm.org
Thu Jan 19 03:15:41 PST 2023
Author: Quentin Colombet
Date: 2023-01-19T11:01:27Z
New Revision: e91a5ce278381771ea6d4e6d602e1e486de655f9
URL: https://github.com/llvm/llvm-project/commit/e91a5ce278381771ea6d4e6d602e1e486de655f9
DIFF: https://github.com/llvm/llvm-project/commit/e91a5ce278381771ea6d4e6d602e1e486de655f9.diff
LOG: [mlir][vector] Add a custom builder for LowerVectorsOp
The `lower_vectors` operation of the transform dialect takes a lot of
arguments to build.
To make C++ code that creates this operation easier to write,
introduce a new structure, named `LowerVectorsOptions`, that
aggregates all the options used to build this operation.
This allows using patterns like:
```
LowerVectorsOptions opts;
opts.setOptZ(...)
.setOptY(...)...;
builder.create<LowerVectorsOp>(target, opts);
```
This avoids having to pass all N options directly to the builder
in the right order.
NFC
Differential Revision: https://reviews.llvm.org/D141923
Added:
Modified:
mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h
mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h
index f2e37413d465..032ce2da8a89 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h
@@ -11,12 +11,14 @@
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/OpImplementation.h"
namespace mlir {
namespace vector {
class VectorOp;
+struct LowerVectorsOptions;
} // namespace vector
} // namespace mlir
@@ -32,6 +34,53 @@ class DialectRegistry;
namespace vector {
void registerTransformDialectExtension(DialectRegistry ®istry);
+
+/// Helper structure used to hold the different options of LowerVectorsOp.
+struct LowerVectorsOptions : public VectorTransformsOptions {
+ // Have the default values match the LowerVectorsOp values in the td file.
+ LowerVectorsOptions() : VectorTransformsOptions() {
+ setVectorTransformsOptions(VectorContractLowering::OuterProduct);
+ setVectorMultiReductionLowering(
+ VectorMultiReductionLowering::InnerParallel);
+ setVectorTransposeLowering(VectorTransposeLowering::EltWise);
+ setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
+ }
+
+ /// Duplicate the base API of VectorTransformsOptions but return the
+ /// LowerVectorsOptions type, so the different options can be set up
+ /// in any order via chained setXXX calls. @{
+ LowerVectorsOptions &setVectorTransformsOptions(VectorContractLowering opt) {
+ VectorTransformsOptions::setVectorTransformsOptions(opt);
+ return *this;
+ }
+
+ LowerVectorsOptions &
+ setVectorMultiReductionLowering(VectorMultiReductionLowering opt) {
+ VectorTransformsOptions::setVectorMultiReductionLowering(opt);
+ return *this;
+ }
+ LowerVectorsOptions &setVectorTransposeLowering(VectorTransposeLowering opt) {
+ VectorTransformsOptions::setVectorTransposeLowering(opt);
+ return *this;
+ }
+ LowerVectorsOptions &setVectorTransferSplit(VectorTransferSplit opt) {
+ VectorTransformsOptions::setVectorTransferSplit(opt);
+ return *this;
+ }
+ /// @}
+
+ bool transposeAVX2Lowering = false; // Forwarded by the LowerVectorsOp builder.
+ LowerVectorsOptions &setTransposeAVX2Lowering(bool opt) {
+ transposeAVX2Lowering = opt;
+ return *this;
+ }
+
+ bool unrollVectorTransfers = true; // Forwarded by the LowerVectorsOp builder.
+ LowerVectorsOptions &setUnrollVectorTransfers(bool opt) {
+ unrollVectorTransfers = opt;
+ return *this;
+ }
+};
} // namespace vector
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 6c98fc699467..060e6bcef5cf 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -46,6 +46,18 @@ def LowerVectorsOp : Op<Transform_Dialect, "vector.lower_vectors",
);
let results = (outs PDL_Operation:$results);
+ let builders = [ // Build from a LowerVectorsOptions bundle.
+ OpBuilder<(ins "Type":$resultType, "Value":$target,
+ "const vector::LowerVectorsOptions &":$options), [{
+ return build($_builder, $_state, resultType, target,
+ options.vectorContractLowering, // Positional: keep in build() order.
+ options.vectorMultiReductionLowering, options.vectorTransferSplit,
+ options.vectorTransposeLowering, options.transposeAVX2Lowering,
+ options.unrollVectorTransfers);
+ }]
+ >
+ ];
+
let assemblyFormat = [{
$target
oilist (
More information about the Mlir-commits
mailing list