[Mlir-commits] [mlir] 071358e - [mlir][Linalg] Add producer-consumer fusion when producer is a ConstantOp

llvmlistbot at llvm.org
Wed May 20 09:16:51 PDT 2020


Author: MaheshRavishankar
Date: 2020-05-20T09:16:19-07:00
New Revision: 071358e08224b9971f6b7fc49a5e014a5662187c

URL: https://github.com/llvm/llvm-project/commit/071358e08224b9971f6b7fc49a5e014a5662187c
DIFF: https://github.com/llvm/llvm-project/commit/071358e08224b9971f6b7fc49a5e014a5662187c.diff

LOG: [mlir][Linalg] Add producer-consumer fusion when producer is a ConstantOp
and the consumer is a GenericOp.
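
For illustration, a rough before/after sketch of the rewrite, mirroring the
first test added below (with #map0 = affine_map<(d0, d1, d2) -> (d0)> and
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> as defined there):

    // Before: a splat constant tensor feeds the linalg.generic.
    %cst = constant dense<42.0> : tensor<5xf32>
    %1 = linalg.generic
         {args_in = 2 : i64, args_out = 1 : i64,
          indexing_maps = [#map0, #map1, #map1],
          iterator_types = ["parallel", "parallel", "parallel"]}
         %cst, %arg0 {
         ^bb0(%arg1: f32, %arg2: f32):
           %2 = mulf %arg1, %arg2 : f32
           linalg.yield %2 : f32
         }: tensor<5xf32>, tensor<5x?x?xf32> -> tensor<5x?x?xf32>

    // After (roughly): the splat is materialized as a scalar constant that is
    // captured inside the region; the constant tensor operand and its
    // indexing map are dropped.
    %cst = constant 4.200000e+01 : f32
    %1 = linalg.generic
         {args_in = 1 : i64, args_out = 1 : i64,
          indexing_maps = [#map1, #map1],
          iterator_types = ["parallel", "parallel", "parallel"]}
         %arg0 {
         ^bb0(%arg1: f32):
           %2 = mulf %cst, %arg1 : f32
           linalg.yield %2 : f32
         }: tensor<5x?x?xf32> -> tensor<5x?x?xf32>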

Differential Revision: https://reviews.llvm.org/D79838

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/test/Dialect/Linalg/fusion-tensor.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index fbeac83d4305..3123f95452fd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -766,6 +766,66 @@ template <typename LinalgOpTy> struct FuseTensorReshapeOpAsConsumer {
     return fusedOp;
   }
 };
+
+/// Implementation of fusion on tensor ops when the producer is a splat constant.
+template <typename LinalgOpTy> struct FuseConstantOpAsProducer {
+  static bool isFusible(ConstantOp producer, LinalgOpTy consumer,
+                        unsigned consumerIdx) {
+    return producer.getResult().getType().isa<RankedTensorType>() &&
+           producer.value().template cast<DenseElementsAttr>().isSplat();
+  }
+
+  static Operation *fuse(ConstantOp producer, LinalgOpTy consumer,
+                         unsigned consumerIdx, PatternRewriter &rewriter,
+                         OperationFolder *folder = nullptr) {
+    if (!isFusible(producer, consumer, consumerIdx))
+      return nullptr;
+
+    // The indexing_maps for the operands of the fused operation are the same
+    // as those of the consumer's operands, with the indexing map at
+    // consumerIdx dropped.
+    SmallVector<AffineMap, 4> fusedIndexMaps =
+        llvm::to_vector<4>(llvm::map_range(
+            consumer.indexing_maps(), [](Attribute attr) -> AffineMap {
+              return attr.cast<AffineMapAttr>().getValue();
+            }));
+    fusedIndexMaps.erase(std::next(fusedIndexMaps.begin(), consumerIdx));
+
+    // The operand list is the same as the consumer's, with the operand at
+    // consumerIdx (the constant) dropped.
+    SmallVector<Value, 4> fusedOperands(consumer.operand_begin(),
+                                        consumer.operand_end());
+    fusedOperands.erase(std::next(fusedOperands.begin(), consumerIdx));
+
+    // Create a constant scalar value from the splat constant.
+    Value scalarConstant = rewriter.create<ConstantOp>(
+        producer.getLoc(),
+        producer.value().template cast<DenseElementsAttr>().getSplatValue());
+
+    auto fusedOp = rewriter.create<LinalgOpTy>(
+        rewriter.getUnknownLoc(), consumer.getResultTypes(), fusedOperands,
+        rewriter.getI64IntegerAttr(consumer.getNumOperands() - 1),
+        rewriter.getI64IntegerAttr(consumer.getNumResults()),
+        rewriter.getAffineMapArrayAttr(fusedIndexMaps),
+        consumer.iterator_types(),
+        /*doc=*/nullptr,
+        /*library_call=*/nullptr);
+
+    // Map the block argument corresponding to the replaced operand to the
+    // scalar constant.
+    Region &consumerRegion = consumer.region();
+    Block &entryBlock = *consumerRegion.begin();
+    unsigned argIndex =
+        entryBlock.getNumArguments() - consumer.getNumOperands() + consumerIdx;
+    BlockAndValueMapping mapping;
+    mapping.map(entryBlock.getArgument(argIndex), scalarConstant);
+    Region &fusedRegion = fusedOp.region();
+    rewriter.cloneRegionBefore(consumerRegion, fusedRegion, fusedRegion.begin(),
+                               mapping);
+    return fusedOp;
+  }
+};
+
 } // namespace
 
 Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
@@ -789,6 +849,9 @@ Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
     } else if (auto reshapeOpProducer = dyn_cast<TensorReshapeOp>(producer)) {
       return FuseTensorReshapeOpAsProducer<GenericOp>::fuse(
           reshapeOpProducer, genericOp, consumerIdx, rewriter, folder);
+    } else if (auto constantOpProducer = dyn_cast<ConstantOp>(producer)) {
+      return FuseConstantOpAsProducer<GenericOp>::fuse(
+          constantOpProducer, genericOp, consumerIdx, rewriter, folder);
     }
     return nullptr;
   }
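
Note that isFusible restricts this to ranked-tensor splat constants: a
non-splat constant producer cannot be replaced by a single scalar in the
region, so fuseTensorOps returns nullptr and no rewrite is applied. A
hypothetical input that is left untouched (not part of the tests below;
#map0 and #map1 as in the first test):

    %0 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0]> : tensor<5xf32>
    %1 = linalg.generic
         {args_in = 2 : i64, args_out = 1 : i64,
          indexing_maps = [#map0, #map1, #map1],
          iterator_types = ["parallel", "parallel", "parallel"]}
         %0, %arg0 {
         ^bb0(%arg1: f32, %arg2: f32):
           %2 = mulf %arg1, %arg2 : f32
           linalg.yield %2 : f32
         }: tensor<5xf32>, tensor<5x?x?xf32> -> tensor<5x?x?xf32>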

diff  --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
index 2c00f77edd3f..83bd1753eb28 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -219,3 +219,58 @@ func @generic_op_reshape_consumer_nofusion(%arg0 : tensor<?x?x?x5xf32>,
 
 // CHECK-LABEL: func @generic_op_reshape_consumer_nofusion
 //       CHECK: linalg.tensor_reshape
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func @generic_op_constant_fusion(%arg0 : tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
+{
+  %0 = constant dense<42.0> : tensor<5xf32>
+  %1 = linalg.generic
+       {args_in = 2 : i64, args_out = 1 : i64,
+         indexing_maps = [#map0, #map1, #map1],
+         iterator_types = ["parallel", "parallel", "parallel"]}
+       %0, %arg0 {
+       ^bb0(%arg1: f32, %arg2: f32):
+         %2 = mulf %arg1, %arg2 : f32
+         linalg.yield %2 : f32
+       }: tensor<5xf32>, tensor<5x?x?xf32> -> tensor<5x?x?xf32>
+  return %1 : tensor<5x?x?xf32>
+}
+//   CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @generic_op_constant_fusion
+//       CHECK:   %[[CST:.*]] = constant {{.*}} : f32
+//       CHECK:   linalg.generic
+//  CHECK-SAME:     args_in = 1 : i64
+//  CHECK-SAME:     args_out = 1 : i64
+//       CHECK:   ^{{.*}}(%[[ARG1:.*]]: f32)
+//       CHECK:     mulf %[[CST]], %[[ARG1]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> ()>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func @generic_op_zero_dim_constant_fusion(%arg0 : tensor<5x?x?xf32>)
+  -> tensor<5x?x?xf32>
+{
+  %0 = constant dense<42.0> : tensor<f32>
+  %1 = linalg.generic
+       {args_in = 2 : i64, args_out = 1 : i64,
+         indexing_maps = [#map0, #map1, #map1],
+         iterator_types = ["parallel", "parallel", "parallel"]}
+       %0, %arg0 {
+       ^bb0(%arg1: f32, %arg2: f32):
+         %2 = mulf %arg1, %arg2 : f32
+         linalg.yield %2 : f32
+       }: tensor<f32>, tensor<5x?x?xf32> -> tensor<5x?x?xf32>
+  return %1 : tensor<5x?x?xf32>
+}
+//   CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @generic_op_zero_dim_constant_fusion
+//       CHECK:   %[[CST:.*]] = constant {{.*}} : f32
+//       CHECK:   linalg.generic
+//  CHECK-SAME:     args_in = 1 : i64
+//  CHECK-SAME:     args_out = 1 : i64
+//       CHECK:   ^{{.*}}(%[[ARG1:.*]]: f32)
+//       CHECK:     mulf %[[CST]], %[[ARG1]]


        

