[Mlir-commits] [mlir] edec423 - [mlir][vector] Share enums with the transform dialect

Quentin Colombet llvmlistbot at llvm.org
Tue Jan 17 03:15:44 PST 2023


Author: Quentin Colombet
Date: 2023-01-17T11:11:17Z
New Revision: edec423981e836304f4d724a4c805614320bf088

URL: https://github.com/llvm/llvm-project/commit/edec423981e836304f4d724a4c805614320bf088
DIFF: https://github.com/llvm/llvm-project/commit/edec423981e836304f4d724a4c805614320bf088.diff

LOG: [mlir][vector] Share enums with the transform dialect

Refactor the definition of the enums that are used in the lower_vectors
operation of the transformation dialect.
This avoids duplicating the definition of all the configurations that
this operation can trigger.

NFC

Differential Revision: https://reviews.llvm.org/D141867

Added: 
    mlir/include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td

Modified: 
    mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h
    mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
    mlir/include/mlir/Dialect/Vector/Transforms/CMakeLists.txt
    mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
    mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/test/Dialect/LLVM/transform-e2e.mlir
    mlir/test/Dialect/Vector/transform-vector.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h
index d17706ffc3ddd..f2e37413d4657 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h
@@ -11,6 +11,7 @@
 
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/IR/OpImplementation.h"
 
 namespace mlir {

diff  --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 1009c099acca6..6c98fc699467b 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -13,6 +13,7 @@ include "mlir/Dialect/Transform/IR/TransformDialect.td"
 include "mlir/Dialect/Transform/IR/TransformEffects.td"
 include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
 include "mlir/Dialect/PDL/IR/PDLTypes.td"
+include "mlir/Dialect/Vector/Transforms/VectorTransformsBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"
 
@@ -31,16 +32,30 @@ def LowerVectorsOp : Op<Transform_Dialect, "vector.lower_vectors",
 
   // TODO: evolve this to proper enums.
   let arguments = (ins PDL_Operation:$target,
-     DefaultValuedAttr<StrAttr, "\"outerproduct\"">:$contraction_lowering,
-     DefaultValuedAttr<StrAttr, "\"innerparallel\"">:$multireduction_lowering,
-     DefaultValuedAttr<StrAttr, "\"linalg-copy\"">:$split_transfers,
-     DefaultValuedAttr<StrAttr, "\"eltwise\"">:$transpose_lowering,
+     DefaultValuedAttr<VectorContractLoweringAttr,
+       "vector::VectorContractLowering::OuterProduct">:$contraction_lowering,
+     DefaultValuedAttr<VectorMultiReductionLoweringAttr,
+       "vector::VectorMultiReductionLowering::InnerParallel">:
+         $multireduction_lowering,
+     DefaultValuedAttr<VectorTransferSplitAttr,
+       "vector::VectorTransferSplit::LinalgCopy">:$split_transfers,
+     DefaultValuedAttr<VectorTransposeLoweringAttr,
+       "vector::VectorTransposeLowering::EltWise">:$transpose_lowering,
      DefaultValuedAttr<BoolAttr, "false">:$transpose_avx2_lowering,
      DefaultValuedAttr<BoolAttr, "true">:$unroll_vector_transfers
   );
   let results = (outs PDL_Operation:$results);
 
-  let assemblyFormat = "$target attr-dict";
+  let assemblyFormat = [{
+    $target
+    oilist (
+      `contraction_lowering` `=` $contraction_lowering
+      | `multireduction_lowering` `=` $multireduction_lowering
+      | `split_transfers` `=` $split_transfers
+      | `transpose_lowering` `=` $transpose_lowering
+    )
+    attr-dict
+  }];
 }
 
 #endif // VECTOR_TRANSFORM_OPS

diff  --git a/mlir/include/mlir/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Vector/Transforms/CMakeLists.txt
index 35868d1e69233..2c288fe3b77fa 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/CMakeLists.txt
@@ -1,3 +1,7 @@
+set(LLVM_TARGET_DEFINITIONS VectorTransformsBase.td)
+mlir_tablegen(VectorTransformsEnums.h.inc -gen-enum-decls)
+mlir_tablegen(VectorTransformsEnums.cpp.inc -gen-enum-defs)
+
 set(LLVM_TARGET_DEFINITIONS Passes.td)
 mlir_tablegen(Passes.h.inc -gen-pass-decls -name Vector)
 add_public_tablegen_target(MLIRVectorTransformsIncGen)

