[Mlir-commits] [mlir] 3035e67 - [mlir][spirv] Add VectorInsertDynamicOp and vector.insertelement lowering

Artur Bialas llvmlistbot at llvm.org
Tue Nov 10 00:49:58 PST 2020


Author: Artur Bialas
Date: 2020-11-10T09:49:12+01:00
New Revision: 3035e676a3880a6f207f8668bde7e47248520e07

URL: https://github.com/llvm/llvm-project/commit/3035e676a3880a6f207f8668bde7e47248520e07
DIFF: https://github.com/llvm/llvm-project/commit/3035e676a3880a6f207f8668bde7e47248520e07.diff

LOG: [mlir][spirv] Add VectorInsertDynamicOp and vector.insertelement lowering

VectorInsertDynamicOp in SPIRV dialect
conversion from vector.insertelement to spirv VectorInsertDynamicOp

Differential Revision: https://reviews.llvm.org/D90927

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
    mlir/test/Conversion/VectorToSPIRV/simple.mlir
    mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir
    mlir/test/Dialect/SPIRV/composite-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index e7f1a3888ab3..54e5efe5f295 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -3177,6 +3177,7 @@ def SPV_OC_OpAccessChain               : I32EnumAttrCase<"OpAccessChain", 65>;
 def SPV_OC_OpDecorate                  : I32EnumAttrCase<"OpDecorate", 71>;
 def SPV_OC_OpMemberDecorate            : I32EnumAttrCase<"OpMemberDecorate", 72>;
 def SPV_OC_OpVectorExtractDynamic      : I32EnumAttrCase<"OpVectorExtractDynamic", 77>;
+def SPV_OC_OpVectorInsertDynamic       : I32EnumAttrCase<"OpVectorInsertDynamic", 78>;
 def SPV_OC_OpCompositeConstruct        : I32EnumAttrCase<"OpCompositeConstruct", 80>;
 def SPV_OC_OpCompositeExtract          : I32EnumAttrCase<"OpCompositeExtract", 81>;
 def SPV_OC_OpCompositeInsert           : I32EnumAttrCase<"OpCompositeInsert", 82>;
@@ -3310,9 +3311,9 @@ def SPV_OpcodeAttr :
       SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant,
       SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
       SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad,
-      SPV_OC_OpStore, SPV_OC_OpCopyMemory, SPV_OC_OpAccessChain,
-      SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate,
-      SPV_OC_OpVectorExtractDynamic, SPV_OC_OpCompositeConstruct,
+      SPV_OC_OpStore, SPV_OC_OpCopyMemory, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
+      SPV_OC_OpMemberDecorate, SPV_OC_OpVectorExtractDynamic,
+      SPV_OC_OpVectorInsertDynamic, SPV_OC_OpCompositeConstruct,
       SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose,
       SPV_OC_OpConvertFToU, SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF,
       SPV_OC_OpConvertUToF, SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert,

diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td
index c1f76d0fee20..8c9eac195878 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td
@@ -171,7 +171,8 @@ def SPV_CompositeInsertOp : SPV_Op<"CompositeInsert", [NoSideEffect]> {
 // -----
 
 def SPV_VectorExtractDynamicOp : SPV_Op<"VectorExtractDynamic",
-            [NoSideEffect, TypesMatchWith<"type of 'value' matches element type of 'vector'",
+            [NoSideEffect,
+            TypesMatchWith<"type of 'result' matches element type of 'vector'",
                      "vector", "result",
                      "$_self.cast<mlir::VectorType>().getElementType()">]> {
   let summary = [{
@@ -225,4 +226,69 @@ def SPV_VectorExtractDynamicOp : SPV_Op<"VectorExtractDynamic",
 
 // -----
 
+def SPV_VectorInsertDynamicOp : SPV_Op<"VectorInsertDynamic",
+        [NoSideEffect,
+        TypesMatchWith<"type of 'component' matches element type of 'vector'",
+                "vector", "component",
+                "$_self.cast<mlir::VectorType>().getElementType()">,
+                AllTypesMatch<["vector", "result"]>]> {
+  let summary = [{
+    Make a copy of a vector, with a single, variably selected, component
+    modified.
+  }];
+
+  let description = [{
+    Result Type must be an OpTypeVector.
+
+    Vector must have the same type as Result Type and is the vector that the
+    non-written components are copied from.
+
+    Component is the value supplied for the component selected by Index. It
+    must have the same type as the type of components in Result Type.
+
+    Index must be a scalar integer. It is interpreted as a 0-based index of
+    which component to modify.
+
+    Behavior is undefined if Index's value is less than zero or greater than
+    or equal to the number of components in Vector.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    scalar-type ::= integer-type | float-type | boolean-type
+    vector-insert-dynamic-op ::= `spv.VectorInsertDynamic ` ssa-use `,`
+                                  ssa-use `[` ssa-use `]`
+                                  `:` `vector<` integer-literal `x` scalar-type `>` `,`
+                                  integer-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %scalar = ... : f32
+    %2 = spv.VectorInsertDynamic %scalar, %0[%1] : vector<8xf32>, i32
+    ```
+  }];
+
+  let arguments = (ins
+    SPV_Vector:$vector,
+    SPV_Scalar:$component,
+    SPV_Integer:$index
+  );
+
+  let results = (outs
+    SPV_Vector:$result
+  );
+
+  // The type rules from the description are enforced by the operand types and
+  // the TypesMatchWith/AllTypesMatch traits, so no extra checks are needed.
+  let verifier = [{ return success(); }];
+
+  let assemblyFormat = [{
+    $component `,` $vector `[` $index `]` attr-dict `:` type($vector) `,` type($index)
+  }];
+}
+
+// -----
+
 #endif // SPIRV_COMPOSITE_OPS

diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 8446f42afde3..220fa62f9e9a 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -97,14 +97,36 @@ struct VectorExtractElementOpConvert final
   }
 };
 
+/// Converts vector.insertelement into spv.VectorInsertDynamic.
+///
+/// Every operand of the new op comes from the adaptor (the converted operand
+/// values) rather than from the original op.
+struct VectorInsertElementOpConvert final
+    : public SPIRVOpLowering<vector::InsertElementOp> {
+  using SPIRVOpLowering<vector::InsertElementOp>::SPIRVOpLowering;
+  LogicalResult
+  matchAndRewrite(vector::InsertElementOp insertElementOp,
+                  ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only vector types SPIR-V can represent as a composite are lowered.
+    if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType()))
+      return failure();
+    vector::InsertElementOp::Adaptor adaptor(operands);
+    rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
+        insertElementOp, insertElementOp.getType(), adaptor.dest(),
+        adaptor.source(), adaptor.position());
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateVectorToSPIRVPatterns(MLIRContext *context,
                                          SPIRVTypeConverter &typeConverter,
                                          OwningRewritePatternList &patterns) {
   patterns.insert<VectorBroadcastConvert, VectorExtractOpConvert,
-                  VectorInsertOpConvert, VectorExtractElementOpConvert>(
-      context, typeConverter);
+                  VectorInsertOpConvert, VectorExtractElementOpConvert,
+                  VectorInsertElementOpConvert>(context, typeConverter);
 }
 
 namespace {

diff  --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
index 5ff2cc55e6d0..3594a6db805e 100644
--- a/mlir/test/Conversion/VectorToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
@@ -39,3 +39,21 @@ func @extract_element_negative(%arg0 : vector<5xf32>, %id : i32) {
   %0 = vector.extractelement %arg0[%id : i32] : vector<5xf32>
   spv.ReturnValue %0: f32
 }
+
+// -----
+
+// CHECK-LABEL: insert_element
+//  CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
+//       CHECK:   spv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32
+func @insert_element(%scalar: f32, %vec: vector<4xf32>, %idx: i32) {
+  %res = vector.insertelement %scalar, %vec[%idx : i32] : vector<4xf32>
+  spv.ReturnValue %res : vector<4xf32>
+}
+
+// -----
+
+func @insert_element_negative(%scalar: f32, %vec: vector<5xf32>, %idx: i32) {
+  // expected-error @+1 {{failed to legalize operation 'vector.insertelement'}}
+  %res = vector.insertelement %scalar, %vec[%idx : i32] : vector<5xf32>
+  spv.Return
+}

diff  --git a/mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir b/mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir
index 84b3fec1e0e5..468d5419081a 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir
@@ -16,4 +16,9 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
     %0 = spv.VectorExtractDynamic %vec[%id] : vector<4xf32>, i32
     spv.ReturnValue %0: f32
   }
+  spv.func @vector_dynamic_insert(%scalar: f32, %vector: vector<4xf32>, %index: i32) -> vector<4xf32> "None" {
+    // CHECK: spv.VectorInsertDynamic %{{.*}}, %{{.*}}[%{{.*}}] : vector<4xf32>, i32
+    %result = spv.VectorInsertDynamic %scalar, %vector[%index] : vector<4xf32>, i32
+    spv.ReturnValue %result : vector<4xf32>
+  }
 }

diff  --git a/mlir/test/Dialect/SPIRV/composite-ops.mlir b/mlir/test/Dialect/SPIRV/composite-ops.mlir
index 5d969737ce76..77d091fe1107 100644
--- a/mlir/test/Dialect/SPIRV/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/composite-ops.mlir
@@ -273,3 +273,13 @@ func @vector_dynamic_extract(%vec: vector<4xf32>, %id : i32) -> f32 {
   %0 = spv.VectorExtractDynamic %vec[%id] : vector<4xf32>, i32
   return %0 : f32
 }
+
+//===----------------------------------------------------------------------===//
+// spv.VectorInsertDynamic
+//===----------------------------------------------------------------------===//
+
+func @vector_dynamic_insert(%scalar: f32, %vector: vector<4xf32>, %index: i32) -> vector<4xf32> {
+  // CHECK: spv.VectorInsertDynamic %{{.*}}, %{{.*}}[%{{.*}}] : vector<4xf32>, i32
+  %result = spv.VectorInsertDynamic %scalar, %vector[%index] : vector<4xf32>, i32
+  return %result : vector<4xf32>
+}


        


More information about the Mlir-commits mailing list