[Mlir-commits] [llvm] [mlir] [mlir][spirv] Add vector.interleave to spirv.VectorShuffle conversion (PR #93240)

Angel Zhang llvmlistbot at llvm.org
Mon May 27 06:18:09 PDT 2024


https://github.com/angelz913 updated https://github.com/llvm/llvm-project/pull/93240

>From cdc9def9d70feb5e3acaea7491c31d06039460d0 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Thu, 23 May 2024 21:09:49 +0000
Subject: [PATCH 1/4] [mlir][spirv] Add vector.interleave to
 spirv.VectorShuffle conversion

---
 .../VectorToSPIRV/VectorToSPIRV.cpp           | 56 +++++++++++++++++--
 1 file changed, 51 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index c2dd37f481466..95464ef6d438e 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -578,6 +578,54 @@ struct VectorShuffleOpConvert final
   }
 };
 
+struct VectorInterleaveOpConvert final
+    : public OpConversionPattern<vector::InterleaveOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Check the source vector type
+    auto sourceType = interleaveOp.getSourceVectorType();
+    if (sourceType.getRank() != 1 || sourceType.isScalable()) {
+      return rewriter.notifyMatchFailure(interleaveOp,
+                                         "unsupported source vector type");
+    }
+
+    // Check the result vector type
+    auto oldResultType = interleaveOp.getResultVectorType();
+    Type newResultType = getTypeConverter()->convertType(oldResultType);
+    if (!newResultType)
+      return rewriter.notifyMatchFailure(interleaveOp,
+                                         "unsupported result vector type");
+
+    int n = sourceType.getNumElements();
+
+    // Single-element vectors are converted to scalars by the SPIR-V type
+    // converter, so for vector<1xT> sources the adapted operands are
+    // scalars and spirv.VectorShuffle (which needs vector operands) cannot
+    // be used. Build the two-element result from the scalars instead.
+    if (n == 1) {
+      Value newOperands[] = {adaptor.getLhs(), adaptor.getRhs()};
+      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
+          interleaveOp, newResultType, newOperands);
+      return success();
+    }
+
+    // Interleave the indices: [0, n, 1, n + 1, ..., n - 1, 2n - 1].
+    auto seq = llvm::seq<int64_t>(2 * n);
+    auto indices = llvm::to_vector(
+        llvm::map_range(seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; }));
+
+    // Emit a SPIR-V shuffle.
+    rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
+        interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
+        rewriter.getI32ArrayAttr(indices));
+
+    return success();
+  }
+};
+
 struct VectorLoadOpConverter final
     : public OpConversionPattern<vector::LoadOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -822,16 +870,14 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
       VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
       VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
       VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
-      VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
-      typeConverter, patterns.getContext(), PatternBenefit(1));
+      VectorInterleaveOpConvert, VectorSplatPattern, VectorLoadOpConverter,
+      VectorStoreOpConverter>(typeConverter, patterns.getContext(),
+                              PatternBenefit(1));
 
   // Make sure that the more specialized dot product pattern has higher benefit
   // than the generic one that extracts all elements.
   patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
                                            PatternBenefit(2));
-
-  // Need this until vector.interleave is handled.
-  vector::populateVectorInterleaveToShufflePatterns(patterns);
 }
 
 void mlir::populateVectorReductionToSPIRVDotProductPatterns(

>From 51b25d480e8661b124dceaa7dcb2840320bc01e7 Mon Sep 17 00:00:00 2001
From: Angel Zhang <anzhouzhang913 at gmail.com>
Date: Mon, 27 May 2024 08:33:14 -0400
Subject: [PATCH 2/4] Use VectorType for sourceType

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
 mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 95464ef6d438e..aa3670f81fea3 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -586,7 +586,7 @@ struct VectorInterleaveOpConvert final
   matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Check the source vector type
-    auto sourceType = interleaveOp.getSourceVectorType();
+    VectorType sourceType = interleaveOp.getSourceVectorType();
     if (sourceType.getRank() != 1 || sourceType.isScalable()) {
       return rewriter.notifyMatchFailure(interleaveOp,
                                          "unsupported source vector type");

>From fedd990a69c8ba1b79c4905c29467d9e352ba96c Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Mon, 27 May 2024 12:42:35 +0000
Subject: [PATCH 3/4] Use VectorType for oldResultType

---
 mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index aa3670f81fea3..0af0595eebe0d 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -593,7 +593,7 @@ struct VectorInterleaveOpConvert final
     }
 
     // Check the result vector type
-    auto oldResultType = interleaveOp.getResultVectorType();
+    VectorType oldResultType = interleaveOp.getResultVectorType();
     Type newResultType = getTypeConverter()->convertType(oldResultType);
     if (!newResultType)
       return rewriter.notifyMatchFailure(interleaveOp,

>From e84085ecde10db678805185f5b62147ff1e11810 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Mon, 27 May 2024 13:17:57 +0000
Subject: [PATCH 4/4] Remove the interleave-to-shuffle patterns and their CMake/Bazel dependencies

---
 .../Vector/TransformOps/VectorTransformOps.td | 14 -------
 .../Vector/Transforms/LoweringPatterns.h      |  3 --
 .../Conversion/VectorToSPIRV/CMakeLists.txt   |  1 -
 .../VectorToSPIRV/VectorToSPIRV.cpp           |  1 -
 .../TransformOps/VectorTransformOps.cpp       |  5 ---
 .../Transforms/LowerVectorInterleave.cpp      | 41 -------------------
 .../Vector/vector-interleave-to-shuffle.mlir  | 21 ----------
 .../llvm-project-overlay/mlir/BUILD.bazel     |  1 -
 8 files changed, 87 deletions(-)
 delete mode 100644 mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index bc3c16d40520e..f6371f39c3944 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -306,20 +306,6 @@ def ApplyLowerInterleavePatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
-def ApplyInterleaveToShufflePatternsOp : Op<Transform_Dialect,
-    "apply_patterns.vector.interleave_to_shuffle",
-    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
-  let description = [{
-    Indicates that 1D vector interleave operations should be rewritten as
-    vector shuffle operations.
-
-    This is motivated by some current codegen backends not handling vector
-    interleave operations.
-  }];
-
-  let assemblyFormat = "attr-dict";
-}
-
 def ApplyRewriteNarrowTypePatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.rewrite_narrow_types",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 8fd9904fabc0e..350d2777cadf5 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -273,9 +273,6 @@ void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns,
                                               int64_t targetRank = 1,
                                               PatternBenefit benefit = 1);
 
