[Mlir-commits] [mlir] eec9d0b - [mlir][Linalg] use linalg.reduce to simplify the mergeReductions in partialReductionInterface (#94579)

llvmlistbot at llvm.org
Thu Jun 27 17:50:22 PDT 2024


Author: zhicong zhong
Date: 2024-06-28T08:50:18+08:00
New Revision: eec9d0b6816e815fbe009941c1fda3b39c38adeb

URL: https://github.com/llvm/llvm-project/commit/eec9d0b6816e815fbe009941c1fda3b39c38adeb
DIFF: https://github.com/llvm/llvm-project/commit/eec9d0b6816e815fbe009941c1fda3b39c38adeb.diff

LOG: [mlir][Linalg] use linalg.reduce to simplify the mergeReductions in partialReductionInterface  (#94579)

The current implementation of `mergeReductions` in
`LinalgOpPartialReductionInterface` builds a `linalg.generic` from
scratch. Since the existing `linalg.reduce` op has the same semantics
as that generic op, this PR replaces the generic op with
`linalg.reduce` to simplify the implementation.
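
As a rough illustration, the merged reduction for the first test case in this
patch (shapes taken from transform-tile-reduction.mlir; the value names
`%partial`, `%init`, and `%merged` are placeholders) changes along these lines:

```mlir
// Before: the merge was emitted as an explicit linalg.generic that
// reduces the appended tile dimension d1 into the original result.
%merged = linalg.generic
    {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                      affine_map<(d0, d1) -> (d0)>],
     iterator_types = ["parallel", "reduction"]}
    ins(%partial : tensor<?x5xf32>) outs(%init : tensor<?xf32>) {
^bb0(%in: f32, %out: f32):
  %0 = arith.addf %in, %out : f32
  linalg.yield %0 : f32
} -> tensor<?xf32>

// After: the same combining region, but the iteration space and
// indexing maps are implied by linalg.reduce and its dimensions list.
%merged = linalg.reduce
    ins(%partial : tensor<?x5xf32>) outs(%init : tensor<?xf32>)
    dimensions = [1]
    (%in: f32, %out: f32) {
      %0 = arith.addf %in, %out : f32
      linalg.yield %0 : f32
    }
```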

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
    mlir/test/Dialect/Linalg/transform-tile-reduction.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index b2a1e7c71f58e..39780490e9e49 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -447,39 +447,10 @@ struct LinalgOpPartialReductionInterface
                                          Location loc, ValueRange partialReduce,
                                          ArrayRef<int> reductionDims) const {
     auto linalgOp = cast<LinalgOp>(op);
-
-    // Step 1. Recover the dims that actually need to be merged from the
-    // original operation. We can classify the original iterators as follows:
-    //
-    // parallel                         --> parallel
-    // reduction + not in reductionDims --> parallel (already reduced)
-    // reduction + in reductionDims     --> reduction (will reduce now)
-    SmallVector<utils::IteratorType> iterators(linalgOp.getNumLoops(),
-                                               utils::IteratorType::parallel);
-    for (int redIdx : reductionDims)
-      iterators[redIdx] = utils::IteratorType::reduction;
-
-    // Step 2. For each partial result, create a map to index it. This map
-    // is simply the indexing map for the original result with reductionDims
-    // appended (as produced in tileToPartialReduction).
-    int64_t numInits = linalgOp.getNumDpsInits();
-    SmallVector<AffineMap> indexingMaps(numInits * 2);
-    for (int idx : llvm::seq<int>(0, numInits)) {
-      AffineMap &inputMap = indexingMaps[idx];
-      AffineMap &outputMap = indexingMaps[numInits + idx];
-
-      outputMap =
-          linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(idx));
-      inputMap = outputMap;
-      for (int redPos : reductionDims) {
-        inputMap = inputMap.insertResult(b.getAffineDimExpr(redPos),
-                                         inputMap.getNumResults());
-      }
-    }
-
-    auto reduction = b.create<GenericOp>(
-        loc, op->getResultTypes(), partialReduce, linalgOp.getDpsInits(),
-        indexingMaps, iterators,
+    SmallVector<int64_t> reductionDimsInt64(reductionDims.begin(),
+                                            reductionDims.end());
+    auto reduction = b.create<linalg::ReduceOp>(
+        loc, partialReduce, linalgOp.getDpsInits(), reductionDimsInt64,
         [&linalgOp](OpBuilder &b, Location loc, ValueRange inputs) {
           int64_t numInits = linalgOp.getNumDpsInits();
           SmallVector<Value> yieldedValues;

diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index 8feb3c2a2c306..cce4b4efa61c8 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -23,9 +23,8 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
-// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 //     CHECK: func @reduction_tile(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?xf32>
 // CHECK-DAG:   %[[I:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG:   %[[C5:.*]] = arith.constant 5 : index
@@ -37,10 +36,10 @@ module attributes {transform.with_named_sequence} {
 //     CHECK:   %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
 //     CHECK:   %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
 //     CHECK:   %[[L:.*]] = scf.for %[[K:.*]] = %[[C0]] to %[[D1]] step %[[C5]] iter_args(%[[ARG3:.*]] = %[[F]]) -> (tensor<?x5xf32>) {
-//     CHECK:     %[[PS:.*]] = affine.min #[[MAP2]](%[[K]])[%[[D1]]]
+//     CHECK:     %[[PS:.*]] = affine.min #[[MAP0]](%[[K]])[%[[D1]]]
 //     CHECK:     %[[EXT2:.*]] = tensor.extract_slice %[[ARG0]][0, %[[K:.*]]] [%[[D0]], %[[PS]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
 //     CHECK:     %[[EXT:.*]] = tensor.extract_slice %[[ARG3]][0, 0] [%[[D0]], %[[PS]]] [1, 1] : tensor<?x5xf32> to tensor<?x?xf32>
-//     CHECK:     %[[PR:.*]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[EXT2]] : tensor<?x?xf32>) outs(%[[EXT]] : tensor<?x?xf32>) {
+//     CHECK:     %[[PR:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[EXT2]] : tensor<?x?xf32>) outs(%[[EXT]] : tensor<?x?xf32>) {
 //     CHECK:       arith.mulf
 //     CHECK:       arith.addf
 //     CHECK:       linalg.yield
@@ -48,10 +47,10 @@ module attributes {transform.with_named_sequence} {
 //     CHECK:     %[[INS:.*]] = tensor.insert_slice %[[PR]] into %[[ARG3]][0, 0] [%[[D0]], %[[PS]]] [1, 1] : tensor<?x?xf32> into tensor<?x5xf32>
 //     CHECK:     scf.yield %[[INS]] : tensor<?x5xf32>
 //     CHECK:   }
-//     CHECK:   %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) {
+//     CHECK:   %[[R:.*]] = linalg.reduce ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) dimensions = [1]
 //     CHECK:     arith.addf
 //     CHECK:     linalg.yield
-//     CHECK:   } -> tensor<?xf32>
+//     CHECK:   }
 //     CHECK:   return %[[R]] : tensor<?xf32>
 
 // -----
@@ -81,7 +80,6 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
 // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d1, d0)>
-// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d1)>
 //     CHECK: func @reduction_tile_transpose
 //     CHECK:   tensor.empty(%{{.*}}) : tensor<5x?xf32>
 //     CHECK:   linalg.fill {{.*}} : tensor<5x?xf32>) -> tensor<5x?xf32>
@@ -91,7 +89,7 @@ module attributes {transform.with_named_sequence} {
 //     CHECK:     %[[INS:.*]] = tensor.insert_slice %[[R]] into %[[ARG3]][0, 0] [%[[D0]], %[[D1]]] [1, 1] : tensor<?x?xf32> into tensor<5x?xf32>
 //     CHECK:     scf.yield {{.*}} : tensor<5x?xf32>
 //     CHECK:   }
-//     CHECK:   linalg.generic
+//     CHECK:   linalg.reduce
 //     CHECK:   return
 
 // -----
@@ -150,10 +148,11 @@ module attributes {transform.with_named_sequence} {
 //     CHECK:       tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?xf32> into tensor<?x5xf32>
 //     CHECK:     }
 //     CHECK:   }
-//     CHECK:   %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel", "reduction"]} ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) {
+//     CHECK:   %[[R:.*]] = linalg.reduce ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) dimensions = [1]
+//     CHECK:   {
 //     CHECK:     arith.addf
 //     CHECK:     linalg.yield
-//     CHECK:   } -> tensor<?xf32>
+//     CHECK:   }
 //     CHECK:   return %[[R]] : tensor<?xf32>
 
 // -----
@@ -177,8 +176,6 @@ module attributes {transform.with_named_sequence} {
 // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 5)) + s0, s0 ceildiv 5)>
 // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (0, d0)>
 // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 5))>
-// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 //     CHECK: func @matmul_tile_parallel(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>
 // CHECK-DAG:   %[[I:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
@@ -203,10 +200,10 @@ module attributes {transform.with_named_sequence} {
 //     CHECK:       tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, 0, %[[IV]]] [%[[D0]], %[[D2]], 1] [1, 1, 1] : tensor<?x?xf32> into tensor<?x?x5xf32>
 //     CHECK:     }
 //     CHECK:   }
-//     CHECK:   %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[L]] : tensor<?x?x5xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) {
+//     CHECK:   %[[R:.*]] = linalg.reduce ins(%[[L]] : tensor<?x?x5xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) dimensions = [2]
 //     CHECK:     arith.addf
 //     CHECK:     linalg.yield
-//     CHECK:   } -> tensor<?x?xf32>
+//     CHECK:   }
 //     CHECK:   return %[[R]] : tensor<?x?xf32>
 
 // -----
@@ -270,10 +267,10 @@ module attributes {transform.with_named_sequence} {
 //     CHECK:       tensor.parallel_insert_slice %[[CARRY]] into %[[ARG3]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?xf32> into tensor<?x5xf32>
 //     CHECK:     }
 //     CHECK:   }
-//     CHECK:   %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP3]]], iterator_types = ["parallel", "reduction"]} ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) {
+//     CHECK:   %[[R:.*]] = linalg.reduce ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) dimensions = [1]
 //     CHECK:     arith.addf
 //     CHECK:     linalg.yield
-//     CHECK:   } -> tensor<?xf32>
+//     CHECK:   }
 //     CHECK:   return %[[R]] : tensor<?xf32>
 
 // -----
@@ -307,7 +304,7 @@ module attributes {transform.with_named_sequence} {
     //      CHECK:     iterator_types = ["parallel", "reduction"]
     transform.print %2 {name = "expecting parallel reduction"} : !transform.any_op
     //      CHECK:     expecting parallel reduction
-    // CHECK-NEXT:     linalg.generic
+    // CHECK-NEXT:     linalg.reduce
     //      CHECK:     iterator_types = ["parallel", "reduction"]
     transform.print %3 {name = "expecting parallel reduction"} : !transform.any_op
     transform.yield
@@ -401,7 +398,7 @@ module {
 // CHECK:       %[[OUT:.*]] = linalg.generic  {indexing_maps = [{{.*}}, {{.*}}, {{.*}}], iterator_types = ["parallel", "parallel", "parallel"]} ins(%{{.*}}, %{{.*}}: tensor<2x64xf32>, tensor<4096x2x64xf32>) outs(%{{.*}}: tensor<4096x2x64xf32>)
 // CHECK:       scf.yield %[[OUT]] : tensor<4096x2x64xf32>
 // CHECK:     scf.yield %[[L1]] : tensor<4096x2x64xf32>
-// CHECK:   %[[OUT2:.*]] = linalg.generic {indexing_maps = [{{.*}}, {{.*}}], iterator_types = ["parallel", "reduction", "reduction"]} ins(%{{.*}} : tensor<4096x2x64xf32>) outs(%{{.*}} : tensor<4096xf32>)
+// CHECK:   %[[OUT2:.*]] = linalg.reduce ins(%{{.*}} : tensor<4096x2x64xf32>) outs(%{{.*}} : tensor<4096xf32>)
 // CHECK:  return %[[OUT2]] : tensor<4096xf32>
 
 // -----
@@ -445,6 +442,6 @@ module attributes {transform.with_named_sequence} {
 // CHECK:       %[[INSERT1:.+]] = tensor.insert_slice %[[UPDATED]]#0 into %[[SUM]]
 // CHECK:       %[[INSERT2:.+]] = tensor.insert_slice %[[UPDATED]]#1 into %[[MAX]]
 // CHECK:       scf.yield %[[INSERT1]], %[[INSERT1]]
-// CHECK:       linalg.generic
+// CHECK:       linalg.reduce
 // CHECK:         arith.addf
 // CHECK:         arith.maximumf


        