diff  --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index cd49eb1514c9d..fb2d07ebf413d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -13,6 +13,7 @@
 #include <optional>
 
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/PatternMatch.h"
@@ -25,48 +26,6 @@ namespace vector {
 //===----------------------------------------------------------------------===//
 // Vector transformation options exposed as auxiliary structs.
 //===----------------------------------------------------------------------===//
-/// Enum to control the lowering of `vector.transpose` operations.
-enum class VectorTransposeLowering {
-  /// Lower transpose into element-wise extract and inserts.
-  EltWise = 0,
-  /// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix
-  /// intrinsics.
-  Flat = 1,
-  /// Lower 2-D transpose to `vector.shuffle`.
-  Shuffle = 2,
-};
-/// Enum to control the lowering of `vector.multi_reduction` operations.
-enum class VectorMultiReductionLowering {
-  /// Lower multi_reduction into outer-reduction and inner-parallel ops.
-  InnerParallel = 0,
-  /// Lower multi_reduction into outer-parallel and inner-reduction ops.
-  InnerReduction = 1,
-};
-/// Enum to control the lowering of `vector.contract` operations.
-enum class VectorContractLowering {
-  /// Progressively lower to finer grained `vector.contract` and dot-products.
-  Dot = 0,
-  /// Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics.
-  Matmul = 1,
-  /// Lower to `vector.outerproduct`.
-  OuterProduct = 2,
-  /// Lower contract with all reduction dimensions unrolled to 1 to a vector
-  /// elementwise operations.
-  ParallelArith = 3,
-};
-/// Enum to control the splitting of `vector.transfer` operations into
-/// in-bounds and out-of-bounds variants.
-enum class VectorTransferSplit {
-  /// Do not split vector transfer operations.
-  None = 0,
-  /// Split using in-bounds + out-of-bounds vector.transfer operations.
-  VectorTransfer = 1,
-  /// Split using an in-bounds vector.transfer + linalg.fill + linalg.copy
-  /// operations.
-  LinalgCopy = 2,
-  /// Do not split vector transfer operation but instead mark it as "in-bounds".
-  ForceInBounds = 3
-};
 /// Structure to control the behavior of vector transform patterns.
 struct VectorTransformsOptions {
   /// Option to control the lowering of vector.contract.

diff  --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td
new file mode 100644
index 0000000000000..fb1f2ab717687
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td
@@ -0,0 +1,86 @@
+//===- VectorTransformsBase.td - Vector transform ops -------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef VECTOR_TRANSFORMS_BASE
+#define VECTOR_TRANSFORMS_BASE
+
+include "mlir/IR/EnumAttr.td"
+
+// Lower transpose into element-wise extract and inserts.
+def VectorTransposeLowering_Elementwise:
+  I32EnumAttrCase<"EltWise",  0, "eltwise">;
+// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix
+// intrinsics.
+def VectorTransposeLowering_FlatTranspose:
+  I32EnumAttrCase<"Flat",  1, "flat_transpose">;
+// Lower 2-D transpose to `vector.shuffle`.
+def VectorTransposeLowering_Shuffle:
+  I32EnumAttrCase<"Shuffle",  2, "shuffle">;
+def VectorTransposeLoweringAttr : I32EnumAttr<
+    "VectorTransposeLowering",
+    "control the lowering of `vector.transpose` operations.",
+    [VectorTransposeLowering_Elementwise, VectorTransposeLowering_FlatTranspose,
+     VectorTransposeLowering_Shuffle]> {
+  let cppNamespace = "::mlir::vector";
+}
+
+// Lower multi_reduction into outer-reduction and inner-parallel ops.
+def VectorMultiReductionLowering_InnerParallel:
+  I32EnumAttrCase<"InnerParallel", 0, "innerparallel">;
+// Lower multi_reduction into outer-parallel and inner-reduction ops.
+def VectorMultiReductionLowering_InnerReduction:
+  I32EnumAttrCase<"InnerReduction", 1, "innerreduction">;
+def VectorMultiReductionLoweringAttr: I32EnumAttr<
+    "VectorMultiReductionLowering",
+    "control the lowering of `vector.multi_reduction`.",
+  [VectorMultiReductionLowering_InnerParallel,
+   VectorMultiReductionLowering_InnerReduction]> {
+  let cppNamespace = "::mlir::vector";
+}
+
+// Progressively lower to finer grained `vector.contract` and dot-products.
+def VectorContractLowering_Dot: I32EnumAttrCase<"Dot", 0, "dot">;
+// Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics.
+def VectorContractLowering_Matmul:
+  I32EnumAttrCase<"Matmul", 1, "matmulintrinsics">;
+// Lower to `vector.outerproduct`.
+def VectorContractLowering_OuterProduct:
+  I32EnumAttrCase<"OuterProduct", 2, "outerproduct">;
+// Lower contract with all reduction dimensions unrolled to 1 to a vector
+// elementwise operations.
+def VectorContractLowering_ParallelArith:
+  I32EnumAttrCase<"ParallelArith", 3, "parallelarith">;
+def VectorContractLoweringAttr: I32EnumAttr<
+    "VectorContractLowering",
+    "control the lowering of `vector.contract` operations.",
+  [VectorContractLowering_Dot, VectorContractLowering_Matmul,
+   VectorContractLowering_OuterProduct, VectorContractLowering_ParallelArith]> {
+  let cppNamespace = "::mlir::vector";
+}
+
+// Do not split vector transfer operations.
+def VectorTransferSplit_None: I32EnumAttrCase<"None", 0, "none">;
+// Split using in-bounds + out-of-bounds vector.transfer operations.
+def VectorTransferSplit_VectorTransfer:
+  I32EnumAttrCase<"VectorTransfer", 1, "vector-transfer">;
+// Split using an in-bounds vector.transfer + linalg.fill + linalg.copy
+// operations.
+def VectorTransferSplit_LinalgCopy:
+  I32EnumAttrCase<"LinalgCopy", 2, "linalg-copy">;
+// Do not split vector transfer operation but instead mark it as "in-bounds".
+def VectorTransferSplit_ForceInBounds:
+  I32EnumAttrCase<"ForceInBounds", 3, "force-in-bounds">;
+def VectorTransferSplitAttr: I32EnumAttr<
+    "VectorTransferSplit",
+    "control the splitting of `vector.transfer` operations into in-bounds"
+    " and out-of-bounds variants.",
+  [VectorTransferSplit_None, VectorTransferSplit_VectorTransfer,
+   VectorTransferSplit_LinalgCopy, VectorTransferSplit_ForceInBounds]> {
+  let cppNamespace = "::mlir::vector";
+}
+#endif

diff  --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 79d1ff7040d45..60996b9add614 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -53,32 +53,12 @@ DiagnosedSilenceableFailure transform::LowerVectorsOp::apply(
     MLIRContext *ctx = getContext();
     RewritePatternSet patterns(ctx);
     vector::VectorTransposeLowering vectorTransposeLowering =
-        llvm::StringSwitch<vector::VectorTransposeLowering>(
-            getTransposeLowering())
-            .Case("eltwise", vector::VectorTransposeLowering::EltWise)
-            .Case("flat_transpose", vector::VectorTransposeLowering::Flat)
-            .Case("shuffle", vector::VectorTransposeLowering::Shuffle)
-            .Default(vector::VectorTransposeLowering::EltWise);
+        getTransposeLowering();
     vector::VectorMultiReductionLowering vectorMultiReductionLowering =
-        llvm::StringSwitch<vector::VectorMultiReductionLowering>(
-            getMultireductionLowering())
-            .Case("innerreduction",
-                  vector::VectorMultiReductionLowering::InnerReduction)
-            .Default(vector::VectorMultiReductionLowering::InnerParallel);
+        getMultireductionLowering();
     vector::VectorContractLowering vectorContractLowering =
-        llvm::StringSwitch<vector::VectorContractLowering>(
-            getContractionLowering())
-            .Case("matrixintrinsics", vector::VectorContractLowering::Matmul)
-            .Case("dot", vector::VectorContractLowering::Dot)
-            .Case("outerproduct", vector::VectorContractLowering::OuterProduct)
-            .Default(vector::VectorContractLowering::OuterProduct);
-    vector::VectorTransferSplit vectorTransferSplit =
-        llvm::StringSwitch<vector::VectorTransferSplit>(getSplitTransfers())
-            .Case("none", vector::VectorTransferSplit::None)
-            .Case("linalg-copy", vector::VectorTransferSplit::LinalgCopy)
-            .Case("vector-transfers",
-                  vector::VectorTransferSplit::VectorTransfer)
-            .Default(vector::VectorTransferSplit::None);
+        getContractionLowering();
+    vector::VectorTransferSplit vectorTransferSplit = getSplitTransfers();
 
     vector::VectorTransformsOptions vectorTransformOptions;
     vectorTransformOptions.setVectorTransformsOptions(vectorContractLowering)

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index e67f147399e59..504ea7afe8301 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -3040,3 +3040,9 @@ void mlir::vector::populateVectorScanLoweringPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<ScanToArithOps>(patterns.getContext(), benefit);
 }
+
+//===----------------------------------------------------------------------===//
+// TableGen'd enum attribute definitions
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"

diff  --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir
index 2aeda92dfd528..f899a81d1a5d0 100644
--- a/mlir/test/Dialect/LLVM/transform-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir
@@ -21,5 +21,5 @@ transform.sequence failures(propagate) {
   transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap} %module_op
     {bufferize_function_boundaries = true}
   %func = transform.structured.match ops{["func.func"]} in %module_op
-  transform.vector.lower_vectors %func { multireduction_lowering = "innerreduce"}
+  transform.vector.lower_vectors %func multireduction_lowering = "innerreduction"
 }

diff  --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir
index 25c26b63c6107..a753b229576a7 100644
--- a/mlir/test/Dialect/Vector/transform-vector.mlir
+++ b/mlir/test/Dialect/Vector/transform-vector.mlir
@@ -22,5 +22,5 @@ transform.sequence failures(propagate) {
   transform.bufferization.one_shot_bufferize %module_op
 
   %func = transform.structured.match ops{["func.func"]} in %module_op
-  transform.vector.lower_vectors %func { multireduction_lowering = "innerreduce"}
+  transform.vector.lower_vectors %func multireduction_lowering = "innerreduction"
 }


        


More information about the Mlir-commits mailing list