[Mlir-commits] [mlir] 8ce6f7d - [mlir][Linalg] NFC - Fail gracefully instead of crashing in SplitReduction
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Feb 9 13:05:35 PST 2023
Author: Nicolas Vasilache
Date: 2023-02-09T12:59:57-08:00
New Revision: 8ce6f7dd9340abedb7065b9593b986e4c2ff4c02
URL: https://github.com/llvm/llvm-project/commit/8ce6f7dd9340abedb7065b9593b986e4c2ff4c02
DIFF: https://github.com/llvm/llvm-project/commit/8ce6f7dd9340abedb7065b9593b986e4c2ff4c02.diff
LOG: [mlir][Linalg] NFC - Fail gracefully instead of crashing in SplitReduction
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
index f597273204f24..26dee47b8926a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -11,8 +11,8 @@
//
//===----------------------------------------------------------------------===//
-#include <utility>
#include <optional>
+#include <utility>
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -42,15 +42,16 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
SmallVector<unsigned> dims;
op.getReductionDims(dims);
- assert(dims.size() == 1);
+
+ if (dims.size() != 1)
+ return b.notifyMatchFailure(op, "needs a single reduction dimension");
unsigned reductionDim = dims[0];
if (control.innerParallel) {
insertSplitDimension = reductionDim + 1;
}
SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
int64_t reductionDimSize = loopRanges[reductionDim];
- if (reductionDimSize == ShapedType::kDynamic ||
- reductionDimSize % ratio != 0)
+ if (reductionDimSize == ShapedType::kDynamic || reductionDimSize % ratio != 0)
return b.notifyMatchFailure(
op, "Reduction dimension not divisible by split ratio");
if (op.getNumDpsInits() != 1)
@@ -85,19 +86,22 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
if (control.innerParallel) {
newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
newShape.push_back(ratio); // parallel (insert)
- exprs.push_back(b.getAffineDimExpr(dim < insertSplitDimension? dim : dim + 1));
+ exprs.push_back(
+ b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
} else {
newShape.push_back(ratio); // parallel (insert)
newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
- exprs.push_back(b.getAffineDimExpr(dim < insertSplitDimension? dim : dim + 1));
+ exprs.push_back(
+ b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
}
reassociation.push_back({index++, index++});
continue;
}
newShape.push_back(op.getShape(operand)[idx]);
- exprs.push_back(b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
+ exprs.push_back(
+ b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
reassociation.push_back({index++});
}
newMaps.push_back(
More information about the Mlir-commits mailing list