[Mlir-commits] [llvm] [mlir] [mlir][vector] Add Vector-dialect interleave-to-shuffle pattern, enable in VectorToSPIRV (PR #92012)

Benoit Jacob llvmlistbot at llvm.org
Mon May 13 12:32:17 PDT 2024


https://github.com/bjacob updated https://github.com/llvm/llvm-project/pull/92012

>From 6426a940d6077200749531b40de845ea53f6b621 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Fri, 10 May 2024 16:12:51 -0400
Subject: [PATCH 1/2] add interleave-to-shuffle pattern

---
 .../Vector/TransformOps/VectorTransformOps.td | 14 +++++++
 .../Vector/Transforms/LoweringPatterns.h      |  3 ++
 .../Conversion/VectorToSPIRV/CMakeLists.txt   |  1 +
 .../VectorToSPIRV/VectorToSPIRV.cpp           |  4 ++
 .../TransformOps/VectorTransformOps.cpp       |  5 +++
 .../Transforms/LowerVectorInterleave.cpp      | 41 +++++++++++++++++++
 .../Vector/vector-interleave-to-shuffle.mlir  | 21 ++++++++++
 7 files changed, 89 insertions(+)
 create mode 100644 mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index f6371f39c3944..bc3c16d40520e 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -306,6 +306,20 @@ def ApplyLowerInterleavePatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyInterleaveToShufflePatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.interleave_to_shuffle",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that 1D vector interleave operations should be rewritten as
+    vector shuffle operations.
+
+    This is motivated by some current codegen backends not handling vector
+    interleave operations.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyRewriteNarrowTypePatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.rewrite_narrow_types",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 350d2777cadf5..8fd9904fabc0e 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -273,6 +273,9 @@ void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns,
                                               int64_t targetRank = 1,
                                               PatternBenefit benefit = 1);
 
+void populateVectorInterleaveToShufflePatterns(RewritePatternSet &patterns,
+                                               PatternBenefit benefit = 1);
+
 } // namespace vector
 } // namespace mlir
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
index bb9f793d7fe0f..113983146f5be 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
@@ -14,5 +14,6 @@ add_mlir_conversion_library(MLIRVectorToSPIRV
   MLIRSPIRVDialect
   MLIRSPIRVConversion
   MLIRVectorDialect
+  MLIRVectorTransforms
   MLIRTransforms
   )
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 868a3521e7a0f..c2dd37f481466 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -828,6 +829,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
   // than the generic one that extracts all elements.
   patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
                                            PatternBenefit(2));
+
+  // Rewrite vector.interleave as vector.shuffle until it is handled directly.
+  vector::populateVectorInterleaveToShufflePatterns(patterns);
 }
 
 void mlir::populateVectorReductionToSPIRVDotProductPatterns(
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 885644864c0f7..61fd6bd972e3a 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -164,6 +164,11 @@ void transform::ApplyLowerInterleavePatternsOp::populatePatterns(
   vector::populateVectorInterleaveLoweringPatterns(patterns);
 }
 
+void transform::ApplyInterleaveToShufflePatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  vector::populateVectorInterleaveToShufflePatterns(patterns);
+}
+
 void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   populateVectorNarrowTypeRewritePatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
index 3a456076f8fba..5326760c9b4eb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
 
 #define DEBUG_TYPE "vector-interleave-lowering"
 
@@ -77,9 +78,49 @@ class UnrollInterleaveOp : public OpRewritePattern<vector::InterleaveOp> {
   int64_t targetRank = 1;
 };
 
+/// Rewrite vector.interleave op into an equivalent vector.shuffle op, when
+/// applicable: `sourceType` must be 1D and non-scalable.
+///
+/// Example:
+///
+/// ```mlir
+/// vector.interleave %a, %b : vector<7xi16>
+/// ```
+///
+/// Is rewritten into:
+///
+/// ```mlir
+/// vector.shuffle %a, %b [0, 7, 1, 8, 2, 9, 3, 10, 4, 11, 5, 12, 6, 13]
+///   : vector<7xi16>, vector<7xi16>
+/// ```
+class InterleaveToShuffle : public OpRewritePattern<vector::InterleaveOp> {
+public:
+  InterleaveToShuffle(MLIRContext *context, PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit) {}
+
+  LogicalResult matchAndRewrite(vector::InterleaveOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType sourceType = op.getSourceVectorType();
+    if (sourceType.getRank() != 1 || sourceType.isScalable())
+      return failure();
+    int64_t n = sourceType.getNumElements();
+    // Mask entry i: i / 2 (lhs lane) if i is even, else n + i / 2 (rhs lane).
+    auto seq = llvm::seq<int64_t>(2 * n);
+    auto zip = llvm::to_vector(llvm::map_range(
+        seq, [n](int64_t i) { return (i % 2 ? n : 0) + i / 2; }));
+    rewriter.replaceOpWithNewOp<ShuffleOp>(op, op.getLhs(), op.getRhs(), zip);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::vector::populateVectorInterleaveLoweringPatterns(
     RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
   patterns.add<UnrollInterleaveOp>(targetRank, patterns.getContext(), benefit);
 }
+
+void mlir::vector::populateVectorInterleaveToShufflePatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<InterleaveToShuffle>(patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir b/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir
new file mode 100644
index 0000000000000..ed3b3396bf3ea
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+// CHECK-LABEL: @vector_interleave_to_shuffle
+func.func @vector_interleave_to_shuffle(%a: vector<7xi16>, %b: vector<7xi16>) -> vector<14xi16>
+{
+  %0 = vector.interleave %a, %b : vector<7xi16>
+  return %0 : vector<14xi16>
+}
+// CHECK: vector.shuffle %arg0, %arg1 [0, 7, 1, 8, 2, 9, 3, 10, 4, 11, 5, 12, 6, 13] : vector<7xi16>, vector<7xi16>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %f = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+
+    transform.apply_patterns to %f {
+      transform.apply_patterns.vector.interleave_to_shuffle
+    } : !transform.any_op
+    transform.yield
+  }
+}

>From 14804fe4d7530ed720fe43c8fc459ca5467bfae5 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Mon, 13 May 2024 15:32:06 -0400
Subject: [PATCH 2/2] bazel

---
 utils/bazel/llvm-project-overlay/mlir/BUILD.bazel | 1 +
 1 file changed, 1 insertion(+)

diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 6304b7b548d81..debd8daf55497 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -5877,6 +5877,7 @@ cc_library(
         ":Support",
         ":TransformUtils",
         ":VectorDialect",
+        ":VectorTransforms",
         "//llvm:Support",
     ],
 )



More information about the Mlir-commits mailing list