[Mlir-commits] [mlir] f9dca10 - [mlir][spirv] Add VectorExtractDynamicOp and vector.extractelement lowering

Artur Bialas llvmlistbot at llvm.org
Wed Nov 4 23:27:29 PST 2020


Author: Artur Bialas
Date: 2020-11-05T08:26:54+01:00
New Revision: f9dca1039a4aa06d0449666888a6bc18b2572159

URL: https://github.com/llvm/llvm-project/commit/f9dca1039a4aa06d0449666888a6bc18b2572159
DIFF: https://github.com/llvm/llvm-project/commit/f9dca1039a4aa06d0449666888a6bc18b2572159.diff

LOG: [mlir][spirv] Add VectorExtractDynamicOp and vector.extractelement lowering

VectorExtractDynamicOp in SPIRV dialect
conversion from vector.extractelement to spirv VectorExtractDynamicOp

Differential Revision: https://reviews.llvm.org/D90679

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
    mlir/test/Conversion/VectorToSPIRV/simple.mlir
    mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir
    mlir/test/Dialect/SPIRV/composite-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index cc23969dea5d..e7f1a3888ab3 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -3176,6 +3176,7 @@ def SPV_OC_OpCopyMemory                : I32EnumAttrCase<"OpCopyMemory", 63>;
 def SPV_OC_OpAccessChain               : I32EnumAttrCase<"OpAccessChain", 65>;
 def SPV_OC_OpDecorate                  : I32EnumAttrCase<"OpDecorate", 71>;
 def SPV_OC_OpMemberDecorate            : I32EnumAttrCase<"OpMemberDecorate", 72>;
+def SPV_OC_OpVectorExtractDynamic      : I32EnumAttrCase<"OpVectorExtractDynamic", 77>;
 def SPV_OC_OpCompositeConstruct        : I32EnumAttrCase<"OpCompositeConstruct", 80>;
 def SPV_OC_OpCompositeExtract          : I32EnumAttrCase<"OpCompositeExtract", 81>;
 def SPV_OC_OpCompositeInsert           : I32EnumAttrCase<"OpCompositeInsert", 82>;
@@ -3309,8 +3310,9 @@ def SPV_OpcodeAttr :
       SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant,
       SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
       SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad,
-      SPV_OC_OpStore, SPV_OC_OpCopyMemory, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
-      SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeConstruct,
+      SPV_OC_OpStore, SPV_OC_OpCopyMemory, SPV_OC_OpAccessChain,
+      SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate,
+      SPV_OC_OpVectorExtractDynamic, SPV_OC_OpCompositeConstruct,
       SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose,
       SPV_OC_OpConvertFToU, SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF,
       SPV_OC_OpConvertUToF, SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert,

diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td
index b824192c1309..c1f76d0fee20 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVCompositeOps.td
@@ -168,4 +168,61 @@ def SPV_CompositeInsertOp : SPV_Op<"CompositeInsert", [NoSideEffect]> {
   ];
 }
 
