[Mlir-commits] [mlir] 2d2cdf4 - [mlir][linalg] Changing the positions of introduced parallel loop in SplitReduction to be consistent with IREE's downstream passes

Murali Vijayaraghavan llvmlistbot at llvm.org
Tue Nov 29 20:07:48 PST 2022


Author: Murali Vijayaraghavan
Date: 2022-11-30T04:01:07Z
New Revision: 2d2cdf41763236e280a7e1c815f8e32441ee1f15

URL: https://github.com/llvm/llvm-project/commit/2d2cdf41763236e280a7e1c815f8e32441ee1f15
DIFF: https://github.com/llvm/llvm-project/commit/2d2cdf41763236e280a7e1c815f8e32441ee1f15.diff

LOG: [mlir][linalg] Changing the position of the introduced parallel loop in SplitReduction to be consistent with IREE's downstream passes

IREE's downstream passes depend on the parallel loop introduced by
SplitReduction having the same position as the dimension introduced in
the intermediate tensor (the order of loops was changed in
https://reviews.llvm.org/D137478).

Differential Revision: https://reviews.llvm.org/D138972

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
    mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
index efa6b1f91063e..7282215999fd4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -35,6 +35,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
   SplitReductionOptions control = controlSplitReductionFn(op);
   int64_t ratio = control.ratio;
   unsigned insertSplitIndex = control.index;
+  unsigned insertSplitDimension = control.index;
   if (ratio <= 1)
     return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");
 
@@ -42,6 +43,9 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
   op.getReductionDims(dims);
   assert(dims.size() == 1);
   unsigned reductionDim = dims[0];
+  if (control.innerParallel) {
+    insertSplitDimension = reductionDim + 1;
+  }
   SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
   int64_t reductionDimSize = loopRanges[reductionDim];
   if (reductionDimSize == ShapedType::kDynamic ||
@@ -78,19 +82,21 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
       unsigned dim = map.getDimPosition(idx);
       if (reductionDim == dim) {
         if (control.innerParallel) {
-          newShape.push_back(op.getShape(operand)[idx] / ratio);
-          newShape.push_back(ratio);
+          newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
+          newShape.push_back(ratio); // parallel (insert)
+          exprs.push_back(b.getAffineDimExpr(dim < insertSplitDimension? dim : dim + 1));
+          exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
         } else {
-          newShape.push_back(ratio);
-          newShape.push_back(op.getShape(operand)[idx] / ratio);
+          newShape.push_back(ratio); // parallel (insert)
+          newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
+          exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
+          exprs.push_back(b.getAffineDimExpr(dim < insertSplitDimension? dim : dim + 1));
         }
-        exprs.push_back(b.getAffineDimExpr(reductionDim));
-        exprs.push_back(b.getAffineDimExpr(reductionDim + 1));
         reassociation.push_back({index++, index++});
         continue;
       }
       newShape.push_back(op.getShape(operand)[idx]);
-      exprs.push_back(b.getAffineDimExpr(dim < reductionDim ? dim : dim + 1));
+      exprs.push_back(b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
       reassociation.push_back({index++});
     }
     newMaps.push_back(
@@ -117,17 +123,13 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
   for (unsigned idx : llvm::seq<unsigned>(0, oldShape.size() + 1)) {
     if (insertSplitIndex == idx) {
       newOutputShape.push_back(ratio);
-      if (control.innerParallel) {
-        outputExpr.push_back(b.getAffineDimExpr(reductionDim + 1));
-      } else {
-        outputExpr.push_back(b.getAffineDimExpr(reductionDim));
-      }
+      outputExpr.push_back(b.getAffineDimExpr(insertSplitDimension));
     }
     if (idx < oldShape.size()) {
       newOutputShape.push_back(oldShape[idx]);
       unsigned dim = oldOutputMap.getDimPosition(idx);
       outputExpr.push_back(
-          b.getAffineDimExpr(dim < reductionDim ? dim : dim + 1));
+          b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
     }
   }
   Value emptyOrAllocTensor;
@@ -150,11 +152,12 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
                                    op.getContext()));
   SmallVector<utils::IteratorType> newIteratorTypes;
   for (auto &it : llvm::enumerate(op.getIteratorTypesArray())) {
-    if (reductionDim == it.index() && !control.innerParallel)
+    if (insertSplitDimension == it.index())
       newIteratorTypes.push_back(utils::IteratorType::parallel);
     newIteratorTypes.push_back(it.value());
-    if (reductionDim == it.index() && control.innerParallel)
-      newIteratorTypes.push_back(utils::IteratorType::parallel);
+  }
+  if (insertSplitDimension == op.getIteratorTypesArray().size()) {
+    newIteratorTypes.push_back(utils::IteratorType::parallel);
   }
   // Create the new op matching the original op with an extra parallel
   // dimension.

diff  --git a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
index 2500b2c46b17a..7b4d99d126431 100644
--- a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
@@ -106,9 +106,9 @@ func.func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32>
   return %0 : tensor<5x2xf32>
 }
 
-//  CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d0)>
-//  CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>
-//  CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1)>
+//  CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d1, d0)>
+//  CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)>
+//  CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d2)>
 //  CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 //  CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 // CHECK-LABEL:  func @generic_split_3d
@@ -117,7 +117,7 @@ func.func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32>
 //  CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x4x8xf32>
 //  CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32>
 //      CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32>
-//      CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "reduction", "parallel"]}
+//      CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
 // CHECK-SAME:   ins(%[[I1]], %[[I2]] : tensor<4x8x2xf32>, tensor<5x4x8xf32>) outs(%[[F]] : tensor<5x2x4xf32>) {
 //      CHECK:   arith.addf
 //      CHECK:   arith.maxf


        


More information about the Mlir-commits mailing list