[Mlir-commits] [mlir] Remove FullyConnectedOp from TOSA Dialect (PR #126152)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Feb 6 15:37:56 PST 2025
https://github.com/Jerry-Ge created https://github.com/llvm/llvm-project/pull/126152
This patch removes the FullyConnected operator from the TOSA dialect, along with all associated tests and transforms.
This is part of the TOSA v1.0 alignment effort: https://discourse.llvm.org/t/rfc-tosa-dialect-increment-to-v1-0/83708
>From 54f463210aebed0a503daf9d63a2a94774a7db51 Mon Sep 17 00:00:00 2001
From: Tai Ly <tai.ly at arm.com>
Date: Thu, 1 Aug 2024 20:10:34 +0000
Subject: [PATCH] Remove FullyConnectedOp from TOSA Dialect
This patch removes the FullyConnected operator from the TOSA dialect,
along with all associated tests and transforms.
Signed-off-by: Tai Ly <tai.ly at arm.com>
Change-Id: Ib8c928cb21daf325f00cdad302680af2d7c13da5
---
.../mlir/Dialect/Tosa/IR/TosaOpBase.td | 9 -
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 26 ---
.../mlir/Dialect/Tosa/IR/TosaTypesBase.td | 2 +-
.../mlir/Dialect/Tosa/Transforms/Passes.h | 1 -
.../TosaToLinalg/TosaToLinalgNamed.cpp | 79 ---------
.../TosaToLinalg/TosaToLinalgNamedPass.cpp | 1 -
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 93 +---------
.../Dialect/Tosa/Transforms/CMakeLists.txt | 1 -
.../Tosa/Transforms/TosaDecomposeConv2D.cpp | 164 ------------------
.../Transforms/TosaOptionalDecompositions.cpp | 1 -
.../Tosa/Transforms/TosaValidation.cpp | 14 --
.../TosaToLinalg/tosa-to-linalg-named.mlir | 71 --------
mlir/test/Dialect/Tosa/invalid.mlir | 20 ---
mlir/test/Dialect/Tosa/ops.mlir | 7 -
.../Dialect/Tosa/tosa-decompose-conv2d.mlir | 76 --------
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 45 -----
.../Tosa/CPU/test-fully-connected.mlir | 36 ----
17 files changed, 4 insertions(+), 642 deletions(-)
delete mode 100644 mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
delete mode 100644 mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
delete mode 100644 mlir/test/Integration/Dialect/Tosa/CPU/test-fully-connected.mlir
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index f492bad78e775ca..862d98ad436a659 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -150,15 +150,6 @@ def Tosa_TransConvOpQuantInfoBuilder : OpBuilder<
outputShape, acc_type);
}]>;
-// The tosa.fully_connected op has its own builder as it does not have
-// strides/dilation/padding.
-def Tosa_FCOpQuantInfoBuilder : OpBuilder<
- (ins "Type":$outputType, "Value":$input, "Value":$weight, "Value":$bias),
- [{
- buildFCOpWithQuantInfo($_builder, $_state, outputType,
- input, weight, bias);
- }]>;
-
// The tosa.matmul op is also intended to be generated where a fully_connected
// op must be constructed where the weight is not a constant. In this case,
// the fully_connected op must be expressed using matmul.
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 98bcbca3b02fa12..74b93bcb371fa47 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -224,32 +224,6 @@ def Tosa_FFT2dOp : Tosa_InferShapedTypeOp<"fft2d"> {
}];
}
-//===----------------------------------------------------------------------===//
-// Operator: fully_connected
-//===----------------------------------------------------------------------===//
-def Tosa_FullyConnectedOp : Tosa_InferShapedTypeOp<"fully_connected"> {
- let summary = "Fully Connected operator";
-
- let description = [{
- Performs a fully connected network.
- }];
-
- let arguments = (ins
- Tosa_Tensor2D:$input,
- TosaTensorRankOf<[Tosa_Weight], [2]>:$weight,
- Tosa_Tensor1D:$bias,
- OptionalAttr<I32Attr>:$input_zp,
- OptionalAttr<I32Attr>:$weight_zp
- );
-
- let results = (outs
- Tosa_Tensor2D:$output
- );
-
- let builders = [Tosa_FCOpQuantInfoBuilder];
- let hasVerifier = 1;
-}
-
//===----------------------------------------------------------------------===//
// Operator: matmul
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 7aa1f72ec6e1792..6d5cf9182736f0a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -81,7 +81,7 @@ def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
"number">;
// For weight tensors from tosa::Conv2DOp, tosa::Conv3DOp,
-// tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp, tosa::FullyConnectedOp
+// tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp
def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
Tosa_QuantizedInt, AnyFloat]>;
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index 1f9522b51a4cf5c..565970367e5dc56 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -26,7 +26,6 @@ namespace tosa {
// Expose Rewrite Functions that decompose TOSA Ops into further TOSA Ops.
// The rewrites can be selectively added to a conversion pass.
-void populateTosaDecomposeConv2D(MLIRContext *ctx, RewritePatternSet &patterns);
void populateTosaDecomposeTransposeConv(MLIRContext *ctx,
RewritePatternSet &patterns);
void populateTosaDecomposeDepthwise(MLIRContext *ctx,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 6321cb6087394a3..a8fd536dd254842 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -607,84 +607,6 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
}
};
-class FullyConnectedConverter
- : public OpConversionPattern<tosa::FullyConnectedOp> {
-public:
- using OpConversionPattern<tosa::FullyConnectedOp>::OpConversionPattern;
- LogicalResult
- matchAndRewrite(tosa::FullyConnectedOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const final {
- Location loc = op.getLoc();
- auto outputTy = cast<ShapedType>(op.getType());
- auto input = op.getInput();
- auto inputTy = cast<ShapedType>(input.getType());
-
- auto bias = op.getBias();
-
- auto weight = op.getWeight();
- auto weightTy = cast<ShapedType>(weight.getType());
- auto weightShape = weightTy.getShape();
-
- auto outputETy = outputTy.getElementType();
-
- SmallVector<Value> dynDims;
- dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());
-
- if (!inputTy.hasRank() || inputTy.isDynamicDim(0)) {
- dynDims[0] = rewriter.create<tensor::DimOp>(loc, input, 0);
- }
-
- if (!weightTy.hasRank() || weightTy.isDynamicDim(0)) {
- dynDims[1] = rewriter.create<tensor::DimOp>(loc, weight, 0);
- }
-
- SmallVector<Value> filteredDims = condenseValues(dynDims);
-
- SmallVector<int64_t> permutation = {1, 0};
- auto permutationAttr = rewriter.getI64TensorAttr(permutation);
- Value permutationValue =
- rewriter.create<arith::ConstantOp>(loc, permutationAttr);
-
- SmallVector<int64_t> newWeightShape = {weightShape[1], weightShape[0]};
- Type newWeightTy =
- RankedTensorType::get(newWeightShape, weightTy.getElementType());
-
- Value transposedWeight = rewriter.create<tosa::TransposeOp>(
- loc, newWeightTy, weight, permutationValue);
-
- Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
- loc, outputTy.getShape(), outputETy, filteredDims);
-
- Value broadcastBias =
- linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
-
- if (!op.getInputZp() && !op.getWeightZp()) {
- Value matmul = rewriter
- .create<linalg::MatmulOp>(
- loc, TypeRange{op.getType()},
- ValueRange{input, transposedWeight}, broadcastBias)
- ->getResult(0);
-
- rewriter.replaceOp(op, matmul);
- return success();
- }
-
- auto inputZp = rewriter.create<arith::ConstantOp>(loc, op.getInputZpAttr());
- auto outputZp =
- rewriter.create<arith::ConstantOp>(loc, op.getWeightZpAttr());
- Value matmul =
- rewriter
- .create<linalg::QuantizedMatmulOp>(
- loc, TypeRange{op.getType()},
- ValueRange{input, transposedWeight, inputZp, outputZp},
- broadcastBias)
- ->getResult(0);
-
- rewriter.replaceOp(op, matmul);
- return success();
- }
-};
-
class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
public:
using OpConversionPattern::OpConversionPattern;
@@ -1090,7 +1012,6 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
DepthwiseConvConverter,
MatMulConverter,
AvgPool2dConverter,
- FullyConnectedConverter,
TransposeConverter
>(patterns->getContext());
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
index 7d943b3779fb02e..80df3908991dde2 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
@@ -62,7 +62,6 @@ struct TosaToLinalgNamed
target.addIllegalOp<tosa::MaxPool2dOp>();
target.addIllegalOp<tosa::AvgPool2dOp>();
target.addIllegalOp<tosa::MatMulOp>();
- target.addIllegalOp<tosa::FullyConnectedOp>();
target.addIllegalOp<tosa::TransposeOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 031c279ff09e275..be54319e64fde68 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -540,26 +540,9 @@ static void buildTransConvOpWithQuantInfo(
}
}
-/// The tosa.fully_connected op has its own builder as it does not have
-/// strides/dilation/padding.
-static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
- Type outputType, Value input, Value weight,
- Value bias) {
-
- result.addOperands({input, weight, bias});
- auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
- if (quantAttr) {
- result.addAttribute("quantization_info", quantAttr);
- result.addTypes(
- buildConvOpResultTypeInfo(builder, outputType, input, weight));
- } else {
- result.addTypes(outputType);
- }
-}
-
-/// The tosa.matmul op is also intended to be generated where a
-/// fully_connected op must be constructed where the weight is not a constant.
-/// In this case, the fully_connected op must be expressed using matmul.
+/// The tosa.matmul op is also intended to be generated where a fully_connected
+/// op must be constructed where the weight is not a constant. In this case,
+/// the fully_connected op must be expressed using matmul.
/// TODO: Add link to the leglization document explaining this.
static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
OperationState &result, Type outputType,
@@ -863,76 +846,6 @@ bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
return succeeded(verifyCompatibleShape(l[0], r[0]));
}
-LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
- MLIRContext *context, ::std::optional<Location> location,
- FullyConnectedOp::Adaptor adaptor,
- SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- ShapeAdaptor inputShape(adaptor.getInput().getType());
- ShapeAdaptor weightShape(adaptor.getWeight().getType());
- ShapeAdaptor biasShape(adaptor.getBias().getType());
-
- // All shapes are dynamic.
- SmallVector<int64_t> outShape;
- outShape.resize(2, ShapedType::kDynamic);
-
- if (inputShape.hasRank()) {
- outShape[0] = inputShape.getDimSize(0);
- }
-
- if (weightShape.hasRank()) {
- outShape[1] = weightShape.getDimSize(0);
- }
-
- if (biasShape.hasRank()) {
- outShape[1] = outShape[1] == ShapedType::kDynamic ? biasShape.getDimSize(0)
- : outShape[1];
- }
-
- inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
- return success();
-}
-
-LogicalResult FullyConnectedOp::verify() {
- // All TOSA conv ops have an input() and weight().
- auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
-
- RankedTensorType weightType =
- llvm::dyn_cast<RankedTensorType>(getWeight().getType());
-
- // Must be ranked tensor types
- if (!inputType) {
- emitOpError("expect a ranked tensor for input, got ") << getInput();
- return failure();
- }
- if (!weightType) {
- emitOpError("expect a ranked tensor for weight, got ") << getWeight();
- return failure();
- }
-
- auto inputEType = inputType.getElementType();
- auto weightEType = weightType.getElementType();
-
- bool inputIsQuant = !llvm::isa<FloatType>(inputEType);
- bool weightIsQuant = !llvm::isa<FloatType>(weightEType);
-
- // Either both must be quantized or both unquantized.
- if (inputIsQuant != weightIsQuant) {
- emitOpError(
- "expect both input and weight to be float or not together, got ")
- << inputEType << " and " << weightEType;
- return failure();
- }
-
- // Quantized type must have constructed the quantizationattr, and unquantized
- // types should not have a quantizationattr.
- if ((inputIsQuant && !getInputZp()) || (!inputIsQuant && getInputZp())) {
- emitOpError("input zero point is required for quantized type, and not "
- "allowed for float type");
- return failure();
- }
- return success();
-}
-
LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
MatMulOp::Adaptor adaptor,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index 5b0f5ec4cd5687b..9c3345b617cc5f3 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -1,6 +1,5 @@
add_mlir_dialect_library(MLIRTosaTransforms
TosaDecomposeTransposeConv.cpp
- TosaDecomposeConv2D.cpp
TosaDecomposeDepthwise.cpp
TosaFolders.cpp
TosaInferShapes.cpp
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
deleted file mode 100644
index 4eba89b59bbd79f..000000000000000
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
+++ /dev/null
@@ -1,164 +0,0 @@
-//===- TosaDecomposeConv2D.cpp --------------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Decompose TOSA Conv2D operation to a series of TOSA Ops specifically
-// (1) Convert a 1x1 Convolution to a Reshape->FC->Reshape
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#include "mlir/Dialect/Tosa/Transforms/Passes.h"
-#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
-
-using namespace mlir;
-using namespace mlir::tosa;
-
-namespace {
-
-SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape) {
- return to_vector(llvm::map_range(shape, [](int64_t dim) {
- return ShapedType::isDynamic(dim) ? -1 : dim;
- }));
-}
-
-struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
- explicit Conv2DIsFullyConnected(MLIRContext *context)
- : OpRewritePattern(context) {}
-
- LogicalResult matchAndRewrite(tosa::Conv2DOp op,
- PatternRewriter &rewriter) const override {
- Value input = op.getInput();
- Value weight = op.getWeight();
- ShapedType inputType = cast<ShapedType>(input.getType());
- ShapedType weightType = cast<ShapedType>(weight.getType());
- ShapedType resultType = cast<ShapedType>(op.getType());
-
- auto numDynamic =
- llvm::count_if(inputType.getShape(), ShapedType::isDynamic);
- if (numDynamic > 1)
- return rewriter.notifyMatchFailure(
- op, "at most one dim in input may be dynamic");
- if (!weightType.hasRank())
- return rewriter.notifyMatchFailure(op, "unranked weight input");
-
- if (!llvm::all_of(op.getStride(), [](int64_t v) { return v == 1; }))
- return failure();
-
- // Only works for a 1x1 kernel.
- ArrayRef<int64_t> weightShape = weightType.getShape();
- if (weightShape[1] != 1 || weightShape[2] != 1)
- return failure();
-
- llvm::ArrayRef<int64_t> padAttr = op.getPad();
- llvm::SmallVector<int64_t> pad(8, 0);
- for (const auto &it : llvm::enumerate(padAttr))
- pad[it.index() + 2] = it.value();
-
- Type inputETy = inputType.getElementType();
- if (llvm::any_of(pad, [](int64_t p) { return p != 0; })) {
- auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
- if (failed(failureOrMaybeZps))
- return failure();
-
- auto maybeZps = failureOrMaybeZps.value();
-
- Attribute zeroAttr =
- maybeZps ? rewriter.getIntegerAttr(inputETy, maybeZps->inputZp)
- : rewriter.getZeroAttr(inputETy);
-
- llvm::SmallVector<int64_t> newShape(inputType.getShape());
-
- for (int i = 0, s = newShape.size(); i < s; ++i) {
- if (newShape[i] != ShapedType::kDynamic) {
- newShape[i] += pad[i * 2] + pad[i * 2 + 1];
- }
- }
-
- Value padSizeVal = getTosaConstShape(rewriter, op->getLoc(), pad);
-
- auto padTy = RankedTensorType::get({}, inputETy);
- auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);
- Value padVal =
- rewriter.create<tosa::ConstOp>(op->getLoc(), padTy, padAttr);
- inputType = RankedTensorType::get(newShape, inputETy);
- input = rewriter.create<tosa::PadOp>(op->getLoc(), inputType, input,
- padSizeVal, padVal);
- }
-
- // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC].
- ArrayRef<int64_t> inputShape = inputType.getShape();
- int64_t combined = ShapedType::kDynamic;
- if (numDynamic == 0)
- combined = inputShape[0] * inputShape[1] * inputShape[2];
- llvm::SmallVector<int64_t, 2> revisedInputShape{combined, inputShape[3]};
- auto revisedInputShapeType =
- RankedTensorType::get(revisedInputShape, inputType.getElementType());
- auto reshapedInput = rewriter
- .create<tosa::ReshapeOp>(
- op.getLoc(), revisedInputShapeType, input,
- rewriter.getDenseI64ArrayAttr(
- convertFromMlirShape(revisedInputShape)))
- .getResult();
-
- // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC].
- llvm::SmallVector<int64_t, 2> revisedWeightShape{weightShape[0],
- weightShape[3]};
- auto revisedWeightShapeType = RankedTensorType::get(
- revisedWeightShape,
- dyn_cast<RankedTensorType>(weight.getType()).getElementType());
- auto reshapedWeight = rewriter
- .create<tosa::ReshapeOp>(
- op.getLoc(), revisedWeightShapeType, weight,
- rewriter.getDenseI64ArrayAttr(
- convertFromMlirShape(revisedWeightShape)))
- .getResult();
-
- // Perform a fully connected network over the reshaped input and weight.
- llvm::SmallVector<int64_t, 2> fullyConnectedShape{combined, weightShape[0]};
- auto fullyConnectedShapeType =
- RankedTensorType::get(fullyConnectedShape, resultType.getElementType());
-
- auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
- if (failed(failureOrMaybeZps))
- return failure();
-
- auto maybeZps = failureOrMaybeZps.value();
- Value fullyConnectedValue;
- if (maybeZps) {
- fullyConnectedValue =
- rewriter
- .create<tosa::FullyConnectedOp>(
- op.getLoc(), fullyConnectedShapeType, reshapedInput,
- reshapedWeight, op.getBias(),
- rewriter.getI32IntegerAttr(maybeZps->inputZp),
- rewriter.getI32IntegerAttr(maybeZps->weightZp))
- .getResult();
- } else {
- fullyConnectedValue = rewriter
- .create<tosa::FullyConnectedOp>(
- op.getLoc(), fullyConnectedShapeType,
- reshapedInput, reshapedWeight, op.getBias())
- .getResult();
- }
-
- // Reshape output to [N, IH, IW, OC].
- llvm::SmallVector<int64_t, 4> outputShape{inputShape[0], inputShape[1],
- inputShape[2], weightShape[0]};
- rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
- op, resultType, fullyConnectedValue,
- rewriter.getDenseI64ArrayAttr(convertFromMlirShape(outputShape)));
- return success();
- }
-};
-
-} // namespace
-
-void mlir::tosa::populateTosaDecomposeConv2D(MLIRContext *ctx,
- RewritePatternSet &patterns) {
- patterns.add<Conv2DIsFullyConnected>(ctx);
-}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp
index 603185e48aa94e1..ffa2ea3d0629f37 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp
@@ -38,7 +38,6 @@ struct TosaOptionalDecompositions
RewritePatternSet patterns(ctx);
auto func = getOperation();
- mlir::tosa::populateTosaDecomposeConv2D(ctx, patterns);
mlir::tosa::populateTosaDecomposeTransposeConv(ctx, patterns);
mlir::tosa::populateTosaDecomposeDepthwise(ctx, patterns);
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 678bb47935bd20d..7f59ff70d337492 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -61,19 +61,6 @@ static LogicalResult checkConstantOperandTranspose(Operation *op) {
return success();
}
-static LogicalResult checkConstantOperandFullyConnected(Operation *op) {
- if (auto fcOp = dyn_cast<tosa::FullyConnectedOp>(op)) {
- DenseElementsAttr weight;
- if (!matchPattern(fcOp.getWeight(), m_Constant(&weight)))
- return op->emitOpError("weight of fully_connected is not constant");
-
- DenseElementsAttr bias;
- if (!matchPattern(fcOp.getBias(), m_Constant(&bias)))
- return op->emitOpError("bias of fully_connected is not constant");
- }
- return success();
-}
-
struct TosaLevel {
int32_t MAX_RANK = 0;
int32_t MAX_KERNEL = 0;
@@ -123,7 +110,6 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
void populateConstantOperandChecks() {
constCheckers.emplace_back(checkConstantOperandPad);
constCheckers.emplace_back(checkConstantOperandTranspose);
- constCheckers.emplace_back(checkConstantOperandFullyConnected);
}
bool levelCheckKernel(Operation *op, int32_t v,
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 87c388b6f5ee30b..a524359b49759ec 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -83,77 +83,6 @@ func.func @matmul_dyn_output(%arg0: tensor<1x1x8xf32>, %arg1: tensor<1x8x1xf32>)
// -----
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d1)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-
-// CHECK-LABEL: @fully_connected
-func.func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2: tensor<6xf32>) -> (tensor<5x6xf32>) {
- // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
- // CHECK: %[[TRANSPOSEDINIT:.+]] = tensor.empty() : tensor<3x6xf32>
- // CHECK: %[[TRANSPOSED:.+]] = linalg.transpose ins(%arg1 : tensor<6x3xf32>) outs(%[[TRANSPOSEDINIT]] : tensor<3x6xf32>) permutation = [1, 0]
- // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<5x6xf32>
-
- // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xf32>) outs(%[[INIT]] : tensor<5x6xf32>) {
- // CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
- // CHECK: linalg.yield %[[IN]] : f32
- // CHECK: } -> tensor<5x6xf32>
-
- // CHECK: linalg.matmul ins(%arg0, %[[TRANSPOSED]] : tensor<5x3xf32>, tensor<3x6xf32>) outs(%[[BROADCAST]] : tensor<5x6xf32>) -> tensor<5x6xf32>
-
- %0 = tosa.fully_connected %arg0, %arg1, %arg2 : (tensor<5x3xf32>, tensor<6x3xf32>, tensor<6xf32>) -> tensor<5x6xf32>
- return %0 : tensor<5x6xf32>
-}
-
-// -----
-
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d1)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-
-// CHECK-LABEL: @quantized_fully_connected
-func.func @quantized_fully_connected(%arg0: tensor<5x3xi8>, %arg1: tensor<6x3xi8>, %arg2: tensor<6xi32>) -> (tensor<5x6xi32>) {
- // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
- // CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<6x3xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x6xi8>) permutation = [1, 0]
- // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<5x6xi32>
-
- // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xi32>) outs(%[[INIT]] : tensor<5x6xi32>) {
- // CHECK: ^bb0(%[[IN:.+]]: i32, %[[OUT:.+]]: i32):
- // CHECK: linalg.yield %[[IN]] : i32
- // CHECK: } -> tensor<5x6xi32>
-
- // CHECK: %[[C1:.+]] = arith.constant 1 : i32
- // CHECK: %[[C2:.+]] = arith.constant 2 : i32
- // CHECK: linalg.quantized_matmul ins(%arg0, %[[TRANSPOSE]], %[[C1]], %[[C2]] : tensor<5x3xi8>, tensor<3x6xi8>, i32, i32) outs(%[[BROADCAST]] : tensor<5x6xi32>) -> tensor<5x6xi32>
-
- %0 = tosa.fully_connected %arg0, %arg1, %arg2 {input_zp = 1 : i32, weight_zp = 2 : i32} : (tensor<5x3xi8>, tensor<6x3xi8>, tensor<6xi32>) -> tensor<5x6xi32>
- return %0 : tensor<5x6xi32>
-}
-
-// -----
-
-// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d1)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-
-// CHECK-LABEL: @fully_connected_dyn
-func.func @fully_connected_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<6x3xf32>, %arg2: tensor<6xf32>) -> (tensor<?x6xf32>) {
- // CHECK: %[[C0:.+]] = arith.constant 0 : index
- // CHECK: %[[DIM0:.+]] = tensor.dim %arg0, %c0 : tensor<?x3xf32>
- // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
- // CHECK: %[[TRANSPOSED:.+]] = linalg.transpose ins(%arg1 : tensor<6x3xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x6xf32>) permutation = [1, 0]
- // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x6xf32>
-
- // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xf32>) outs(%[[INIT]] : tensor<?x6xf32>) {
- // CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
- // CHECK: linalg.yield %[[IN]] : f32
- // CHECK: } -> tensor<?x6xf32>
-
- // CHECK: linalg.matmul ins(%arg0, %[[TRANSPOSED]] : tensor<?x3xf32>, tensor<3x6xf32>) outs(%[[BROADCAST]] : tensor<?x6xf32>) -> tensor<?x6xf32>
-
- %0 = tosa.fully_connected %arg0, %arg1, %arg2 : (tensor<?x3xf32>, tensor<6x3xf32>, tensor<6xf32>) -> tensor<?x6xf32>
- return %0 : tensor<?x6xf32>
-}
-
-// -----
-
// CHECK-LABEL: @max_pool
func.func @max_pool(%arg0: tensor<1x6x34x62xf32>) -> () {
// CHECK-DAG: [[CONST:%.+]] = arith.constant -3.40282347E+38
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 006c5bd52a9f669..4cfeaaf80c9cb3c 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -314,26 +314,6 @@ func.func @test_transpose_element_type_mismatch(%arg0: tensor<2x3xi32>) -> tenso
// -----
-func.func @test_fully_connected_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<2x3xf32>) -> tensor<273x2xf32> {
- %0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
- %1 = tosa.reshape %arg0 {new_shape = array<i64: 273, 3>} : (tensor<13x21x3xf32>) -> tensor<273x3xf32>
- // expected-error at +1 {{'tosa.fully_connected' op weight of fully_connected is not constant}}
- %2 = tosa.fully_connected %1, %arg1, %0 : (tensor<273x3xf32>, tensor<2x3xf32>, tensor<2xf32>) -> tensor<273x2xf32>
- return %2 : tensor<273x2xf32>
-}
-
-// -----
-
-func.func @test_fully_connected_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<2xf32>) -> tensor<273x2xf32> {
- %0 = "tosa.const"() {value = dense<[[-0.613216758, -0.63714242, -0.73500061], [0.180762768, 0.773053169, -0.933686495]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
- %1 = tosa.reshape %arg0 {new_shape = array<i64: 273, 3>} : (tensor<13x21x3xf32>) -> tensor<273x3xf32>
- // expected-error at +1 {{'tosa.fully_connected' op bias of fully_connected is not constant}}
- %2 = tosa.fully_connected %1, %0, %arg1 : (tensor<273x3xf32>, tensor<2x3xf32>, tensor<2xf32>) -> tensor<273x2xf32>
- return %2 : tensor<273x2xf32>
-}
-
-// -----
-
func.func @test_reduce_sum_type_mismatch(%arg0 : tensor<2x3x4x5xf32>) -> () {
// expected-error at +2 {{failed to infer returned types}}
// expected-error at +1 {{'tosa.reduce_sum' op inferred type(s) 'tensor<1x3x4x5xf32>' are incompatible with return type(s) of operation 'tensor<1x3x4x5xi32>'}}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index d00230d12aab1b4..a3f3f9161b7e94e 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -111,13 +111,6 @@ func.func @test_fft2d_with_local_bound(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1
return %0, %1 : tensor<1x4x8xf32>, tensor<1x4x8xf32>
}
-// -----
-// CHECK-LABEL: fully_connected
-func.func @test_fully_connected(%arg0: tensor<14x19xf32>, %arg1: tensor<19x28xf32>, %arg2: tensor<28xf32>) -> tensor<14x28xf32> {
- %0 = tosa.fully_connected %arg0, %arg1, %arg2 : (tensor<14x19xf32>, tensor<19x28xf32>, tensor<28xf32>) -> tensor<14x28xf32>
- return %0 : tensor<14x28xf32>
-}
-
// -----
// CHECK-LABEL: test_matmul
func.func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
deleted file mode 100644
index e4a2897908072a6..000000000000000
--- a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
+++ /dev/null
@@ -1,76 +0,0 @@
-// RUN: mlir-opt --split-input-file --tosa-optional-decompositions %s | FileCheck %s
-
-// -----
-
-// CHECK-LABEL: @conv2d_as_fully_connected
-func.func @conv2d_as_fully_connected(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<3x1x1x2xf32>, %arg2: tensor<3xf32>) -> tensor<4x10x10x3xf32> {
- // CHECK-NOT: tosa.conv2d
- // CHECK: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 400, 2>}
- // CHECK-SAME: -> tensor<400x2xf32>
- // CHECK: %[[VAR1:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 3, 2>}
- // CHECK-SAME: -> tensor<3x2xf32>
- // CHECK: %[[VAR2:.*]] = tosa.fully_connected %[[VAR0]], %[[VAR1]], %arg2
- // CHECK-SAME: -> tensor<400x3xf32>
- // CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]] {new_shape = array<i64: 4, 10, 10, 3>}
- // CHECK-SAME: -> tensor<4x10x10x3xf32>
- // CHECK: return %[[VAR3]]
- %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32>
- return %0 : tensor<4x10x10x3xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @conv2d_as_fully_connected_quant
-func.func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<3x1x1x2xi8>, %arg2: tensor<3xi32>) -> tensor<4x10x10x3xi32> {
- // CHECK-NOT: tosa.conv2d
- // CHECK: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 400, 2>}
- // CHECK-SAME: -> tensor<400x2xi8>
- // CHECK: %[[VAR1:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 3, 2>}
- // CHECK-SAME: -> tensor<3x2xi8>
- // CHECK: %[[VAR2:.*]] = tosa.fully_connected %[[VAR0]], %[[VAR1]], %arg2
- // CHECK-SAME: {input_zp = 42 : i32, weight_zp = 24 : i32}
- // CHECK-SAME: -> tensor<400x3xi32>
- // CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]] {new_shape = array<i64: 4, 10, 10, 3>}
- // CHECK-SAME: -> tensor<4x10x10x3xi32>
- // CHECK: return %[[VAR3]]
- %input_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8>
- %weight_zp = "tosa.const"() {value = dense<24> : tensor<1xi8>} : () -> tensor<1xi8>
- %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<4x10x10x3xi32>
- return %0 : tensor<4x10x10x3xi32>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @conv_with_dynamic_dim(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x14x14x64xi8>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<384x1x1x64xi8>,
-// CHECK-SAME: %[[VAL_2:.*]]: tensor<384xi32>) -> tensor<?x14x14x384xi32> {
-func.func @conv_with_dynamic_dim(%arg0: tensor<?x14x14x64xi8>, %arg1: tensor<384x1x1x64xi8>, %arg2: tensor<384xi32>) -> tensor<?x14x14x384xi32> {
-// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_0]] {new_shape = array<i64: -1, 64>} : (tensor<?x14x14x64xi8>) -> tensor<?x64xi8>
-// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 384, 64>} : (tensor<384x1x1x64xi8>) -> tensor<384x64xi8>
-// CHECK: %[[VAL_5:.*]] = tosa.fully_connected %[[VAL_3]], %[[VAL_4]], %[[VAL_2]] {input_zp = -6 : i32, weight_zp = 11 : i32} : (tensor<?x64xi8>, tensor<384x64xi8>, tensor<384xi32>) -> tensor<?x384xi32>
-// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: -1, 14, 14, 384>} : (tensor<?x384xi32>) -> tensor<?x14x14x384xi32>
-// CHECK: return %[[VAL_6]] : tensor<?x14x14x384xi32>
-// CHECK: }
- %input_zp = "tosa.const"() {value = dense<-6> : tensor<1xi8>} : () -> tensor<1xi8>
- %weight_zp = "tosa.const"() {value = dense<11> : tensor<1xi8>} : () -> tensor<1xi8>
- %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<?x14x14x64xi8>, tensor<384x1x1x64xi8>, tensor<384xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<?x14x14x384xi32>
- return %0 : tensor<?x14x14x384xi32>
-}
-
-// -----
-
-// CHECK-LABEL: @conv2d_as_fully_connected_padded
-func.func @conv2d_as_fully_connected_padded(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<3x1x1x2xi8>, %arg2: tensor<3xi32>) -> tensor<4x12x12x3xi32> {
- // CHECK-DAG: %[[PAD_SHAPE:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
- // CHECK-DAG: %[[PAD_VAL:.+]] = "tosa.const"() <{value = dense<42> : tensor<i8>}
- // CHECK-DAG: %[[PAD:.+]] = tosa.pad %arg0, %[[PAD_SHAPE]], %[[PAD_VAL]] : (tensor<4x10x10x2xi8>, !tosa.shape<8>, tensor<i8>) -> tensor<4x12x12x2xi8>
- // CHECK-DAG: %[[RESHAPE_INPUT:.+]] = tosa.reshape %[[PAD]] {new_shape = array<i64: 576, 2>}
- // CHECK-DAG: %[[RESHAPE_FILTER:.+]] = tosa.reshape %arg1 {new_shape = array<i64: 3, 2>}
- // CHECK-DAG: %[[FULLY:.+]] = tosa.fully_connected %[[RESHAPE_INPUT]], %[[RESHAPE_FILTER]], %arg2 {input_zp = 42 : i32, weight_zp = 24 : i32}
- // CHECK: %[[RESHAPE:.+]] = tosa.reshape %[[FULLY]] {new_shape = array<i64: 4, 12, 12, 3>}
- %input_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8>
- %weight_zp = "tosa.const"() {value = dense<24> : tensor<1xi8>} : () -> tensor<1xi8>
- %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<4x12x12x3xi32>
- return %0 : tensor<4x12x12x3xi32>
-}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 73eabab657f380d..8ffd649de03ee69 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -271,51 +271,6 @@ func.func @test_dynamic_argmax(%arg0 : tensor<2x?xi32>) -> () {
// -----
-// CHECK-LABEL: @test_static_fully_connected
-func.func @test_static_fully_connected(%arg0 : tensor<3x4xf32>, %arg1 : tensor<5x4xf32>, %arg2 : tensor<5xf32>) -> () {
- // CHECK: tosa.fully_connected %arg0, %arg1, %arg2 : (tensor<3x4xf32>, tensor<5x4xf32>, tensor<5xf32>) -> tensor<3x5xf32>
- %0 = tosa.fully_connected %arg0, %arg1, %arg2 : (tensor<3x4xf32>, tensor<5x4xf32>, tensor<5xf32>) -> tensor<?x?xf32>
- return
-}
-
-// -----
-
-// CHECK-LABEL: @test_static_input_fully_connected
-func.func @test_static_input_fully_connected(%arg0 : tensor<3x4xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?xf32>) -> () {
- // CHECK: tosa.fully_connected %arg0, %arg1, %arg2 : (tensor<3x4xf32>, tensor<?x?xf32>, tensor<?xf32>) -> tensor<3x?xf32>
- %0 = tosa.fully_connected %arg0, %arg1, %arg2 : (tensor<3x4xf32>, tensor<?x?xf32>, tensor<?xf32>) -> tensor<?x?xf32>
- return
-}
-
-// -----
-
-// CHECK-LABEL: @test_static_weight_fully_connected
-func.func @test_static_weight_fully_connected(%arg0 : tensor<?x?xf32>, %arg1 : tensor<5x4xf32>, %arg2 : tensor<?xf32>) -> () {
- // CHECK: tosa.fully_connected %arg0, %arg1, %arg2 : (tensor<?x?xf32>, tensor<5x4xf32>, tensor<?xf32>) -> tensor<?x5xf32>
- %0 = tosa.fully_connected %arg0, %arg1, %arg2 : (tensor<?x?xf32>, tensor<5x4xf32>, tensor<?xf32>) -> tensor<?x?xf32>
- return
-}
-
-// -----
-
-// CHECK-LABEL: @test_static_bias_fully_connected
-func.func @test_static_bias_fully_connected(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<5xf32>) -> () {
- // CHECK: tosa.fully_connected %arg0, %arg1, %arg2 : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<5xf32>) -> tensor<?x5xf32>
- %0 = tosa.fully_connected %arg0, %arg1, %arg2 : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<5xf32>) -> tensor<?x?xf32>
- return
-}
-
-// -----
-
-// CHECK-LABEL: @test_static_out_fully_connected
-func.func @test_static_out_fully_connected(%arg0 : tensor<3x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<5xf32>) -> () {
- // CHECK: tosa.fully_connected %arg0, %arg1, %arg2 : (tensor<3x?xf32>, tensor<?x?xf32>, tensor<5xf32>) -> tensor<3x5xf32>
- %0 = tosa.fully_connected %arg0, %arg1, %arg2 : (tensor<3x?xf32>, tensor<?x?xf32>, tensor<5xf32>) -> tensor<?x?xf32>
- return
-}
-
-// -----
-
// CHECK-LABEL: @test_static_matmul
func.func @test_static_matmul(%arg0 : tensor<2x3x4xi32>, %arg1 : tensor<2x4x5xi32>) -> () {
// CHECK: tosa.matmul %arg0, %arg1 : (tensor<2x3x4xi32>, tensor<2x4x5xi32>) -> tensor<2x3x5xi32>
diff --git a/mlir/test/Integration/Dialect/Tosa/CPU/test-fully-connected.mlir b/mlir/test/Integration/Dialect/Tosa/CPU/test-fully-connected.mlir
deleted file mode 100644
index e599650bfffb9e6..000000000000000
--- a/mlir/test/Integration/Dialect/Tosa/CPU/test-fully-connected.mlir
+++ /dev/null
@@ -1,36 +0,0 @@
-// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,tosa-to-linalg,tosa-to-arith))" | \
-// RUN: mlir-opt -one-shot-bufferize="bufferize-function-boundaries" -buffer-deallocation-pipeline -test-lower-to-llvm | \
-// RUN: mlir-runner -O3 -e main -entry-point-result=void \
-// RUN: -shared-libs=%mlir_runner_utils \
-// RUN: | FileCheck %s
-
-func.func private @printMemrefF32(tensor<*xf32>)
-
-func.func @main() {
- %A = arith.constant dense<[
- [8.0, 1.0, 6.0],
- [3.0, 5.0, 7.0],
- [4.0, 9.0, 2.0]
- ]> : tensor<3x3xf32>
-
- %B = arith.constant dense<[
- [1.0, 1.0, 1.0],
- [1.0, 1.0, 1.0],
- [1.0, 1.0, 1.0]
- ]> : tensor<3x3xf32>
-
- %C = arith.constant dense<[0.0, 1.0, 2.0]> : tensor<3xf32>
-
- %result = tosa.fully_connected %A, %B, %C : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>) -> tensor<3x3xf32>
-
- %result_unranked = tensor.cast %result : tensor<3x3xf32> to tensor<*xf32>
- call @printMemrefF32(%result_unranked) : (tensor<*xf32>) -> ()
- return
-}
-
-// CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1] data =
-// CHECK-NEXT: [
-// CHECK-SAME: [15, 16, 17]
-// CHECK-NEXT: [15, 16, 17]
-// CHECK-NEXT: [15, 16, 17]
-// CHECK-SAME: ]
More information about the Mlir-commits
mailing list