[Mlir-commits] [mlir] [mlir][TOSA] Remove rollback from TOSA -> Linalg patterns (PR #136308)

Matthias Springer llvmlistbot at llvm.org
Fri Apr 18 07:20:51 PDT 2025


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/136308

Reorganize the implementation slightly so that patterns check all preconditions before starting the actual rewrite. I.e., patterns no longer start rewriting and then abort, which would cause a pattern rollback. Pattern rollbacks are expensive and will be disallowed as part of the One-Shot Dialect Conversion refactoring.


>From c5cf03a74d5a27d6f03da7d3c21532e98906c638 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Fri, 18 Apr 2025 16:14:03 +0200
Subject: [PATCH] fix tosa

---
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  | 217 ++++++++++--------
 1 file changed, 119 insertions(+), 98 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 9ca93ab28daed..bc4ef58cbcd62 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -91,6 +91,50 @@ createConstOpFromZpVal(Operation *op, const int64_t &zp, Type requiredAttrType,
       op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
 }
 
+/// Return "failure" (with a match-failure note) if \p op cannot be lowered.
+static LogicalResult
+isSupportedElementwiseOperation(ConversionPatternRewriter &rewriter,
+                                Operation *op, RankedTensorType resultType) {
+  auto elementTy =
+      cast<ShapedType>(op->getOperand(0).getType()).getElementType();
+
+  // tosa::MulOp: shift must be a constant, and zero for float element types.
+  if (isa<tosa::MulOp>(op)) {
+    auto shiftVal = cast<tosa::MulOp>(op).getShift();
+    DenseElementsAttr shiftElem;
+    if (!matchPattern(shiftVal, m_Constant(&shiftElem)))
+      return rewriter.notifyMatchFailure(op, "shift value of mul not found");
+
+    int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
+    if (isa<FloatType>(elementTy) && shift != 0)
+      return rewriter.notifyMatchFailure(op,
+                                         "Cannot have shift value for float");
+    return success();
+  }
+
+  // tosa::NegateOp: both zero points must be statically determinable.
+  if (isa<tosa::NegateOp>(op)) {
+    auto negate = cast<tosa::NegateOp>(op);
+    if (failed(negate.getInput1ZeroPoint()))
+      return rewriter.notifyMatchFailure(
+          op, "input1 zero point cannot be statically determined");
+    if (failed(negate.getOutputZeroPoint()))
+      return rewriter.notifyMatchFailure(
+          op, "output zero point cannot be statically determined");
+    return success();
+  }
+
+  // tosa::CastOp: source and result element types must both be int or float.
+  if (isa<tosa::CastOp>(op)) {
+    if (!elementTy.isIntOrFloat() ||
+        !resultType.getElementType().isIntOrFloat())
+      return rewriter.notifyMatchFailure(op, "unsupported type");
+    return success();
+  }
+
+  return success();
+}
+
 static Value createLinalgBodyCalculationForElementwiseOp(
     Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
     ConversionPatternRewriter &rewriter) {
@@ -139,17 +183,14 @@ static Value createLinalgBodyCalculationForElementwiseOp(
     auto shiftVal = cast<tosa::MulOp>(op).getShift();
     DenseElementsAttr shiftElem;
     if (!matchPattern(shiftVal, m_Constant(&shiftElem))) {
-      (void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
-      return nullptr;
+      llvm_unreachable("shift value of mul not found");
     }
 
     int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
 
     if (isa<FloatType>(elementTy)) {
       if (shift != 0) {
-        (void)rewriter.notifyMatchFailure(op,
-                                          "Cannot have shift value for float");
-        return nullptr;
+        llvm_unreachable("Cannot have shift value for float");
       }
       return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0], args[1]);
     }
@@ -196,16 +237,12 @@ static Value createLinalgBodyCalculationForElementwiseOp(
 
     FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
     if (failed(maybeInZp)) {
-      (void)rewriter.notifyMatchFailure(
-          op, "input1 zero point cannot be statically determined");
-      return nullptr;
+      llvm_unreachable("input1 zero point cannot be statically determined");
     }
 
     FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
     if (failed(maybeOutZp)) {
-      (void)rewriter.notifyMatchFailure(
-          op, "output zero point cannot be statically determined");
-      return nullptr;
+      llvm_unreachable("output zero point cannot be statically determined");
     }
 
     int64_t inZp = *maybeInZp;
@@ -548,10 +585,8 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   if (isa<tosa::CastOp>(op)) {
     Type srcTy = elementTy;
     Type dstTy = resultTypes.front();
-    if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat()) {
-      (void)rewriter.notifyMatchFailure(op, "unsupported type");
-      return nullptr;
-    }
+    if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat())
+      llvm_unreachable("unsupported type");
 
     bool bitExtend =
         srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();
@@ -706,8 +741,8 @@ static Value createLinalgBodyCalculationForElementwiseOp(
     }
   }
 
-  (void)rewriter.notifyMatchFailure(
-      op, "unhandled op for linalg body calculation for elementwise op");
+  llvm_unreachable(
+      "unhandled op for linalg body calculation for elementwise op");
   return nullptr;
 }
 
