[Mlir-commits] [mlir] 7060920 - Relax FuseTensorReshapeOpAsProducer identity mapping constraint

Ahmed S. Taei llvmlistbot at llvm.org
Tue Oct 6 15:32:19 PDT 2020


Author: Ahmed S. Taei
Date: 2020-10-06T22:31:39Z
New Revision: 7060920bd1f70b778105703a5c95066658ed5886

URL: https://github.com/llvm/llvm-project/commit/7060920bd1f70b778105703a5c95066658ed5886
DIFF: https://github.com/llvm/llvm-project/commit/7060920bd1f70b778105703a5c95066658ed5886.diff

LOG: Relax FuseTensorReshapeOpAsProducer identity mapping constraint

Differential Revision: https://reviews.llvm.org/D88869

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
    mlir/test/Dialect/Linalg/fusion-tensor.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index a62b1ada2c18..ac57d5f97c1d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -326,7 +326,7 @@ static bool isTensorReshapeOpFusible(TensorReshapeOp reshapeOp,
   if ((asProducer && returnType.getRank() < operandType.getRank()) ||
       (!asProducer && operandType.getRank() < returnType.getRank()))
     return false;
-  return useIndexMap.isIdentity();
+  return useIndexMap.isPermutation();
 }
 
 /// Based on the type of `op` create a linalg op of the same type, i.e. if `op`
@@ -381,10 +381,13 @@ struct FuseTensorReshapeOpAsProducer {
               return attr.cast<AffineMapAttr>().getValue();
             }));
 
+    // Accepted consumer maps are either identity or permutation.
+    auto invMap = inversePermutation(fusedIndexMaps[consumerIdx]);
+
     // Compute the indexing map to use for the operand of the producer.
-    AffineMap modifiedMap = linearizeCollapsedDims(
-        fusedIndexMaps[consumerIdx], producer.getResultType().getShape(),
-        producer.getReassociationMaps());
+    AffineMap modifiedMap =
+        linearizeCollapsedDims(invMap, producer.getResultType().getShape(),
+                               producer.getReassociationMaps());
     for (AffineExpr expr : modifiedMap.getResults()) {
       if (!expr.isPureAffine())
         return nullptr;
@@ -439,10 +442,13 @@ struct FuseTensorReshapeOpAsConsumer {
             producer.indexing_maps(), [](Attribute attr) -> AffineMap {
               return attr.cast<AffineMapAttr>().getValue();
             }));
+
+    auto invMap = inversePermutation(producer.getOutputIndexingMap(0));
+
     // Compute the indexing map to use for the operand of the producer.
-    AffineMap modifiedMap = linearizeCollapsedDims(
-        producer.getOutputIndexingMap(0), consumer.getSrcType().getShape(),
-        consumer.getReassociationMaps());
+    AffineMap modifiedMap =
+        linearizeCollapsedDims(invMap, consumer.getSrcType().getShape(),
+                               consumer.getReassociationMaps());
     for (AffineExpr expr : modifiedMap.getResults()) {
       if (!expr.isPureAffine())
         return nullptr;

diff  --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
index ccadff54e40b..3f8b0680d7a4 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -558,3 +558,100 @@ func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xi32>)
 //       CHECK: linalg.indexed_generic
 //  CHECK-SAME:   indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
 //   CHECK-NOT: linalg.tensor_reshape
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 + d2 * 7)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+#map0 = affine_map<(d0, d1, d2) -> (d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func @generic_op_021_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<3x7x5xf32> {
+  %0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32>
+  %1 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<3x5x7xf32>) {
+    ^bb0(%arg2: f32):  // no predecessors
+      linalg.yield %arg2 : f32
+    } -> tensor<3x7x5xf32>
+    return %1 : tensor<3x7x5xf32>
+}
+
+// CHECK-LABEL: func @generic_op_021_permultation_reshape_producer_fusion
+//   CHECK-NOT: linalg.tensor_reshape
+//       CHECK: linalg.generic
+//  CHECK-SAME:   indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+//   CHECK-NOT: linalg.tensor_reshape
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0 * 7 + d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+#map0 = affine_map<(d0, d1, d2) -> (d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func @generic_op_120_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x7x3xf32> {
+  %0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32>
+  %1 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<3x5x7xf32>) {
+    ^bb0(%arg2: f32):  // no predecessors
+      linalg.yield %arg2 : f32
+    } -> tensor<5x7x3xf32>
+    return %1 : tensor<5x7x3xf32>
+}
+
+// CHECK-LABEL: func @generic_op_120_permultation_reshape_producer_fusion
+//   CHECK-NOT: linalg.tensor_reshape
+//       CHECK: linalg.generic
+//  CHECK-SAME:   indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+//   CHECK-NOT: linalg.tensor_reshape
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+#map0 = affine_map<(d0, d1, d2) -> (d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func @generic_op_102_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x3x7xf32> {
+  %0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32>
+  %1 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<3x5x7xf32>) {
+    ^bb0(%arg2: f32):  // no predecessors
+      linalg.yield %arg2 : f32
+    } -> tensor<5x3x7xf32>
+    return %1 : tensor<5x3x7xf32>
+}
+
+// CHECK-LABEL: func @generic_op_102_permultation_reshape_producer_fusion
+//   CHECK-NOT: linalg.tensor_reshape
+//       CHECK: linalg.generic
+//  CHECK-SAME:   indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+//   CHECK-NOT: linalg.tensor_reshape
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
+
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d2)>
+func @generic_op_102_permultation_reshape_consumer_fusion(%arg0 : tensor<3x5x7xf32>) -> tensor<5x21xf32> {
+  %0 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<3x5x7xf32>) {
+    ^bb0(%arg2: f32):  // no predecessors
+      linalg.yield %arg2 : f32
+  } -> tensor<5x3x7xf32>
+  %1 = linalg.tensor_reshape %0 [#map2, #map3] : tensor<5x3x7xf32> into tensor<5x21xf32>
+  return %1 : tensor<5x21xf32>
+}
+
+// CHECK-LABEL: func @generic_op_102_permultation_reshape_consumer_fusion
+//   CHECK-NOT: linalg.tensor_reshape
+//       CHECK: linalg.generic
+//  CHECK-SAME:   indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+//   CHECK-NOT: linalg.tensor_reshape


        


More information about the Mlir-commits mailing list