[Mlir-commits] [mlir] edec423 - [mlir][vector] Share enums with the transform dialect
Quentin Colombet
llvmlistbot at llvm.org
Tue Jan 17 03:15:44 PST 2023
Author: Quentin Colombet
Date: 2023-01-17T11:11:17Z
New Revision: edec423981e836304f4d724a4c805614320bf088
URL: https://github.com/llvm/llvm-project/commit/edec423981e836304f4d724a4c805614320bf088
DIFF: https://github.com/llvm/llvm-project/commit/edec423981e836304f4d724a4c805614320bf088.diff
LOG: [mlir][vector] Share enums with the transform dialect
Refactor the definitions of the enums used by the `lower_vectors`
operation of the transform dialect.
This avoids duplicating the definitions of all the configurations that
this operation can trigger.
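NFC for the lowering logic itself. The textual syntax of
`transform.vector.lower_vectors` does change: the options are now enum
attributes written outside the attribute dictionary (see the updated tests).
Unrecognized spellings are no longer silently mapped to a default value.
Two spellings changed: "matrixintrinsics" is now "matmulintrinsics", and
"vector-transfers" is now "vector-transfer".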
Differential Revision: https://reviews.llvm.org/D141867
Added:
mlir/include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td
Modified:
mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h
mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
mlir/include/mlir/Dialect/Vector/Transforms/CMakeLists.txt
mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/test/Dialect/LLVM/transform-e2e.mlir
mlir/test/Dialect/Vector/transform-vector.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h
index d17706ffc3ddd..f2e37413d4657 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h
@@ -11,6 +11,7 @@
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/OpImplementation.h"
namespace mlir {
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 1009c099acca6..6c98fc699467b 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -13,6 +13,7 @@ include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/IR/TransformEffects.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
include "mlir/Dialect/PDL/IR/PDLTypes.td"
+include "mlir/Dialect/Vector/Transforms/VectorTransformsBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"
@@ -31,16 +32,30 @@ def LowerVectorsOp : Op<Transform_Dialect, "vector.lower_vectors",
// TODO: evolve this to proper enums.
let arguments = (ins PDL_Operation:$target,
- DefaultValuedAttr<StrAttr, "\"outerproduct\"">:$contraction_lowering,
- DefaultValuedAttr<StrAttr, "\"innerparallel\"">:$multireduction_lowering,
- DefaultValuedAttr<StrAttr, "\"linalg-copy\"">:$split_transfers,
- DefaultValuedAttr<StrAttr, "\"eltwise\"">:$transpose_lowering,
+ DefaultValuedAttr<VectorContractLoweringAttr,
+ "vector::VectorContractLowering::OuterProduct">:$contraction_lowering,
+ DefaultValuedAttr<VectorMultiReductionLoweringAttr,
+ "vector::VectorMultiReductionLowering::InnerParallel">:
+ $multireduction_lowering,
+ DefaultValuedAttr<VectorTransferSplitAttr,
+ "vector::VectorTransferSplit::LinalgCopy">:$split_transfers,
+ DefaultValuedAttr<VectorTransposeLoweringAttr,
+ "vector::VectorTransposeLowering::EltWise">:$transpose_lowering,
DefaultValuedAttr<BoolAttr, "false">:$transpose_avx2_lowering,
DefaultValuedAttr<BoolAttr, "true">:$unroll_vector_transfers
);
let results = (outs PDL_Operation:$results);
- let assemblyFormat = "$target attr-dict";
+ let assemblyFormat = [{
+ $target
+ oilist (
+ `contraction_lowering` `=` $contraction_lowering
+ | `multireduction_lowering` `=` $multireduction_lowering
+ | `split_transfers` `=` $split_transfers
+ | `transpose_lowering` `=` $transpose_lowering
+ )
+ attr-dict
+ }];
}
#endif // VECTOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Vector/Transforms/CMakeLists.txt
index 35868d1e69233..2c288fe3b77fa 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/CMakeLists.txt
@@ -1,3 +1,7 @@
+set(LLVM_TARGET_DEFINITIONS VectorTransformsBase.td)
+mlir_tablegen(VectorTransformsEnums.h.inc -gen-enum-decls)
+mlir_tablegen(VectorTransformsEnums.cpp.inc -gen-enum-defs)
+
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Vector)
add_public_tablegen_target(MLIRVectorTransformsIncGen)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index cd49eb1514c9d..fb2d07ebf413d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -13,6 +13,7 @@
#include <optional>
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
@@ -25,48 +26,6 @@ namespace vector {
//===----------------------------------------------------------------------===//
// Vector transformation options exposed as auxiliary structs.
//===----------------------------------------------------------------------===//
-/// Enum to control the lowering of `vector.transpose` operations.
-enum class VectorTransposeLowering {
- /// Lower transpose into element-wise extract and inserts.
- EltWise = 0,
- /// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix
- /// intrinsics.
- Flat = 1,
- /// Lower 2-D transpose to `vector.shuffle`.
- Shuffle = 2,
-};
-/// Enum to control the lowering of `vector.multi_reduction` operations.
-enum class VectorMultiReductionLowering {
- /// Lower multi_reduction into outer-reduction and inner-parallel ops.
- InnerParallel = 0,
- /// Lower multi_reduction into outer-parallel and inner-reduction ops.
- InnerReduction = 1,
-};
-/// Enum to control the lowering of `vector.contract` operations.
-enum class VectorContractLowering {
- /// Progressively lower to finer grained `vector.contract` and dot-products.
- Dot = 0,
- /// Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics.
- Matmul = 1,
- /// Lower to `vector.outerproduct`.
- OuterProduct = 2,
- /// Lower contract with all reduction dimensions unrolled to 1 to a vector
- /// elementwise operations.
- ParallelArith = 3,
-};
-/// Enum to control the splitting of `vector.transfer` operations into
-/// in-bounds and out-of-bounds variants.
-enum class VectorTransferSplit {
- /// Do not split vector transfer operations.
- None = 0,
- /// Split using in-bounds + out-of-bounds vector.transfer operations.
- VectorTransfer = 1,
- /// Split using an in-bounds vector.transfer + linalg.fill + linalg.copy
- /// operations.
- LinalgCopy = 2,
- /// Do not split vector transfer operation but instead mark it as "in-bounds".
- ForceInBounds = 3
-};
/// Structure to control the behavior of vector transform patterns.
struct VectorTransformsOptions {
/// Option to control the lowering of vector.contract.
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td
new file mode 100644
index 0000000000000..fb1f2ab717687
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td
@@ -0,0 +1,86 @@
+//===- VectorTransformsBase.td - Vector transform enums ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef VECTOR_TRANSFORMS_BASE
+#define VECTOR_TRANSFORMS_BASE
+
+include "mlir/IR/EnumAttr.td"
+
+// Lower transpose into element-wise extract and insert operations.
+def VectorTransposeLowering_Elementwise:
+ I32EnumAttrCase<"EltWise", 0, "eltwise">;
+// Lower 2-D transpose to `vector.flat_transpose`, which maps 1-1 to LLVM
+// matrix intrinsics.
+def VectorTransposeLowering_FlatTranspose:
+ I32EnumAttrCase<"Flat", 1, "flat_transpose">;
+// Lower 2-D transpose to `vector.shuffle`.
+def VectorTransposeLowering_Shuffle:
+ I32EnumAttrCase<"Shuffle", 2, "shuffle">;
+def VectorTransposeLoweringAttr : I32EnumAttr<
+ "VectorTransposeLowering",
+ "control the lowering of `vector.transpose` operations.",
+ [VectorTransposeLowering_Elementwise, VectorTransposeLowering_FlatTranspose,
+ VectorTransposeLowering_Shuffle]> {
+ let cppNamespace = "::mlir::vector";
+}
+
+// Lower multi_reduction into outer-reduction and inner-parallel ops.
+def VectorMultiReductionLowering_InnerParallel:
+ I32EnumAttrCase<"InnerParallel", 0, "innerparallel">;
+// Lower multi_reduction into outer-parallel and inner-reduction ops.
+def VectorMultiReductionLowering_InnerReduction:
+ I32EnumAttrCase<"InnerReduction", 1, "innerreduction">;
+def VectorMultiReductionLoweringAttr: I32EnumAttr<
+ "VectorMultiReductionLowering",
+ "control the lowering of `vector.multi_reduction`.",
+ [VectorMultiReductionLowering_InnerParallel,
+ VectorMultiReductionLowering_InnerReduction]> {
+ let cppNamespace = "::mlir::vector";
+}
+
+// Progressively lower to finer-grained `vector.contract` and dot-products.
+def VectorContractLowering_Dot: I32EnumAttrCase<"Dot", 0, "dot">;
+// Lower to `vector.matrix_multiply`, which maps 1-1 to LLVM matrix intrinsics.
+def VectorContractLowering_Matmul:
+ I32EnumAttrCase<"Matmul", 1, "matmulintrinsics">;
+// Lower to `vector.outerproduct`.
+def VectorContractLowering_OuterProduct:
+ I32EnumAttrCase<"OuterProduct", 2, "outerproduct">;
+// Lower a contract whose reduction dimensions are all unrolled to 1 into
+// vector elementwise operations.
+def VectorContractLowering_ParallelArith:
+ I32EnumAttrCase<"ParallelArith", 3, "parallelarith">;
+def VectorContractLoweringAttr: I32EnumAttr<
+ "VectorContractLowering",
+ "control the lowering of `vector.contract` operations.",
+ [VectorContractLowering_Dot, VectorContractLowering_Matmul,
+ VectorContractLowering_OuterProduct, VectorContractLowering_ParallelArith]> {
+ let cppNamespace = "::mlir::vector";
+}
+
+// Do not split vector transfer operations.
+def VectorTransferSplit_None: I32EnumAttrCase<"None", 0, "none">;
+// Split using in-bounds + out-of-bounds vector.transfer operations.
+def VectorTransferSplit_VectorTransfer:
+ I32EnumAttrCase<"VectorTransfer", 1, "vector-transfer">;
+// Split using an in-bounds vector.transfer + linalg.fill + linalg.copy
+// operations.
+def VectorTransferSplit_LinalgCopy:
+ I32EnumAttrCase<"LinalgCopy", 2, "linalg-copy">;
+// Do not split the vector transfer operation; mark it as "in-bounds" instead.
+def VectorTransferSplit_ForceInBounds:
+ I32EnumAttrCase<"ForceInBounds", 3, "force-in-bounds">;
+def VectorTransferSplitAttr: I32EnumAttr<
+ "VectorTransferSplit",
+ "control the splitting of `vector.transfer` operations into in-bounds"
+ " and out-of-bounds variants.",
+ [VectorTransferSplit_None, VectorTransferSplit_VectorTransfer,
+ VectorTransferSplit_LinalgCopy, VectorTransferSplit_ForceInBounds]> {
+ let cppNamespace = "::mlir::vector";
+}
+#endif // VECTOR_TRANSFORMS_BASE
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 79d1ff7040d45..60996b9add614 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -53,32 +53,12 @@ DiagnosedSilenceableFailure transform::LowerVectorsOp::apply(
MLIRContext *ctx = getContext();
RewritePatternSet patterns(ctx);
vector::VectorTransposeLowering vectorTransposeLowering =
- llvm::StringSwitch<vector::VectorTransposeLowering>(
- getTransposeLowering())
- .Case("eltwise", vector::VectorTransposeLowering::EltWise)
- .Case("flat_transpose", vector::VectorTransposeLowering::Flat)
- .Case("shuffle", vector::VectorTransposeLowering::Shuffle)
- .Default(vector::VectorTransposeLowering::EltWise);
+ getTransposeLowering();
vector::VectorMultiReductionLowering vectorMultiReductionLowering =
- llvm::StringSwitch<vector::VectorMultiReductionLowering>(
- getMultireductionLowering())
- .Case("innerreduction",
- vector::VectorMultiReductionLowering::InnerReduction)
- .Default(vector::VectorMultiReductionLowering::InnerParallel);
+ getMultireductionLowering();
vector::VectorContractLowering vectorContractLowering =
- llvm::StringSwitch<vector::VectorContractLowering>(
- getContractionLowering())
- .Case("matrixintrinsics", vector::VectorContractLowering::Matmul)
- .Case("dot", vector::VectorContractLowering::Dot)
- .Case("outerproduct", vector::VectorContractLowering::OuterProduct)
- .Default(vector::VectorContractLowering::OuterProduct);
- vector::VectorTransferSplit vectorTransferSplit =
- llvm::StringSwitch<vector::VectorTransferSplit>(getSplitTransfers())
- .Case("none", vector::VectorTransferSplit::None)
- .Case("linalg-copy", vector::VectorTransferSplit::LinalgCopy)
- .Case("vector-transfers",
- vector::VectorTransferSplit::VectorTransfer)
- .Default(vector::VectorTransferSplit::None);
+ getContractionLowering();
+ vector::VectorTransferSplit vectorTransferSplit = getSplitTransfers();
vector::VectorTransformsOptions vectorTransformOptions;
vectorTransformOptions.setVectorTransformsOptions(vectorContractLowering)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index e67f147399e59..504ea7afe8301 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -3040,3 +3040,9 @@ void mlir::vector::populateVectorScanLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<ScanToArithOps>(patterns.getContext(), benefit);
}
+
+//===----------------------------------------------------------------------===//
+// TableGen'd enum attribute definitions
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"
diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir
index 2aeda92dfd528..f899a81d1a5d0 100644
--- a/mlir/test/Dialect/LLVM/transform-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir
@@ -21,5 +21,5 @@ transform.sequence failures(propagate) {
transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap} %module_op
{bufferize_function_boundaries = true}
%func = transform.structured.match ops{["func.func"]} in %module_op
- transform.vector.lower_vectors %func { multireduction_lowering = "innerreduce"}
+ transform.vector.lower_vectors %func multireduction_lowering = "innerreduction"
}
diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir
index 25c26b63c6107..a753b229576a7 100644
--- a/mlir/test/Dialect/Vector/transform-vector.mlir
+++ b/mlir/test/Dialect/Vector/transform-vector.mlir
@@ -22,5 +22,5 @@ transform.sequence failures(propagate) {
transform.bufferization.one_shot_bufferize %module_op
%func = transform.structured.match ops{["func.func"]} in %module_op
- transform.vector.lower_vectors %func { multireduction_lowering = "innerreduce"}
+ transform.vector.lower_vectors %func multireduction_lowering = "innerreduction"
}
More information about the Mlir-commits
mailing list