[Mlir-commits] [mlir] d137c05 - [mlir][spirv] Add conversion from vector.reduction
Lei Zhang
llvmlistbot at llvm.org
Wed Apr 27 07:29:53 PDT 2022
Author: Lei Zhang
Date: 2022-04-27T10:29:46-04:00
New Revision: d137c05fc9a3aa04c66daaa7a92b90998676f693
URL: https://github.com/llvm/llvm-project/commit/d137c05fc9a3aa04c66daaa7a92b90998676f693
DIFF: https://github.com/llvm/llvm-project/commit/d137c05fc9a3aa04c66daaa7a92b90998676f693.diff
LOG: [mlir][spirv] Add conversion from vector.reduction
Only addition and multiplication are supported for now; the remaining
combining kinds are left to be implemented.
Reviewed By: hanchung
Differential Revision: https://reviews.llvm.org/D124380
Added:
mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
Modified:
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
Removed:
mlir/test/Conversion/VectorToSPIRV/simple.mlir
################################################################################
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 5bdcc384165c9..e404d75f7d243 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -19,7 +19,9 @@
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include <numeric>
@@ -249,6 +251,69 @@ struct VectorInsertStridedSliceOpConvert final
}
};
+struct VectorReductionPattern final
+    : public OpConversionPattern<vector::ReductionOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  /// Lowers a 1-D `vector.reduction` to scalar SPIR-V ops. Every element is
+  /// extracted with spv.CompositeExtract, the optional accumulator is
+  /// appended after them, and the resulting values are folded left-to-right
+  /// with the scalar op for the reduction's combining kind (the integer or
+  /// float variant is chosen from the converted result type). Only ADD and
+  /// MUL are implemented so far.
+  LogicalResult
+  matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type resultType = typeConverter->convertType(reduceOp.getType());
+    if (!resultType)
+      return failure();
+
+    auto srcVectorType = adaptor.getVector().getType().dyn_cast<VectorType>();
+    if (!srcVectorType || srcVectorType.getRank() != 1)
+      return rewriter.notifyMatchFailure(reduceOp, "not 1-D vector source");
+
+    // getDimSize() returns int64_t; keep that width instead of narrowing.
+    int64_t numElements = srcVectorType.getDimSize(0);
+    Value acc = adaptor.getAcc();
+    int64_t numValues = numElements + (acc ? 1 : 0);
+    vector::CombiningKind reductionKind = reduceOp.getKind();
+
+    // Reject unimplemented kinds before any op is created, so a failed match
+    // leaves the IR untouched. A single value needs no combining and is the
+    // result as-is, so that case is still accepted for every kind.
+    if (numValues > 1 && !isSupportedKind(reductionKind))
+      return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
+
+    // Extract all elements, then append the accumulator (if any) last.
+    SmallVector<Value, 4> values;
+    values.reserve(numValues);
+    Location loc = reduceOp.getLoc();
+    for (int i = 0; i < numElements; ++i) {
+      values.push_back(rewriter.create<spirv::CompositeExtractOp>(
+          loc, srcVectorType.getElementType(), adaptor.getVector(),
+          rewriter.getI32ArrayAttr({i})));
+    }
+    if (acc)
+      values.push_back(acc);
+
+    // Fold them left-to-right: ((v0 op v1) op v2) op ...
+    Value result = values.front();
+    for (Value next : llvm::makeArrayRef(values).drop_front()) {
+      switch (reductionKind) {
+#define INT_FLOAT_CASE(KIND, IOP, FOP)                                         \
+  case vector::CombiningKind::KIND:                                            \
+    if (resultType.isa<IntegerType>()) {                                       \
+      result = rewriter.create<spirv::IOP>(loc, resultType, result, next);     \
+    } else {                                                                   \
+      assert(resultType.isa<FloatType>());                                     \
+      result = rewriter.create<spirv::FOP>(loc, resultType, result, next);     \
+    }                                                                          \
+    break
+
+        INT_FLOAT_CASE(ADD, IAddOp, FAddOp);
+        INT_FLOAT_CASE(MUL, IMulOp, FMulOp);
+// The helper macro is local to this switch; do not leak it into the file.
+#undef INT_FLOAT_CASE
+
+      // Already rejected above via isSupportedKind(); listed explicitly so the
+      // switch stays exhaustive (-Wswitch) and still fails safely if the two
+      // lists ever disagree.
+      case vector::CombiningKind::MINUI:
+      case vector::CombiningKind::MINSI:
+      case vector::CombiningKind::MINF:
+      case vector::CombiningKind::MAXUI:
+      case vector::CombiningKind::MAXSI:
+      case vector::CombiningKind::MAXF:
+      case vector::CombiningKind::AND:
+      case vector::CombiningKind::OR:
+      case vector::CombiningKind::XOR:
+        return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
+      }
+    }
+
+    rewriter.replaceOp(reduceOp, result);
+    return success();
+  }
+
+private:
+  /// Returns true if matchAndRewrite() knows how to combine two scalars with
+  /// `kind`. Must stay in sync with the switch in matchAndRewrite().
+  static bool isSupportedKind(vector::CombiningKind kind) {
+    switch (kind) {
+    case vector::CombiningKind::ADD:
+    case vector::CombiningKind::MUL:
+      return true;
+    case vector::CombiningKind::MINUI:
+    case vector::CombiningKind::MINSI:
+    case vector::CombiningKind::MINF:
+    case vector::CombiningKind::MAXUI:
+    case vector::CombiningKind::MAXSI:
+    case vector::CombiningKind::MAXF:
+    case vector::CombiningKind::AND:
+    case vector::CombiningKind::OR:
+    case vector::CombiningKind::XOR:
+      return false;
+    }
+    return false;
+  }
+};
+
class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
public:
using OpConversionPattern<vector::SplatOp>::OpConversionPattern;
@@ -312,6 +377,7 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
VectorExtractElementOpConvert, VectorExtractOpConvert,
VectorExtractStridedSliceOpConvert, VectorFmaOpConvert,
VectorInsertElementOpConvert, VectorInsertOpConvert,
- VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
- VectorSplatPattern>(typeConverter, patterns.getContext());
+ VectorReductionPattern, VectorInsertStridedSliceOpConvert,
+ VectorShuffleOpConvert, VectorSplatPattern>(
+ typeConverter, patterns.getContext());
}
diff --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
similarity index 79%
rename from mlir/test/Conversion/VectorToSPIRV/simple.mlir
rename to mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 421c803b5a91a..8cd4258e83d2f 100644
--- a/mlir/test/Conversion/VectorToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
// NOTE(review): in the hunks below each "-"/"+" line pair renders identically because whitespace was not preserved in this archived copy; the underlying change is presumably whitespace-only (indentation/trailing spaces) - confirm against the original .diff.
@@ -46,7 +46,7 @@ func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
// CHECK: return %[[R]]
func.func @extract_size1_vector(%arg0 : vector<1xf32>) -> f32 {
%0 = vector.extract %arg0[0] : vector<1xf32>
- return %0: f32
+ return %0: f32
}
// -----
@@ -56,7 +56,7 @@ func.func @extract_size1_vector(%arg0 : vector<1xf32>) -> f32 {
// CHECK: spv.CompositeInsert %[[S]], %[[V]][2 : i32] : f32 into vector<4xf32>
func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
%1 = vector.insert %arg1, %arg0[2] : f32 into vector<4xf32>
- return %1: vector<4xf32>
+ return %1: vector<4xf32>
}
// -----
@@ -67,7 +67,7 @@ func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
// CHECK: return %[[R]]
func.func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf32> {
%1 = vector.insert %arg1, %arg0[0] : f32 into vector<1xf32>
- return %1 : vector<1xf32>
+ return %1 : vector<1xf32>
}
// -----
@@ -84,7 +84,7 @@ func.func @extract_element(%arg0 : vector<4xf32>, %id : i32) -> f32 {
// CHECK-LABEL: @extract_element_index
func.func @extract_element_index(%arg0 : vector<4xf32>, %id : index) -> f32 {
- // CHECK: vector.extractelement
+ // CHECK: vector.extractelement
%0 = vector.extractelement %arg0[%id : index] : vector<4xf32>
return %0: f32
}
@@ -93,7 +93,7 @@ func.func @extract_element_index(%arg0 : vector<4xf32>, %id : index) -> f32 {
// CHECK-LABEL: @extract_element_size5_vector
func.func @extract_element_size5_vector(%arg0 : vector<5xf32>, %id : i32) -> f32 {
- // CHECK: vector.extractelement
+ // CHECK: vector.extractelement
%0 = vector.extractelement %arg0[%id : i32] : vector<5xf32>
return %0: f32
}
@@ -124,7 +124,7 @@ func.func @insert_element(%val: f32, %arg0 : vector<4xf32>, %id : i32) -> vector
// CHECK-LABEL: @insert_element_index
func.func @insert_element_index(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> {
- // CHECK: vector.insertelement
+ // CHECK: vector.insertelement
%0 = vector.insertelement %val, %arg0[%id : index] : vector<4xf32>
return %0: vector<4xf32>
}
@@ -133,19 +133,19 @@ func.func @insert_element_index(%val: f32, %arg0 : vector<4xf32>, %id : index) -
// CHECK-LABEL: @insert_element_size5_vector
func.func @insert_element_size5_vector(%val: f32, %arg0 : vector<5xf32>, %id : i32) -> vector<5xf32> {
- // CHECK: vector.insertelement
+ // CHECK: vector.insertelement
%0 = vector.insertelement %val, %arg0[%id : i32] : vector<5xf32>
- return %0 : vector<5xf32>
+ return %0 : vector<5xf32>
}
// -----
// CHECK-LABEL: @insert_strided_slice
// CHECK-SAME: %[[PART:.+]]: vector<2xf32>, %[[ALL:.+]]: vector<4xf32>
-// CHECK: spv.VectorShuffle [0 : i32, 4 : i32, 5 : i32, 3 : i32] %[[ALL]] : vector<4xf32>, %[[PART]] : vector<2xf32> -> vector<4xf32>
+// CHECK: spv.VectorShuffle [0 : i32, 4 : i32, 5 : i32, 3 : i32] %[[ALL]] : vector<4xf32>, %[[PART]] : vector<2xf32> -> vector<4xf32>
func.func @insert_strided_slice(%arg0: vector<2xf32>, %arg1: vector<4xf32>) -> vector<4xf32> {
%0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [1], strides = [1]} : vector<2xf32> into vector<4xf32>
- return %0 : vector<4xf32>
+ return %0 : vector<4xf32>
}
// -----
@@ -156,7 +156,7 @@ func.func @insert_strided_slice(%arg0: vector<2xf32>, %arg1: vector<4xf32>) -> v
// CHECK: spv.CompositeInsert %[[S]], %[[FULL]][2 : i32] : f32 into vector<3xf32>
func.func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: vector<3xf32>) -> vector<3xf32> {
%1 = vector.insert_strided_slice %arg0, %arg1 {offsets = [2], strides = [1]} : vector<1xf32> into vector<3xf32>
- return %1 : vector<3xf32>
+ return %1 : vector<3xf32>
}
// -----
@@ -166,7 +166,7 @@ func.func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: vector<3xf32>) -> v
// CHECK: spv.GLSL.Fma %[[A]], %[[B]], %[[C]] : vector<4xf32>
func.func @fma(%a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>) -> vector<4xf32> {
%0 = vector.fma %a, %b, %c: vector<4xf32>
- return %0 : vector<4xf32>
+ return %0 : vector<4xf32>
}
// -----
@@ -210,3 +210,36 @@ func.func @shuffle(%v0 : vector<2x16xf32>, %v1: vector<1x16xf32>) -> vector<3x16
%shuffle = vector.shuffle %v0, %v1 [0, 1, 2] : vector<2x16xf32>, vector<1x16xf32>
return %shuffle : vector<3x16xf32>
}
+
+// -----
+
+// Integer add reduction without an accumulator: the four extracted elements
+// are combined by a left-to-right chain of spv.IAdd ops.
+// CHECK-LABEL: func @reduction
+// CHECK-SAME: (%[[VEC:.+]]: vector<4xi32>)
+// CHECK: %[[E0:.+]] = spv.CompositeExtract %[[VEC]][0 : i32] : vector<4xi32>
+// CHECK: %[[E1:.+]] = spv.CompositeExtract %[[VEC]][1 : i32] : vector<4xi32>
+// CHECK: %[[E2:.+]] = spv.CompositeExtract %[[VEC]][2 : i32] : vector<4xi32>
+// CHECK: %[[E3:.+]] = spv.CompositeExtract %[[VEC]][3 : i32] : vector<4xi32>
+// CHECK: %[[SUM0:.+]] = spv.IAdd %[[E0]], %[[E1]]
+// CHECK: %[[SUM1:.+]] = spv.IAdd %[[SUM0]], %[[E2]]
+// CHECK: %[[SUM2:.+]] = spv.IAdd %[[SUM1]], %[[E3]]
+// CHECK: return %[[SUM2]]
+func.func @reduction(%vec : vector<4xi32>) -> i32 {
+  %sum = vector.reduction <add>, %vec : vector<4xi32> into i32
+  return %sum : i32
+}
+
+// -----
+
+// Float multiply reduction with an accumulator: the extracted elements are
+// combined left-to-right with spv.FMul, and the accumulator is folded in last.
+// CHECK-LABEL: func @reduction
+// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[ACC:.+]]: f32)
+// CHECK: %[[S0:.+]] = spv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
+// CHECK: %[[S1:.+]] = spv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
+// CHECK: %[[S2:.+]] = spv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
+// CHECK: %[[MUL0:.+]] = spv.FMul %[[S0]], %[[S1]]
+// CHECK: %[[MUL1:.+]] = spv.FMul %[[MUL0]], %[[S2]]
+// CHECK: %[[MUL2:.+]] = spv.FMul %[[MUL1]], %[[ACC]]
+// CHECK: return %[[MUL2]]
+func.func @reduction(%v : vector<3xf32>, %s: f32) -> f32 {
+  %reduce = vector.reduction <mul>, %v, %s : vector<3xf32> into f32
+  return %reduce : f32
+}
More information about the Mlir-commits mailing list