+// -----
+
+def SPV_VectorExtractDynamicOp : SPV_Op<"VectorExtractDynamic",
+            [NoSideEffect, TypesMatchWith<"type of 'result' matches element type of 'vector'",
+                     "vector", "result",
+                     "$_self.cast<mlir::VectorType>().getElementType()">]> {
+  let summary = [{
+    Extract a single, dynamically selected, component of a vector.
+  }];
+
+  let description = [{
+    Result Type must be a scalar type.
+
+    Vector must have a type OpTypeVector whose Component Type is Result
+    Type.
+
+    Index must be a scalar integer. It is interpreted as a 0-based index of
+    which component of Vector to extract.
+
+    Behavior is undefined if Index's value is less than zero or greater than
+    or equal to the number of components in Vector.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    scalar-type ::= integer-type | float-type | boolean-type
+    vector-extract-dynamic-op ::= `spv.VectorExtractDynamic` ssa-use `[` ssa-use `]`
+                                    `:` `vector<` integer-literal `x` scalar-type `>` `,`
+                                    integer-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %2 = spv.VectorExtractDynamic %0[%1] : vector<8xf32>, i32
+    ```
+  }];
+
+  let arguments = (ins
+    SPV_Vector:$vector,
+    SPV_Integer:$index
+  );
+
+  let results = (outs
+    SPV_Scalar:$result
+  );
+
+  let verifier = [{ return success(); }];
+
+  let assemblyFormat = [{
+    $vector `[` $index `]` attr-dict `:` type($vector) `,` type($index)
+  }];
+
+}
+
+// -----
+
 #endif // SPIRV_COMPOSITE_OPS

diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index ad26118dc643..8446f42afde3 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -78,13 +78,33 @@ struct VectorInsertOpConvert final : public SPIRVOpLowering<vector::InsertOp> {
     return success();
   }
 };
+
+/// Lowers vector.extractelement to spv.VectorExtractDynamic.
+struct VectorExtractElementOpConvert final
+    : public SPIRVOpLowering<vector::ExtractElementOp> {
+  using SPIRVOpLowering<vector::ExtractElementOp>::SPIRVOpLowering;
+  LogicalResult
+  matchAndRewrite(vector::ExtractElementOp extractElementOp,
+                  ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!spirv::CompositeType::isValid(extractElementOp.getVectorType()))
+      return failure();
+    vector::ExtractElementOp::Adaptor adaptor(operands); // Remapped operands.
+    rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
+        extractElementOp, extractElementOp.getType(), adaptor.vector(),
+        adaptor.position());
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateVectorToSPIRVPatterns(MLIRContext *context,
                                          SPIRVTypeConverter &typeConverter,
                                          OwningRewritePatternList &patterns) {
   patterns.insert<VectorBroadcastConvert, VectorExtractOpConvert,
-                  VectorInsertOpConvert>(context, typeConverter);
+                  VectorInsertOpConvert, VectorExtractElementOpConvert>(
+      context, typeConverter);
 }
 
 namespace {

diff  --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
index 34f1ef52c237..5ff2cc55e6d0 100644
--- a/mlir/test/Conversion/VectorToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -convert-vector-to-spirv %s -o - | FileCheck %s
+// RUN: mlir-opt -split-input-file -convert-vector-to-spirv -verify-diagnostics %s -o - | FileCheck %s
 
 // CHECK-LABEL: broadcast
 //  CHECK-SAME: %[[A:.*]]: f32
@@ -21,3 +21,21 @@ func @extract_insert(%arg0 : vector<4xf32>) {
   %1 = vector.insert %0, %arg0[0] : f32 into vector<4xf32>
   spv.Return
 }
+
+// -----
+
+// CHECK-LABEL: extract_element
+//  CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
+//       CHECK:   spv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
+func @extract_element(%arg0 : vector<4xf32>, %id : i32) {
+  %0 = vector.extractelement %arg0[%id : i32] : vector<4xf32>
+  spv.ReturnValue %0: f32
+}
+
+// -----
+
+func @extract_element_negative(%arg0 : vector<5xf32>, %id : i32) {
+// expected-error @+1 {{failed to legalize operation 'vector.extractelement'}}
+  %0 = vector.extractelement %arg0[%id : i32] : vector<5xf32>
+  spv.ReturnValue %0: f32
+}

diff  --git a/mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir b/mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir
index aca7a9d93539..84b3fec1e0e5 100644
--- a/mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/composite-op.mlir
@@ -11,4 +11,9 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
     %0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : vector<3xf32>
     spv.ReturnValue %0: vector<3xf32>
   }
+  spv.func @vector_dynamic_extract(%vec: vector<4xf32>, %id : i32) -> f32 "None" {
+    // CHECK: spv.VectorExtractDynamic %{{.*}}[%{{.*}}] : vector<4xf32>, i32
+    %0 = spv.VectorExtractDynamic %vec[%id] : vector<4xf32>, i32
+    spv.ReturnValue %0: f32
+  }
 }

diff  --git a/mlir/test/Dialect/SPIRV/composite-ops.mlir b/mlir/test/Dialect/SPIRV/composite-ops.mlir
index 18c20ba418fe..5d969737ce76 100644
--- a/mlir/test/Dialect/SPIRV/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/composite-ops.mlir
@@ -261,3 +261,15 @@ func @composite_insert_invalid_result_type(%arg0: !spv.array<4xf32>, %arg1 : f32
   %0 = "spv.CompositeInsert"(%arg1, %arg0) {indices = [0: i32]} : (f32, !spv.array<4xf32>) -> !spv.array<4xf64>
   return %0: !spv.array<4xf64>
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.VectorExtractDynamic
+//===----------------------------------------------------------------------===//
+
+func @vector_dynamic_extract(%vec: vector<4xf32>, %id : i32) -> f32 {
+  // CHECK: spv.VectorExtractDynamic %{{.*}}[%{{.*}}] : vector<4xf32>, i32
+  %0 = spv.VectorExtractDynamic %vec[%id] : vector<4xf32>, i32
+  return %0 : f32
+}


        


More information about the Mlir-commits mailing list