[Mlir-commits] [mlir] 9f622b3 - [mlir][spirv] Add more vector conversion patterns

Lei Zhang llvmlistbot at llvm.org
Fri Feb 5 06:11:24 PST 2021


Author: Lei Zhang
Date: 2021-02-05T09:11:16-05:00
New Revision: 9f622b3d5d6aa21ca28a3a06a27434cedec98fc9

URL: https://github.com/llvm/llvm-project/commit/9f622b3d5d6aa21ca28a3a06a27434cedec98fc9
DIFF: https://github.com/llvm/llvm-project/commit/9f622b3d5d6aa21ca28a3a06a27434cedec98fc9.diff

LOG: [mlir][spirv] Add more vector conversion patterns

This patch introduces a few more straightforward patterns
to convert vector ops operating on 1-4 element vectors
to their corresponding SPIR-V counterparts.

This patch also enables converting vector<1xT> to T.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D96042

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
    mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
    mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
    mlir/test/Conversion/VectorToSPIRV/simple.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 52a35a17869f..93221b06db37 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -19,10 +19,40 @@
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include <numeric>
 
 using namespace mlir;
 
+/// Gets the first integer value from `attr`, assuming it is a non-empty
+/// integer array attribute.
+///
+/// The value is returned zero-extended; callers that store it into a narrower
+/// (e.g. int32_t) variable are responsible for the value being in range.
+static uint64_t getFirstIntValue(ArrayAttr attr) {
+  // Dereferencing begin() of an empty range would be undefined behavior.
+  assert(attr && !attr.empty() && "expected a non-empty integer array");
+  return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
+}
+
 namespace {
+
+/// Converts vector.bitcast to spv.Bitcast. When the converted result type is
+/// already identical to the (converted) source type, the source value is
+/// forwarded instead of creating a cast.
+struct VectorBitcastConvert final
+    : public OpConversionPattern<vector::BitCastOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::BitCastOp bitcastOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type resultType = getTypeConverter()->convertType(bitcastOp.getType());
+    if (!resultType)
+      return failure();
+
+    Value input = vector::BitCastOp::Adaptor(operands).source();
+
+    // Types differ: an actual spv.Bitcast is needed.
+    if (resultType != input.getType()) {
+      rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, resultType,
+                                                    input);
+      return success();
+    }
+
+    // Types already match: nothing to cast, reuse the input value.
+    rewriter.replaceOp(bitcastOp, input);
+    return success();
+  }
+};
+
 struct VectorBroadcastConvert final
     : public OpConversionPattern<vector::BroadcastOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -49,17 +79,58 @@ struct VectorExtractOpConvert final
   LogicalResult
   matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    if (extractOp.getType().isa<VectorType>() ||
-        !spirv::CompositeType::isValid(extractOp.getVectorType()))
+    // Only support extracting a scalar value now. A vector<1xT> result is
+    // accepted because the type converter turns it into the scalar T.
+    VectorType resultVectorType = extractOp.getType().dyn_cast<VectorType>();
+    if (resultVectorType && resultVectorType.getNumElements() > 1)
+      return failure();
+
+    auto dstType = getTypeConverter()->convertType(extractOp.getType());
+    if (!dstType)
       return failure();
+
     vector::ExtractOp::Adaptor adaptor(operands);
+
+    // A vector<1xT> source has likewise been converted to the scalar T. That
+    // scalar already is the extracted element, and spv.CompositeExtract
+    // cannot index into a non-composite value, so forward it directly.
+    if (adaptor.vector().getType().isa<spirv::ScalarType>()) {
+      rewriter.replaceOp(extractOp, adaptor.vector());
+      return success();
+    }
+
+    // Keep the source check the previous version of this pattern had: a
+    // single position can only index a valid SPIR-V composite vector.
+    if (!spirv::CompositeType::isValid(extractOp.getVectorType()))
+      return failure();
+
-    int32_t id = extractOp.position().begin()->cast<IntegerAttr>().getInt();
+    int32_t id = getFirstIntValue(extractOp.position());
     rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
         extractOp, adaptor.vector(), id);
     return success();
   }
 };
 
