[Mlir-commits] [mlir] 04af7d2 - [mlir][spirv][vector] Fix vector shuffle conversion for scalar inputs
Jakub Kuderski
llvmlistbot at llvm.org
Mon Jul 31 17:13:41 PDT 2023
Author: Jakub Kuderski
Date: 2023-07-31T20:13:30-04:00
New Revision: 04af7d224bd8ee2b346525b02ed04f57d8d9d1f8
URL: https://github.com/llvm/llvm-project/commit/04af7d224bd8ee2b346525b02ed04f57d8d9d1f8
DIFF: https://github.com/llvm/llvm-project/commit/04af7d224bd8ee2b346525b02ed04f57d8d9d1f8.diff
LOG: [mlir][spirv][vector] Fix vector shuffle conversion for scalar inputs
Handle the case where either shuffle operand (v0 or v1), or both, may be converted to a SPIR-V scalar.
Fixes: https://github.com/llvm/llvm-project/issues/64271
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D156717
Added:
Modified:
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index c15b99b5a62d3e..57191ce8dd4c94 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -27,7 +27,10 @@
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/FormatVariadic.h"
+#include <cassert>
+#include <cstdint>
#include <numeric>
using namespace mlir;
@@ -444,24 +447,48 @@ struct VectorShuffleOpConvert final
return rewriter.notifyMatchFailure(shuffleOp,
"unsupported result vector type");
- auto oldSourceType = shuffleOp.getV1VectorType();
- if (oldSourceType.getNumElements() > 1) {
- SmallVector<int32_t, 4> components = llvm::to_vector<4>(
- llvm::map_range(shuffleOp.getMask(), [](Attribute attr) -> int32_t {
- return cast<IntegerAttr>(attr).getValue().getZExtValue();
- }));
+ SmallVector<int32_t, 4> mask = llvm::map_to_vector<4>(
+ shuffleOp.getMask(), [](Attribute attr) -> int32_t {
+ return cast<IntegerAttr>(attr).getValue().getZExtValue();
+ });
+
+ auto oldV1Type = shuffleOp.getV1VectorType();
+ auto oldV2Type = shuffleOp.getV2VectorType();
+
+ // When both operands are SPIR-V vectors, emit a SPIR-V shuffle.
+ if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1) {
rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
- rewriter.getI32ArrayAttr(components));
+ rewriter.getI32ArrayAttr(mask));
return success();
}
- SmallVector<Value, 2> oldOperands = {adaptor.getV1(), adaptor.getV2()};
- SmallVector<Value, 4> newOperands;
- newOperands.reserve(oldResultType.getNumElements());
- for (const APInt &i : shuffleOp.getMask().getAsValueRange<IntegerAttr>()) {
- newOperands.push_back(oldOperands[i.getZExtValue()]);
+ // When at least one of the operands becomes a scalar after type conversion
+ // for SPIR-V, extract all the required elements and construct the result
+ // vector.
+ auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
+ Value scalarOrVec, int32_t idx) -> Value {
+ if (auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType()))
+ return rewriter.create<spirv::CompositeExtractOp>(loc, scalarOrVec,
+ idx);
+
+ assert(idx == 0 && "Invalid scalar element index");
+ return scalarOrVec;
+ };
+
+ int32_t numV1Elems = oldV1Type.getNumElements();
+ SmallVector<Value> newOperands(mask.size());
+ for (auto [shuffleIdx, newOperand] : llvm::zip_equal(mask, newOperands)) {
+ Value vec = adaptor.getV1();
+ int32_t elementIdx = shuffleIdx;
+ if (elementIdx >= numV1Elems) {
+ vec = adaptor.getV2();
+ elementIdx -= numV1Elems;
+ }
+
+ newOperand = getElementAtIdx(vec, elementIdx);
}
+
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
shuffleOp, newResultType, newOperands);
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 922351265e38b9..310df8030db300 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -442,6 +442,47 @@ func.func @shuffle(%v0 : vector<2x16xf32>, %v1: vector<1x16xf32>) -> vector<3x16
// -----
+// CHECK-LABEL: func @shuffle
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xi32>, %[[ARG1:.+]]: vector<3xi32>
+// CHECK: %[[V0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<1xi32> to i32
+// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[ARG1]][1 : i32] : vector<3xi32>
+// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[ARG1]][2 : i32] : vector<3xi32>
+// CHECK: %[[RES:.+]] = spirv.CompositeConstruct %[[V0]], %[[S1]], %[[S2]] : (i32, i32, i32) -> vector<3xi32>
+// CHECK: return %[[RES]]
+func.func @shuffle(%v0 : vector<1xi32>, %v1: vector<3xi32>) -> vector<3xi32> {  // v0 converts to a scalar i32; v1 stays a vector
+ %shuffle = vector.shuffle %v0, %v1 [0, 2, 3] : vector<1xi32>, vector<3xi32>  // mask: 0 -> v0, 2 -> v1[1], 3 -> v1[2]
+ return %shuffle : vector<3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @shuffle
+// CHECK-SAME: %[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<1xi32>
+// CHECK: %[[V1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : vector<1xi32> to i32
+// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<3xi32>
+// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[ARG0]][2 : i32] : vector<3xi32>
+// CHECK: %[[RES:.+]] = spirv.CompositeConstruct %[[S0]], %[[S1]], %[[V1]] : (i32, i32, i32) -> vector<3xi32>
+// CHECK: return %[[RES]]
+func.func @shuffle(%v0 : vector<3xi32>, %v1: vector<1xi32>) -> vector<3xi32> {  // v1 converts to a scalar i32; v0 stays a vector
+ %shuffle = vector.shuffle %v0, %v1 [0, 2, 3] : vector<3xi32>, vector<1xi32>  // mask: 0 -> v0[0], 2 -> v0[2], 3 -> v1
+ return %shuffle : vector<3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @shuffle
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xi32>, %[[ARG1:.+]]: vector<1xi32>
+// CHECK: %[[V0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<1xi32> to i32
+// CHECK: %[[V1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : vector<1xi32> to i32
+// CHECK: %[[RES:.+]] = spirv.CompositeConstruct %[[V0]], %[[V1]] : (i32, i32) -> vector<2xi32>
+// CHECK: return %[[RES]]
+func.func @shuffle(%v0 : vector<1xi32>, %v1: vector<1xi32>) -> vector<2xi32> {  // both operands convert to scalar i32
+ %shuffle = vector.shuffle %v0, %v1 [0, 1] : vector<1xi32>, vector<1xi32>  // mask: 0 -> v0, 1 -> v1; no extracts needed
+ return %shuffle : vector<2xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @reduction_add
// CHECK-SAME: (%[[V:.+]]: vector<4xi32>)
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<4xi32>
More information about the Mlir-commits mailing list