[Mlir-commits] [mlir] [mlir][linalg] Enable expansion of parallel dims of reduction ops (PR #83473)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 29 12:20:31 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Quinn Dawkins (qedawkins)

<details>
<summary>Changes</summary>

This adds support for expanding linalg ops that have reduction iterators, provided that every loop indexed by the reshaped operand is a parallel loop. This improves the ability to make fusion decisions with respect to reduction operations. To recover the previous behavior, users of the patterns can add a control function that restricts propagation of reshapes by expansion through linalg ops with reduction iterators.

---
Full diff: https://github.com/llvm/llvm-project/pull/83473.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+14-4) 
- (modified) mlir/test/Dialect/Linalg/reshape_fusion.mlir (+90) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 4797bfb2267d7f..6310f9105960be 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -526,7 +526,10 @@ static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
   // - All the indexing maps for operands and results are projected
   //   permutations.
   // - The fused tensor is not a scalar.
-  // - All the loops are parallel loops.
+  // - All the loops for the reshaped operand are parallel loops.
+  SmallVector<utils::IteratorType> iteratorTypes =
+      genericOp.getIteratorTypesArray();
+  AffineMap operandMap = genericOp.getMatchingIndexingMap(fusableOpOperand);
   return genericOp.hasPureTensorSemantics() &&
          llvm::all_of(genericOp.getIndexingMaps().getValue(),
                       [](Attribute attr) {
@@ -534,9 +537,11 @@ static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
                             .getValue()
                             .isProjectedPermutation();
                       }) &&
-         genericOp.getMatchingIndexingMap(fusableOpOperand).getNumResults() >
-             0 &&
-         llvm::all_of(genericOp.getIteratorTypesArray(), isParallelIterator);
+         operandMap.getNumResults() > 0 &&
+         llvm::all_of(operandMap.getResults(), [&](AffineExpr expr) {
+           return isParallelIterator(
+               iteratorTypes[cast<AffineDimExpr>(expr).getPosition()]);
+         });
 }
 
 namespace {
@@ -848,6 +853,11 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
   // The iterator types of the expanded op are all parallel.
   SmallVector<utils::IteratorType> iteratorTypes(
       expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
+  for (auto [i, type] : llvm::enumerate(genericOp.getIteratorTypesArray())) {
+    ReassociationIndicesRef group = expansionInfo.getExpandedDims(i);
+    for (auto i : group)
+      iteratorTypes[i] = type;
+  }
 
   TypeRange resultTypes = ValueRange(outputs).getTypes();
   auto fusedOp =
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 0e40b5fbed97cb..5c0a83258b4b95 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -573,3 +573,93 @@ module {
 // CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] :
 // CHECK-SAME:       outs(%[[ARG2]], %[[OUTS]] :
 //      CHECK:   return %[[GENERIC]]#1
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @generic_op_reshape_consumer_fusion_reduction(%arg0 : tensor<?x?xf32>,
+                                                        %arg1 : tensor<?x?xf32>,
+                                                        %arg2 : tensor<?x?xf32>) ->
+                                                        tensor<?x?x4x5xf32>
+{
+  %0 = linalg.generic {
+     indexing_maps = [#map0, #map1, #map2],
+     iterator_types = ["parallel", "parallel", "reduction"]}
+       ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+       outs(%arg2 : tensor<?x?xf32>) {
+    ^bb0(%arg3: f32, %arg4: f32, %s: f32):
+      %1 = arith.mulf %arg3, %arg4 : f32
+      linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] :
+    tensor<?x?xf32> into tensor<?x?x4x5xf32>
+  return %1 : tensor<?x?x4x5xf32>
+}
+
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d3, d4)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
+//      CHECK: func @generic_op_reshape_consumer_fusion_reduction
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
+// CHECK-SAME:     [0, 1, 2], [3]
+// CHECK-SAME:     tensor<?x?xf32> into tensor<?x4x5x?xf32>
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG2]]
+// CHECK-SAME:     [0], [1, 2, 3]
+// CHECK-SAME:     tensor<?x?xf32> into tensor<?x?x4x5xf32>
+//      CHECK:   %[[T3:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:     ins(%[[ARG0]], %[[T1]] : tensor<?x?xf32>, tensor<?x4x5x?xf32>)
+// CHECK-SAME:     outs(%[[T2]] : tensor<?x?x4x5xf32>)
+//      CHECK:   return %[[T3]] : tensor<?x?x4x5xf32>
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d2, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
+func.func @generic_op_reshape_producer_fusion_with_reduction(%arg0 : tensor<?x7x?x8xf32>,
+                                         %arg1 : tensor<?x4x?xf32>,
+                                         %arg2 : tensor<?x?xf32>) ->
+                                         tensor<?x?xf32>
+{
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] :
+    tensor<?x7x?x8xf32> into tensor<?x?xf32>
+  %1 = linalg.generic {
+     indexing_maps = [#map0, #map1, #map2],
+     iterator_types = ["parallel", "reduction", "parallel"]}
+       ins(%0, %arg1 : tensor<?x?xf32>, tensor<?x4x?xf32>)
+       outs(%arg2 : tensor<?x?xf32>) {
+    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+      %1 = arith.mulf %arg3, %arg4 : f32
+      %2 = arith.addf %1, %arg5 : f32
+      linalg.yield %2 : f32
+  } -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+//  CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d0, d1)>
+//  CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+//  CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
+//      CHECK: func @generic_op_reshape_producer_fusion_with_reduction
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x4x?xf32>
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
+// CHECK-SAME:     [0, 1], [2], [3, 4]
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG2]]
+// CHECK-SAME:     [0, 1], [2, 3]
+//      CHECK:   %[[T3:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME:     ["parallel", "parallel", "reduction", "parallel", "parallel"]
+// CHECK-SAME:     ins(%[[ARG0]], %[[T1]] : tensor<?x7x?x8xf32>, tensor<?x8x4x?x7xf32>)
+// CHECK-SAME:     outs(%[[T2]] : tensor<?x8x?x7xf32>)
+//      CHECK:   %[[T4:.+]] = tensor.collapse_shape %[[T3]]
+// CHECK-SAME:     [0, 1], [2, 3]
+// CHECK-SAME:     tensor<?x8x?x7xf32> into tensor<?x?xf32>
+//      CHECK:   return %[[T4]]

``````````

</details>


https://github.com/llvm/llvm-project/pull/83473


More information about the Mlir-commits mailing list