[Mlir-commits] [mlir] [mlir][Linalg] use linalg.reduce to simplify the mergeReductions in partialReductionInterface (PR #94579)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jun 5 22:49:47 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: zhicong zhong (zhczhong)
<details>
<summary>Changes</summary>
The current implementation of `mergeReductions` in `LinalgOpPartialReductionInterface` builds a `linalg.generic` from scratch, even though the existing `linalg.reduce` op has the same semantics as this generic op. This PR replaces the generic op with `linalg.reduce` to simplify the implementation.
---
Full diff: https://github.com/llvm/llvm-project/pull/94579.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (+5-33)
- (modified) mlir/test/Dialect/Linalg/transform-tile-reduction.mlir (+17-20)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index c3ab3cecfada7..c038d03c15342 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -443,40 +443,12 @@ struct LinalgOpPartialReductionInterface
Operation *mergeReductions(Operation *op, OpBuilder &b, Location loc,
ValueRange partialReduce,
ArrayRef<int> reductionDims) const {
- auto linalgOp = cast<LinalgOp>(op);
- // Step 1. Recover the dims that actually need to be merged from the
- // original operation. We can classify the original iterators as follows:
- //
- // parallel --> parallel
- // reduction + not in reductionDims --> parallel (already reduced)
- // reduction + in reductionDims --> reduction (will reduce now)
- SmallVector<utils::IteratorType> iterators(linalgOp.getNumLoops(),
- utils::IteratorType::parallel);
- for (int redIdx : reductionDims)
- iterators[redIdx] = utils::IteratorType::reduction;
-
- // Step 2. For each partial result, create a map to index it. This map
- // is simply the indexing map for the original result with reductionDims
- // appended (as produced in tileToPartialReduction).
- int64_t numInits = linalgOp.getNumDpsInits();
- SmallVector<AffineMap> indexingMaps(numInits * 2);
- for (int idx : llvm::seq<int>(0, numInits)) {
- AffineMap &inputMap = indexingMaps[idx];
- AffineMap &outputMap = indexingMaps[numInits + idx];
-
- outputMap =
- linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(idx));
- inputMap = outputMap;
- for (int redPos : reductionDims) {
- inputMap = inputMap.insertResult(b.getAffineDimExpr(redPos),
- inputMap.getNumResults());
- }
- }
-
- auto reduction = b.create<GenericOp>(
- loc, op->getResultTypes(), partialReduce, linalgOp.getDpsInits(),
- indexingMaps, iterators,
+ auto linalgOp = cast<LinalgOp>(op);
+ SmallVector<int64_t> reductionDimsInt64(reductionDims.begin(),
+ reductionDims.end());
+ auto reduction = b.create<linalg::ReduceOp>(
+ loc, partialReduce, linalgOp.getDpsInits(), reductionDimsInt64,
[&linalgOp](OpBuilder &b, Location loc, ValueRange inputs) {
int64_t numInits = linalgOp.getNumDpsInits();
SmallVector<Value> yieldedValues;
diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index f3cf7c4dffa05..4a8bb42676fdb 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -23,9 +23,8 @@ module attributes {transform.with_named_sequence} {
}
}
-// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
-// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: func @reduction_tile(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?xf32>
// CHECK-DAG: %[[I:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
@@ -37,10 +36,10 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
// CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
// CHECK: %[[L:.*]] = scf.for %[[K:.*]] = %[[C0]] to %[[D1]] step %[[C5]] iter_args(%[[ARG3:.*]] = %[[F]]) -> (tensor<?x5xf32>) {
-// CHECK: %[[PS:.*]] = affine.min #[[MAP2]](%[[K]])[%[[D1]]]
+// CHECK: %[[PS:.*]] = affine.min #[[MAP0]](%[[K]])[%[[D1]]]
// CHECK: %[[EXT2:.*]] = tensor.extract_slice %[[ARG0]][0, %[[K:.*]]] [%[[D0]], %[[PS]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[EXT:.*]] = tensor.extract_slice %[[ARG3]][0, 0] [%[[D0]], %[[PS]]] [1, 1] : tensor<?x5xf32> to tensor<?x?xf32>
-// CHECK: %[[PR:.*]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[EXT2]] : tensor<?x?xf32>) outs(%[[EXT]] : tensor<?x?xf32>) {
+// CHECK: %[[PR:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[EXT2]] : tensor<?x?xf32>) outs(%[[EXT]] : tensor<?x?xf32>) {
// CHECK: arith.mulf
// CHECK: arith.addf
// CHECK: linalg.yield
@@ -48,10 +47,10 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[INS:.*]] = tensor.insert_slice %[[PR]] into %[[ARG3]][0, 0] [%[[D0]], %[[PS]]] [1, 1] : tensor<?x?xf32> into tensor<?x5xf32>
// CHECK: scf.yield %[[INS]] : tensor<?x5xf32>
// CHECK: }
-// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) {
+// CHECK: %[[R:.*]] = linalg.reduce ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) dimensions = [1]
// CHECK: arith.addf
// CHECK: linalg.yield
-// CHECK: } -> tensor<?xf32>
+// CHECK: }
// CHECK: return %[[R]] : tensor<?xf32>
// -----
@@ -81,7 +80,6 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d1, d0)>
-// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d1)>
// CHECK: func @reduction_tile_transpose
// CHECK: tensor.empty(%{{.*}}) : tensor<5x?xf32>
// CHECK: linalg.fill {{.*}} : tensor<5x?xf32>) -> tensor<5x?xf32>
@@ -91,7 +89,7 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[INS:.*]] = tensor.insert_slice %[[R]] into %[[ARG3]][0, 0] [%[[D0]], %[[D1]]] [1, 1] : tensor<?x?xf32> into tensor<5x?xf32>
// CHECK: scf.yield {{.*}} : tensor<5x?xf32>
// CHECK: }
-// CHECK: linalg.generic
+// CHECK: linalg.reduce
// CHECK: return
// -----
@@ -150,10 +148,11 @@ module attributes {transform.with_named_sequence} {
// CHECK: tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?xf32> into tensor<?x5xf32>
// CHECK: }
// CHECK: }
-// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel", "reduction"]} ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) {
+// CHECK: %[[R:.*]] = linalg.reduce ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) dimensions = [1]
+// CHECK: {
// CHECK: arith.addf
// CHECK: linalg.yield
-// CHECK: } -> tensor<?xf32>
+// CHECK: }
// CHECK: return %[[R]] : tensor<?xf32>
// -----
@@ -177,8 +176,6 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 5)) + s0, s0 ceildiv 5)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (0, d0)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 5))>
-// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK: func @matmul_tile_parallel(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>
// CHECK-DAG: %[[I:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
@@ -203,10 +200,10 @@ module attributes {transform.with_named_sequence} {
// CHECK: tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, 0, %[[IV]]] [%[[D0]], %[[D2]], 1] [1, 1, 1] : tensor<?x?xf32> into tensor<?x?x5xf32>
// CHECK: }
// CHECK: }
-// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[L]] : tensor<?x?x5xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) {
+// CHECK: %[[R:.*]] = linalg.reduce ins(%[[L]] : tensor<?x?x5xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) dimensions = [2]
// CHECK: arith.addf
// CHECK: linalg.yield
-// CHECK: } -> tensor<?x?xf32>
+// CHECK: }
// CHECK: return %[[R]] : tensor<?x?xf32>
// -----
@@ -270,10 +267,10 @@ module attributes {transform.with_named_sequence} {
// CHECK: tensor.parallel_insert_slice %[[CARRY]] into %[[ARG3]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?xf32> into tensor<?x5xf32>
// CHECK: }
// CHECK: }
-// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP3]]], iterator_types = ["parallel", "reduction"]} ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) {
+// CHECK: %[[R:.*]] = linalg.reduce ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) dimensions = [1]
// CHECK: arith.addf
// CHECK: linalg.yield
-// CHECK: } -> tensor<?xf32>
+// CHECK: }
// CHECK: return %[[R]] : tensor<?xf32>
// -----
@@ -307,7 +304,7 @@ module attributes {transform.with_named_sequence} {
// CHECK: iterator_types = ["parallel", "reduction"]
transform.print %2 {name = "expecting parallel reduction"} : !transform.any_op
// CHECK: expecting parallel reduction
- // CHECK-NEXT: linalg.generic
+ // CHECK-NEXT: linalg.reduce
// CHECK: iterator_types = ["parallel", "reduction"]
transform.print %3 {name = "expecting parallel reduction"} : !transform.any_op
transform.yield
@@ -402,7 +399,7 @@ module {
// CHECK: %[[OUT:.*]] = linalg.generic {indexing_maps = [{{.*}}, {{.*}}, {{.*}}], iterator_types = ["parallel", "parallel", "parallel"]} ins(%{{.*}}, %{{.*}}: tensor<2x64xf32>, tensor<4096x2x64xf32>) outs(%{{.*}}: tensor<4096x2x64xf32>)
// CHECK: scf.yield %[[OUT]] : tensor<4096x2x64xf32>
// CHECK: scf.yield %[[L1]] : tensor<4096x2x64xf32>
-// CHECK: %[[OUT2:.*]] = linalg.generic {indexing_maps = [{{.*}}, {{.*}}], iterator_types = ["parallel", "reduction", "reduction"]} ins(%{{.*}} : tensor<4096x2x64xf32>) outs(%{{.*}} : tensor<4096xf32>)
+// CHECK: %[[OUT2:.*]] = linalg.reduce ins(%{{.*}} : tensor<4096x2x64xf32>) outs(%{{.*}} : tensor<4096xf32>)
// CHECK: return %[[OUT2]] : tensor<4096xf32>
// -----
@@ -446,6 +443,6 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[UPDATED]]#0 into %[[SUM]]
// CHECK: %[[INSERT2:.+]] = tensor.insert_slice %[[UPDATED]]#1 into %[[MAX]]
// CHECK: scf.yield %[[INSERT1]], %[[INSERT1]]
-// CHECK: linalg.generic
+// CHECK: linalg.reduce
// CHECK: arith.addf
// CHECK: arith.maximumf
``````````
</details>
https://github.com/llvm/llvm-project/pull/94579
More information about the Mlir-commits mailing list