[llvm-branch-commits] [mlir] dddf6ab - Simplifying the SplitReduction logic that uses the control to get the

Murali Vijayaraghavan via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Nov 17 14:26:56 PST 2022


Author: Murali Vijayaraghavan
Date: 2022-11-17T22:26:02Z
New Revision: dddf6ab27212d9813a360eb95440c61e81a308be

URL: https://github.com/llvm/llvm-project/commit/dddf6ab27212d9813a360eb95440c61e81a308be
DIFF: https://github.com/llvm/llvm-project/commit/dddf6ab27212d9813a360eb95440c61e81a308be.diff

LOG: Simplifying the SplitReduction logic that uses the control to get the
dimension where the extra parallel dimension is inserted

Currently, the innerParallel and non innerParallel strategies use two
different ways to determine where the extra loop is inserted and where the
extra dimension for the intermediate result is inserted - innerParallel
adds the extra (parallel) loop right after the pre-existing reduction
loop, whereas non innerParallel adds the reduction loop immediately after
the index supplied by the control, and the parallel loop at the index
supplied by the control. The semantics of the index supplied by the
control is supposed to only control where the extra tensor dimension is
inserted in the intermediate tensor. Conflating this index with where
the reduction (and parallel) loops are inserted leads to more complex
(and confusing) logic overall. This differential removes the conflation of
the two uses of the index, and keeps the reduction and parallel loops in the
same vicinity and uses the supplied index to only determine the position
of the extra tensor dimension. It also simplifies the code by merging
the two strategies in a lot more places.

Differential Revision: https://reviews.llvm.org/D137478

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
    mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
index 2fb550b27ec0c..26a49b91db1ed 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -34,7 +34,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
 
   SplitReductionOptions control = controlSplitReductionFn(op);
   int64_t ratio = control.ratio;
-  unsigned insertSplitDimension = control.index;
+  unsigned insertSplitIndex = control.index;
   if (ratio <= 1)
     return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");
 
@@ -45,10 +45,14 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
   SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
   int64_t reductionDimSize = loopRanges[reductionDim];
   if (reductionDimSize == ShapedType::kDynamicSize ||
-      reductionDimSize % ratio != 0 ||
-      insertSplitDimension >= loopRanges.size())
+      reductionDimSize % ratio != 0)
     return b.notifyMatchFailure(
         op, "Reduction dimension not divisible by split ratio");
+  if (op.getNumDpsInits() != 1)
+    return b.notifyMatchFailure(op, "More than one output in split reduction");
+  if (insertSplitIndex > op.getShape(op.getDpsInitOperand(0)).size())
+    return b.notifyMatchFailure(op, "Insert dimension position too large "
+                                    "compared to intermediate tensor size");
 
   SmallVector<Operation *, 4> combinerOps;
   if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
@@ -80,25 +84,13 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
           newShape.push_back(ratio);
           newShape.push_back(op.getShape(operand)[idx] / ratio);
         }
+        exprs.push_back(b.getAffineDimExpr(reductionDim));
+        exprs.push_back(b.getAffineDimExpr(reductionDim + 1));
         reassociation.push_back({index++, index++});
-        if (control.innerParallel) {
-          exprs.push_back(b.getAffineDimExpr(reductionDim));
-          exprs.push_back(b.getAffineDimExpr(reductionDim + 1));
-        } else {
-          exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
-          exprs.push_back(
-              b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
-        }
         continue;
       }
       newShape.push_back(op.getShape(operand)[idx]);
-      if (control.innerParallel) {
-        exprs.push_back(
-            b.getAffineDimExpr(dim <= reductionDim ? dim : dim + 1));
-      } else {
-        exprs.push_back(
-            b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
-      }
+      exprs.push_back(b.getAffineDimExpr(dim < reductionDim ? dim : dim + 1));
       reassociation.push_back({index++});
     }
     newMaps.push_back(
@@ -122,26 +114,20 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
   AffineMap oldOutputMap = op.getMatchingIndexingMap(op.getDpsInitOperand(0));
   ArrayRef<int64_t> oldShape = op.getShape(op.getDpsInitOperand(0));
   SmallVector<AffineExpr> outputExpr;
-  for (unsigned idx :
-       llvm::seq<unsigned>(0, oldOutputMap.getNumResults() + 1)) {
-    if (idx == insertSplitDimension) {
+  for (unsigned idx : llvm::seq<unsigned>(0, oldShape.size() + 1)) {
+    if (insertSplitIndex == idx) {
       newOutputShape.push_back(ratio);
       if (control.innerParallel) {
         outputExpr.push_back(b.getAffineDimExpr(reductionDim + 1));
       } else {
-        outputExpr.push_back(b.getAffineDimExpr(insertSplitDimension));
+        outputExpr.push_back(b.getAffineDimExpr(reductionDim));
       }
-      continue;
     }
-    unsigned oldIdx = idx < insertSplitDimension ? idx : idx - 1;
-    newOutputShape.push_back(oldShape[oldIdx]);
-    unsigned dim = oldOutputMap.getDimPosition(oldIdx);
-    if (control.innerParallel) {
-      outputExpr.push_back(
-          b.getAffineDimExpr(dim <= reductionDim ? dim : dim + 1));
-    } else {
+    if (idx < oldShape.size()) {
+      newOutputShape.push_back(oldShape[idx]);
+      unsigned dim = oldOutputMap.getDimPosition(idx);
       outputExpr.push_back(
-          b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
+          b.getAffineDimExpr(dim < reductionDim ? dim : dim + 1));
     }
   }
   Value emptyOrAllocTensor;
@@ -164,10 +150,10 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
                                    op.getContext()));
   SmallVector<utils::IteratorType> newIteratorTypes;
   for (auto &it : llvm::enumerate(op.getIteratorTypesArray())) {
-    if (insertSplitDimension == it.index() && !control.innerParallel)
+    if (reductionDim == it.index() && !control.innerParallel)
       newIteratorTypes.push_back(utils::IteratorType::parallel);
     newIteratorTypes.push_back(it.value());
-    if (insertSplitDimension == it.index() && control.innerParallel)
+    if (reductionDim == it.index() && control.innerParallel)
       newIteratorTypes.push_back(utils::IteratorType::parallel);
   }
   // Create the new op matching the original op with an extra parallel
@@ -185,7 +171,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
   SmallVector<utils::IteratorType> reductionIteratorTypes;
   SmallVector<AffineExpr> exprs;
   for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
-    if (insertSplitDimension == i) {
+    if (insertSplitIndex == i) {
       reductionIteratorTypes.push_back(utils::IteratorType::reduction);
     } else {
       exprs.push_back(b.getAffineDimExpr(i));

diff  --git a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
index cb7a92198c04d..eb035d9ffe092 100644
--- a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
@@ -106,9 +106,9 @@ func.func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32>
   return %0 : tensor<5x2xf32>
 }
 
-//  CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d1, d0)>
-//  CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)>
-//  CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d2)>
+//  CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d0)>
+//  CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>
+//  CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1)>
 //  CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 //  CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 // CHECK-LABEL:  func @generic_split_3d
@@ -117,7 +117,7 @@ func.func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32>
 //  CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x4x8xf32>
 //  CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32>
 //      CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32>
-//      CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
+//      CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "reduction", "parallel"]}
 // CHECK-SAME:   ins(%[[I1]], %[[I2]] : tensor<4x8x2xf32>, tensor<5x4x8xf32>) outs(%[[F]] : tensor<5x2x4xf32>) {
 //      CHECK:   arith.addf
 //      CHECK:   arith.maxf


        


More information about the llvm-branch-commits mailing list