[Mlir-commits] [mlir] 565ee6a - [mlir][spirv] add support for lowering extract_strided_slice to a scalar type

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri May 7 07:52:27 PDT 2021


Author: thomasraoux
Date: 2021-05-07T07:52:02-07:00
New Revision: 565ee6afc707d5744d0ec90936f0c0564c1acf69

URL: https://github.com/llvm/llvm-project/commit/565ee6afc707d5744d0ec90936f0c0564c1acf69
DIFF: https://github.com/llvm/llvm-project/commit/565ee6afc707d5744d0ec90936f0c0564c1acf69.diff

LOG: [mlir][spirv] add support for lowering extract_strided_slice to a scalar type

Differential Revision: https://reviews.llvm.org/D102041

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
    mlir/test/Conversion/VectorToSPIRV/simple.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index edabae72913bf..de9dfa1c4bc17 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -113,9 +113,6 @@ struct VectorExtractStridedSliceOpConvert final
     if (!dstType)
       return failure();
 
-    // Extract vector<1xT> not supported yet.
-    if (dstType.isa<spirv::ScalarType>())
-      return failure();
 
     uint64_t offset = getFirstIntValue(extractOp.offsets());
     uint64_t size = getFirstIntValue(extractOp.sizes());
@@ -125,6 +122,13 @@ struct VectorExtractStridedSliceOpConvert final
 
     Value srcVector = operands.front();
 
+    // Extract vector<1xT> case.
+    if (dstType.isa<spirv::ScalarType>()) {
+      rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp,
+                                                             srcVector, offset);
+      return success();
+    }
+
     SmallVector<int32_t, 2> indices(size);
     std::iota(indices.begin(), indices.end(), offset);
 

diff  --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
index 9f9657cd56952..4a471a4108531 100644
--- a/mlir/test/Conversion/VectorToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
@@ -91,8 +91,10 @@ func @extract_element_negative(%arg0 : vector<5xf32>, %id : i32) {
 // CHECK-LABEL: func @extract_strided_slice
 //  CHECK-SAME: %[[ARG:.+]]: vector<4xf32>
 //       CHECK: %{{.+}} = spv.VectorShuffle [1 : i32, 2 : i32] %[[ARG]] : vector<4xf32>, %[[ARG]] : vector<4xf32> -> vector<2xf32>
+//       CHECK: %{{.+}} = spv.CompositeExtract %[[ARG]][1 : i32] : vector<4xf32>
 func @extract_strided_slice(%arg0: vector<4xf32>) {
   %0 = vector.extract_strided_slice %arg0 {offsets = [1], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+  %1 = vector.extract_strided_slice %arg0 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
   spv.Return
 }
 


        


More information about the Mlir-commits mailing list