[Mlir-commits] [llvm] [mlir] [mlir][spirv] Add vector.interleave to spirv.VectorShuffle conversion (PR #93240)

Angel Zhang llvmlistbot at llvm.org
Mon May 27 13:59:09 PDT 2024


https://github.com/angelz913 updated https://github.com/llvm/llvm-project/pull/93240

>From 928f7282113d16e0395ca473565d6784d46c0056 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Thu, 23 May 2024 21:09:49 +0000
Subject: [PATCH 1/9] [mlir][spirv] Add vector.interleave to
 spirv.VectorShuffle conversion

---
 .../VectorToSPIRV/VectorToSPIRV.cpp           | 44 ++++++++++++++++---
 1 file changed, 39 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index c2dd37f481466..95464ef6d438e 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -578,6 +578,42 @@ struct VectorShuffleOpConvert final
   }
 };
 
+struct VectorInterleaveOpConvert final
+    : public OpConversionPattern<vector::InterleaveOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Check the source vector type
+    auto sourceType = interleaveOp.getSourceVectorType();
+    if (sourceType.getRank() != 1 || sourceType.isScalable()) {
+      return rewriter.notifyMatchFailure(interleaveOp,
+                                         "unsupported source vector type");
+    }
+
+    // Check the result vector type
+    auto oldResultType = interleaveOp.getResultVectorType();
+    Type newResultType = getTypeConverter()->convertType(oldResultType);
+    if (!newResultType)
+      return rewriter.notifyMatchFailure(interleaveOp,
+                                         "unsupported result vector type");
+
+    // Interleave the indices
+    int n = sourceType.getNumElements();
+    auto seq = llvm::seq<int64_t>(2 * n);
+    auto indices = llvm::to_vector(
+        llvm::map_range(seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; }));
+
+    // Emit a SPIR-V shuffle.
+    rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
+        interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
+        rewriter.getI32ArrayAttr(indices));
+
+    return success();
+  }
+};
+
 struct VectorLoadOpConverter final
     : public OpConversionPattern<vector::LoadOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -822,16 +858,14 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
       VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
       VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
       VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
-      VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
-      typeConverter, patterns.getContext(), PatternBenefit(1));
+      VectorInterleaveOpConvert, VectorSplatPattern, VectorLoadOpConverter,
+      VectorStoreOpConverter>(typeConverter, patterns.getContext(),
+                              PatternBenefit(1));
 
   // Make sure that the more specialized dot product pattern has higher benefit
   // than the generic one that extracts all elements.
   patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
                                            PatternBenefit(2));
-
-  // Need this until vector.interleave is handled.
-  vector::populateVectorInterleaveToShufflePatterns(patterns);
 }
 
 void mlir::populateVectorReductionToSPIRVDotProductPatterns(

>From 1ff747a49e1e71b73a5de5fec5c6f4db430f9830 Mon Sep 17 00:00:00 2001
From: Angel Zhang <anzhouzhang913 at gmail.com>
Date: Mon, 27 May 2024 08:33:14 -0400
Subject: [PATCH 2/9] Use VectorType for sourceType

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
 mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 95464ef6d438e..aa3670f81fea3 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -586,7 +586,7 @@ struct VectorInterleaveOpConvert final
   matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Check the source vector type
-    auto sourceType = interleaveOp.getSourceVectorType();
+    VectorType sourceType = interleaveOp.getSourceVectorType();
     if (sourceType.getRank() != 1 || sourceType.isScalable()) {
       return rewriter.notifyMatchFailure(interleaveOp,
                                          "unsupported source vector type");

>From e6eb044c38a88e965ffebe60877e571fbd97d0a1 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Mon, 27 May 2024 12:42:35 +0000
Subject: [PATCH 3/9] Use VectorType for oldResultType

---
 mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index aa3670f81fea3..0af0595eebe0d 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -593,7 +593,7 @@ struct VectorInterleaveOpConvert final
     }
 
     // Check the result vector type
-    auto oldResultType = interleaveOp.getResultVectorType();
+    VectorType oldResultType = interleaveOp.getResultVectorType();
     Type newResultType = getTypeConverter()->convertType(oldResultType);
     if (!newResultType)
       return rewriter.notifyMatchFailure(interleaveOp,

>From 7e9ea7fb8768582432c2b53dd10c0264ea23357c Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Mon, 27 May 2024 16:06:10 +0000
Subject: [PATCH 4/9] Handle one-element input vector case and remove
 CMake/Bazel dependencies

---
 mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt  |  1 -
 .../Conversion/VectorToSPIRV/VectorToSPIRV.cpp    | 15 ++++++++++++++-
 .../Conversion/VectorToSPIRV/vector-to-spirv.mlir | 13 +++++++++++++
 utils/bazel/llvm-project-overlay/mlir/BUILD.bazel |  1 -
 4 files changed, 27 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
index 113983146f5be..bb9f793d7fe0f 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
@@ -14,6 +14,5 @@ add_mlir_conversion_library(MLIRVectorToSPIRV
   MLIRSPIRVDialect
   MLIRSPIRVConversion
   MLIRVectorDialect
-  MLIRVectorTransforms
   MLIRTransforms
   )
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 0af0595eebe0d..a63ef5ab451eb 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -601,6 +601,19 @@ struct VectorInterleaveOpConvert final
 
     // Interleave the indices
     int n = sourceType.getNumElements();
+
+    // Input vectors of size 1 are converted to scalars by the type converter. 
+    // We cannot use spirv::VectorShuffleOp directly in this case, and need to 
+    // use spirv::CompositeConstructOp.
+    if (n == 1) {
+      SmallVector<Value> newOperands(2);
+      newOperands[0] = adaptor.getLhs();
+      newOperands[1] = adaptor.getRhs();
+      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
+          interleaveOp, newResultType, newOperands);
+      return success();
+    }
+
     auto seq = llvm::seq<int64_t>(2 * n);
     auto indices = llvm::to_vector(
         llvm::map_range(seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; }));
