[Mlir-commits] [mlir] b2e72cd - [mlir][spirv] Support conversion of extract op from vector<1xT> type
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Apr 20 09:12:14 PDT 2021
Author: thomasraoux
Date: 2021-04-20T09:11:41-07:00
New Revision: b2e72cd38de859194b18d598fdfe704315be3d36
URL: https://github.com/llvm/llvm-project/commit/b2e72cd38de859194b18d598fdfe704315be3d36
DIFF: https://github.com/llvm/llvm-project/commit/b2e72cd38de859194b18d598fdfe704315be3d36.diff
LOG: [mlir][spirv] Support conversion of extract op from vector<1xT> type
Differential Revision: https://reviews.llvm.org/D100814
Added:
Modified:
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/test/Conversion/VectorToSPIRV/simple.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 4cfcb4148cf0f..edabae72913bf 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -89,6 +89,11 @@ struct VectorExtractOpConvert final
return failure();
vector::ExtractOp::Adaptor adaptor(operands);
+ if (adaptor.vector().getType().isa<spirv::ScalarType>()) {
+ rewriter.replaceOp(extractOp, adaptor.vector());
+ return success();
+ }
+
int32_t id = getFirstIntValue(extractOp.position());
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
extractOp, adaptor.vector(), id);
diff --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
index 836d3853d3351..9f9657cd56952 100644
--- a/mlir/test/Conversion/VectorToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
@@ -40,6 +40,24 @@ func @extract(%arg0 : vector<2xf32>) {
// -----
+module attributes { spv.target_env = #spv.target_env<#spv.vce<v1.0, [Float16], []>, {}> } {
+
+// CHECK-LABEL: func @extract_scalar
+// CHECK-SAME: %[[ARG0:.+]]: vector<2xf16>
+// CHECK-SAME: %[[ARG1:.+]]: vector<4xf32>
+// CHECK: %[[S:.+]] = spv.Bitcast %[[ARG0]] : vector<2xf16> to f32
+// CHECK: spv.CompositeInsert %[[S]], %[[ARG1]][0 : i32] : f32 into vector<4xf32>
+func @extract_scalar(%arg0 : vector<2xf16>, %arg1 : vector<4xf32>) {
+ %0 = vector.bitcast %arg0 : vector<2xf16> to vector<1xf32>
+ %1 = vector.extract %0[0] : vector<1xf32>
+ %2 = vector.insert %1, %arg1[0] : f32 into vector<4xf32>
+ spv.Return
+}
+
+} // end module
+
+// -----
+
// CHECK-LABEL: extract_insert
// CHECK-SAME: %[[V:.*]]: vector<4xf32>
// CHECK: %[[S:.*]] = spv.CompositeExtract %[[V]][1 : i32] : vector<4xf32>
More information about the Mlir-commits mailing list