[Mlir-commits] [mlir] 2d2cdf4 - [mlir][linalg] Changing the position of the introduced parallel loop in SplitReduction to be consistent with IREE's downstream passes
Murali Vijayaraghavan
llvmlistbot at llvm.org
Tue Nov 29 20:07:48 PST 2022
Author: Murali Vijayaraghavan
Date: 2022-11-30T04:01:07Z
New Revision: 2d2cdf41763236e280a7e1c815f8e32441ee1f15
URL: https://github.com/llvm/llvm-project/commit/2d2cdf41763236e280a7e1c815f8e32441ee1f15
DIFF: https://github.com/llvm/llvm-project/commit/2d2cdf41763236e280a7e1c815f8e32441ee1f15.diff
LOG: [mlir][linalg] Changing the position of the introduced parallel loop in SplitReduction to be consistent with IREE's downstream passes
IREE's passes depend on the parallel loop introduced by SplitReduction
occupying the same position as the dimension inserted into the
intermediate tensor (the loop order was changed in
https://reviews.llvm.org/D137478).
Differential Revision: https://reviews.llvm.org/D138972
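To illustrate the invariant this change restores: with innerParallel,
the split parallel dimension is now inserted immediately after the
reduction dimension in the iteration space, so it lines up with the
dimension created in the expanded intermediate tensor. A minimal sketch
(not part of this commit; %in and %fill are assumed to be the input
tensor and a zero-filled accumulator) of splitting a 1-D sum reduction
with ratio 4 under innerParallel:

  %exp = tensor.expand_shape %in [[0, 1]]
      : tensor<32xf32> into tensor<8x4xf32>
  // Reduction loop d0 (size 8) keeps its position; the inserted
  // parallel loop d1 (size 4) follows it directly and indexes the
  // accumulator dimension of the intermediate tensor.
  %partial = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d1)>],
      iterator_types = ["reduction", "parallel"]}
      ins(%exp : tensor<8x4xf32>) outs(%fill : tensor<4xf32>) {
    ^bb0(%a: f32, %acc: f32):
      %s = arith.addf %a, %acc : f32
      linalg.yield %s : f32
  } -> tensor<4xf32>
  // A follow-up linalg.generic (not shown) reduces the remaining
  // parallel dimension of %partial to the final scalar.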
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
index efa6b1f91063e..7282215999fd4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -35,6 +35,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
SplitReductionOptions control = controlSplitReductionFn(op);
int64_t ratio = control.ratio;
unsigned insertSplitIndex = control.index;
+ unsigned insertSplitDimension = control.index;
if (ratio <= 1)
return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");
@@ -42,6 +43,9 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
op.getReductionDims(dims);
assert(dims.size() == 1);
unsigned reductionDim = dims[0];
+ if (control.innerParallel) {
+ insertSplitDimension = reductionDim + 1;
+ }
SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
int64_t reductionDimSize = loopRanges[reductionDim];
if (reductionDimSize == ShapedType::kDynamic ||
@@ -78,19 +82,21 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
unsigned dim = map.getDimPosition(idx);
if (reductionDim == dim) {
if (control.innerParallel) {
- newShape.push_back(op.getShape(operand)[idx] / ratio);
- newShape.push_back(ratio);
+ newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
+ newShape.push_back(ratio); // parallel (insert)
+ exprs.push_back(b.getAffineDimExpr(dim < insertSplitDimension? dim : dim + 1));
+ exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
} else {
- newShape.push_back(ratio);
- newShape.push_back(op.getShape(operand)[idx] / ratio);
+ newShape.push_back(ratio); // parallel (insert)
+ newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
+ exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
+ exprs.push_back(b.getAffineDimExpr(dim < insertSplitDimension? dim : dim + 1));
}
- exprs.push_back(b.getAffineDimExpr(reductionDim));
- exprs.push_back(b.getAffineDimExpr(reductionDim + 1));
reassociation.push_back({index++, index++});
continue;
}
newShape.push_back(op.getShape(operand)[idx]);
- exprs.push_back(b.getAffineDimExpr(dim < reductionDim ? dim : dim + 1));
+ exprs.push_back(b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
reassociation.push_back({index++});
}
newMaps.push_back(
@@ -117,17 +123,13 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
for (unsigned idx : llvm::seq<unsigned>(0, oldShape.size() + 1)) {
if (insertSplitIndex == idx) {
newOutputShape.push_back(ratio);
- if (control.innerParallel) {
- outputExpr.push_back(b.getAffineDimExpr(reductionDim + 1));
- } else {
- outputExpr.push_back(b.getAffineDimExpr(reductionDim));
- }
+ outputExpr.push_back(b.getAffineDimExpr(insertSplitDimension));
}
if (idx < oldShape.size()) {
newOutputShape.push_back(oldShape[idx]);
unsigned dim = oldOutputMap.getDimPosition(idx);
outputExpr.push_back(
- b.getAffineDimExpr(dim < reductionDim ? dim : dim + 1));
+ b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
}
}
Value emptyOrAllocTensor;
@@ -150,11 +152,12 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
op.getContext()));
SmallVector<utils::IteratorType> newIteratorTypes;
for (auto &it : llvm::enumerate(op.getIteratorTypesArray())) {
- if (reductionDim == it.index() && !control.innerParallel)
+ if (insertSplitDimension == it.index())
newIteratorTypes.push_back(utils::IteratorType::parallel);
newIteratorTypes.push_back(it.value());
- if (reductionDim == it.index() && control.innerParallel)
- newIteratorTypes.push_back(utils::IteratorType::parallel);
+ }
+ if (insertSplitDimension == op.getIteratorTypesArray().size()) {
+ newIteratorTypes.push_back(utils::IteratorType::parallel);
}
// Create the new op matching the original op with an extra parallel
// dimension.
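For reference, the effect of the new iterator-type placement on the
generic_split_3d test below can be worked out from the CHECK'd affine
maps and tensor shapes:

  // Original loops (d0, d1, d2) = (2, 32, 5) with
  // iterator_types = ["parallel", "reduction", "parallel"]; ratio = 4.
  // With inner_parallel, reductionDim = 1, so insertSplitDimension = 2:
  // d1 becomes the smaller reduction loop of size 8, and the new
  // parallel loop of size 4 is inserted directly after it as d2.
  //   new loops:          d0 = 2,      d1 = 8,      d2 = 4,     d3 = 5
  //   new iterator_types: ["parallel", "reduction", "parallel", "parallel"]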
diff --git a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
index 2500b2c46b17a..7b4d99d126431 100644
--- a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
@@ -106,9 +106,9 @@ func.func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32>
return %0 : tensor<5x2xf32>
}
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d0)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>
-// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d1, d0)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d2)>
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @generic_split_3d
@@ -117,7 +117,7 @@ func.func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32>
// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x4x8xf32>
// CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32>
// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32>
-// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "reduction", "parallel"]}
+// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
// CHECK-SAME: ins(%[[I1]], %[[I2]] : tensor<4x8x2xf32>, tensor<5x4x8xf32>) outs(%[[F]] : tensor<5x2x4xf32>) {
// CHECK: arith.addf
// CHECK: arith.maxf
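For context, this test exercises the pattern through the transform
dialect. An invocation along these lines applies the split (a sketch
only; the exact transform-op spelling, including the inner_parallel
unit attribute and the handle names, is assumed from tests of this
vintage rather than quoted from the commit):

  transform.sequence failures(propagate) {
  ^bb1(%arg1: !pdl.operation):
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
    %1:4 = transform.structured.split_reduction %0
        { split_factor = 4, insert_split_dimension = 2, inner_parallel }
  }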