[Mlir-commits] [mlir] d51275c - [mlir][spirv] Add support to convert std.splat op
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon May 3 10:58:07 PDT 2021
Author: thomasraoux
Date: 2021-05-03T10:57:40-07:00
New Revision: d51275cbc071e318df94d7f5a469b17b8960dab0
URL: https://github.com/llvm/llvm-project/commit/d51275cbc071e318df94d7f5a469b17b8960dab0
DIFF: https://github.com/llvm/llvm-project/commit/d51275cbc071e318df94d7f5a469b17b8960dab0.diff
LOG: [mlir][spirv] Add support to convert std.splat op
Differential Revision: https://reviews.llvm.org/D101511
Added:
Modified:
mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
index 2a6e7f2818602..3851bacc1f843 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -470,6 +470,16 @@ class SelectOpPattern final : public OpConversionPattern<SelectOp> {
ConversionPatternRewriter &rewriter) const override;
};
+/// Lowers std.splat into spv.CompositeConstruct, replicating the scalar
+/// operand once for every element of the result vector.
+struct SplatPattern final : public OpConversionPattern<SplatOp> {
+  using OpConversionPattern<SplatOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(SplatOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
/// Converts memref.store to spv.Store on integers.
class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
@@ -1127,6 +1137,23 @@ SelectOpPattern::matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
return success();
}
+//===----------------------------------------------------------------------===//
+// SplatOp
+//===----------------------------------------------------------------------===//
+
+// Replaces a vector-typed std.splat with an spv.CompositeConstruct whose
+// constituents are all the (already type-converted) scalar operand.
+//
+// The result type is passed through the pattern's type converter instead of
+// being used verbatim, so that it agrees with the converted operand (e.g. if
+// the converter rewrites the vector's element type). The pattern fails to
+// match when the result type cannot be converted, is not a vector, or is not
+// a valid SPIR-V composite type.
+LogicalResult
+SplatPattern::matchAndRewrite(SplatOp op, ArrayRef<Value> operands,
+                              ConversionPatternRewriter &rewriter) const {
+  // convertType returns a null Type on failure; dyn_cast_or_null handles that.
+  Type dstType = getTypeConverter()->convertType(op.getType());
+  auto dstVecType = dstType.dyn_cast_or_null<VectorType>();
+  if (!dstVecType || !spirv::CompositeType::isValid(dstVecType))
+    return failure();
+
+  SplatOp::Adaptor adaptor(operands);
+  // One copy of the converted scalar per element of the result vector.
+  SmallVector<Value, 4> source(dstVecType.getNumElements(), adaptor.input());
+  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstVecType,
+                                                           source);
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//
@@ -1332,7 +1359,7 @@ void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
AllocOpPattern, DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
LoadOpPattern, StoreOpPattern,
- ReturnOpPattern, SelectOpPattern,
+ ReturnOpPattern, SelectOpPattern, SplatPattern,
// Type cast patterns
UIToFPI1Pattern, ZeroExtendI1Pattern, TruncI1Pattern,
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index fe769482c787b..c9ce74b92c0b5 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -1249,3 +1249,18 @@ func @tensor_extract_constant(%a : index, %b: index, %c: index) -> i32 {
// CHECK: spv.ReturnValue %[[VAL]]
return %extract : i32
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// splat
+//===----------------------------------------------------------------------===//
+
+// Splatting an f32 into vector<4xf32> is expected to lower to a single
+// spv.CompositeConstruct that repeats the scalar argument four times, whose
+// result is returned directly.
+// CHECK-LABEL: func @splat
+// CHECK-SAME: (%[[A:.+]]: f32)
+// CHECK: %[[VAL:.+]] = spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]] : vector<4xf32>
+// CHECK: spv.ReturnValue %[[VAL]]
+func @splat(%f : f32) -> vector<4xf32> {
+ %splat = splat %f : vector<4xf32>
+ return %splat : vector<4xf32>
+}
More information about the Mlir-commits mailing list