[Mlir-commits] [mlir] d51275c - [mlir][spirv] Add support to convert std.splat op

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon May 3 10:58:07 PDT 2021


Author: thomasraoux
Date: 2021-05-03T10:57:40-07:00
New Revision: d51275cbc071e318df94d7f5a469b17b8960dab0

URL: https://github.com/llvm/llvm-project/commit/d51275cbc071e318df94d7f5a469b17b8960dab0
DIFF: https://github.com/llvm/llvm-project/commit/d51275cbc071e318df94d7f5a469b17b8960dab0.diff

LOG: [mlir][spirv] Add support to convert std.splat op

Differential Revision: https://reviews.llvm.org/D101511

Added: 
    

Modified: 
    mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
    mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
index 2a6e7f2818602..3851bacc1f843 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -470,6 +470,16 @@ class SelectOpPattern final : public OpConversionPattern<SelectOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Lowers std.splat ops that produce vectors into spv.CompositeConstruct.
+struct SplatPattern final : public OpConversionPattern<SplatOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  /// Replaces `op` with an equivalent SPIR-V composite construction.
+  LogicalResult
+  matchAndRewrite(SplatOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 /// Converts memref.store to spv.Store on integers.
 class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
 public:
@@ -1127,6 +1137,35 @@ SelectOpPattern::matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// SplatOp
+//===----------------------------------------------------------------------===//
+
+/// Replaces a vector-typed std.splat with an spv.CompositeConstruct that
+/// repeats the converted scalar input once per element of the result vector.
+LogicalResult
+SplatPattern::matchAndRewrite(SplatOp op, ArrayRef<Value> operands,
+                              ConversionPatternRewriter &rewriter) const {
+  // Only vector results are handled; anything else (e.g. tensors) is left
+  // unmatched.
+  if (!op.getType().isa<VectorType>())
+    return failure();
+  // Use the converted result type rather than the std one: the operands have
+  // already been converted, so the constituents and the result must agree
+  // (the converter may, e.g., change the element bit width).
+  auto dstVecType = getTypeConverter()
+                        ->convertType(op.getType())
+                        .dyn_cast_or_null<VectorType>();
+  if (!dstVecType || !spirv::CompositeType::isValid(dstVecType))
+    return failure();
+  SplatOp::Adaptor adaptor(operands);
+  // Pass the scalar input once per element of the result vector.
+  SmallVector<Value, 4> source(dstVecType.getNumElements(), adaptor.input());
+  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstVecType,
+                                                           source);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // StoreOp
 //===----------------------------------------------------------------------===//
@@ -1332,7 +1359,7 @@ void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
       AllocOpPattern, DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
       LoadOpPattern, StoreOpPattern,
 
-      ReturnOpPattern, SelectOpPattern,
+      ReturnOpPattern, SelectOpPattern, SplatPattern,
 
       // Type cast patterns
       UIToFPI1Pattern, ZeroExtendI1Pattern, TruncI1Pattern,

diff  --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index fe769482c787b..c9ce74b92c0b5 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -1249,3 +1249,18 @@ func @tensor_extract_constant(%a : index, %b: index, %c: index) -> i32 {
   // CHECK: spv.ReturnValue %[[VAL]]
   return %extract : i32
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// splat
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @splat
+//  CHECK-SAME: (%[[SCALAR:.+]]: f32)
+//       CHECK:   %[[VEC:.+]] = spv.CompositeConstruct %[[SCALAR]], %[[SCALAR]], %[[SCALAR]], %[[SCALAR]] : vector<4xf32>
+//       CHECK:   spv.ReturnValue %[[VEC]]
+func @splat(%scalar : f32) -> vector<4xf32> {
+  %vec = splat %scalar : vector<4xf32>
+  return %vec : vector<4xf32>
+}


        


More information about the Mlir-commits mailing list