@@ -930,17 +965,11 @@ broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
   });
 }
 
-static LogicalResult
-emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
-                           Operation *operation, ValueRange operands,
-                           ArrayRef<OpFoldResult> targetShape,
-                           const TypeConverter &converter) {
+static LogicalResult emitElementwiseComputation(
+    ConversionPatternRewriter &rewriter, Location loc, Operation *operation,
+    ValueRange operands, ArrayRef<OpFoldResult> targetShape,
+    const TypeConverter &converter, RankedTensorType resultType) {
   // Generate output tensor
-  auto resultType = cast_or_null<RankedTensorType>(
-      converter.convertType(operation->getResultTypes().front()));
-  if (!resultType) {
-    return rewriter.notifyMatchFailure(operation, "failed to convert type");
-  }
   Value outputTensor = rewriter.create<tensor::EmptyOp>(
       loc, targetShape, resultType.getElementType());
 
@@ -967,7 +996,6 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
   affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
 
   // Emit 'linalg.generic' op
-  bool encounteredError = false;
   auto linalgOp = rewriter.create<linalg::GenericOp>(
       loc, outputTensor.getType(), operands, outputTensor, affineMaps,
       getNParallelLoopsAttrs(rank),
@@ -975,15 +1003,10 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
         Value opResult = createLinalgBodyCalculationForElementwiseOp(
             operation, blockArgs.take_front(operation->getNumOperands()),
             {resultType.getElementType()}, rewriter);
-        if (!opResult) {
-          encounteredError = true;
-          return;
-        }
+        assert(opResult &&
+               "unable to create linalg.generic body for elementwise op");
         opBuilder.create<linalg::YieldOp>(loc, opResult);
       });
-  if (encounteredError)
-    return rewriter.notifyMatchFailure(
-        operation, "unable to create linalg.generic body for elementwise op");
 
   // Cast 'linalg.generic' result into original result type if needed
   auto castResult = rewriter.createOrFold<tensor::CastOp>(
@@ -1008,13 +1031,20 @@ elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
                                  ConversionPatternRewriter &rewriter,
                                  const TypeConverter &converter) {
 
-  // Collect op properties
+  // Check if operation is supported.
   assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
   assert(operation->getNumOperands() >= 1 &&
          "elementwise op expects at least 1 operand");
   if (!operandsAndResultsRanked(operation))
     return rewriter.notifyMatchFailure(operation,
                                        "Unranked tensors not supported");
+  auto resultType = cast_or_null<RankedTensorType>(
+      converter.convertType(operation->getResultTypes().front()));
+  if (!resultType) {
+    return rewriter.notifyMatchFailure(operation, "failed to convert type");
+  }
+  if (failed(isSupportedElementwiseOperation(rewriter, operation, resultType)))
+    return failure();
 
   // Lower operation
   IndexPool indexPool;
@@ -1026,7 +1056,7 @@ elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
       broadcastDynamicDimensions(rewriter, loc, indexPool, operandsToBroadcast,
                                  targetShape, masterOperands);
   return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
-                                    targetShape, converter);
+                                    targetShape, converter, resultType);
 }
 
 // Returns the constant initial value for a given reduction operation. The
@@ -1126,7 +1156,7 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
   if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
     return rewriter.create<arith::OrIOp>(loc, args);
 
-  return {};
+  llvm_unreachable("unhandled reduction op");
 }
 
 // Performs the match and rewrite for reduction operations. This includes
@@ -1142,6 +1172,10 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
     return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
 
   auto elementTy = resultTy.getElementType();
+  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
+  if (!fillValueAttr)
+    return rewriter.notifyMatchFailure(
+        op, "No initial value found for reduction operation");
   Value input = op->getOperand(0);
 
   SmallVector<int64_t> reduceShape;
