[Mlir-commits] [mlir] 27fca57 - [mlir][Linalg] Add support for fusion between indexed_generic ops and tensor_reshape ops

Hanhan Wang llvmlistbot at llvm.org
Wed Jun 3 15:00:06 PDT 2020


Author: Hanhan Wang
Date: 2020-06-03T14:59:47-07:00
New Revision: 27fca57546c2828e2684c02b7aa677cbd6603bfd

URL: https://github.com/llvm/llvm-project/commit/27fca57546c2828e2684c02b7aa677cbd6603bfd
DIFF: https://github.com/llvm/llvm-project/commit/27fca57546c2828e2684c02b7aa677cbd6603bfd.diff

LOG: [mlir][Linalg] Add support for fusion between indexed_generic ops and tensor_reshape ops

Summary:
Fusion with tensor_reshape works by embedding the reshape's reassociation
information into the indexing maps, so the existing pattern for generic ops
also works for indexed_generic ops.

Depends On D80347

Differential Revision: https://reviews.llvm.org/D80348
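
To illustrate with the producer-fusion test added below (a sketch of the
before/after shape, not additional functionality):

  // Before fusion: the reshape expands tensor<?x?x?xi32> into
  // tensor<?x?x4x?xi32>, which the indexed_generic then reads through the
  // identity map (d0, d1, d2, d3) -> (d0, d1, d2, d3).
  %0 = linalg.tensor_reshape %arg0 [affine_map<(i, j, k, l) -> (i)>,
                                    affine_map<(i, j, k, l) -> (j, k)>,
                                    affine_map<(i, j, k, l) -> (l)>] :
    tensor<?x?x?xi32> into tensor<?x?x4x?xi32>
  // After fusion the reshape is gone: the collapsed pair (j, k), where k has
  // static size 4, is linearized directly into the indexing map
  //   affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
  // while the op's region, including its index arguments, is unchanged,
  // which is why the generic-op pattern carries over to indexed_generic.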

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/test/Dialect/Linalg/fusion-tensor.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 9964e1355097..48231642eae7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -937,6 +937,11 @@ Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
         return FuseTensorReshapeOpAsProducer<GenericOp>::fuse(
             reshapeOpProducer, genericOpConsumer, consumerIdx, rewriter,
             folder);
+      } else if (auto indexedGenericOpConsumer =
+                     dyn_cast<IndexedGenericOp>(consumer)) {
+        return FuseTensorReshapeOpAsProducer<IndexedGenericOp>::fuse(
+            reshapeOpProducer, indexedGenericOpConsumer, consumerIdx, rewriter,
+            folder);
       }
     } else if (auto constantOpProducer = dyn_cast<ConstantOp>(producer)) {
       if (auto genericOpConsumer = dyn_cast<GenericOp>(consumer)) {
@@ -954,6 +959,11 @@ Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
       if (genericOpProducer.hasTensorSemantics())
         return FuseTensorReshapeOpAsConsumer<GenericOp>::fuse(
             genericOpProducer, reshapeOp, consumerIdx, rewriter, folder);
+    } else if (auto indexedGenericOpProducer =
+                   dyn_cast<IndexedGenericOp>(producer)) {
+      if (indexedGenericOpProducer.hasTensorSemantics())
+        return FuseTensorReshapeOpAsConsumer<IndexedGenericOp>::fuse(
+            indexedGenericOpProducer, reshapeOp, consumerIdx, rewriter, folder);
     }
     return nullptr;
   }

diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
index 9b73c02a4ed2..6d6a409edbd2 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -421,3 +421,68 @@ func @indexed_generic_op_fusion(%arg0: tensor<?x?xi32>) {
 //      CHECK:   %[[VAL4:.+]] = subi %[[VAL3]], %[[SUB_OPERAND2]] : i32
 //      CHECK:   linalg.yield %[[VAL4]] : i32
 //   CHECK-NOT: linalg.indexed_generic
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xi32>)
+  -> tensor<?x?x4x?xi32> {
+  %0 = linalg.tensor_reshape %arg0 [affine_map<(i, j, k, l) -> (i)>,
+                                    affine_map<(i, j, k, l) -> (j, k)>,
+                                    affine_map<(i, j, k, l) -> (l)>] :
+    tensor<?x?x?xi32> into tensor<?x?x4x?xi32>
+  %1 = linalg.indexed_generic {
+    args_in = 1 : i64,
+    args_out = 1 : i64,
+    indexing_maps = [#map0, #map0],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"] } %0 {
+  ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i32):       // no predecessors
+    %2 = index_cast %arg2 : index to i32
+    %3 = addi %arg6, %2 : i32
+    linalg.yield %3 : i32
+  }: tensor<?x?x4x?xi32> -> tensor<?x?x4x?xi32>
+  return %1 : tensor<?x?x4x?xi32>
+}
+
+// CHECK-LABEL: func @indexed_generic_op_reshape_producer_fusion
+//   CHECK-NOT: linalg.tensor_reshape
+//       CHECK: linalg.indexed_generic
+//  CHECK-SAME:   args_in = 1
+//  CHECK-SAME:   args_out = 1
+//  CHECK-SAME:   indexing_maps = [#[[MAP0]], #[[MAP1]]]
+//   CHECK-NOT: linalg.tensor_reshape
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xi32>)
+  -> tensor<?x?xi32> {
+  %0 = linalg.indexed_generic {
+    args_in = 1 : i64,
+    args_out = 1 : i64,
+    indexing_maps = [#map0, #map0],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"] } %arg0 {
+  ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i32):       // no predecessors
+    %2 = index_cast %arg2 : index to i32
+    %3 = addi %arg6, %2 : i32
+    linalg.yield %3 : i32
+  }: tensor<?x?x4x5xi32> -> tensor<?x?x4x5xi32>
+  %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
+                                 affine_map<(i, j, k, l) -> (j, k, l)>] :
+    tensor<?x?x4x5xi32> into tensor<?x?xi32>
+  return %1 : tensor<?x?xi32>
+}
+
+// CHECK-LABEL: func @indexed_generic_op_reshape_consumer_fusion
+//   CHECK-NOT: linalg.tensor_reshape
+//       CHECK: linalg.indexed_generic
+//  CHECK-SAME:   args_in = 1
+//  CHECK-SAME:   args_out = 1
+//  CHECK-SAME:   indexing_maps = [#[[MAP0]], #[[MAP1]]]
+//   CHECK-NOT: linalg.tensor_reshape
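
For reference, #[[MAP1]] in the consumer-fusion case follows from linearizing
the collapsed dimensions: the reshape folds (d1, d2, d3), with static trailing
sizes 4 and 5, into a single dimension, so the fused op indexes it as

  d1 * (4 * 5) + d2 * 5 + d3 = d1 * 20 + d2 * 5 + d3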
