[Mlir-commits] [mlir] [mlir][spirv] Handle scalar shuffles in vector to spirv conversion (PR #98809)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jul 14 04:44:39 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Jakub Kuderski (kuhar)

<details>
<summary>Changes</summary>

Shuffles that produce a single-element (scalar) result may not get canonicalized away before conversion to SPIR-V, so they need to be handled during the vector-to-SPIR-V conversion itself. Because SPIR-V does not support 1-element vectors, we can't emit `spirv.VectorShuffle` for them and instead need to lower them to `spirv.CompositeExtract`.

---
Full diff: https://github.com/llvm/llvm-project/pull/98809.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+16-9) 
- (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+24) 


``````````diff
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index c9363295ec32f..a4390447532a5 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -521,7 +521,7 @@ struct VectorShuffleOpConvert final
   LogicalResult
   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto oldResultType = shuffleOp.getResultVectorType();
+    VectorType oldResultType = shuffleOp.getResultVectorType();
     Type newResultType = getTypeConverter()->convertType(oldResultType);
     if (!newResultType)
       return rewriter.notifyMatchFailure(shuffleOp,
@@ -532,20 +532,22 @@ struct VectorShuffleOpConvert final
           return cast<IntegerAttr>(attr).getValue().getZExtValue();
         });
 
-    auto oldV1Type = shuffleOp.getV1VectorType();
-    auto oldV2Type = shuffleOp.getV2VectorType();
+    VectorType oldV1Type = shuffleOp.getV1VectorType();
+    VectorType oldV2Type = shuffleOp.getV2VectorType();
 
-    // When both operands are SPIR-V vectors, emit a SPIR-V shuffle.
-    if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1) {
+    // When both operands and the result are SPIR-V vectors, emit a SPIR-V
+    // shuffle.
+    if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1 &&
+        oldResultType.getNumElements() > 1) {
       rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
           shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
           rewriter.getI32ArrayAttr(mask));
       return success();
     }
 
-    // When at least one of the operands becomes a scalar after type conversion
-    // for SPIR-V, extract all the required elements and construct the result
-    // vector.
+    // When at least one of the operands or the result becomes a scalar after
+    // type conversion for SPIR-V, extract all the required elements and
+    // construct the result vector.
     auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
                                Value scalarOrVec, int32_t idx) -> Value {
       if (auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType()))
@@ -569,9 +571,14 @@ struct VectorShuffleOpConvert final
       newOperand = getElementAtIdx(vec, elementIdx);
     }
 
+    // Handle the scalar result corner case.
+    if (newOperands.size() == 1) {
+      rewriter.replaceOp(shuffleOp, newOperands.front());
+      return success();
+    }
+
     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
         shuffleOp, newResultType, newOperands);
-
     return success();
   }
 };
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 0d67851dfe41d..667aad7645c51 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -483,6 +483,30 @@ func.func @shuffle(%v0 : vector<1xi32>, %v1: vector<1xi32>) -> vector<2xi32> {
 
 // -----
 
+// CHECK-LABEL:  func @shuffle
+//  CHECK-SAME:  %[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>
+//       CHECK:    %[[EXTR:.+]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<4xi32>
+//       CHECK:    %[[RES:.+]]  = builtin.unrealized_conversion_cast %[[EXTR]] : i32 to vector<1xi32>
+//       CHECK:    return %[[RES]] : vector<1xi32>
+func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<1xi32> {
+  %shuffle = vector.shuffle %v0, %v1 [0] : vector<4xi32>, vector<4xi32>
+  return %shuffle : vector<1xi32>
+}
+
+// -----
+
+// CHECK-LABEL:  func @shuffle
+//  CHECK-SAME:  %[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>
+//       CHECK:    %[[EXTR:.+]] = spirv.CompositeExtract %[[ARG1]][1 : i32] : vector<4xi32>
+//       CHECK:    %[[RES:.+]]  = builtin.unrealized_conversion_cast %[[EXTR]] : i32 to vector<1xi32>
+//       CHECK:    return %[[RES]] : vector<1xi32>
+func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<1xi32> {
+  %shuffle = vector.shuffle %v0, %v1 [5] : vector<4xi32>, vector<4xi32>
+  return %shuffle : vector<1xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @interleave
 //  CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: vector<2xf32>)
 //       CHECK: %[[SHUFFLE:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %[[ARG0]], %[[ARG1]] : vector<2xf32>, vector<2xf32> -> vector<4xf32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/98809


More information about the Mlir-commits mailing list