[Mlir-commits] [mlir] 27fca57 - [mlir][Linalg] Add support for fusion between indexed_generic ops and tensor_reshape ops
Hanhan Wang
llvmlistbot at llvm.org
Wed Jun 3 15:00:06 PDT 2020
Author: Hanhan Wang
Date: 2020-06-03T14:59:47-07:00
New Revision: 27fca57546c2828e2684c02b7aa677cbd6603bfd
URL: https://github.com/llvm/llvm-project/commit/27fca57546c2828e2684c02b7aa677cbd6603bfd
DIFF: https://github.com/llvm/llvm-project/commit/27fca57546c2828e2684c02b7aa677cbd6603bfd.diff
LOG: [mlir][Linalg] Add support for fusion between indexed_generic ops and tensor_reshape ops
Summary:
Fusion with tensor_reshape ops works by folding the reshape information into the indexing maps,
so the existing pattern also works for indexed_generic ops.
Depends on D80347
Differential Revision: https://reviews.llvm.org/D80348
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/test/Dialect/Linalg/fusion-tensor.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 9964e1355097..48231642eae7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -937,6 +937,11 @@ Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
return FuseTensorReshapeOpAsProducer<GenericOp>::fuse(
reshapeOpProducer, genericOpConsumer, consumerIdx, rewriter,
folder);
+ } else if (auto indexedGenericOpConsumer =
+ dyn_cast<IndexedGenericOp>(consumer)) {
+ return FuseTensorReshapeOpAsProducer<IndexedGenericOp>::fuse(
+ reshapeOpProducer, indexedGenericOpConsumer, consumerIdx, rewriter,
+ folder);
}
} else if (auto constantOpProducer = dyn_cast<ConstantOp>(producer)) {
if (auto genericOpConsumer = dyn_cast<GenericOp>(consumer)) {
@@ -954,6 +959,11 @@ Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
if (genericOpProducer.hasTensorSemantics())
return FuseTensorReshapeOpAsConsumer<GenericOp>::fuse(
genericOpProducer, reshapeOp, consumerIdx, rewriter, folder);
+ } else if (auto indexedGenericOpProducer =
+ dyn_cast<IndexedGenericOp>(producer)) {
+ if (indexedGenericOpProducer.hasTensorSemantics())
+ return FuseTensorReshapeOpAsConsumer<IndexedGenericOp>::fuse(
+ indexedGenericOpProducer, reshapeOp, consumerIdx, rewriter, folder);
}
return nullptr;
}
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
index 9b73c02a4ed2..6d6a409edbd2 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -421,3 +421,68 @@ func @indexed_generic_op_fusion(%arg0: tensor<?x?xi32>) {
// CHECK: %[[VAL4:.+]] = subi %[[VAL3]], %[[SUB_OPERAND2]] : i32
// CHECK: linalg.yield %[[VAL4]] : i32
// CHECK-NOT: linalg.indexed_generic
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xi32>)
+ -> tensor<?x?x4x?xi32> {
+ %0 = linalg.tensor_reshape %arg0 [affine_map<(i, j, k, l) -> (i)>,
+ affine_map<(i, j, k, l) -> (j, k)>,
+ affine_map<(i, j, k, l) -> (l)>] :
+ tensor<?x?x?xi32> into tensor<?x?x4x?xi32>
+ %1 = linalg.indexed_generic {
+ args_in = 1 : i64,
+ args_out = 1 : i64,
+ indexing_maps = [#map0, #map0],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"] } %0 {
+ ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i32): // no predecessors
+ %2 = index_cast %arg2 : index to i32
+ %3 = addi %arg6, %2 : i32
+ linalg.yield %3 : i32
+ }: tensor<?x?x4x?xi32> -> tensor<?x?x4x?xi32>
+ return %1 : tensor<?x?x4x?xi32>
+}
+
+// CHECK-LABEL: func @indexed_generic_op_reshape_producer_fusion
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: linalg.indexed_generic
+// CHECK-SAME: args_in = 1
+// CHECK-SAME: args_out = 1
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-NOT: linalg.tensor_reshape
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xi32>)
+ -> tensor<?x?xi32> {
+ %0 = linalg.indexed_generic {
+ args_in = 1 : i64,
+ args_out = 1 : i64,
+ indexing_maps = [#map0, #map0],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"] } %arg0 {
+ ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i32): // no predecessors
+ %2 = index_cast %arg2 : index to i32
+ %3 = addi %arg6, %2 : i32
+ linalg.yield %3 : i32
+ }: tensor<?x?x4x5xi32> -> tensor<?x?x4x5xi32>
+ %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
+ affine_map<(i, j, k, l) -> (j, k, l)>] :
+ tensor<?x?x4x5xi32> into tensor<?x?xi32>
+ return %1 : tensor<?x?xi32>
+}
+
+// CHECK-LABEL: func @indexed_generic_op_reshape_consumer_fusion
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: linalg.indexed_generic
+// CHECK-SAME: args_in = 1
+// CHECK-SAME: args_out = 1
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-NOT: linalg.tensor_reshape
More information about the Mlir-commits
mailing list