[Mlir-commits] [mlir] [mlir][spirv] Handle scalar shuffles in vector to spirv conversion (PR #98809)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Jul 14 04:44:38 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-spirv
Author: Jakub Kuderski (kuhar)
<details>
<summary>Changes</summary>
Scalar shuffles (`vector.shuffle` ops whose result is a single-element vector) may not get canonicalized away before conversion to SPIR-V, so they need to be handled during the vector-to-SPIR-V conversion itself. Because SPIR-V does not support 1-element vectors, we can't emit `spirv.VectorShuffle` for them and instead need to lower them to `spirv.CompositeExtract`.
---
Full diff: https://github.com/llvm/llvm-project/pull/98809.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+16-9)
- (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+24)
``````````diff
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index c9363295ec32f..a4390447532a5 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -521,7 +521,7 @@ struct VectorShuffleOpConvert final
LogicalResult
matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto oldResultType = shuffleOp.getResultVectorType();
+ VectorType oldResultType = shuffleOp.getResultVectorType();
Type newResultType = getTypeConverter()->convertType(oldResultType);
if (!newResultType)
return rewriter.notifyMatchFailure(shuffleOp,
@@ -532,20 +532,22 @@ struct VectorShuffleOpConvert final
return cast<IntegerAttr>(attr).getValue().getZExtValue();
});
- auto oldV1Type = shuffleOp.getV1VectorType();
- auto oldV2Type = shuffleOp.getV2VectorType();
+ VectorType oldV1Type = shuffleOp.getV1VectorType();
+ VectorType oldV2Type = shuffleOp.getV2VectorType();
- // When both operands are SPIR-V vectors, emit a SPIR-V shuffle.
- if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1) {
+ // When both operands and the result are SPIR-V vectors, emit a SPIR-V
+ // shuffle.
+ if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1 &&
+ oldResultType.getNumElements() > 1) {
rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
rewriter.getI32ArrayAttr(mask));
return success();
}
- // When at least one of the operands becomes a scalar after type conversion
- // for SPIR-V, extract all the required elements and construct the result
- // vector.
+ // When at least one of the operands or the result becomes a scalar after
+ // type conversion for SPIR-V, extract all the required elements and
+ // construct the result vector.
auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
Value scalarOrVec, int32_t idx) -> Value {
if (auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType()))
@@ -569,9 +571,14 @@ struct VectorShuffleOpConvert final
newOperand = getElementAtIdx(vec, elementIdx);
}
+ // Handle the scalar result corner case.
+ if (newOperands.size() == 1) {
+ rewriter.replaceOp(shuffleOp, newOperands.front());
+ return success();
+ }
+
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
shuffleOp, newResultType, newOperands);
-
return success();
}
};
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 0d67851dfe41d..667aad7645c51 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -483,6 +483,30 @@ func.func @shuffle(%v0 : vector<1xi32>, %v1: vector<1xi32>) -> vector<2xi32> {
// -----
+// CHECK-LABEL: func @shuffle
+// CHECK-SAME: %[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>
+// CHECK: %[[EXTR:.+]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<4xi32>
+// CHECK: %[[RES:.+]] = builtin.unrealized_conversion_cast %[[EXTR]] : i32 to vector<1xi32>
+// CHECK: return %[[RES]] : vector<1xi32>
+func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<1xi32> {
+ %shuffle = vector.shuffle %v0, %v1 [0] : vector<4xi32>, vector<4xi32>
+ return %shuffle : vector<1xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @shuffle
+// CHECK-SAME: %[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>
+// CHECK: %[[EXTR:.+]] = spirv.CompositeExtract %[[ARG1]][1 : i32] : vector<4xi32>
+// CHECK: %[[RES:.+]] = builtin.unrealized_conversion_cast %[[EXTR]] : i32 to vector<1xi32>
+// CHECK: return %[[RES]] : vector<1xi32>
+func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<1xi32> {
+ %shuffle = vector.shuffle %v0, %v1 [5] : vector<4xi32>, vector<4xi32>
+ return %shuffle : vector<1xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @interleave
// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: vector<2xf32>)
// CHECK: %[[SHUFFLE:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %[[ARG0]], %[[ARG1]] : vector<2xf32>, vector<2xf32> -> vector<4xf32>
``````````
</details>
https://github.com/llvm/llvm-project/pull/98809
More information about the Mlir-commits
mailing list