[Mlir-commits] [mlir] 565ee6a - [mlir][spirv] add support for lowering vector.extract_strided_slice to a scalar type
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri May 7 07:52:27 PDT 2021
Author: thomasraoux
Date: 2021-05-07T07:52:02-07:00
New Revision: 565ee6afc707d5744d0ec90936f0c0564c1acf69
URL: https://github.com/llvm/llvm-project/commit/565ee6afc707d5744d0ec90936f0c0564c1acf69
DIFF: https://github.com/llvm/llvm-project/commit/565ee6afc707d5744d0ec90936f0c0564c1acf69.diff
LOG: [mlir][spirv] add support for lowering vector.extract_strided_slice to a scalar type
Differential Revision: https://reviews.llvm.org/D102041
Added:
Modified:
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/test/Conversion/VectorToSPIRV/simple.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index edabae72913bf..de9dfa1c4bc17 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -113,9 +113,6 @@ struct VectorExtractStridedSliceOpConvert final
if (!dstType)
return failure();
- // Extract vector<1xT> not supported yet.
- if (dstType.isa<spirv::ScalarType>())
- return failure();
uint64_t offset = getFirstIntValue(extractOp.offsets());
uint64_t size = getFirstIntValue(extractOp.sizes());
@@ -125,6 +122,13 @@ struct VectorExtractStridedSliceOpConvert final
Value srcVector = operands.front();
+ // Extract vector<1xT> case.
+ if (dstType.isa<spirv::ScalarType>()) {
+ rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
+ srcVector, offset);
+ return success();
+ }
+
SmallVector<int32_t, 2> indices(size);
std::iota(indices.begin(), indices.end(), offset);
diff --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
index 9f9657cd56952..4a471a4108531 100644
--- a/mlir/test/Conversion/VectorToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
@@ -91,8 +91,10 @@ func @extract_element_negative(%arg0 : vector<5xf32>, %id : i32) {
// CHECK-LABEL: func @extract_strided_slice
// CHECK-SAME: %[[ARG:.+]]: vector<4xf32>
// CHECK: %{{.+}} = spv.VectorShuffle [1 : i32, 2 : i32] %[[ARG]] : vector<4xf32>, %[[ARG]] : vector<4xf32> -> vector<2xf32>
+// CHECK: %{{.+}} = spv.CompositeExtract %[[ARG]][1 : i32] : vector<4xf32>
func @extract_strided_slice(%arg0: vector<4xf32>) {
%0 = vector.extract_strided_slice %arg0 {offsets = [1], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+ %1 = vector.extract_strided_slice %arg0 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
spv.Return
}
More information about the Mlir-commits mailing list