[Mlir-commits] [mlir] [mlir][spirv] Handle scalar shuffles in vector to spirv conversion (PR #98809)

Jakub Kuderski llvmlistbot at llvm.org
Sun Jul 14 04:44:09 PDT 2024


https://github.com/kuhar created https://github.com/llvm/llvm-project/pull/98809

Vector shuffles that produce a single-element (scalar) result may not get canonicalized before conversion to SPIR-V, so they need to be handled during the vector-to-SPIR-V conversion itself. Because SPIR-V does not support 1-element vectors, we can't emit `spirv.VectorShuffle` for such shuffles and instead need to lower them to `spirv.CompositeExtract`.

>From ff91e5a761f95a154678b787d1da43b32c573dbf Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sun, 14 Jul 2024 07:40:55 -0400
Subject: [PATCH] [mlir][spirv] Handle scalar shuffles in vector to spirv
 conversion

These may not get canonicalized before conversion to spirv and need to
be handled during vector to spirv conversion. Because spirv does not
support 1-element vectors, we can't emit `spirv.VectorShuffle` and need
to lower this to `spirv.CompositeExtract`.
---
 .../VectorToSPIRV/VectorToSPIRV.cpp           | 25 ++++++++++++-------
 .../VectorToSPIRV/vector-to-spirv.mlir        | 24 ++++++++++++++++++
 2 files changed, 40 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index c9363295ec32f..a4390447532a5 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -521,7 +521,7 @@ struct VectorShuffleOpConvert final
   LogicalResult
   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto oldResultType = shuffleOp.getResultVectorType();
+    VectorType oldResultType = shuffleOp.getResultVectorType();
     Type newResultType = getTypeConverter()->convertType(oldResultType);
     if (!newResultType)
       return rewriter.notifyMatchFailure(shuffleOp,
@@ -532,20 +532,22 @@ struct VectorShuffleOpConvert final
           return cast<IntegerAttr>(attr).getValue().getZExtValue();
         });
 
-    auto oldV1Type = shuffleOp.getV1VectorType();
-    auto oldV2Type = shuffleOp.getV2VectorType();
+    VectorType oldV1Type = shuffleOp.getV1VectorType();
+    VectorType oldV2Type = shuffleOp.getV2VectorType();
 
-    // When both operands are SPIR-V vectors, emit a SPIR-V shuffle.
-    if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1) {
+    // When both operands and the result are SPIR-V vectors, emit a SPIR-V
+    // shuffle.
+    if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1 &&
+        oldResultType.getNumElements() > 1) {
       rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
           shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
           rewriter.getI32ArrayAttr(mask));
       return success();
     }
 
-    // When at least one of the operands becomes a scalar after type conversion
-    // for SPIR-V, extract all the required elements and construct the result
-    // vector.
+    // When at least one of the operands or the result becomes a scalar after
+    // type conversion for SPIR-V, extract all the required elements and
+    // construct the result vector.
     auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
                                Value scalarOrVec, int32_t idx) -> Value {
       if (auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType()))
@@ -569,9 +571,14 @@ struct VectorShuffleOpConvert final
       newOperand = getElementAtIdx(vec, elementIdx);
     }
 
+    // Handle the scalar result corner case.
+    if (newOperands.size() == 1) {
+      rewriter.replaceOp(shuffleOp, newOperands.front());
+      return success();
+    }
+
     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
         shuffleOp, newResultType, newOperands);
-
     return success();
   }
 };
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 0d67851dfe41d..667aad7645c51 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -483,6 +483,30 @@ func.func @shuffle(%v0 : vector<1xi32>, %v1: vector<1xi32>) -> vector<2xi32> {
 
 // -----
 
+// CHECK-LABEL:  func @shuffle
+//  CHECK-SAME:  %[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>
+//       CHECK:    %[[EXTR:.+]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<4xi32>
+//       CHECK:    %[[RES:.+]]  = builtin.unrealized_conversion_cast %[[EXTR]] : i32 to vector<1xi32>
+//       CHECK:    return %[[RES]] : vector<1xi32>
+func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<1xi32> {
+  %shuffle = vector.shuffle %v0, %v1 [0] : vector<4xi32>, vector<4xi32>
+  return %shuffle : vector<1xi32>
+}
+
+// -----
+
+// CHECK-LABEL:  func @shuffle
+//  CHECK-SAME:  %[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>
+//       CHECK:    %[[EXTR:.+]] = spirv.CompositeExtract %[[ARG1]][1 : i32] : vector<4xi32>
+//       CHECK:    %[[RES:.+]]  = builtin.unrealized_conversion_cast %[[EXTR]] : i32 to vector<1xi32>
+//       CHECK:    return %[[RES]] : vector<1xi32>
+func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<1xi32> {
+  %shuffle = vector.shuffle %v0, %v1 [5] : vector<4xi32>, vector<4xi32>
+  return %shuffle : vector<1xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @interleave
 //  CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: vector<2xf32>)
 //       CHECK: %[[SHUFFLE:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %[[ARG0]], %[[ARG1]] : vector<2xf32>, vector<2xf32> -> vector<4xf32>



More information about the Mlir-commits mailing list