@@ -609,7 +622,7 @@ struct VectorInterleaveOpConvert final
     rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
         interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
         rewriter.getI32ArrayAttr(indices));
-
+    
     return success();
   }
 };
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index b24088d951259..f52e771f1d4a8 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -494,6 +494,19 @@ func.func @interleave(%a: vector<2xf32>, %b: vector<2xf32>) -> vector<4xf32> {
 
 // -----
 
+// CHECK-LABEL: func @interleave_size1
+// CHECK-SAME: (%[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: vector<1xf32>)
+//       CHECK: %[[V0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<1xf32> to f32
+//       CHECK: %[[V1:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : vector<1xf32> to f32
+//       CHECK: %[[RES:.*]] = spirv.CompositeConstruct %[[V0]], %[[V1]] : (f32, f32) -> vector<2xf32>
+//       CHECK: return %[[RES]]
+func.func @interleave_size1(%a: vector<1xf32>, %b: vector<1xf32>) -> vector<2xf32> {
+  %0 = vector.interleave %a, %b : vector<1xf32>
+  return %0 : vector<2xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @reduction_add
 //  CHECK-SAME: (%[[V:.+]]: vector<4xi32>)
 //       CHECK:   %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<4xi32>
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index a7bbe459fd9d7..f31f75ca5c74a 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4976,7 +4976,6 @@ cc_library(
         ":VectorToLLVM",
         ":VectorToSCF",
         ":VectorTransformOpsIncGen",
-        ":VectorTransforms",
         ":X86VectorTransforms",
     ],
 )

>From ff34b53a136f8af220a444d98426a817dfff9224 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Mon, 27 May 2024 16:39:02 +0000
Subject: [PATCH 5/9] Remove check for source type and reformat code

---
 .../lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 14 ++++----------
 1 file changed, 4 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index a63ef5ab451eb..69f89d087dd3c 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -585,13 +585,6 @@ struct VectorInterleaveOpConvert final
   LogicalResult
   matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // Check the source vector type
-    VectorType sourceType = interleaveOp.getSourceVectorType();
-    if (sourceType.getRank() != 1 || sourceType.isScalable()) {
-      return rewriter.notifyMatchFailure(interleaveOp,
-                                         "unsupported source vector type");
-    }
-
     // Check the result vector type
     VectorType oldResultType = interleaveOp.getResultVectorType();
     Type newResultType = getTypeConverter()->convertType(oldResultType);
@@ -600,10 +593,11 @@ struct VectorInterleaveOpConvert final
                                          "unsupported result vector type");
 
     // Interleave the indices
+    VectorType sourceType = interleaveOp.getSourceVectorType();
     int n = sourceType.getNumElements();
 
