[Mlir-commits] [mlir] Remove FullyConnectedOp from TOSA Dialect (PR #126152)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 6 15:38:32 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tosa

Author: Jerry-Ge (Jerry-Ge)

<details>
<summary>Changes</summary>

This patch removes the FullyConnected operator from the TOSA dialect, along with all associated tests and transforms.

This is part of the TOSA v1.0 alignment effort: https://discourse.llvm.org/t/rfc-tosa-dialect-increment-to-v1-0/83708

---

Patch is 38.65 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/126152.diff


17 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (-9) 
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (-26) 
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+1-1) 
- (modified) mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h (-1) 
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (-79) 
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp (-1) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+3-90) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt (-1) 
- (removed) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp (-164) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp (-1) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (-14) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (-71) 
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (-20) 
- (modified) mlir/test/Dialect/Tosa/ops.mlir (-7) 
- (removed) mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir (-76) 
- (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (-45) 
- (removed) mlir/test/Integration/Dialect/Tosa/CPU/test-fully-connected.mlir (-36) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index f492bad78e775ca..862d98ad436a659 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -150,15 +150,6 @@ def Tosa_TransConvOpQuantInfoBuilder : OpBuilder<
                                   outputShape, acc_type);
   }]>;
 
-// The tosa.fully_connected op has its own builder as it does not have
-// strides/dilation/padding.
-def Tosa_FCOpQuantInfoBuilder : OpBuilder<
-  (ins "Type":$outputType, "Value":$input, "Value":$weight, "Value":$bias),
-  [{
-    buildFCOpWithQuantInfo($_builder, $_state, outputType,
-                           input, weight, bias);
-  }]>;
-
 // The tosa.matmul op is also intended to be generated where a fully_connected
 // op must be constructed where the weight is not a constant. In this case,
 // the fully_connected op must be expressed using matmul.
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 98bcbca3b02fa12..74b93bcb371fa47 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -224,32 +224,6 @@ def Tosa_FFT2dOp : Tosa_InferShapedTypeOp<"fft2d"> {
   }];
 }
 
-//===----------------------------------------------------------------------===//
-// Operator: fully_connected
-//===----------------------------------------------------------------------===//
-def Tosa_FullyConnectedOp : Tosa_InferShapedTypeOp<"fully_connected"> {
-  let summary = "Fully Connected operator";
-
-  let description = [{
-    Performs a fully connected network.
-  }];
-
-  let arguments = (ins
-    Tosa_Tensor2D:$input,
-    TosaTensorRankOf<[Tosa_Weight], [2]>:$weight,
-    Tosa_Tensor1D:$bias,
-    OptionalAttr<I32Attr>:$input_zp,
-    OptionalAttr<I32Attr>:$weight_zp
-  );
-
-  let results = (outs
-    Tosa_Tensor2D:$output
-  );
-
-  let builders = [Tosa_FCOpQuantInfoBuilder];
-  let hasVerifier = 1;
-}
-
 //===----------------------------------------------------------------------===//
 // Operator: matmul
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 7aa1f72ec6e1792..6d5cf9182736f0a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -81,7 +81,7 @@ def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
                                 "number">;
 
 // For weight tensors from tosa::Conv2DOp, tosa::Conv3DOp,
-// tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp, tosa::FullyConnectedOp
+// tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp
 def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
                              Tosa_QuantizedInt, AnyFloat]>;
 
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index 1f9522b51a4cf5c..565970367e5dc56 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -26,7 +26,6 @@ namespace tosa {
 
 // Expose Rewrite Functions that decompose TOSA Ops into further TOSA Ops.
 // The rewrites can be selectively added to a conversion pass.
-void populateTosaDecomposeConv2D(MLIRContext *ctx, RewritePatternSet &patterns);
 void populateTosaDecomposeTransposeConv(MLIRContext *ctx,
                                         RewritePatternSet &patterns);
 void populateTosaDecomposeDepthwise(MLIRContext *ctx,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 6321cb6087394a3..a8fd536dd254842 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -607,84 +607,6 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
   }
 };
 
