[Mlir-commits] [mlir] [mlir][spirv] Add vector.interleave to spirv.VectorShuffle conversion (PR #93240)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 23 14:19:56 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-spirv

Author: Angel Zhang (angelz913)

<details>
<summary>Changes</summary>

Add `vector.interleave` to `spirv.VectorShuffle` conversion, and remove the `vector.interleave` to `vector.shuffle` conversion in `populateVectorToSPIRVPatterns`.

---
Full diff: https://github.com/llvm/llvm-project/pull/93240.diff


1 File Affected:

- (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+41-5) 


``````````diff
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index c2dd37f481466..b86ebe1a4bb54 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -578,6 +578,44 @@ struct VectorShuffleOpConvert final
   }
 };
 
+/// Converts a rank-1, fixed-length `vector.interleave` into a single
+/// `spirv.VectorShuffle` whose indices alternate between the two operands.
+struct VectorInterleaveOpConvert final
+    : public OpConversionPattern<vector::InterleaveOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only rank-1, non-scalable source vectors are handled.
+    auto sourceType = interleaveOp.getSourceVectorType();
+    if (sourceType.getRank() != 1 || sourceType.isScalable())
+      return rewriter.notifyMatchFailure(interleaveOp,
+                                         "unsupported source vector type");
+
+    // The result type must be convertible by the type converter.
+    auto oldResultType = interleaveOp.getResultVectorType();
+    Type newResultType = getTypeConverter()->convertType(oldResultType);
+    if (!newResultType)
+      return rewriter.notifyMatchFailure(interleaveOp,
+                                         "unsupported result vector type");
+
+    // Build the shuffle mask [0, n, 1, n + 1, ...]: position i maps to i / 2,
+    // plus n for odd i so that it selects from the second shuffle operand.
+    int n = sourceType.getNumElements();
+    auto seq = llvm::seq<int64_t>(2 * n);
+    auto indices = llvm::to_vector(llvm::map_range(
+        seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; }));
+
+    // Emit a SPIR-V shuffle.
+    rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
+        interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
+        rewriter.getI32ArrayAttr(indices));
+
+    return success();
+  }
+};
+
 struct VectorLoadOpConverter final
     : public OpConversionPattern<vector::LoadOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -821,17 +859,15 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
       VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
       VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
       VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
-      VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
-      VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
+      VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert, 
+      VectorInterleaveOpConvert, VectorSplatPattern, 
+      VectorLoadOpConverter, VectorStoreOpConverter>(
       typeConverter, patterns.getContext(), PatternBenefit(1));
 
   // Make sure that the more specialized dot product pattern has higher benefit
   // than the generic one that extracts all elements.
   patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
                                            PatternBenefit(2));
-
-  // Need this until vector.interleave is handled.
-  vector::populateVectorInterleaveToShufflePatterns(patterns);
 }
 
 void mlir::populateVectorReductionToSPIRVDotProductPatterns(

``````````

</details>


https://github.com/llvm/llvm-project/pull/93240


More information about the Mlir-commits mailing list