+/// Converts a unit-stride vector.extract_strided_slice into a
+/// spv.VectorShuffle that selects lanes [offset, offset + size) of the source
+/// vector (the source is passed as both shuffle operands).
+struct VectorExtractStridedSliceOpConvert final
+    : public OpConversionPattern<vector::ExtractStridedSliceOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
+                  ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto dstType = getTypeConverter()->convertType(extractOp.getType());
+    if (!dstType)
+      return failure();
+
+    // Extract vector<1xT> not supported yet.
+    if (dstType.isa<spirv::ScalarType>())
+      return failure();
+
+    Value srcVector = operands.front();
+    // The source operand may still carry a type the converter could not
+    // handle; bail out rather than build a spv.VectorShuffle over a type that
+    // is not a valid SPIR-V composite vector.
+    auto srcVectorType = srcVector.getType().dyn_cast<VectorType>();
+    if (!srcVectorType || !spirv::CompositeType::isValid(srcVectorType))
+      return failure();
+
+    uint64_t offset = getFirstIntValue(extractOp.offsets());
+    uint64_t size = getFirstIntValue(extractOp.sizes());
+    uint64_t stride = getFirstIntValue(extractOp.strides());
+    if (stride != 1)
+      return failure();
+
+    // Contiguous lane indices offset, offset + 1, ..., offset + size - 1.
+    SmallVector<int32_t, 2> indices(size);
+    std::iota(indices.begin(), indices.end(), static_cast<int32_t>(offset));
+
+    rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
+        extractOp, dstType, srcVector, srcVector,
+        rewriter.getI32ArrayAttr(indices));
+
+    return success();
+  }
+};
+
 struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -86,7 +157,7 @@ struct VectorInsertOpConvert final
         !spirv::CompositeType::isValid(insertOp.getDestVectorType()))
       return failure();
     vector::InsertOp::Adaptor adaptor(operands);
-    int32_t id = insertOp.position().begin()->cast<IntegerAttr>().getInt();
+    int32_t id = getFirstIntValue(insertOp.position());
     rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
         insertOp, adaptor.source(), adaptor.dest(), id);
     return success();
@@ -129,13 +200,53 @@ struct VectorInsertElementOpConvert final
   }
 };
 