@@ -1164,11 +1198,6 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
                                    dynDims)
           .getResult();
 
-  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
-  if (!fillValueAttr)
-    return rewriter.notifyMatchFailure(
-        op, "No initial value found for reduction operation");
-
   auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
   auto filledTensor = rewriter
                           .create<linalg::FillOp>(loc, ValueRange{fillValue},
@@ -1212,7 +1241,6 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
     }
   }
 
-  bool didEncounterError = false;
   linalg::LinalgOp linalgOp = rewriter.create<linalg::ReduceOp>(
       loc, inputs, outputs, axis,
       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
@@ -1220,8 +1248,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
             blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
         auto result = createLinalgBodyCalculationForReduceOp(
             op, binaryArgs, elementTy, rewriter);
-        if (result)
-          didEncounterError = true;
+        assert(result && "could not create reduction body");
 
         SmallVector<Value> resultsToYield;
         if (isNanIgnoreMode) {
@@ -1247,10 +1274,6 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
         nestedBuilder.create<linalg::YieldOp>(loc, resultsToYield);
       });
 
-  if (!didEncounterError)
-    return rewriter.notifyMatchFailure(
-        op, "unable to create linalg.generic body for reduce op");
-
   if (isNanIgnoreMode) {
     // Materialize a check to see whether we encountered any non-NaN values, if
     // we didn't we need to select a tensor of NaNs since the result will just
@@ -1358,13 +1381,6 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
     if (!isa<IntegerType>(inputTy.getElementType()))
       return rewriter.notifyMatchFailure(op, "only support integer type");
 
-    SmallVector<Value> dynDims;
-    for (int i = 0; i < outputTy.getRank(); i++) {
-      if (outputTy.isDynamicDim(i)) {
-        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
-      }
-    }
-
     // The shift and multiplier values.
     DenseElementsAttr shiftElems;
     if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
@@ -1376,6 +1392,21 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
       return rewriter.notifyMatchFailure(
           op, "tosa.rescale requires constant multiplier input values");
 
+    if (failed(op.getInputZeroPoint()))
+      return rewriter.notifyMatchFailure(
+          op, "input zero point cannot be statically determined");
+
+    if (failed(op.getOutputZeroPoint()))
+      return rewriter.notifyMatchFailure(
+          op, "output zero point cannot be statically determined");
+
+    SmallVector<Value> dynDims;
+    for (int i = 0; i < outputTy.getRank(); i++) {
+      if (outputTy.isDynamicDim(i)) {
+        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
+      }
+    }
+
     llvm::SmallVector<int8_t> shiftValues =
         llvm::to_vector(shiftElems.getValues<int8_t>());
     // explicit cast is required here
@@ -1473,23 +1504,10 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
           int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;
 
           FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
-          if (failed(maybeIZp)) {
-            (void)rewriter.notifyMatchFailure(
-                op, "input zero point cannot be statically determined");
-            return;
-          }
-
           auto inputZp = createConstOpFromZpVal<int32_t>(
               op, *maybeIZp, nestedBuilder.getIntegerType(inBitwidth),
               nestedBuilder);
-
           FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
-          if (failed(maybeOZp)) {
-            (void)rewriter.notifyMatchFailure(
-                op, "output zero point cannot be statically determined");
-            return;
-          };
-
           auto outputZp = createConstOpFromZpVal<int32_t>(
               op, *maybeOZp, nestedBuilder.getI32Type(), nestedBuilder);
 
@@ -1783,6 +1801,15 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
       return rewriter.notifyMatchFailure(
           op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
 
+    SmallVector<int64_t> scale, offset, border;
+    if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) ||
+        !tosa::getConstShapeValues(op.getOffset().getDefiningOp(), offset) ||
+        !tosa::getConstShapeValues(op.getBorder().getDefiningOp(), border)) {
+      return rewriter.notifyMatchFailure(
+          op, "tosa.resize scale/offset/border should have compile time "
+              "constant values.");
+    }
+
     SmallVector<AffineMap, 2> affineMaps = {
         rewriter.getMultiDimIdentityMap(resultTy.getRank())};
     auto emptyTensor = b.create<tensor::EmptyOp>(resultTy.getShape(), resultETy,
@@ -1810,15 +1837,6 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
       Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
       Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);
 
