[Mlir-commits] [mlir] [mlir][linalg] Enable expansion of parallel dims of reduction ops (PR #83473)

Quinn Dawkins llvmlistbot at llvm.org
Thu Feb 29 13:21:38 PST 2024


https://github.com/qedawkins updated https://github.com/llvm/llvm-project/pull/83473

>From 542787dcd06b8a091ce963e9d2f0deb7e4855035 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Thu, 29 Feb 2024 14:47:06 -0500
Subject: [PATCH 1/2] [mlir][linalg] Enable expansion of parallel dims of
 reduction ops

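This adds support for propagating reshapes by expansion through linalg ops
that have reduction iterators, provided that every loop dimension indexed by
the reshaped operand is parallel. The expanded op keeps the original iterator
type for each expanded dimension. This improves the ability to make fusion
decisions with respect to reduction operations. To recover the previous
behavior, users of the patterns can add a control function that restricts
propagation of reshapes by expansion through linalg ops with reduction
iterators.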
---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 18 +++-
 mlir/test/Dialect/Linalg/reshape_fusion.mlir  | 90 +++++++++++++++++++
 2 files changed, 104 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 4797bfb2267d7f..6310f9105960be 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -526,7 +526,10 @@ static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
   // - All the indexing maps for operands and results are projected
   //   permutations.
   // - The fused tensor is not a scalar.
-  // - All the loops are parallel loops.
+  // - All the loops for the reshaped operand are parallel loops.
+  SmallVector<utils::IteratorType> iteratorTypes =
+      genericOp.getIteratorTypesArray();
+  AffineMap operandMap = genericOp.getMatchingIndexingMap(fusableOpOperand);
   return genericOp.hasPureTensorSemantics() &&
          llvm::all_of(genericOp.getIndexingMaps().getValue(),
                       [](Attribute attr) {
@@ -534,9 +537,11 @@ static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
                             .getValue()
                             .isProjectedPermutation();
                       }) &&
-         genericOp.getMatchingIndexingMap(fusableOpOperand).getNumResults() >
-             0 &&
-         llvm::all_of(genericOp.getIteratorTypesArray(), isParallelIterator);
+         operandMap.getNumResults() > 0 &&
+         llvm::all_of(operandMap.getResults(), [&](AffineExpr expr) {
+           return isParallelIterator(
+               iteratorTypes[cast<AffineDimExpr>(expr).getPosition()]);
+         });
 }
 
 namespace {
@@ -848,6 +853,11 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
   // The iterator types of the expanded op are all parallel.
   SmallVector<utils::IteratorType> iteratorTypes(
       expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
+  for (auto [i, type] : llvm::enumerate(genericOp.getIteratorTypesArray())) {
+    ReassociationIndicesRef group = expansionInfo.getExpandedDims(i);
+    for (auto i : group)
+      iteratorTypes[i] = type;
+  }
 
   TypeRange resultTypes = ValueRange(outputs).getTypes();
   auto fusedOp =
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 0e40b5fbed97cb..5c0a83258b4b95 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -573,3 +573,93 @@ module {
 // CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] :
 // CHECK-SAME:       outs(%[[ARG2]], %[[OUTS]] :
 //      CHECK:   return %[[GENERIC]]#1
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @generic_op_reshape_consumer_fusion_reduction(%arg0 : tensor<?x?xf32>,
+                                                        %arg1 : tensor<?x?xf32>,
+                                                        %arg2 : tensor<?x?xf32>) ->
+                                                        tensor<?x?x4x5xf32>
+{
+  %0 = linalg.generic {
+     indexing_maps = [#map0, #map1, #map2],
+     iterator_types = ["parallel", "parallel", "reduction"]}
+       ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+       outs(%arg2 : tensor<?x?xf32>) {
+    ^bb0(%arg3: f32, %arg4: f32, %s: f32):
+      %1 = arith.mulf %arg3, %arg4 : f32
+      linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0], [1, 2, 3]] :
+    tensor<?x?xf32> into tensor<?x?x4x5xf32>
+  return %1 : tensor<?x?x4x5xf32>
+}
+
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d3, d4)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
+//      CHECK: func @generic_op_reshape_consumer_fusion_reduction
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
+// CHECK-SAME:     [0, 1, 2], [3]
+// CHECK-SAME:     tensor<?x?xf32> into tensor<?x4x5x?xf32>
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG2]]
+// CHECK-SAME:     [0], [1, 2, 3]
+// CHECK-SAME:     tensor<?x?xf32> into tensor<?x?x4x5xf32>
+//      CHECK:   %[[T3:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME:     ins(%[[ARG0]], %[[T1]] : tensor<?x?xf32>, tensor<?x4x5x?xf32>)
+// CHECK-SAME:     outs(%[[T2]] : tensor<?x?x4x5xf32>)
+//      CHECK:   return %[[T3]] : tensor<?x?x4x5xf32>
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d2, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
+func.func @generic_op_reshape_producer_fusion_with_reduction(%arg0 : tensor<?x7x?x8xf32>,
+                                         %arg1 : tensor<?x4x?xf32>,
+                                         %arg2 : tensor<?x?xf32>) ->
+                                         tensor<?x?xf32>
+{
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] :
+    tensor<?x7x?x8xf32> into tensor<?x?xf32>
+  %1 = linalg.generic {
+     indexing_maps = [#map0, #map1, #map2],
+     iterator_types = ["parallel", "reduction", "parallel"]}
+       ins(%0, %arg1 : tensor<?x?xf32>, tensor<?x4x?xf32>)
+       outs(%arg2 : tensor<?x?xf32>) {
+    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+      %1 = arith.mulf %arg3, %arg4 : f32
+      %2 = arith.addf %1, %arg5 : f32
+      linalg.yield %2 : f32
+  } -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+//  CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d0, d1)>
+//  CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+//  CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
+//      CHECK: func @generic_op_reshape_producer_fusion_with_reduction
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x4x?xf32>
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]]
+// CHECK-SAME:     [0, 1], [2], [3, 4]
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG2]]
+// CHECK-SAME:     [0, 1], [2, 3]
+//      CHECK:   %[[T3:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME:     ["parallel", "parallel", "reduction", "parallel", "parallel"]
+// CHECK-SAME:     ins(%[[ARG0]], %[[T1]] : tensor<?x7x?x8xf32>, tensor<?x8x4x?x7xf32>)
+// CHECK-SAME:     outs(%[[T2]] : tensor<?x8x?x7xf32>)
+//      CHECK:   %[[T4:.+]] = tensor.collapse_shape %[[T3]]
+// CHECK-SAME:     [0, 1], [2, 3]
+// CHECK-SAME:     tensor<?x8x?x7xf32> into tensor<?x?xf32>
+//      CHECK:   return %[[T4]]

>From c0b6a7ac1d6a57dda17b737312f16fc660bc4e63 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Thu, 29 Feb 2024 16:21:18 -0500
Subject: [PATCH 2/2] Address comments

---
 .../lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 6310f9105960be..402a7d58333a5c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -853,11 +853,9 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
   // The iterator types of the expanded op are all parallel.
   SmallVector<utils::IteratorType> iteratorTypes(
       expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
-  for (auto [i, type] : llvm::enumerate(genericOp.getIteratorTypesArray())) {
-    ReassociationIndicesRef group = expansionInfo.getExpandedDims(i);
-    for (auto i : group)
-      iteratorTypes[i] = type;
-  }
+  for (auto [i, type] : llvm::enumerate(genericOp.getIteratorTypesArray()))
+    for (auto j : expansionInfo.getExpandedDims(i))
+      iteratorTypes[j] = type;
 
   TypeRange resultTypes = ValueRange(outputs).getTypes();
   auto fusedOp =



More information about the Mlir-commits mailing list