[Mlir-commits] [mlir] [mlir][TOSA] Remove rollback from TOSA -> Linalg patterns (PR #136308)
llvmlistbot at llvm.org
Fri Apr 18 07:21:27 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Matthias Springer (matthias-springer)
Reorganize the implementation slightly, such that patterns check all preconditions before starting the actual rewrite. That is, patterns no longer start rewriting and then abort, which would trigger a pattern rollback. Pattern rollbacks are expensive and will be disallowed as part of the One-Shot Dialect Conversion refactoring.
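For illustration, here is a minimal sketch of the restructuring (using a hypothetical `FooOp`/`FooOpConverter`, not an op from this PR): every `notifyMatchFailure` check moves to the top of `matchAndRewrite`, before any op is created, so the pattern either fails without touching the IR or runs to completion. Checks that body-building helpers such as `createLinalgBodyCalculationForElementwiseOp` used to repeat can then become asserts or `llvm_unreachable`, as in the diff below.

```cpp
// Minimal sketch of the restructuring (hypothetical FooOp/FooOpConverter,
// not code from this PR).
struct FooOpConverter : public OpConversionPattern<FooOp> {
  using OpConversionPattern<FooOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(FooOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // 1. Verify all preconditions up front. Failing here is harmless:
    //    nothing has been rewritten yet, so no rollback is required.
    auto resultType = dyn_cast<RankedTensorType>(op.getType());
    if (!resultType || !resultType.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "expected statically shaped ranked tensor result");

    // 2. Only then rewrite. From this point on the pattern must succeed;
    //    conditions already verified above may be turned into assertions
    //    (or llvm_unreachable) inside the helpers that build the body.
    Value empty = rewriter.create<tensor::EmptyOp>(
        op.getLoc(), resultType.getShape(), resultType.getElementType());
    rewriter.replaceOp(op, empty);
    return success();
  }
};
```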
---
Full diff: https://github.com/llvm/llvm-project/pull/136308.diff
1 file affected:
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+119-98)
``````````diff
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 9ca93ab28daed..bc4ef58cbcd62 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -91,6 +91,50 @@ createConstOpFromZpVal(Operation *op, const int64_t &zp, Type requiredAttrType,
op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
}
+/// Return "failure" if the given elementwise operation cannot be converted.
+static LogicalResult
+isSupportedElementwiseOperation(ConversionPatternRewriter &rewriter,
+ Operation *op, RankedTensorType resultType) {
+ auto elementTy =
+ cast<ShapedType>(op->getOperand(0).getType()).getElementType();
+
+ // tosa::MulOp
+ if (isa<tosa::MulOp>(op)) {
+ auto shiftVal = cast<tosa::MulOp>(op).getShift();
+ DenseElementsAttr shiftElem;
+ if (!matchPattern(shiftVal, m_Constant(&shiftElem)))
+ return rewriter.notifyMatchFailure(op, "shift value of mul not found");
+
+ int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
+ if (isa<FloatType>(elementTy) && shift != 0)
+ return rewriter.notifyMatchFailure(op,
+ "Cannot have shift value for float");
+ return success();
+ }
+
+ // tosa::NegateOp
+ if (isa<tosa::NegateOp>(op)) {
+ auto negate = cast<tosa::NegateOp>(op);
+ if (failed(negate.getInput1ZeroPoint()))
+ return rewriter.notifyMatchFailure(
+ op, "input1 zero point cannot be statically determined");
+ if (failed(negate.getOutputZeroPoint()))
+ return rewriter.notifyMatchFailure(
+ op, "output zero point cannot be statically determined");
+ return success();
+ }
+
+ // tosa::CastOp
+ if (isa<tosa::CastOp>(op)) {
+ if (!elementTy.isIntOrFloat() ||
+ !resultType.getElementType().isIntOrFloat())
+ return rewriter.notifyMatchFailure(op, "unsupported type");
+ return success();
+ }
+
+ return success();
+}
+
static Value createLinalgBodyCalculationForElementwiseOp(
Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
ConversionPatternRewriter &rewriter) {
@@ -139,17 +183,14 @@ static Value createLinalgBodyCalculationForElementwiseOp(
auto shiftVal = cast<tosa::MulOp>(op).getShift();
DenseElementsAttr shiftElem;
if (!matchPattern(shiftVal, m_Constant(&shiftElem))) {
- (void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
- return nullptr;
+ llvm_unreachable("shift value of mul not found");
}
int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
if (isa<FloatType>(elementTy)) {
if (shift != 0) {
- (void)rewriter.notifyMatchFailure(op,
- "Cannot have shift value for float");
- return nullptr;
+ llvm_unreachable("Cannot have shift value for float");
}
return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0], args[1]);
}
@@ -196,16 +237,12 @@ static Value createLinalgBodyCalculationForElementwiseOp(
FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
if (failed(maybeInZp)) {
- (void)rewriter.notifyMatchFailure(
- op, "input1 zero point cannot be statically determined");
- return nullptr;
+ llvm_unreachable("input1 zero point cannot be statically determined");
}
FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
if (failed(maybeOutZp)) {
- (void)rewriter.notifyMatchFailure(
- op, "output zero point cannot be statically determined");
- return nullptr;
+ llvm_unreachable("output zero point cannot be statically determined");
}
int64_t inZp = *maybeInZp;
@@ -548,10 +585,8 @@ static Value createLinalgBodyCalculationForElementwiseOp(
if (isa<tosa::CastOp>(op)) {
Type srcTy = elementTy;
Type dstTy = resultTypes.front();
- if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat()) {
- (void)rewriter.notifyMatchFailure(op, "unsupported type");
- return nullptr;
- }
+ if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat())
+ llvm_unreachable("unsupported type");
bool bitExtend =
srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();
@@ -706,8 +741,8 @@ static Value createLinalgBodyCalculationForElementwiseOp(
}
}
- (void)rewriter.notifyMatchFailure(
- op, "unhandled op for linalg body calculation for elementwise op");
+ llvm_unreachable(
+ "unhandled op for linalg body calculation for elementwise op");
return nullptr;
}
@@ -930,17 +965,11 @@ broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
});
}
-static LogicalResult
-emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
- Operation *operation, ValueRange operands,
- ArrayRef<OpFoldResult> targetShape,
- const TypeConverter &converter) {
+static LogicalResult emitElementwiseComputation(
+ ConversionPatternRewriter &rewriter, Location loc, Operation *operation,
+ ValueRange operands, ArrayRef<OpFoldResult> targetShape,
+ const TypeConverter &converter, RankedTensorType resultType) {
// Generate output tensor
- auto resultType = cast_or_null<RankedTensorType>(
- converter.convertType(operation->getResultTypes().front()));
- if (!resultType) {
- return rewriter.notifyMatchFailure(operation, "failed to convert type");
- }
Value outputTensor = rewriter.create<tensor::EmptyOp>(
loc, targetShape, resultType.getElementType());
@@ -967,7 +996,6 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
// Emit 'linalg.generic' op
- bool encounteredError = false;
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, outputTensor.getType(), operands, outputTensor, affineMaps,
getNParallelLoopsAttrs(rank),
@@ -975,15 +1003,10 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
Value opResult = createLinalgBodyCalculationForElementwiseOp(
operation, blockArgs.take_front(operation->getNumOperands()),
{resultType.getElementType()}, rewriter);
- if (!opResult) {
- encounteredError = true;
- return;
- }
+ assert(opResult &&
+ "unable to create linalg.generic body for elementwise op");
opBuilder.create<linalg::YieldOp>(loc, opResult);
});
- if (encounteredError)
- return rewriter.notifyMatchFailure(
- operation, "unable to create linalg.generic body for elementwise op");
// Cast 'linalg.generic' result into original result type if needed
auto castResult = rewriter.createOrFold<tensor::CastOp>(
@@ -1008,13 +1031,20 @@ elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
ConversionPatternRewriter &rewriter,
const TypeConverter &converter) {
- // Collect op properties
+ // Check if operation is supported.
assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
assert(operation->getNumOperands() >= 1 &&
"elementwise op expects at least 1 operand");
if (!operandsAndResultsRanked(operation))
return rewriter.notifyMatchFailure(operation,
"Unranked tensors not supported");
+ auto resultType = cast_or_null<RankedTensorType>(
+ converter.convertType(operation->getResultTypes().front()));
+ if (!resultType) {
+ return rewriter.notifyMatchFailure(operation, "failed to convert type");
+ }
+ if (failed(isSupportedElementwiseOperation(rewriter, operation, resultType)))
+ return failure();
// Lower operation
IndexPool indexPool;
@@ -1026,7 +1056,7 @@ elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
broadcastDynamicDimensions(rewriter, loc, indexPool, operandsToBroadcast,
targetShape, masterOperands);
return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
- targetShape, converter);
+ targetShape, converter, resultType);
}
// Returns the constant initial value for a given reduction operation. The
@@ -1126,7 +1156,7 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
return rewriter.create<arith::OrIOp>(loc, args);
- return {};
+ llvm_unreachable("unhandled reduction op");
}
// Performs the match and rewrite for reduction operations. This includes
@@ -1142,6 +1172,10 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
auto elementTy = resultTy.getElementType();
+ auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
+ if (!fillValueAttr)
+ return rewriter.notifyMatchFailure(
+ op, "No initial value found for reduction operation");
Value input = op->getOperand(0);
SmallVector<int64_t> reduceShape;
@@ -1164,11 +1198,6 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
dynDims)
.getResult();
- auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
- if (!fillValueAttr)
- return rewriter.notifyMatchFailure(
- op, "No initial value found for reduction operation");
-
auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
auto filledTensor = rewriter
.create<linalg::FillOp>(loc, ValueRange{fillValue},
@@ -1212,7 +1241,6 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
}
}
- bool didEncounterError = false;
linalg::LinalgOp linalgOp = rewriter.create<linalg::ReduceOp>(
loc, inputs, outputs, axis,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
@@ -1220,8 +1248,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
auto result = createLinalgBodyCalculationForReduceOp(
op, binaryArgs, elementTy, rewriter);
- if (result)
- didEncounterError = true;
+ assert(result && "could not create reduction body");
SmallVector<Value> resultsToYield;
if (isNanIgnoreMode) {
@@ -1247,10 +1274,6 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
nestedBuilder.create<linalg::YieldOp>(loc, resultsToYield);
});
- if (!didEncounterError)
- return rewriter.notifyMatchFailure(
- op, "unable to create linalg.generic body for reduce op");
-
if (isNanIgnoreMode) {
// Materialize a check to see whether we encountered any non-NaN values, if
// we didn't we need to select a tensor of NaNs since the result will just
@@ -1358,13 +1381,6 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
if (!isa<IntegerType>(inputTy.getElementType()))
return rewriter.notifyMatchFailure(op, "only support integer type");
- SmallVector<Value> dynDims;
- for (int i = 0; i < outputTy.getRank(); i++) {
- if (outputTy.isDynamicDim(i)) {
- dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
- }
- }
-
// The shift and multiplier values.
DenseElementsAttr shiftElems;
if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
@@ -1376,6 +1392,21 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
return rewriter.notifyMatchFailure(
op, "tosa.rescale requires constant multiplier input values");
+ if (failed(op.getInputZeroPoint()))
+ return rewriter.notifyMatchFailure(
+ op, "input zero point cannot be statically determined");
+
+ if (failed(op.getOutputZeroPoint()))
+ return rewriter.notifyMatchFailure(
+ op, "output zero point cannot be statically determined");
+
+ SmallVector<Value> dynDims;
+ for (int i = 0; i < outputTy.getRank(); i++) {
+ if (outputTy.isDynamicDim(i)) {
+ dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
+ }
+ }
+
llvm::SmallVector<int8_t> shiftValues =
llvm::to_vector(shiftElems.getValues<int8_t>());
// explicit cast is required here
@@ -1473,23 +1504,10 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
- if (failed(maybeIZp)) {
- (void)rewriter.notifyMatchFailure(
- op, "input zero point cannot be statically determined");
- return;
- }
-
auto inputZp = createConstOpFromZpVal<int32_t>(
op, *maybeIZp, nestedBuilder.getIntegerType(inBitwidth),
nestedBuilder);
-
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
- if (failed(maybeOZp)) {
- (void)rewriter.notifyMatchFailure(
- op, "output zero point cannot be statically determined");
- return;
- };
-
auto outputZp = createConstOpFromZpVal<int32_t>(
op, *maybeOZp, nestedBuilder.getI32Type(), nestedBuilder);
@@ -1783,6 +1801,15 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
return rewriter.notifyMatchFailure(
op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
+ SmallVector<int64_t> scale, offset, border;
+ if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) ||
+ !tosa::getConstShapeValues(op.getOffset().getDefiningOp(), offset) ||
+ !tosa::getConstShapeValues(op.getBorder().getDefiningOp(), border)) {
+ return rewriter.notifyMatchFailure(
+ op, "tosa.resize scale/offset/border should have compile time "
+ "constant values.");
+ }
+
SmallVector<AffineMap, 2> affineMaps = {
rewriter.getMultiDimIdentityMap(resultTy.getRank())};
auto emptyTensor = b.create<tensor::EmptyOp>(resultTy.getShape(), resultETy,
@@ -1810,15 +1837,6 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);
- SmallVector<int64_t> scale, offset, border;
- if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) ||
- !tosa::getConstShapeValues(op.getOffset().getDefiningOp(), offset) ||
- !tosa::getConstShapeValues(op.getBorder().getDefiningOp(), border)) {
- return rewriter.notifyMatchFailure(
- op, "tosa.resize scale/offset/border should have compile time "
- "constant values.");
- }
-
Value yScaleN, yScaleD, xScaleN, xScaleD;
yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));
yScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[1]));
@@ -2204,6 +2222,9 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
auto inputTy = cast<ShapedType>(input.getType());
auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
auto inElementTy = inputTy.getElementType();
+ if (!isa<IntegerType, FloatType>(inElementTy))
+ return rewriter.notifyMatchFailure(
+ argmaxOp, "unsupported tosa.argmax element type");
auto outElementTy = resultTy.getElementType();
int axis = argmaxOp.getAxis();
auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);
@@ -2213,6 +2234,12 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
argmaxOp,
"tosa.arg_max to linalg.* requires integer-like result type");
+ auto fillValueMaxAttr =
+ createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
+ if (!fillValueMaxAttr)
+ return rewriter.notifyMatchFailure(
+ argmaxOp, "unsupported tosa.argmax element type");
+
SmallVector<Value> dynDims;
for (int i = 0; i < inputTy.getRank(); i++) {
if (inputTy.isDynamicDim(i) && i != axis) {
@@ -2238,12 +2265,6 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
.create<tensor::EmptyOp>(loc, resultTy.getShape(),
inElementTy, dynDims)
.getResult();
- auto fillValueMaxAttr =
- createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
-
- if (!fillValueMaxAttr)
- return rewriter.notifyMatchFailure(
- argmaxOp, "unsupported tosa.argmax element type");
auto fillValueMax =
rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
@@ -2267,7 +2288,6 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
}
- bool didEncounterError = false;
auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
rewriter.getContext());
auto linalgOp = rewriter.create<linalg::GenericOp>(
@@ -2305,8 +2325,7 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
predicate = rewriter.create<arith::CmpIOp>(
nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
} else {
- didEncounterError = true;
- return;
+ llvm_unreachable("unsupported tosa.argmax element type");
}
auto resultMax = rewriter.create<arith::SelectOp>(
@@ -2317,10 +2336,6 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
nestedLoc, ValueRange({resultIndex, resultMax}));
});
- if (didEncounterError)
- return rewriter.notifyMatchFailure(
- argmaxOp, "unsupported tosa.argmax element type");
-
rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
return success();
}
@@ -2416,6 +2431,15 @@ class TableConverter : public OpRewritePattern<tosa::TableOp> {
auto tableElementTy = tableTy.getElementType();
auto resultElementTy = resultTy.getElementType();
+ bool isI8_8_8 = inputElementTy.isInteger(8) &&
+ tableElementTy.isInteger(8) && resultElementTy.isInteger(8);
+ bool isI16_16_32 = inputElementTy.isInteger(16) &&
+ tableElementTy.isInteger(16) &&
+ resultElementTy.isInteger(32);
+ if (!isI8_8_8 && !isI16_16_32)
+ return rewriter.notifyMatchFailure(
+ op, "unable to create body for tosa.table op");
+
SmallVector<Value> dynDims;
for (int i = 0; i < resultTy.getRank(); ++i) {
if (inputTy.isDynamicDim(i)) {
@@ -2446,8 +2470,7 @@ class TableConverter : public OpRewritePattern<tosa::TableOp> {
auto inputValue = block->getArgument(0);
rewriter.setInsertionPointToStart(block);
- if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
- resultElementTy.isInteger(8)) {
+ if (isI8_8_8) {
Value index = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getIndexType(), inputValue);
Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128);
@@ -2459,8 +2482,7 @@ class TableConverter : public OpRewritePattern<tosa::TableOp> {
return success();
}
- if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
- resultElementTy.isInteger(32)) {
+ if (isI16_16_32) {
Value extend = rewriter.create<arith::ExtSIOp>(
loc, rewriter.getI32Type(), inputValue);
@@ -2516,8 +2538,7 @@ class TableConverter : public OpRewritePattern<tosa::TableOp> {
}
}
- return rewriter.notifyMatchFailure(
- op, "unable to create body for tosa.table op");
+ llvm_unreachable("unable to create body for tosa.table op");
}
};
``````````
https://github.com/llvm/llvm-project/pull/136308