-      SmallVector<int64_t> scale, offset, border;
-      if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) ||
-          !tosa::getConstShapeValues(op.getOffset().getDefiningOp(), offset) ||
-          !tosa::getConstShapeValues(op.getBorder().getDefiningOp(), border)) {
-        return rewriter.notifyMatchFailure(
-            op, "tosa.resize scale/offset/border should have compile time "
-                "constant values.");
-      }
-
       Value yScaleN, yScaleD, xScaleN, xScaleD;
       yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));
       yScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[1]));
@@ -2204,6 +2222,9 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
     auto inputTy = cast<ShapedType>(input.getType());
     auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
     auto inElementTy = inputTy.getElementType();
+    if (!isa<IntegerType, FloatType>(inElementTy))
+      return rewriter.notifyMatchFailure(
+          argmaxOp, "unsupported tosa.argmax element type");
     auto outElementTy = resultTy.getElementType();
     int axis = argmaxOp.getAxis();
     auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);
@@ -2213,6 +2234,12 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
           argmaxOp,
           "tosa.arg_max to linalg.* requires integer-like result type");
 
+    auto fillValueMaxAttr =
+        createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
+    if (!fillValueMaxAttr)
+      return rewriter.notifyMatchFailure(
+          argmaxOp, "unsupported tosa.argmax element type");
+
     SmallVector<Value> dynDims;
     for (int i = 0; i < inputTy.getRank(); i++) {
       if (inputTy.isDynamicDim(i) && i != axis) {
@@ -2238,12 +2265,6 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
                               .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                        inElementTy, dynDims)
                               .getResult();
-    auto fillValueMaxAttr =
-        createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
-
-    if (!fillValueMaxAttr)
-      return rewriter.notifyMatchFailure(
-          argmaxOp, "unsupported tosa.argmax element type");
 
     auto fillValueMax =
         rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
@@ -2267,7 +2288,6 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
         dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
     }
 
-    bool didEncounterError = false;
     auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
                                              rewriter.getContext());
     auto linalgOp = rewriter.create<linalg::GenericOp>(
@@ -2305,8 +2325,7 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
             predicate = rewriter.create<arith::CmpIOp>(
                 nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
           } else {
-            didEncounterError = true;
-            return;
+            llvm_unreachable("unsupported tosa.argmax element type");
           }
 
           auto resultMax = rewriter.create<arith::SelectOp>(
@@ -2317,10 +2336,6 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
               nestedLoc, ValueRange({resultIndex, resultMax}));
         });
 
-    if (didEncounterError)
-      return rewriter.notifyMatchFailure(
-          argmaxOp, "unsupported tosa.argmax element type");
-
     rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
     return success();
   }
@@ -2416,6 +2431,15 @@ class TableConverter : public OpRewritePattern<tosa::TableOp> {
     auto tableElementTy = tableTy.getElementType();
     auto resultElementTy = resultTy.getElementType();
 
+    bool isI8_8_8 = inputElementTy.isInteger(8) &&
+                    tableElementTy.isInteger(8) && resultElementTy.isInteger(8);
+    bool isI16_16_32 = inputElementTy.isInteger(16) &&
+                       tableElementTy.isInteger(16) &&
+                       resultElementTy.isInteger(32);
+    if (!isI8_8_8 && !isI16_16_32)
+      return rewriter.notifyMatchFailure(
+          op, "unable to create body for tosa.table op");
+
     SmallVector<Value> dynDims;
     for (int i = 0; i < resultTy.getRank(); ++i) {
       if (inputTy.isDynamicDim(i)) {
@@ -2446,8 +2470,7 @@ class TableConverter : public OpRewritePattern<tosa::TableOp> {
 
       auto inputValue = block->getArgument(0);
       rewriter.setInsertionPointToStart(block);
-      if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
-          resultElementTy.isInteger(8)) {
+      if (isI8_8_8) {
         Value index = rewriter.create<arith::IndexCastOp>(
             loc, rewriter.getIndexType(), inputValue);
         Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128);
@@ -2459,8 +2482,7 @@ class TableConverter : public OpRewritePattern<tosa::TableOp> {
         return success();
       }
 
-      if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
-          resultElementTy.isInteger(32)) {
+      if (isI16_16_32) {
         Value extend = rewriter.create<arith::ExtSIOp>(
             loc, rewriter.getI32Type(), inputValue);
 
@@ -2516,8 +2538,7 @@ class TableConverter : public OpRewritePattern<tosa::TableOp> {
       }
     }
 
-    return rewriter.notifyMatchFailure(
-        op, "unable to create body for tosa.table op");
+    llvm_unreachable("unable to create body for tosa.table op");
   }
 };
 



More information about the Mlir-commits mailing list