[Mlir-commits] [mlir] 9f622b3 - [mlir][spirv] Add more vector conversion patterns
Lei Zhang
llvmlistbot at llvm.org
Fri Feb 5 06:11:24 PST 2021
Author: Lei Zhang
Date: 2021-02-05T09:11:16-05:00
New Revision: 9f622b3d5d6aa21ca28a3a06a27434cedec98fc9
URL: https://github.com/llvm/llvm-project/commit/9f622b3d5d6aa21ca28a3a06a27434cedec98fc9
DIFF: https://github.com/llvm/llvm-project/commit/9f622b3d5d6aa21ca28a3a06a27434cedec98fc9.diff
LOG: [mlir][spirv] Add more vector conversion patterns
This patch introduces a few more straightforward patterns
to convert vector ops operating on 1-4 element vectors
to their corresponding SPIR-V counterparts.
This patch also enables converting vector<1xT> to T.
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D96042
Added:
Modified:
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
mlir/test/Conversion/VectorToSPIRV/simple.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 52a35a17869f..93221b06db37 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -19,10 +19,40 @@
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Transforms/DialectConversion.h"
+#include <numeric>
using namespace mlir;
+/// Gets the first integer value from `attr`, assuming it is a non-empty
+/// integer array attribute (the first element is dereferenced unchecked).
+/// Only that first element is read; any further elements are ignored. The
+/// value is returned zero-extended to 64 bits.
+static uint64_t getFirstIntValue(ArrayAttr attr) {
+ return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
+}
+
namespace {
+
+/// Converts vector.bitcast. When the converted source already has the
+/// converted result type the cast is a no-op and the source is forwarded;
+/// otherwise an spv.Bitcast is emitted.
+struct VectorBitcastConvert final
+ : public OpConversionPattern<vector::BitCastOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::BitCastOp bitcastOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ Type resultType = getTypeConverter()->convertType(bitcastOp.getType());
+ if (!resultType)
+ return failure();
+
+ Value input = vector::BitCastOp::Adaptor(operands).source();
+ if (input.getType() != resultType) {
+ rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, resultType,
+ input);
+ return success();
+ }
+
+ // Converted types already agree: nothing to cast.
+ rewriter.replaceOp(bitcastOp, input);
+ return success();
+ }
+};
+
struct VectorBroadcastConvert final
: public OpConversionPattern<vector::BroadcastOp> {
using OpConversionPattern::OpConversionPattern;
@@ -49,17 +79,58 @@ struct VectorExtractOpConvert final
LogicalResult
matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- if (extractOp.getType().isa<VectorType>() ||
- !spirv::CompositeType::isValid(extractOp.getVectorType()))
+ // Only support extracting a scalar value now.
+ VectorType resultVectorType = extractOp.getType().dyn_cast<VectorType>();
+ if (resultVectorType && resultVectorType.getNumElements() > 1)
+ return failure();
+
+ auto dstType = getTypeConverter()->convertType(extractOp.getType());
+ if (!dstType)
return failure();
+
vector::ExtractOp::Adaptor adaptor(operands);
- int32_t id = extractOp.position().begin()->cast<IntegerAttr>().getInt();
+ int32_t id = getFirstIntValue(extractOp.position());
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
extractOp, adaptor.vector(), id);
return success();
}
};
+/// Lowers a unit-stride vector.extract_strided_slice to an spv.VectorShuffle
+/// of the converted source with itself, selecting the contiguous lanes
+/// [offset, offset + size). Only the first entry of each attribute is read.
+struct VectorExtractStridedSliceOpConvert final
+ : public OpConversionPattern<vector::ExtractStridedSliceOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
+ ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ Type resultType = getTypeConverter()->convertType(extractOp.getType());
+ // Bail out when the result type cannot be converted, or when it became a
+ // scalar (extracting vector<1xT> is not supported yet).
+ if (!resultType || resultType.isa<spirv::ScalarType>())
+ return failure();
+
+ // Only unit strides are handled.
+ if (getFirstIntValue(extractOp.strides()) != 1)
+ return failure();
+
+ uint64_t start = getFirstIntValue(extractOp.offsets());
+ uint64_t count = getFirstIntValue(extractOp.sizes());
+
+ SmallVector<int32_t, 2> mask;
+ for (uint64_t lane = start; lane < start + count; ++lane)
+ mask.push_back(static_cast<int32_t>(lane));
+
+ Value source = operands.front();
+ rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
+ extractOp, resultType, source, source, rewriter.getI32ArrayAttr(mask));
+ return success();
+ }
+};
+
struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
using OpConversionPattern::OpConversionPattern;
@@ -86,7 +157,7 @@ struct VectorInsertOpConvert final
!spirv::CompositeType::isValid(insertOp.getDestVectorType()))
return failure();
vector::InsertOp::Adaptor adaptor(operands);
- int32_t id = insertOp.position().begin()->cast<IntegerAttr>().getInt();
+ int32_t id = getFirstIntValue(insertOp.position());
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
insertOp, adaptor.source(), adaptor.dest(), id);
return success();
@@ -129,13 +200,53 @@ struct VectorInsertElementOpConvert final
}
};
+/// Lowers a unit-stride vector.insert_strided_slice to an spv.VectorShuffle
+/// whose first operand is the destination and whose second operand is the
+/// inserted source. Only the first entry of `offsets`/`strides` is read.
+struct VectorInsertStridedSliceOpConvert final
+ : public OpConversionPattern<vector::InsertStridedSliceOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::InsertStridedSliceOp insertOp,
+ ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ // Converted operands: the slice being inserted first, destination last.
+ Value srcVector = operands.front();
+ Value dstVector = operands.back();
+
+ // Insert scalar values not supported yet.
+ // (A vector<1xT> operand converts to scalar T, so it is rejected here.)
+ if (srcVector.getType().isa<spirv::ScalarType>() ||
+ dstVector.getType().isa<spirv::ScalarType>())
+ return failure();
+
+ // Only unit strides are handled.
+ uint64_t stride = getFirstIntValue(insertOp.strides());
+ if (stride != 1)
+ return failure();
+
+ uint64_t totalSize =
+ dstVector.getType().cast<VectorType>().getNumElements();
+ uint64_t insertSize =
+ srcVector.getType().cast<VectorType>().getNumElements();
+ uint64_t offset = getFirstIntValue(insertOp.offsets());
+
+ // Shuffle mask: start with the identity over the destination lanes
+ // [0, totalSize), then overwrite lanes [offset, offset + insertSize) with
+ // totalSize, totalSize + 1, ..., which address lanes of the second shuffle
+ // operand (the source). E.g. inserting 2 lanes at offset 1 into a 4-lane
+ // vector yields [0, 4, 5, 3]. The second iota must run after the first.
+ SmallVector<int32_t, 2> indices(totalSize);
+ std::iota(indices.begin(), indices.end(), 0);
+ std::iota(indices.begin() + offset, indices.begin() + offset + insertSize,
+ totalSize);
+
+ rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
+ insertOp, dstVector.getType(), dstVector, srcVector,
+ rewriter.getI32ArrayAttr(indices));
+
+ return success();
+ }
+};
+
} // namespace
+/// Adds the vector-to-SPIR-V conversion patterns listed below to `patterns`;
+/// each pattern is constructed with the shared `typeConverter` and `context`.
void mlir::populateVectorToSPIRVPatterns(MLIRContext *context,
SPIRVTypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
- patterns.insert<VectorBroadcastConvert, VectorExtractElementOpConvert,
- VectorExtractOpConvert, VectorFmaOpConvert,
- VectorInsertOpConvert, VectorInsertElementOpConvert>(
- typeConverter, context);
+ patterns.insert<VectorBitcastConvert, VectorBroadcastConvert,
+ VectorExtractElementOpConvert, VectorExtractOpConvert,
+ VectorExtractStridedSliceOpConvert, VectorFmaOpConvert,
+ VectorInsertElementOpConvert, VectorInsertOpConvert,
+ VectorInsertStridedSliceOpConvert>(typeConverter, context);
}
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 1c0445290402..ff8c2b79c3de 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -269,12 +269,13 @@ convertScalarType(const spirv::TargetEnv &targetEnv, spirv::ScalarType type,
static Optional<Type>
convertVectorType(const spirv::TargetEnv &targetEnv, VectorType type,
Optional<spirv::StorageClass> storageClass = {}) {
+ if (type.getRank() == 1 && type.getNumElements() == 1)
+ return type.getElementType();
+
if (!spirv::CompositeType::isValid(type)) {
- // TODO: One-element vector types can be translated into scalar
- // types. Vector types with more than four elements can be translated into
+ // TODO: Vector types with more than four elements can be translated into
// array types.
- LLVM_DEBUG(llvm::dbgs()
- << type << " illegal: 1- and > 4-element unimplemented\n");
+ LLVM_DEBUG(llvm::dbgs() << type << " illegal: > 4-element unimplemented\n");
return llvm::None;
}
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index 8ae93c2e4b9b..f190a0e2fbb2 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -117,9 +117,9 @@ func @float_vector234(%arg0: vector<2xf16>, %arg1: vector<3xf64>) {
return
}
-// CHECK-LABEL: @unsupported_1elem_vector
-func @unsupported_1elem_vector(%arg0: vector<1xi32>) {
- // CHECK: addi
+// vector<1xi32> now converts to scalar i32, so addi lowers to a scalar IAdd.
+// CHECK-LABEL: @one_elem_vector
+func @one_elem_vector(%arg0: vector<1xi32>) {
+ // CHECK: spv.IAdd %{{.+}}, %{{.+}}: i32
%0 = addi %arg0, %arg0: vector<1xi32>
return
}
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
index 558d0f3999d4..b511c476475c 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
@@ -203,18 +203,19 @@ func @float_vector(
%arg1: vector<3xf64>
) { return }
+// A vector<1xi32> argument is converted to a scalar i32 argument.
+// CHECK-LABEL: spv.func @one_element_vector
+// CHECK-SAME: %{{.+}}: i32
+func @one_element_vector(%arg0: vector<1xi32>) { return }
+
} // end module
// -----
-// Check that 1- or > 4-element vectors are not supported.
+// Check that > 4-element vectors are not supported.
module attributes {
spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}>
} {
-// CHECK-NOT: spv.func @one_element_vector
-func @one_element_vector(%arg0: vector<1xi32>) { return }
-
// CHECK-NOT: spv.func @large_vector
func @large_vector(%arg0: vector<1024xi32>) { return }
diff --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
index fddfd911fb19..836d3853d335 100644
--- a/mlir/test/Conversion/VectorToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
@@ -1,5 +1,21 @@
// RUN: mlir-opt -split-input-file -convert-vector-to-spirv -verify-diagnostics %s -o - | FileCheck %s
+module attributes { spv.target_env = #spv.target_env<#spv.vce<v1.0, [Float16], []>, {}> } {
+
+// The target env lists the Float16 capability (presumably needed for the f16
+// vectors to convert). A vector<1xf32> result converts to scalar f32, so the
+// second bitcast is expected to produce f32.
+// CHECK-LABEL: func @bitcast
+// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: vector<2xf16>
+// CHECK: %{{.+}} = spv.Bitcast %[[ARG0]] : vector<2xf32> to vector<4xf16>
+// CHECK: %{{.+}} = spv.Bitcast %[[ARG1]] : vector<2xf16> to f32
+func @bitcast(%arg0 : vector<2xf32>, %arg1: vector<2xf16>) {
+ %0 = vector.bitcast %arg0 : vector<2xf32> to vector<4xf16>
+ %1 = vector.bitcast %arg1 : vector<2xf16> to vector<1xf32>
+ spv.Return
+}
+
+} // end module
+
+// -----
+
// CHECK-LABEL: broadcast
// CHECK-SAME: %[[A:.*]]: f32
// CHECK: spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]] : vector<4xf32>
@@ -12,6 +28,18 @@ func @broadcast(%arg0 : f32) {
// -----
+// Both a vector<1xf32> result and a scalar f32 result are expected to lower
+// to a single-index composite extract.
+// CHECK-LABEL: func @extract
+// CHECK-SAME: %[[ARG:.+]]: vector<2xf32>
+// CHECK: %{{.+}} = spv.CompositeExtract %[[ARG]][0 : i32] : vector<2xf32>
+// CHECK: %{{.+}} = spv.CompositeExtract %[[ARG]][1 : i32] : vector<2xf32>
+func @extract(%arg0 : vector<2xf32>) {
+ %0 = "vector.extract"(%arg0) {position = [0]} : (vector<2xf32>) -> vector<1xf32>
+ %1 = "vector.extract"(%arg0) {position = [1]} : (vector<2xf32>) -> f32
+ spv.Return
+}
+
+// -----
+
// CHECK-LABEL: extract_insert
// CHECK-SAME: %[[V:.*]]: vector<4xf32>
// CHECK: %[[S:.*]] = spv.CompositeExtract %[[V]][1 : i32] : vector<4xf32>
@@ -42,6 +70,16 @@ func @extract_element_negative(%arg0 : vector<5xf32>, %id : i32) {
// -----
+// A unit-stride slice becomes a shuffle of the source with itself that picks
+// lanes [offset, offset + size), here [1, 2].
+// CHECK-LABEL: func @extract_strided_slice
+// CHECK-SAME: %[[ARG:.+]]: vector<4xf32>
+// CHECK: %{{.+}} = spv.VectorShuffle [1 : i32, 2 : i32] %[[ARG]] : vector<4xf32>, %[[ARG]] : vector<4xf32> -> vector<2xf32>
+func @extract_strided_slice(%arg0: vector<4xf32>) {
+ %0 = vector.extract_strided_slice %arg0 {offsets = [1], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+ spv.Return
+}
+
+// -----
+
// CHECK-LABEL: insert_element
// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
// CHECK: spv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32
@@ -60,6 +98,16 @@ func @insert_element_negative(%val: f32, %arg0 : vector<5xf32>, %id : i32) {
// -----
+// Mask entries 4 and 5 select lanes of the second shuffle operand (the
+// inserted part); 0 and 3 keep the destination's own lanes.
+// CHECK-LABEL: func @insert_strided_slice
+// CHECK-SAME: %[[PART:.+]]: vector<2xf32>, %[[ALL:.+]]: vector<4xf32>
+// CHECK: %{{.+}} = spv.VectorShuffle [0 : i32, 4 : i32, 5 : i32, 3 : i32] %[[ALL]] : vector<4xf32>, %[[PART]] : vector<2xf32> -> vector<4xf32>
+func @insert_strided_slice(%arg0: vector<2xf32>, %arg1: vector<4xf32>) {
+ %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [1], strides = [1]} : vector<2xf32> into vector<4xf32>
+ spv.Return
+}
+
+// -----
+
// CHECK-LABEL: func @fma
// CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32>
// CHECK: spv.GLSL.Fma %[[A]], %[[B]], %[[C]] : vector<4xf32>
More information about the Mlir-commits mailing list