[Mlir-commits] [mlir] b2e72cd - [mlir][spirv] Support conversion of extract op from vector<1xT> type

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 20 09:12:14 PDT 2021


Author: thomasraoux
Date: 2021-04-20T09:11:41-07:00
New Revision: b2e72cd38de859194b18d598fdfe704315be3d36

URL: https://github.com/llvm/llvm-project/commit/b2e72cd38de859194b18d598fdfe704315be3d36
DIFF: https://github.com/llvm/llvm-project/commit/b2e72cd38de859194b18d598fdfe704315be3d36.diff

LOG: [mlir][spirv] Support conversion of extract op from vector<1xT> type

Differential Revision: https://reviews.llvm.org/D100814

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
    mlir/test/Conversion/VectorToSPIRV/simple.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 4cfcb4148cf0f..edabae72913bf 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -89,6 +89,11 @@ struct VectorExtractOpConvert final
       return failure();
 
     vector::ExtractOp::Adaptor adaptor(operands);
+    if (adaptor.vector().getType().isa<spirv::ScalarType>()) {
+      rewriter.replaceOp(extractOp, adaptor.vector());
+      return success();
+    }
+
     int32_t id = getFirstIntValue(extractOp.position());
     rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
         extractOp, adaptor.vector(), id);

diff  --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
index 836d3853d3351..9f9657cd56952 100644
--- a/mlir/test/Conversion/VectorToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
@@ -40,6 +40,24 @@ func @extract(%arg0 : vector<2xf32>) {
 
 // -----
 
+module attributes { spv.target_env = #spv.target_env<#spv.vce<v1.0, [Float16], []>, {}> } {
+
+// CHECK-LABEL: func @extract_scalar
+//  CHECK-SAME: %[[ARG0:.+]]: vector<2xf16>
+//  CHECK-SAME: %[[ARG1:.+]]: vector<4xf32>
+//       CHECK:   %[[S:.+]] = spv.Bitcast %[[ARG0]] : vector<2xf16> to f32
+//       CHECK:   spv.CompositeInsert %[[S]], %[[ARG1]][0 : i32] : f32 into vector<4xf32>
+func @extract_scalar(%arg0 : vector<2xf16>, %arg1 : vector<4xf32>) {
+  %0 = vector.bitcast %arg0 : vector<2xf16> to vector<1xf32>
+  %1 = vector.extract %0[0] : vector<1xf32>
+  %2 = vector.insert %1, %arg1[0] : f32 into vector<4xf32>
+  spv.Return
+}
+
+} // end module
+
+// -----
+
 // CHECK-LABEL: extract_insert
 //  CHECK-SAME: %[[V:.*]]: vector<4xf32>
 //       CHECK:   %[[S:.*]] = spv.CompositeExtract %[[V]][1 : i32] : vector<4xf32>


        


More information about the Mlir-commits mailing list