-    // Input vectors of size 1 are converted to scalars by the type converter. 
-    // We cannot use spirv::VectorShuffleOp directly in this case, and need to 
+    // Input vectors of size 1 are converted to scalars by the type converter.
+    // We cannot use spirv::VectorShuffleOp directly in this case, and need to
     // use spirv::CompositeConstructOp.
     if (n == 1) {
       SmallVector<Value> newOperands(2);
@@ -622,7 +616,7 @@ struct VectorInterleaveOpConvert final
     rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
         interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
         rewriter.getI32ArrayAttr(indices));
-    
+
     return success();
   }
 };

>From a7e9433eedf1bd0b63371670d0a367791c358838 Mon Sep 17 00:00:00 2001
From: Angel Zhang <anzhouzhang913 at gmail.com>
Date: Mon, 27 May 2024 13:39:33 -0400
Subject: [PATCH 6/9] Reformat code

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
 mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 69f89d087dd3c..043b0741729d6 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -585,24 +585,22 @@ struct VectorInterleaveOpConvert final
   LogicalResult
   matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // Check the result vector type
+    // Check the result vector type.
     VectorType oldResultType = interleaveOp.getResultVectorType();
     Type newResultType = getTypeConverter()->convertType(oldResultType);
     if (!newResultType)
       return rewriter.notifyMatchFailure(interleaveOp,
                                          "unsupported result vector type");
 
-    // Interleave the indices
+    // Interleave the indices.
     VectorType sourceType = interleaveOp.getSourceVectorType();
     int n = sourceType.getNumElements();
 
     // Input vectors of size 1 are converted to scalars by the type converter.
-    // We cannot use spirv::VectorShuffleOp directly in this case, and need to
-    // use spirv::CompositeConstructOp.
+    // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
+    // use `spirv::CompositeConstructOp`.
     if (n == 1) {
-      SmallVector<Value> newOperands(2);
-      newOperands[0] = adaptor.getLhs();
-      newOperands[1] = adaptor.getRhs();
+      Value newOperands[] = {adaptor.getLhs(), adaptor.getRhs()};
       rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
           interleaveOp, newResultType, newOperands);
       return success();

>From a8e806ad2e6df3f817f30b04677487a59fb86bd1 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Mon, 27 May 2024 20:37:22 +0000
Subject: [PATCH 7/9] Use llvm::map_to_vector

---
 mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 043b0741729d6..7c17042917ff9 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -607,8 +607,8 @@ struct VectorInterleaveOpConvert final
     }
 
     auto seq = llvm::seq<int64_t>(2 * n);
-    auto indices = llvm::to_vector(
-        llvm::map_range(seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; }));
+    auto indices = llvm::map_to_vector(
+      seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; });
 
     // Emit a SPIR-V shuffle.
     rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(

>From 9a688a71a12f7f0ea3028c66d2a7003ddb33b888 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Mon, 27 May 2024 20:53:31 +0000
Subject: [PATCH 8/9] Modify vector.interleave assembly format in the LIT test

---
 mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index f52e771f1d4a8..2592d0fc04111 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -501,7 +501,7 @@ func.func @interleave(%a: vector<2xf32>, %b: vector<2xf32>) -> vector<4xf32> {
 //       CHECK: %[[RES:.*]] = spirv.CompositeConstruct %[[V0]], %[[V1]] : (f32, f32) -> vector<2xf32>
 //       CHECK: return %[[RES]]
 func.func @interleave_size1(%a: vector<1xf32>, %b: vector<1xf32>) -> vector<2xf32> {
-  %0 = vector.interleave %a, %b : vector<1xf32>
+  %0 = vector.interleave %a, %b : vector<1xf32> -> vector<2xf32>
   return %0 : vector<2xf32>
 }
 

>From 08e191aa1dc042a62cb6286ea71083fbaee606a7 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Mon, 27 May 2024 20:58:57 +0000
Subject: [PATCH 9/9] Reformat code

---
 mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 7c17042917ff9..a9ed25fbfbe0c 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -608,7 +608,7 @@ struct VectorInterleaveOpConvert final
 
     auto seq = llvm::seq<int64_t>(2 * n);
     auto indices = llvm::map_to_vector(
-      seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; });
+        seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; });
 
     // Emit a SPIR-V shuffle.
     rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(



More information about the Mlir-commits mailing list