[Mlir-commits] [mlir] Add Vector-dialect interleave-to-shuffle pattern (PR #91800)
Benoit Jacob
llvmlistbot at llvm.org
Fri May 10 13:22:25 PDT 2024
https://github.com/bjacob updated https://github.com/llvm/llvm-project/pull/91800
>From 17f0a9ba738801b56b2a9a5da3fed128c018c3a2 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Fri, 10 May 2024 16:12:51 -0400
Subject: [PATCH 1/3] add interleave-to-shuffle pattern
---
.../Vector/TransformOps/VectorTransformOps.td | 14 ++++++++++
.../Vector/Transforms/LoweringPatterns.h | 3 +++
.../TransformOps/VectorTransformOps.cpp | 5 ++++
.../Transforms/LowerVectorInterleave.cpp | 26 +++++++++++++++++++
.../Vector/vector-interleave-to-shuffle.mlir | 21 +++++++++++++++
5 files changed, 69 insertions(+)
create mode 100644 mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index f6371f39c3944..bc3c16d40520e 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -306,6 +306,20 @@ def ApplyLowerInterleavePatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def ApplyInterleaveToShufflePatternsOp : Op<Transform_Dialect,
+ "apply_patterns.vector.interleave_to_shuffle",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+    Indicates that 1D, non-scalable vector interleave operations should be
+    rewritten as equivalent vector shuffle operations.
+
+ This is motivated by some current codegen backends not handling vector
+ interleave operations.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
def ApplyRewriteNarrowTypePatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.rewrite_narrow_types",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 350d2777cadf5..d7eef637d7c38 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -273,6 +273,9 @@ void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns,
int64_t targetRank = 1,
PatternBenefit benefit = 1);
+void populateVectorInterleaveToShufflePatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
} // namespace vector
} // namespace mlir
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 885644864c0f7..61fd6bd972e3a 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -164,6 +164,13 @@ void transform::ApplyLowerInterleavePatternsOp::populatePatterns(
vector::populateVectorInterleaveLoweringPatterns(patterns);
}
+void transform::ApplyInterleaveToShufflePatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  // Forward to the vector-dialect entry point; the benefit passed here is the
+  // same default (1) that populateVectorInterleaveToShufflePatterns declares.
+  vector::populateVectorInterleaveToShufflePatterns(patterns, /*benefit=*/1);
+}
+
void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
RewritePatternSet &patterns) {
populateVectorNarrowTypeRewritePatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
index 3a456076f8fba..2388a952cdb0d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
#define DEBUG_TYPE "vector-interleave-lowering"
@@ -77,9 +78,43 @@ class UnrollInterleaveOp : public OpRewritePattern<vector::InterleaveOp> {
int64_t targetRank = 1;
};
+class InterleaveToShuffle : public OpRewritePattern<vector::InterleaveOp> {
+public:
+ InterleaveToShuffle(MLIRContext *context,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern(context, benefit) {};
+
+ LogicalResult matchAndRewrite(vector::InterleaveOp op,
+ PatternRewriter &rewriter) const override {
+    // Rewrites
+    //   %r = vector.interleave %lhs, %rhs : vector<Nxty>
+    // into
+    //   %r = vector.shuffle %lhs, %rhs [0, N, 1, N+1, ..., N-1, 2N-1]
+    // so that r[2*i] = lhs[i] and r[2*i+1] = rhs[i]: shuffle mask indices
+    // below N select from %lhs, indices N and above select from %rhs.
+    VectorType sourceType = op.getSourceVectorType();
+    // Interleave acts on the trailing dimension while vector.shuffle permutes
+    // the leading one, so only rank-1 sources map directly. A static mask
+    // cannot describe a scalable vector either.
+    if (sourceType.getRank() != 1 || sourceType.isScalable())
+      return failure();
+    int64_t n = sourceType.getNumElements();
+    auto mask = llvm::map_to_vector(llvm::seq<int64_t>(2 * n), [n](int64_t i) {
+      return (i % 2) * n + i / 2;
+    });
+    rewriter.replaceOpWithNewOp<ShuffleOp>(op, op.getLhs(), op.getRhs(), mask);
+    return success();
+ }
+};
+
} // namespace
void mlir::vector::populateVectorInterleaveLoweringPatterns(
RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
patterns.add<UnrollInterleaveOp>(targetRank, patterns.getContext(), benefit);
}
+
+void mlir::vector::populateVectorInterleaveToShufflePatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<InterleaveToShuffle>(patterns.getContext(), benefit);
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir b/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir
new file mode 100644
index 0000000000000..0b039ba78289c
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir
@@ -0,0 +1,30 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+// CHECK-LABEL: @vector_interleave_to_shuffle
+func.func @vector_interleave_to_shuffle(%a: vector<7xi16>, %b: vector<7xi16>) -> vector<14xi16>
+{
+ %0 = vector.interleave %a, %b : vector<7xi16>
+ return %0 : vector<14xi16>
+}
+// CHECK: vector.shuffle %arg0, %arg1 [0, 7, 1, 8, 2, 9, 3, 10, 4, 11, 5, 12, 6, 13] : vector<7xi16>, vector<7xi16>
+
+// CHECK-LABEL: @vector_interleave_scalable_not_rewritten
+func.func @vector_interleave_scalable_not_rewritten(%a: vector<[4]xi16>, %b: vector<[4]xi16>) -> vector<[8]xi16>
+{
+  %0 = vector.interleave %a, %b : vector<[4]xi16>
+  return %0 : vector<[8]xi16>
+}
+// CHECK-NOT: vector.shuffle
+// CHECK: vector.interleave %arg0, %arg1 : vector<[4]xi16>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %f = transform.structured.match ops{["func.func"]} in %module_op
+ : (!transform.any_op) -> !transform.any_op
+
+ transform.apply_patterns to %f {
+ transform.apply_patterns.vector.interleave_to_shuffle
+ } : !transform.any_op
+ transform.yield
+ }
+}
>From e5ddd6c21d8f63454a16ebf0a1fcaa00c1000d91 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Fri, 10 May 2024 16:21:05 -0400
Subject: [PATCH 2/3] clang-format
---
.../mlir/Dialect/Vector/Transforms/LoweringPatterns.h | 2 +-
mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp | 5 ++---
2 files changed, 3 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index d7eef637d7c38..8fd9904fabc0e 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -274,7 +274,7 @@ void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
void populateVectorInterleaveToShufflePatterns(RewritePatternSet &patterns,
- PatternBenefit benefit = 1);
+ PatternBenefit benefit = 1);
} // namespace vector
} // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
index 2388a952cdb0d..6ba11535f0893 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
@@ -80,9 +80,8 @@ class UnrollInterleaveOp : public OpRewritePattern<vector::InterleaveOp> {
class InterleaveToShuffle : public OpRewritePattern<vector::InterleaveOp> {
public:
- InterleaveToShuffle(MLIRContext *context,
- PatternBenefit benefit = 1)
- : OpRewritePattern(context, benefit) {};
+ InterleaveToShuffle(MLIRContext *context, PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit) {}
LogicalResult matchAndRewrite(vector::InterleaveOp op,
PatternRewriter &rewriter) const override {
>From 2c8e80a9d065776353af76b4cf23162d02df2ff2 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Fri, 10 May 2024 16:22:14 -0400
Subject: [PATCH 3/3] add missing newline at end of file
---
mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
index 6ba11535f0893..35557e05bb45e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
@@ -107,4 +107,4 @@ void mlir::vector::populateVectorInterleaveLoweringPatterns(
void mlir::vector::populateVectorInterleaveToShufflePatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<InterleaveToShuffle>(patterns.getContext(), benefit);
-}
\ No newline at end of file
+}
More information about the Mlir-commits
mailing list