[Mlir-commits] [mlir] 4710750 - [mlir][spirv] Support size-1 vector inserts during conversion
Lei Zhang
llvmlistbot at llvm.org
Fri Jan 21 10:57:51 PST 2022
Author: Lei Zhang
Date: 2022-01-21T13:56:26-05:00
New Revision: 4710750854cee1fdadf5f3381e9431655056b646
URL: https://github.com/llvm/llvm-project/commit/4710750854cee1fdadf5f3381e9431655056b646
DIFF: https://github.com/llvm/llvm-project/commit/4710750854cee1fdadf5f3381e9431655056b646.diff
LOG: [mlir][spirv] Support size-1 vector inserts during conversion
Differential Revision: https://reviews.llvm.org/D115517
Added:
Modified:
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/test/Conversion/VectorToSPIRV/simple.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 27037cb4b6f2a..051b691011d8a 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -157,6 +157,13 @@ struct VectorInsertOpConvert final
LogicalResult
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ // Special case for inserting scalar values into size-1 vectors.
+ if (insertOp.getSourceType().isIntOrFloat() &&
+ insertOp.getDestVectorType().getNumElements() == 1) {
+ rewriter.replaceOp(insertOp, adaptor.source());
+ return success();
+ }
+
if (insertOp.getSourceType().isa<VectorType>() ||
!spirv::CompositeType::isValid(insertOp.getDestVectorType()))
return failure();
@@ -209,20 +216,23 @@ struct VectorInsertStridedSliceOpConvert final
Value srcVector = adaptor.getOperands().front();
Value dstVector = adaptor.getOperands().back();
- // Insert scalar values not supported yet.
- if (srcVector.getType().isa<spirv::ScalarType>() ||
- dstVector.getType().isa<spirv::ScalarType>())
- return failure();
-
uint64_t stride = getFirstIntValue(insertOp.strides());
if (stride != 1)
return failure();
+ uint64_t offset = getFirstIntValue(insertOp.offsets());
+
+ if (srcVector.getType().isa<spirv::ScalarType>()) {
+ assert(!dstVector.getType().isa<spirv::ScalarType>());
+ rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+ insertOp, dstVector.getType(), srcVector, dstVector,
+ rewriter.getI32ArrayAttr(offset));
+ return success();
+ }
uint64_t totalSize =
dstVector.getType().cast<VectorType>().getNumElements();
uint64_t insertSize =
srcVector.getType().cast<VectorType>().getNumElements();
- uint64_t offset = getFirstIntValue(insertOp.offsets());
SmallVector<int32_t, 2> indices(totalSize);
std::iota(indices.begin(), indices.end(), 0);
diff --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
index 8f5cf197713d2..7a3e4b3289729 100644
--- a/mlir/test/Conversion/VectorToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
@@ -61,6 +61,17 @@ func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
// -----
+// CHECK-LABEL: @insert_size1_vector
+// CHECK-SAME: %[[V:.*]]: vector<1xf32>, %[[S:.*]]: f32
+// CHECK: %[[R:.+]] = builtin.unrealized_conversion_cast %[[S]]
+// CHECK: return %[[R]]
+func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf32> {
+ %1 = vector.insert %arg1, %arg0[0] : f32 into vector<1xf32>
+ return %1 : vector<1xf32>
+}
+
+// -----
+
// CHECK-LABEL: @extract_element
// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
// CHECK: spv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
@@ -139,6 +150,17 @@ func @insert_strided_slice(%arg0: vector<2xf32>, %arg1: vector<4xf32>) -> vector
// -----
+// CHECK-LABEL: @insert_size1_vector
+// CHECK-SAME: %[[SUB:.*]]: vector<1xf32>, %[[FULL:.*]]: vector<3xf32>
+// CHECK: %[[S:.+]] = builtin.unrealized_conversion_cast %[[SUB]]
+// CHECK: spv.CompositeInsert %[[S]], %[[FULL]][2 : i32] : f32 into vector<3xf32>
+func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: vector<3xf32>) -> vector<3xf32> {
+ %1 = vector.insert_strided_slice %arg0, %arg1 {offsets = [2], strides = [1]} : vector<1xf32> into vector<3xf32>
+ return %1 : vector<3xf32>
+}
+
+// -----
+
// CHECK-LABEL: @fma
// CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32>
// CHECK: spv.GLSL.Fma %[[A]], %[[B]], %[[C]] : vector<4xf32>
More information about the Mlir-commits mailing list