[Mlir-commits] [mlir] 4710750 - [mlir][spirv] Support size-1 vector inserts during conversion

Lei Zhang llvmlistbot at llvm.org
Fri Jan 21 10:57:51 PST 2022


Author: Lei Zhang
Date: 2022-01-21T13:56:26-05:00
New Revision: 4710750854cee1fdadf5f3381e9431655056b646

URL: https://github.com/llvm/llvm-project/commit/4710750854cee1fdadf5f3381e9431655056b646
DIFF: https://github.com/llvm/llvm-project/commit/4710750854cee1fdadf5f3381e9431655056b646.diff

LOG: [mlir][spirv] Support size-1 vector inserts during conversion

Differential Revision: https://reviews.llvm.org/D115517

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
    mlir/test/Conversion/VectorToSPIRV/simple.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 27037cb4b6f2a..051b691011d8a 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -157,6 +157,13 @@ struct VectorInsertOpConvert final
   LogicalResult
   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    // Special case for inserting scalar values into size-1 vectors.
+    if (insertOp.getSourceType().isIntOrFloat() &&
+        insertOp.getDestVectorType().getNumElements() == 1) {
+      rewriter.replaceOp(insertOp, adaptor.source());
+      return success();
+    }
+
     if (insertOp.getSourceType().isa<VectorType>() ||
         !spirv::CompositeType::isValid(insertOp.getDestVectorType()))
       return failure();
@@ -209,20 +216,23 @@ struct VectorInsertStridedSliceOpConvert final
     Value srcVector = adaptor.getOperands().front();
     Value dstVector = adaptor.getOperands().back();
 
-    // Insert scalar values not supported yet.
-    if (srcVector.getType().isa<spirv::ScalarType>() ||
-        dstVector.getType().isa<spirv::ScalarType>())
-      return failure();
-
     uint64_t stride = getFirstIntValue(insertOp.strides());
     if (stride != 1)
       return failure();
+    uint64_t offset = getFirstIntValue(insertOp.offsets());
+
+    if (srcVector.getType().isa<spirv::ScalarType>()) {
+      assert(!dstVector.getType().isa<spirv::ScalarType>());
+      rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+          insertOp, dstVector.getType(), srcVector, dstVector,
+          rewriter.getI32ArrayAttr(offset));
+      return success();
+    }
 
     uint64_t totalSize =
         dstVector.getType().cast<VectorType>().getNumElements();
     uint64_t insertSize =
         srcVector.getType().cast<VectorType>().getNumElements();
-    uint64_t offset = getFirstIntValue(insertOp.offsets());
 
     SmallVector<int32_t, 2> indices(totalSize);
     std::iota(indices.begin(), indices.end(), 0);

diff  --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
index 8f5cf197713d2..7a3e4b3289729 100644
--- a/mlir/test/Conversion/VectorToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
@@ -61,6 +61,17 @@ func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
 
 // -----
 
+// CHECK-LABEL: @insert_size1_vector
+//  CHECK-SAME: %[[V:.*]]: vector<1xf32>, %[[S:.*]]: f32
+//       CHECK:   %[[R:.+]] = builtin.unrealized_conversion_cast %[[S]]
+//       CHECK:   return %[[R]]
+func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf32> {
+  %1 = vector.insert %arg1, %arg0[0] : f32 into vector<1xf32>
+  return %1 : vector<1xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @extract_element
 //  CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
 //       CHECK:   spv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
@@ -139,6 +150,17 @@ func @insert_strided_slice(%arg0: vector<2xf32>, %arg1: vector<4xf32>) -> vector
 
 // -----
 
+// CHECK-LABEL: @insert_strided_slice_size1_vector
+//  CHECK-SAME: %[[SUB:.*]]: vector<1xf32>, %[[FULL:.*]]: vector<3xf32>
+//       CHECK:   %[[S:.+]] = builtin.unrealized_conversion_cast %[[SUB]]
+//       CHECK:   spv.CompositeInsert %[[S]], %[[FULL]][2 : i32] : f32 into vector<3xf32>
+func @insert_strided_slice_size1_vector(%arg0 : vector<1xf32>, %arg1: vector<3xf32>) -> vector<3xf32> {
+  %1 = vector.insert_strided_slice %arg0, %arg1 {offsets = [2], strides = [1]} : vector<1xf32> into vector<3xf32>
+  return %1 : vector<3xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @fma
 //  CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32>
 //       CHECK:   spv.GLSL.Fma %[[A]], %[[B]], %[[C]] : vector<4xf32>


        


More information about the Mlir-commits mailing list