-class FullyConnectedConverter
-    : public OpConversionPattern<tosa::FullyConnectedOp> {
-public:
-  using OpConversionPattern<tosa::FullyConnectedOp>::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(tosa::FullyConnectedOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    Location loc = op.getLoc();
-    auto outputTy = cast<ShapedType>(op.getType());
-    auto input = op.getInput();
-    auto inputTy = cast<ShapedType>(input.getType());
-
-    auto bias = op.getBias();
-
-    auto weight = op.getWeight();
-    auto weightTy = cast<ShapedType>(weight.getType());
-    auto weightShape = weightTy.getShape();
-
-    auto outputETy = outputTy.getElementType();
-
-    SmallVector<Value> dynDims;
-    dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());
-
-    if (!inputTy.hasRank() || inputTy.isDynamicDim(0)) {
-      dynDims[0] = rewriter.create<tensor::DimOp>(loc, input, 0);
-    }
-
-    if (!weightTy.hasRank() || weightTy.isDynamicDim(0)) {
-      dynDims[1] = rewriter.create<tensor::DimOp>(loc, weight, 0);
-    }
-
-    SmallVector<Value> filteredDims = condenseValues(dynDims);
-
-    SmallVector<int64_t> permutation = {1, 0};
-    auto permutationAttr = rewriter.getI64TensorAttr(permutation);
-    Value permutationValue =
-        rewriter.create<arith::ConstantOp>(loc, permutationAttr);
-
-    SmallVector<int64_t> newWeightShape = {weightShape[1], weightShape[0]};
-    Type newWeightTy =
-        RankedTensorType::get(newWeightShape, weightTy.getElementType());
-
-    Value transposedWeight = rewriter.create<tosa::TransposeOp>(
-        loc, newWeightTy, weight, permutationValue);
-
-    Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
-        loc, outputTy.getShape(), outputETy, filteredDims);
-
-    Value broadcastBias =
-        linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
-
-    if (!op.getInputZp() && !op.getWeightZp()) {
-      Value matmul = rewriter
-                         .create<linalg::MatmulOp>(
-                             loc, TypeRange{op.getType()},
-                             ValueRange{input, transposedWeight}, broadcastBias)
-                         ->getResult(0);
-
-      rewriter.replaceOp(op, matmul);
-      return success();
-    }
-
-    auto inputZp = rewriter.create<arith::ConstantOp>(loc, op.getInputZpAttr());
-    auto outputZp =
-        rewriter.create<arith::ConstantOp>(loc, op.getWeightZpAttr());
-    Value matmul =
-        rewriter
-            .create<linalg::QuantizedMatmulOp>(
-                loc, TypeRange{op.getType()},
-                ValueRange{input, transposedWeight, inputZp, outputZp},
-                broadcastBias)
-            ->getResult(0);
-
-    rewriter.replaceOp(op, matmul);
-    return success();
-  }
-};
-
 class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -1090,7 +1012,6 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
       DepthwiseConvConverter,
       MatMulConverter,
       AvgPool2dConverter,
-      FullyConnectedConverter,
       TransposeConverter
   >(patterns->getContext());
 
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
index 7d943b3779fb02e..80df3908991dde2 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
@@ -62,7 +62,6 @@ struct TosaToLinalgNamed
     target.addIllegalOp<tosa::MaxPool2dOp>();
     target.addIllegalOp<tosa::AvgPool2dOp>();
     target.addIllegalOp<tosa::MatMulOp>();
-    target.addIllegalOp<tosa::FullyConnectedOp>();
     target.addIllegalOp<tosa::TransposeOp>();
 
     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 031c279ff09e275..be54319e64fde68 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -540,26 +540,9 @@ static void buildTransConvOpWithQuantInfo(
   }
 }
 
-/// The tosa.fully_connected op has its own builder as it does not have
-/// strides/dilation/padding.
-static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
-                                   Type outputType, Value input, Value weight,
-                                   Value bias) {
-
-  result.addOperands({input, weight, bias});
-  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
-  if (quantAttr) {
-    result.addAttribute("quantization_info", quantAttr);
-    result.addTypes(
-        buildConvOpResultTypeInfo(builder, outputType, input, weight));
-  } else {
-    result.addTypes(outputType);
-  }
-}
-
-/// The tosa.matmul op is also intended to be generated where a
-/// fully_connected op must be constructed where the weight is not a constant.
-/// In this case, the fully_connected op must be expressed using matmul.
+/// The tosa.matmul op is also intended to be generated where a fully_connected
+/// op must be constructed where the weight is not a constant. In this case,
+/// the fully_connected op must be expressed using matmul.
 /// TODO: Add link to the leglization document explaining this.
 static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
                                        OperationState &result, Type outputType,
@@ -863,76 +846,6 @@ bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
   return succeeded(verifyCompatibleShape(l[0], r[0]));
 }
 
-LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
-    MLIRContext *context, ::std::optional<Location> location,
-    FullyConnectedOp::Adaptor adaptor,
-    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  ShapeAdaptor inputShape(adaptor.getInput().getType());
-  ShapeAdaptor weightShape(adaptor.getWeight().getType());
-  ShapeAdaptor biasShape(adaptor.getBias().getType());
-
-  // All shapes are dynamic.
-  SmallVector<int64_t> outShape;
-  outShape.resize(2, ShapedType::kDynamic);
-
-  if (inputShape.hasRank()) {
-    outShape[0] = inputShape.getDimSize(0);
-  }
-
-  if (weightShape.hasRank()) {
-    outShape[1] = weightShape.getDimSize(0);
-  }
-
-  if (biasShape.hasRank()) {
-    outShape[1] = outShape[1] == ShapedType::kDynamic ? biasShape.getDimSize(0)
-                                                      : outShape[1];
-  }
-
-  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
-  return success();
-}
-
-LogicalResult FullyConnectedOp::verify() {
-  // All TOSA conv ops have an input() and weight().
-  auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
-
-  RankedTensorType weightType =
-      llvm::dyn_cast<RankedTensorType>(getWeight().getType());
-
-  // Must be ranked tensor types
-  if (!inputType) {
-    emitOpError("expect a ranked tensor for input, got ") << getInput();
-    return failure();
-  }
-  if (!weightType) {
-    emitOpError("expect a ranked tensor for weight, got ") << getWeight();
-    return failure();
-  }
-
-  auto inputEType = inputType.getElementType();
-  auto weightEType = weightType.getElementType();
-
-  bool inputIsQuant = !llvm::isa<FloatType>(inputEType);
-  bool weightIsQuant = !llvm::isa<FloatType>(weightEType);
-
-  // Either both must be quantized or both unquantized.
-  if (inputIsQuant != weightIsQuant) {
-    emitOpError(
-        "expect both input and weight to be float or not together, got ")
-        << inputEType << " and " << weightEType;
-    return failure();
-  }
-
-  // Quantized type must have constructed the quantizationattr, and unquantized
-  // types should not have a quantizationattr.
-  if ((inputIsQuant && !getInputZp()) || (!inputIsQuant && getInputZp())) {
-    emitOpError("input zero point is required for quantized type, and not "
-                "allowed for float type");
-    return failure();
-  }
-  return success();
-}
-
 LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     MatMulOp::Adaptor adaptor,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index 5b0f5ec4cd5687b..9c3345b617cc5f3 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -1,6 +1,5 @@
 add_mlir_dialect_library(MLIRTosaTransforms
   TosaDecomposeTransposeConv.cpp
-  TosaDecomposeConv2D.cpp
   TosaDecomposeDepthwise.cpp
   TosaFolders.cpp
   TosaInferShapes.cpp
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
deleted file mode 100644
index 4eba89b59bbd79f..000000000000000
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
+++ /dev/null
@@ -1,164 +0,0 @@
-//===- TosaDecomposeConv2D.cpp --------------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Decompose TOSA Conv2D operation to a series of TOSA Ops specifically
-// (1) Convert a 1x1 Convolution to a Reshape->FC->Reshape
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#include "mlir/Dialect/Tosa/Transforms/Passes.h"
-#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
-
-using namespace mlir;
-using namespace mlir::tosa;
-
-namespace {
-
-SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape) {
-  return to_vector(llvm::map_range(shape, [](int64_t dim) {
-    return ShapedType::isDynamic(dim) ? -1 : dim;
-  }));
-}
-
-struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
-  explicit Conv2DIsFullyConnected(MLIRContext *context)
-      : OpRewritePattern(context) {}
-
-  LogicalResult matchAndRewrite(tosa::Conv2DOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.getInput();
-    Value weight = op.getWeight();
-    ShapedType inputType = cast<ShapedType>(input.getType());
-    ShapedType weightType = cast<ShapedType>(weight.getType());
-    ShapedType resultType = cast<ShapedType>(op.getType());
-
-    auto numDynamic =
-        llvm::count_if(inputType.getShape(), ShapedType::isDynamic);
-    if (numDynamic > 1)
-      return rewriter.notifyMatchFailure(
-          op, "at most one dim in input may be dynamic");
-    if (!weightType.hasRank())
-      return rewriter.notifyMatchFailure(op, "unranked weight input");
-
-    if (!llvm::all_of(op.getStride(), [](int64_t v) { return v == 1; }))
-      return failure();
-
-    // Only works for a 1x1 kernel.
-    ArrayRef<int64_t> weightShape = weightType.getShape();
-    if (weightShape[1] != 1 || weightShape[2] != 1)
-      return failure();
-
-    llvm::ArrayRef<int64_t> padAttr = op.getPad();
-    llvm::SmallVector<int64_t> pad(8, 0);
-    for (const auto &it : llvm::enumerate(padAttr))
-      pad[it.index() + 2] = it.value();
-
-    Type inputETy = inputType.getElementType();
-    if (llvm::any_of(pad, [](int64_t p) { return p != 0; })) {
-      auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
-      if (failed(failureOrMaybeZps))
-        return failure();
-
-      auto maybeZps = failureOrMaybeZps.value();
-
-      Attribute zeroAttr =
-          maybeZps ? rewriter.getIntegerAttr(inputETy, maybeZps->inputZp)
-                   : rewriter.getZeroAttr(inputETy);
-
-      llvm::SmallVector<int64_t> newShape(inputType.getShape());
-
-      for (int i = 0, s = newShape.size(); i < s; ++i) {
-        if (newShape[i] != ShapedType::kDynamic) {
-          newShape[i] += pad[i * 2] + pad[i * 2 + 1];
-        }
-      }
-
-      Value padSizeVal = getTosaConstShape(rewriter, op->getLoc(), pad);
-
-      auto padTy = RankedTensorType::get({}, inputETy);
-      auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);
-      Value padVal =
-          rewriter.create<tosa::ConstOp>(op->getLoc(), padTy, padAttr);
-      inputType = RankedTensorType::get(newShape, inputETy);
-      input = rewriter.create<tosa::PadOp>(op->getLoc(), inputType, input,
-                                           padSizeVal, padVal);
-    }
-
-    // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC].
-    ArrayRef<int64_t> inputShape = inputType.getShape();
-    int64_t combined = ShapedType::kDynamic;
-    if (numDynamic == 0)
-      combined = inputShape[0] * inputShape[1] * inputShape[2];
-    llvm::SmallVector<int64_t, 2> revisedInputShape{combined, inputShape[3]};
-    auto revisedInputShapeType =
-        RankedTensorType::get(revisedInputShape, inputType.getElementType());
-    auto reshapedInput = rewriter
-                             .create<tosa::ReshapeOp>(
-                                 op.getLoc(), revisedInputShapeType, input,
-                                 rewriter.getDenseI64ArrayAttr(
-                                     convertFromMlirShape(revisedInputShape)))
-                             .getResult();
-
-    // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
-    llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
-                                                     weightShape[3]};
-    auto revisedWeightShapeType = RankedTensorType::get(
-        revisedWeightShape,
-        dyn_cast<RankedTensorType>(weight.getType()).getElementType());
-    auto reshapedWeight = rewriter
-                              .create<tosa::ReshapeOp>(
-                                  op.getLoc(), revisedWeightShapeType, weight,
-                                  rewriter.getDenseI64ArrayAttr(
-                                      convertFromMlirShape(revisedWeightShape)))
-                              .getResult();
-
-    // Perform a fully connected network over the reshaped input and weight.
-    llvm::SmallVector<int64_t, 2> fullyConnectedShape{combined, weightShape[0]};
-    auto fullyConnectedShapeType =
-        RankedTensorType::get(fullyConnectedShape, resultType.getElementType());
-
-    auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
-    if (failed(failureOrMaybeZps))
-      return failure();
-
-    auto maybeZps = failureOrMaybeZps.value();
-    Value fullyConnectedValue;
-    if (maybeZps) {
-      fullyConnectedValue =
-          rewriter
-              .create<tosa::FullyConnectedOp>(
-                  op.getLoc(), fullyConnectedShapeType, reshapedInput,
-                  reshapedWeight, op.getBias(),
-                  rewriter.getI32IntegerAttr(maybeZps->inputZp),
-                  rewriter.getI32IntegerAttr(maybeZps->weightZp))
-              .getResult();
-    } else {
-      fullyConnectedValue = rewriter
-                                .create<tosa::FullyConnectedOp>(
-                                    op.getLoc(), fullyConnectedShapeType,
-                                    reshapedInput, reshapedWeight, op.getBias())
-                                .getResult();
-    }
-
-    // Reshape output to [N, IH, IW, OC].
-    llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
-                                              inputShape[2], weightShape[0]};
-    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
-        op, resultType, fullyConnectedValue,
-        rewriter.getDenseI64ArrayAttr(convertFromMlirShape(outputShape)));
-    return success();
-  }
-};
-
-} // namespace
-
-void mlir::tosa::populateTosaDecomposeConv2D(MLIRContext *ctx,
-                                             RewritePatternSet &patterns) {
-  patterns.add<Conv2DIsFullyConnected>(ctx);
-}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp
index 603185e48aa94e1..ffa2ea3d0629f37 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp
@@ -38,7 +38,6 @@ struct TosaOptionalDecompositions
     RewritePatternSet patterns(ctx);
     auto func = getOperation();
 
-    mlir::tosa::populateTosaDecomposeConv2D(ctx, patterns);
   ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/126152


More information about the Mlir-commits mailing list