[Mlir-commits] [mlir] [mlir][spirv] Add vector.interleave to spirv.VectorShuffle conversion (PR #93240)

Angel Zhang llvmlistbot at llvm.org
Thu May 23 14:28:40 PDT 2024


https://github.com/angelz913 updated https://github.com/llvm/llvm-project/pull/93240

From 98b4b90759a60b0d2b4a54f7f012ab4e5ef04cde Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Thu, 23 May 2024 21:09:49 +0000
Subject: [PATCH 1/3] [mlir][spirv] Add vector.interleave to
 spirv.VectorShuffle conversion

---
 .../VectorToSPIRV/VectorToSPIRV.cpp           | 46 +++++++++++++++++--
 1 file changed, 41 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index c2dd37f481466..b86ebe1a4bb54 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -578,6 +578,44 @@ struct VectorShuffleOpConvert final
   }
 };
 
+struct VectorInterleaveOpConvert final
+    : public OpConversionPattern<vector::InterleaveOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+     // Check the source vector type
+    auto sourceType = interleaveOp.getSourceVectorType();
+    if (sourceType.getRank() != 1 || sourceType.isScalable()) {
+      return rewriter.notifyMatchFailure(interleaveOp,
+                                         "unsupported source vector type");
+    }  
+
+    // Check the result vector type
+    auto oldResultType = interleaveOp.getResultVectorType();
+    Type newResultType = getTypeConverter()->convertType(oldResultType);
+    if (!newResultType)
+      return rewriter.notifyMatchFailure(interleaveOp,
+                                         "unsupported result vector type");
+
+    // Interleave the indices
+    int n = sourceType.getNumElements();
+    auto seq = llvm::seq<int64_t>(2 * n);
+    auto indices = llvm::to_vector(llvm::map_range(
+        seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; }));
+
+    // Emit a SPIR-V shuffle.
+    rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
+        interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
+        rewriter.getI32ArrayAttr(indices));
+
+    llvm::errs() << "vector.interleave to spirv.VectorShuffle succeeded\n";
+    
+    return success();
+  }
+};
+
 struct VectorLoadOpConverter final
     : public OpConversionPattern<vector::LoadOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -821,17 +859,15 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
       VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
       VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
       VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
-      VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
-      VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
+      VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert, 
+      VectorInterleaveOpConvert, VectorSplatPattern, 
+      VectorLoadOpConverter, VectorStoreOpConverter>(
       typeConverter, patterns.getContext(), PatternBenefit(1));
 
   // Make sure that the more specialized dot product pattern has higher benefit
   // than the generic one that extracts all elements.
   patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
                                            PatternBenefit(2));
-
-  // Need this until vector.interleave is handled.
-  vector::populateVectorInterleaveToShufflePatterns(patterns);
 }
 
 void mlir::populateVectorReductionToSPIRVDotProductPatterns(

From 0622fae6857f836b27d5b32e3f17d5f90c752401 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Thu, 23 May 2024 21:25:28 +0000
Subject: [PATCH 2/3] Fix style

---
 .../Conversion/VectorToSPIRV/VectorToSPIRV.cpp   | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index b86ebe1a4bb54..c01ff41d8b5e1 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -585,12 +585,12 @@ struct VectorInterleaveOpConvert final
   LogicalResult
   matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-     // Check the source vector type
+    // Check the source vector type
     auto sourceType = interleaveOp.getSourceVectorType();
     if (sourceType.getRank() != 1 || sourceType.isScalable()) {
       return rewriter.notifyMatchFailure(interleaveOp,
                                          "unsupported source vector type");
-    }  
+    }
 
     // Check the result vector type
     auto oldResultType = interleaveOp.getResultVectorType();
@@ -602,8 +602,8 @@ struct VectorInterleaveOpConvert final
     // Interleave the indices
     int n = sourceType.getNumElements();
     auto seq = llvm::seq<int64_t>(2 * n);
-    auto indices = llvm::to_vector(llvm::map_range(
-        seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; }));
+    auto indices = llvm::to_vector(
+      llvm::map_range(seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; }));
 
     // Emit a SPIR-V shuffle.
     rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
@@ -859,10 +859,10 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
       VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
       VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
       VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
-      VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert, 
-      VectorInterleaveOpConvert, VectorSplatPattern, 
-      VectorLoadOpConverter, VectorStoreOpConverter>(
-      typeConverter, patterns.getContext(), PatternBenefit(1));
+      VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
+      VectorInterleaveOpConvert, VectorSplatPattern, VectorLoadOpConverter,
+      VectorStoreOpConverter>(typeConverter, patterns.getContext(),
+                              PatternBenefit(1));
 
   // Make sure that the more specialized dot product pattern has higher benefit
   // than the generic one that extracts all elements.

From 65cb04de63cdcc98be743a7ae4d189b41ab53420 Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Thu, 23 May 2024 21:28:28 +0000
Subject: [PATCH 3/3] Remove debug statement

---
 mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index c01ff41d8b5e1..c10cfb5f3bb2b 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -609,8 +609,6 @@ struct VectorInterleaveOpConvert final
     rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
         interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
         rewriter.getI32ArrayAttr(indices));
-
-    llvm::errs() << "vector.interleave to spirv.VectorShuffle succeeded\n";
     
     return success();
   }



More information about the Mlir-commits mailing list