+/// Converts a unit-stride vector.insert_strided_slice into a
+/// spv.VectorShuffle: lanes [offset, offset + insertSize) are taken from the
+/// source vector and all other lanes from the destination vector.
+struct VectorInsertStridedSliceOpConvert final
+    : public OpConversionPattern<vector::InsertStridedSliceOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::InsertStridedSliceOp insertOp,
+                  ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value srcVector = operands.front();
+    Value dstVector = operands.back();
+
+    // Insert scalar values not supported yet: a converted vector<1xT> is a
+    // scalar, not a VectorType. Also reject operands that
+    // spirv::CompositeType::isValid does not accept, since the shuffle built
+    // below only models a single dimension of a valid SPIR-V vector.
+    auto srcVectorType = srcVector.getType().dyn_cast<VectorType>();
+    auto dstVectorType = dstVector.getType().dyn_cast<VectorType>();
+    if (!srcVectorType || !dstVectorType ||
+        !spirv::CompositeType::isValid(srcVectorType) ||
+        !spirv::CompositeType::isValid(dstVectorType))
+      return failure();
+
+    uint64_t stride = getFirstIntValue(insertOp.strides());
+    if (stride != 1)
+      return failure();
+
+    uint64_t totalSize = dstVectorType.getNumElements();
+    uint64_t insertSize = srcVectorType.getNumElements();
+    uint64_t offset = getFirstIntValue(insertOp.offsets());
+
+    // Shuffle indices [0, totalSize) address dstVector (the first operand);
+    // [totalSize, totalSize + insertSize) address srcVector (the second).
+    SmallVector<int32_t, 2> indices(totalSize);
+    std::iota(indices.begin(), indices.end(), 0);
+    std::iota(indices.begin() + offset, indices.begin() + offset + insertSize,
+              static_cast<int32_t>(totalSize));
+
+    rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
+        insertOp, dstVectorType, dstVector, srcVector,
+        rewriter.getI32ArrayAttr(indices));
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateVectorToSPIRVPatterns(MLIRContext *context,
                                          SPIRVTypeConverter &typeConverter,
                                          OwningRewritePatternList &patterns) {
-  patterns.insert<VectorBroadcastConvert, VectorExtractElementOpConvert,
-                  VectorExtractOpConvert, VectorFmaOpConvert,
-                  VectorInsertOpConvert, VectorInsertElementOpConvert>(
-      typeConverter, context);
+  // Register every vector-to-SPIR-V conversion pattern of this file; each is
+  // constructed with the same SPIR-V type converter and context.
+  patterns.insert<VectorBitcastConvert,
+                  VectorBroadcastConvert,
+                  VectorExtractElementOpConvert,
+                  VectorExtractOpConvert,
+                  VectorExtractStridedSliceOpConvert,
+                  VectorFmaOpConvert,
+                  VectorInsertElementOpConvert,
+                  VectorInsertOpConvert,
+                  VectorInsertStridedSliceOpConvert>(typeConverter, context);
 }

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 1c0445290402..ff8c2b79c3de 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -269,12 +269,13 @@ convertScalarType(const spirv::TargetEnv &targetEnv, spirv::ScalarType type,
 static Optional<Type>
 convertVectorType(const spirv::TargetEnv &targetEnv, VectorType type,
                   Optional<spirv::StorageClass> storageClass = {}) {
+  if (type.getRank() == 1 && type.getNumElements() == 1)
+    return type.getElementType();
+
   if (!spirv::CompositeType::isValid(type)) {
-    // TODO: One-element vector types can be translated into scalar
-    // types. Vector types with more than four elements can be translated into
+    // TODO: Vector types with more than four elements can be translated into
     // array types.
-    LLVM_DEBUG(llvm::dbgs()
-               << type << " illegal: 1- and > 4-element unimplemented\n");
+    LLVM_DEBUG(llvm::dbgs() << type << " illegal: > 4-element unimplemented\n");
     return llvm::None;
   }
 

diff  --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index 8ae93c2e4b9b..f190a0e2fbb2 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -117,9 +117,9 @@ func @float_vector234(%arg0: vector<2xf16>, %arg1: vector<3xf64>) {
   return
 }
 
-// CHECK-LABEL: @unsupported_1elem_vector
-func @unsupported_1elem_vector(%arg0: vector<1xi32>) {
-  // CHECK: addi
+// vector<1xi32> is converted to plain i32, so the vector add is expected to
+// become a scalar spv.IAdd.
+// CHECK-LABEL: @one_elem_vector
+func @one_elem_vector(%arg0: vector<1xi32>) {
+  // CHECK: spv.IAdd %{{.+}}, %{{.+}}: i32
   %0 = addi %arg0, %arg0: vector<1xi32>
   return
 }

diff  --git a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
index 558d0f3999d4..b511c476475c 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
@@ -203,18 +203,19 @@ func @float_vector(
   %arg1: vector<3xf64>
 ) { return }
 
+// A one-element vector argument is converted to its scalar element type.
+// CHECK-LABEL: spv.func @one_element_vector
+// CHECK-SAME: %{{.+}}: i32
+func @one_element_vector(%arg0: vector<1xi32>) { return }
+
 } // end module
 
 // -----
 
-// Check that 1- or > 4-element vectors are not supported.
+// Check that > 4-element vectors are not supported.
 module attributes {
   spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}>
 } {
 
-// CHECK-NOT: spv.func @one_element_vector
-func @one_element_vector(%arg0: vector<1xi32>) { return }
-
+// A 1024-element vector exceeds the supported element count, so no spv.func
+// is expected to be produced for this function.
 // CHECK-NOT: spv.func @large_vector
 func @large_vector(%arg0: vector<1024xi32>) { return }
 

diff  --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
index fddfd911fb19..836d3853d335 100644
--- a/mlir/test/Conversion/VectorToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
@@ -1,5 +1,21 @@
 // RUN: mlir-opt -split-input-file -convert-vector-to-spirv -verify-diagnostics %s -o - | FileCheck %s
 
+// The target env enables Float16, presumably so the f16 vector types used
+// below are convertible (NOTE(review): confirm this is why it is required).
+module attributes { spv.target_env = #spv.target_env<#spv.vce<v1.0, [Float16], []>, {}> } {
+
+// The second bitcast yields vector<1xf32>, which is converted to scalar f32.
+// CHECK-LABEL: func @bitcast
+//  CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: vector<2xf16>
+//       CHECK:   %{{.+}} = spv.Bitcast %[[ARG0]] : vector<2xf32> to vector<4xf16>
+//       CHECK:   %{{.+}} = spv.Bitcast %[[ARG1]] : vector<2xf16> to f32
+func @bitcast(%arg0 : vector<2xf32>, %arg1: vector<2xf16>) {
+  %0 = vector.bitcast %arg0 : vector<2xf32> to vector<4xf16>
+  %1 = vector.bitcast %arg1 : vector<2xf16> to vector<1xf32>
+  spv.Return
+}
+
+} // end module
+
+// -----
+
 // CHECK-LABEL: broadcast
 //  CHECK-SAME: %[[A:.*]]: f32
 //       CHECK:   spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]] : vector<4xf32>
@@ -12,6 +28,18 @@ func @broadcast(%arg0 : f32) {
 
 // -----
 
+// A vector<1xf32> result is converted to f32, so both extracts are expected
+// to lower to spv.CompositeExtract.
+// CHECK-LABEL: func @extract
+//  CHECK-SAME: %[[ARG:.+]]: vector<2xf32>
+//       CHECK:   %{{.+}} = spv.CompositeExtract %[[ARG]][0 : i32] : vector<2xf32>
+//       CHECK:   %{{.+}} = spv.CompositeExtract %[[ARG]][1 : i32] : vector<2xf32>
+func @extract(%arg0 : vector<2xf32>) {
+  %0 = "vector.extract"(%arg0) {position = [0]} : (vector<2xf32>) -> vector<1xf32>
+  %1 = "vector.extract"(%arg0) {position = [1]} : (vector<2xf32>) -> f32
+  spv.Return
+}
+
+// -----
+
 // CHECK-LABEL: extract_insert
 //  CHECK-SAME: %[[V:.*]]: vector<4xf32>
 //       CHECK:   %[[S:.*]] = spv.CompositeExtract %[[V]][1 : i32] : vector<4xf32>
@@ -42,6 +70,16 @@ func @extract_element_negative(%arg0 : vector<5xf32>, %id : i32) {
 
 // -----
 
+// offsets = [1], sizes = [2] selects lanes 1 and 2; the source is passed as
+// both shuffle operands.
+// CHECK-LABEL: func @extract_strided_slice
+//  CHECK-SAME: %[[ARG:.+]]: vector<4xf32>
+//       CHECK: %{{.+}} = spv.VectorShuffle [1 : i32, 2 : i32] %[[ARG]] : vector<4xf32>, %[[ARG]] : vector<4xf32> -> vector<2xf32>
+func @extract_strided_slice(%arg0: vector<4xf32>) {
+  %0 = vector.extract_strided_slice %arg0 {offsets = [1], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+  spv.Return
+}
+
+// -----
+
 // CHECK-LABEL: insert_element
 //  CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
 //       CHECK:   spv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32
@@ -60,6 +98,16 @@ func @insert_element_negative(%val: f32, %arg0 : vector<5xf32>, %id : i32) {
 
 // -----
 
+// Lanes 1-2 of %arg1 are replaced by %arg0: shuffle indices 4 and 5 refer to
+// the lanes of the second shuffle operand (%arg0), the rest keep %arg1's lanes.
+// CHECK-LABEL: func @insert_strided_slice
+//  CHECK-SAME: %[[PART:.+]]: vector<2xf32>, %[[ALL:.+]]: vector<4xf32>
+//       CHECK: %{{.+}} = spv.VectorShuffle [0 : i32, 4 : i32, 5 : i32, 3 : i32] %[[ALL]] : vector<4xf32>, %[[PART]] : vector<2xf32> -> vector<4xf32>
+func @insert_strided_slice(%arg0: vector<2xf32>, %arg1: vector<4xf32>) {
+  %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [1], strides = [1]} : vector<2xf32> into vector<4xf32>
+  spv.Return
+}
+
+// -----
+
 // CHECK-LABEL: func @fma
 //  CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32>
 //       CHECK:   spv.GLSL.Fma %[[A]], %[[B]], %[[C]] : vector<4xf32>


        


More information about the Mlir-commits mailing list