[Mlir-commits] [mlir] [MLIR][SPIR-V] Fix bugs in VectorToSPIRV Pass (PR #93189)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 23 06:18:32 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-spirv

Author: Angel Zhang (angelz913)

<details>
<summary>Changes</summary>

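Separate the vector interleave lowering and interleave-to-shuffle rewrite patterns from the dialect conversion: they are removed from `populateVectorToSPIRVPatterns` and instead applied in their own greedy rewrite step in the pass, before the dialect conversion runs.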

---
Full diff: https://github.com/llvm/llvm-project/pull/93189.diff


3 Files Affected:

- (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (-2) 
- (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp (+11) 
- (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+12) 


``````````diff
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index c2dd37f481466..d05e6962fc8fe 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -830,8 +830,6 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
   patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
                                            PatternBenefit(2));
 
-  // Need this until vector.interleave is handled.
-  vector::populateVectorInterleaveToShufflePatterns(patterns);
 }
 
 void mlir::populateVectorReductionToSPIRVDotProductPatterns(
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
index 1932de1be603b..1de0512783af3 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
@@ -15,8 +15,10 @@
 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace mlir {
 #define GEN_PASS_DEF_CONVERTVECTORTOSPIRV
@@ -36,6 +38,15 @@ void ConvertVectorToSPIRVPass::runOnOperation() {
   MLIRContext *context = &getContext();
   Operation *op = getOperation();
 
+  // Rewrite patterns need to be matched separately from the dialect conversion
+  {
+    RewritePatternSet patterns(context);
+    vector::populateVectorInterleaveLoweringPatterns(patterns);
+    vector::populateVectorInterleaveToShufflePatterns(patterns);
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+      return signalPassFailure();
+  }
+
   auto targetAttr = spirv::lookupTargetEnvOrDefault(op);
   std::unique_ptr<ConversionTarget> target =
       SPIRVConversionTarget::get(targetAttr);
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index cddc4ee385357..6fb40d2beff1a 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -483,6 +483,18 @@ func.func @shuffle(%v0 : vector<1xi32>, %v1: vector<1xi32>) -> vector<2xi32> {
 
 // -----
 
+// CHECK-LABEL: func @interleave
+//  CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: vector<2xf32>)
+//       CHECK: %[[SHUFFLE:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %[[ARG0]], %[[ARG1]] : vector<2xf32>, vector<2xf32> -> vector<4xf32>
+//       CHECK: return %[[SHUFFLE]]
+func.func @interleave(%a: vector<2xf32>, %b: vector<2xf32>) -> vector<4xf32>
+{
+  %0 = vector.interleave %a, %b : vector<2xf32>
+  return %0 : vector<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @reduction_add
 //  CHECK-SAME: (%[[V:.+]]: vector<4xi32>)
 //       CHECK:   %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<4xi32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/93189


More information about the Mlir-commits mailing list