-void populateVectorInterleaveToShufflePatterns(RewritePatternSet &patterns,
-                                               PatternBenefit benefit = 1);
-
 } // namespace vector
 } // namespace mlir
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
index 113983146f5be..bb9f793d7fe0f 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
@@ -14,6 +14,5 @@ add_mlir_conversion_library(MLIRVectorToSPIRV
   MLIRSPIRVDialect
   MLIRSPIRVConversion
   MLIRVectorDialect
-  MLIRVectorTransforms
   MLIRTransforms
   )
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 0af0595eebe0d..f64401559f834 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -18,7 +18,6 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 61fd6bd972e3a..885644864c0f7 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -164,11 +164,6 @@ void transform::ApplyLowerInterleavePatternsOp::populatePatterns(
   vector::populateVectorInterleaveLoweringPatterns(patterns);
 }
 
-void transform::ApplyInterleaveToShufflePatternsOp::populatePatterns(
-    RewritePatternSet &patterns) {
-  vector::populateVectorInterleaveToShufflePatterns(patterns);
-}
-
 void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   populateVectorNarrowTypeRewritePatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
index 5326760c9b4eb..3a456076f8fba 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
@@ -16,7 +16,6 @@
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/Support/LogicalResult.h"
 
 #define DEBUG_TYPE "vector-interleave-lowering"
 
@@ -78,49 +77,9 @@ class UnrollInterleaveOp : public OpRewritePattern<vector::InterleaveOp> {
   int64_t targetRank = 1;
 };
 
-/// Rewrite vector.interleave op into an equivalent vector.shuffle op, when
-/// applicable: `sourceType` must be 1D and non-scalable.
-///
-/// Example:
-///
-/// ```mlir
-/// vector.interleave %a, %b : vector<7xi16>
-/// ```
-///
-/// Is rewritten into:
-///
-/// ```mlir
-/// vector.shuffle %arg0, %arg1 [0, 7, 1, 8, 2, 9, 3, 10, 4, 11, 5, 12, 6, 13]
-///   : vector<7xi16>, vector<7xi16>
-/// ```
-class InterleaveToShuffle : public OpRewritePattern<vector::InterleaveOp> {
-public:
-  InterleaveToShuffle(MLIRContext *context, PatternBenefit benefit = 1)
-      : OpRewritePattern(context, benefit) {};
-
-  LogicalResult matchAndRewrite(vector::InterleaveOp op,
-                                PatternRewriter &rewriter) const override {
-    VectorType sourceType = op.getSourceVectorType();
-    if (sourceType.getRank() != 1 || sourceType.isScalable()) {
-      return failure();
-    }
-    int64_t n = sourceType.getNumElements();
-    auto seq = llvm::seq<int64_t>(2 * n);
-    auto zip = llvm::to_vector(llvm::map_range(
-        seq, [n](int64_t i) { return (i % 2 ? n : 0) + i / 2; }));
-    rewriter.replaceOpWithNewOp<ShuffleOp>(op, op.getLhs(), op.getRhs(), zip);
-    return success();
-  }
-};
-
 } // namespace
 
 void mlir::vector::populateVectorInterleaveLoweringPatterns(
     RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
   patterns.add<UnrollInterleaveOp>(targetRank, patterns.getContext(), benefit);
 }
-
-void mlir::vector::populateVectorInterleaveToShufflePatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<InterleaveToShuffle>(patterns.getContext(), benefit);
-}
diff --git a/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir b/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir
deleted file mode 100644
index ed3b3396bf3ea..0000000000000
--- a/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir
+++ /dev/null
@@ -1,21 +0,0 @@
-// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
-
-// CHECK-LABEL: @vector_interleave_to_shuffle
-func.func @vector_interleave_to_shuffle(%a: vector<7xi16>, %b: vector<7xi16>) -> vector<14xi16>
-{
-  %0 = vector.interleave %a, %b : vector<7xi16>
-  return %0 : vector<14xi16>
-}
-// CHECK: vector.shuffle %arg0, %arg1 [0, 7, 1, 8, 2, 9, 3, 10, 4, 11, 5, 12, 6, 13] : vector<7xi16>, vector<7xi16>
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
-    %f = transform.structured.match ops{["func.func"]} in %module_op
-      : (!transform.any_op) -> !transform.any_op
-
-    transform.apply_patterns to %f {
-      transform.apply_patterns.vector.interleave_to_shuffle
-    } : !transform.any_op
-    transform.yield
-  }
-}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 614fe511d43a5..0c5fa11bc85ec 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4979,7 +4979,6 @@ cc_library(
         ":VectorToLLVM",
         ":VectorToSCF",
         ":VectorTransformOpsIncGen",
-        ":VectorTransforms",
         ":X86VectorTransforms",
     ],
 )



More information about the Mlir-commits mailing list