[Mlir-commits] [mlir] [TOSA] Move CreateOpAndInfer into ConversionUtils.h (PR #106122)
Tai Ly
llvmlistbot at llvm.org
Mon Aug 26 12:28:29 PDT 2024
https://github.com/Tai78641 created https://github.com/llvm/llvm-project/pull/106122
This moves CreateOpAndInfer from TF's legalize_util.h into ConversionUtils.h.
The function is renamed to CreateOpAndInferShape so that it can be upstreamed independently of TensorFlow; otherwise a redefinition error would break the TF build unless the removal of CreateOpAndInfer from TF landed together with this change.
>From 862439a97eeaead5dbce8f5dec4e7574bfe554af Mon Sep 17 00:00:00 2001
From: Tai Ly <tai.ly at arm.com>
Date: Wed, 21 Aug 2024 19:39:37 +0000
Subject: [PATCH] Move CreateOpAndInfer into ConversionUtils.h
This moves CreateOpAndInfer from TF's legalize_util.h into
ConversionUtils.h.
The function is renamed to CreateOpAndInferShape so that it can be
upstreamed independently of TensorFlow; otherwise a redefinition
error would break the TF build unless the removal of
CreateOpAndInfer from TF landed together with this change.
Signed-off-by: Tai Ly <tai.ly at arm.com>
Change-Id: I53f39ec63f2e3763f8e50c03d1203e8dbed6f1bf
---
.../mlir/Dialect/Tosa/Utils/ConversionUtils.h | 137 ++++++++++++++++++
.../Transforms/TosaDecomposeTransposeConv.cpp | 93 +++---------
.../Dialect/Tosa/Utils/ConversionUtils.cpp | 12 +-
3 files changed, 169 insertions(+), 73 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
index ceab7d9c628a54..60e7ed1ce2f876 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -15,7 +15,9 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include <optional>
@@ -79,6 +81,141 @@ checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op,
LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc,
Value &input1, Value &input2);
+LogicalResult EqualizeRanks(ImplicitLocOpBuilder &builder, Value &input1,
+ Value &input2);
+
+namespace {
+
+// Creates a TOSA operation and performs shape inference on the individual
+// op. This allows shape inference during the TFLite to TOSA lowering.
+template <typename TosaOp, typename... Args>
+TosaOp createOpAndInferShape(ImplicitLocOpBuilder &builder, Type result_ty,
+ Args &&...args) {
+ auto op = builder.create<TosaOp>(result_ty, args...);
+
+ InferShapedTypeOpInterface shapeInterface =
+ dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
+ if (!shapeInterface)
+ return op;
+
+ SmallVector<ShapedTypeComponents> returnedShapes;
+ if (shapeInterface
+ .inferReturnTypeComponents(op.getContext(), builder.getLoc(),
+ op->getOperands(), op->getAttrDictionary(),
+ op->getPropertiesStorage(),
+ op->getRegions(), returnedShapes)
+ .failed())
+ return op;
+
+ // We need to use the element type of the existing result type to generate
+ // the new result shaped type. This is because rescale can include a cast to
+ // different bit-width types and does not have a TypeAttr to define the
+ // target type.
+ auto result = op->getResult(0);
+ auto predictedShape = returnedShapes[0];
+ auto currentKnowledge = ValueKnowledge::getKnowledgeFromType(result_ty);
+
+ // Compute the knowledge based on the inferred type.
+ auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
+ inferredKnowledge.dtype = mlir::cast<ShapedType>(result_ty).getElementType();
+ inferredKnowledge.hasRank = predictedShape.hasRank();
+ if (predictedShape.hasRank()) {
+ for (auto dim : predictedShape.getDims()) {
+ inferredKnowledge.sizes.push_back(dim);
+ }
+ }
+
+ // Compute the new type based on the joined version.
+ auto newKnowledge = ValueKnowledge::join(currentKnowledge, inferredKnowledge);
+ Type new_ty =
+ newKnowledge.hasRank
+ ? Type{mlir::RankedTensorType::get(llvm::ArrayRef(newKnowledge.sizes),
+ newKnowledge.dtype)}
+ : Type{mlir::UnrankedTensorType::get(newKnowledge.dtype)};
+ result.setType(new_ty);
+ return op;
+}
+
+} // namespace
+
+// Creates a TOSA operation by:
+// - first equalize ranks for ops with SameOperandsAndResultRank trait
+// - create operator
+// - performs shape inference on this operator
+template <typename TosaOp, typename... Args>
+TosaOp CreateOpAndInferShape(ImplicitLocOpBuilder &builder, Type result_ty,
+ Args &&...args) {
+ if (TosaOp::template hasTrait<OpTrait::SameOperandsAndResultRank>()) {
+ // op requires same ranks for tensor operands
+ if constexpr (sizeof...(Args) == 2) {
+ auto argX = std::get<0>(std::tie(args...));
+ auto argY = std::get<1>(std::tie(args...));
+ using ArgX = decltype(argX);
+ using ArgY = decltype(argY);
+ if constexpr (std::is_same_v<ArgX, Value> &&
+ std::is_same_v<ArgY, Value>) {
+ Value x = std::get<0>(std::tie(args...));
+ Value y = std::get<1>(std::tie(args...));
+ if (EqualizeRanks(builder, x, y).failed()) {
+ // incompatible broadcast shapes, no reshape is inserted
+ // ResultsBroadcastableShape verify will handle this
+ }
+ return createOpAndInferShape<TosaOp>(builder, result_ty, x, y);
+ }
+ }
+ if constexpr (sizeof...(Args) == 3) {
+ auto argX = std::get<0>(std::tie(args...));
+ auto argY = std::get<1>(std::tie(args...));
+ auto argZ = std::get<2>(std::tie(args...));
+ using ArgX = decltype(argX);
+ using ArgY = decltype(argY);
+ using ArgZ = decltype(argZ);
+ if constexpr (std::is_same_v<ArgX, Value> &&
+ std::is_same_v<ArgY, Value> && std::is_same_v<ArgZ, bool>) {
+ // special case for ArithmeticRightShiftOp
+ Value x = std::get<0>(std::tie(args...));
+ Value y = std::get<1>(std::tie(args...));
+ bool round = std::get<2>(std::tie(args...));
+ if (EqualizeRanks(builder, x, y).failed()) {
+ // incompatible broadcast shapes, no reshape is inserted
+ // ResultsBroadcastableShape verify will handle this
+ }
+ return createOpAndInferShape<TosaOp>(builder, result_ty, x, y, round);
+ }
+ if constexpr (std::is_same_v<ArgX, Value> &&
+ std::is_same_v<ArgY, Value> &&
+ std::is_same_v<ArgZ, Value>) {
+ // special case for Select
+ Value x = std::get<0>(std::tie(args...));
+ Value y = std::get<1>(std::tie(args...));
+ Value z = std::get<2>(std::tie(args...));
+
+ if (EqualizeRanks(builder, x, y).failed() ||
+ EqualizeRanks(builder, x, z).failed() ||
+ EqualizeRanks(builder, y, z).failed()) {
+ // incompatible broadcast shapes, no reshape is inserted
+ // ResultsBroadcastableShape verify will handle this
+ }
+
+ return createOpAndInferShape<TosaOp>(builder, result_ty, x, y, z);
+ }
+ }
+ }
+
+ return createOpAndInferShape<TosaOp>(builder, result_ty, args...);
+}
+
+// Creates a TOSA operation by:
+// - first equalize ranks for ops with SameOperandsAndResultRank trait
+// - create operator
+// - performs shape inference on this operator
+template <typename TosaOp, typename... Args>
+TosaOp CreateOpAndInferShape(PatternRewriter &rewriter, Location loc,
+ Type result_ty, Args &&...args) {
+ ImplicitLocOpBuilder builder(loc, rewriter);
+ return CreateOpAndInferShape<TosaOp>(builder, result_ty, args...);
+}
+
} // namespace tosa
} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index a94bb3a920b1db..0779cdb9667a1a 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -26,53 +26,6 @@ using namespace mlir::tosa;
namespace {
-template <typename TosaOp, typename... Args>
-TosaOp createOpAndInfer(PatternRewriter &rewriter, Location loc, Type resultTy,
- Args &&...args) {
- auto op = rewriter.create<TosaOp>(loc, resultTy, args...);
-
- InferShapedTypeOpInterface shapeInterface =
- dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
- if (!shapeInterface)
- return op;
-
- SmallVector<ShapedTypeComponents> returnedShapes;
- if (shapeInterface
- .inferReturnTypeComponents(
- op.getContext(), op.getLoc(), op->getOperands(),
- op->getDiscardableAttrDictionary(), op->getPropertiesStorage(),
- op->getRegions(), returnedShapes)
- .failed())
- return op;
-
- // We need to use the element type of the existing result type to generate
- // the new result shaped type. This is because rescale can include a cast to
- // different bit-width types and does not have a TypeAttr to define the
- // target type.
- auto result = op->getResult(0);
- auto predictedShape = returnedShapes[0];
- auto currentKnowledge =
- mlir::tosa::ValueKnowledge::getKnowledgeFromType(resultTy);
-
- // Compute the knowledge based on the inferred type.
- auto inferredKnowledge =
- mlir::tosa::ValueKnowledge::getPessimisticValueState();
- inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
- inferredKnowledge.hasRank = predictedShape.hasRank();
- if (predictedShape.hasRank()) {
- for (auto dim : predictedShape.getDims()) {
- inferredKnowledge.sizes.push_back(dim);
- }
- }
-
- // Compute the new type based on the joined version.
- auto newKnowledge =
- mlir::tosa::ValueKnowledge::join(currentKnowledge, inferredKnowledge);
- auto newTy = newKnowledge.getType();
- result.setType(newTy);
- return op;
-}
-
class TransposeConvNonStridedConverter
: public OpRewritePattern<tosa::TransposeConv2DOp> {
public:
@@ -187,20 +140,20 @@ class TransposeConvStridedConverter
(weightWidth % stride[1]) ? (stride[1] - weightWidth % stride[1]) : 0;
DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
RankedTensorType::get({4, 2}, rewriter.getI32Type()), weightPadding);
- Value weightPaddingVal = createOpAndInfer<tosa::ConstOp>(
+ Value weightPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);
if (op.getQuantizationInfo().has_value()) {
auto quantInfo = op.getQuantizationInfo().value();
- weight = createOpAndInfer<tosa::PadOp>(
+ weight = CreateOpAndInferShape<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
weightPaddingVal, nullptr,
rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getWeightZp()));
} else {
- weight = createOpAndInfer<tosa::PadOp>(rewriter, loc,
- UnrankedTensorType::get(weightETy),
- weight, weightPaddingVal);
+ weight = CreateOpAndInferShape<tosa::PadOp>(
+ rewriter, loc, UnrankedTensorType::get(weightETy), weight,
+ weightPaddingVal);
}
weightTy = cast<ShapedType>(weight.getType());
@@ -212,7 +165,7 @@ class TransposeConvStridedConverter
outputChannels, weightHeight / stride[0],
stride[0], weightWidth / stride[1],
stride[1], inputChannels};
- weight = createOpAndInfer<tosa::ReshapeOp>(
+ weight = CreateOpAndInferShape<tosa::ReshapeOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
rewriter.getDenseI64ArrayAttr(weightReshapeDims0));
@@ -221,7 +174,7 @@ class TransposeConvStridedConverter
loc, RankedTensorType::get({6}, rewriter.getI32Type()),
rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));
- weight = createOpAndInfer<tosa::TransposeOp>(
+ weight = CreateOpAndInferShape<tosa::TransposeOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
transposeWeightVal);
@@ -229,15 +182,15 @@ class TransposeConvStridedConverter
llvm::SmallVector<int64_t, 6> weightReshapeDims1 = {
outputChannels * stride[0] * stride[1], weightHeight / stride[0],
weightWidth / stride[1], inputChannels};
- weight = createOpAndInfer<tosa::ReshapeOp>(
+ weight = CreateOpAndInferShape<tosa::ReshapeOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
rewriter.getDenseI64ArrayAttr(weightReshapeDims1));
ShapedType restridedWeightTy = cast<ShapedType>(weight.getType());
- weight = createOpAndInfer<tosa::ReverseOp>(
+ weight = CreateOpAndInferShape<tosa::ReverseOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
/* axis = */ rewriter.getI32IntegerAttr(1));
- weight = createOpAndInfer<tosa::ReverseOp>(
+ weight = CreateOpAndInferShape<tosa::ReverseOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
/* axis = */ rewriter.getI32IntegerAttr(2));
@@ -251,19 +204,19 @@ class TransposeConvStridedConverter
DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
RankedTensorType::get({4, 2}, rewriter.getI32Type()), inputPadding);
- Value inputPaddingVal = createOpAndInfer<tosa::ConstOp>(
+ Value inputPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);
if (op.getQuantizationInfo().has_value()) {
auto quantInfo = op.getQuantizationInfo().value();
- input = createOpAndInfer<tosa::PadOp>(
+ input = CreateOpAndInferShape<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(inputETy), input,
inputPaddingVal, nullptr,
rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getInputZp()));
} else {
- input = createOpAndInfer<tosa::PadOp>(rewriter, loc,
- UnrankedTensorType::get(inputETy),
- input, inputPaddingVal);
+ input = CreateOpAndInferShape<tosa::PadOp>(
+ rewriter, loc, UnrankedTensorType::get(inputETy), input,
+ inputPaddingVal);
}
// We use a zero bias as we need to broadcast the bias.
@@ -279,7 +232,7 @@ class TransposeConvStridedConverter
// Perform the convolution using the zero bias.
Value conv2d;
if (op.getQuantizationInfo()) {
- conv2d = createOpAndInfer<tosa::Conv2DOp>(
+ conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), input,
weight, zeroBias,
/*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
@@ -288,7 +241,7 @@ class TransposeConvStridedConverter
*op.getQuantizationInfo())
.getResult();
} else {
- conv2d = createOpAndInfer<tosa::Conv2DOp>(
+ conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), input,
weight, zeroBias,
/*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
@@ -307,7 +260,7 @@ class TransposeConvStridedConverter
// Factor striding out of the convolution result.
llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
- conv2d = createOpAndInfer<tosa::ReshapeOp>(
+ conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
rewriter.getDenseI64ArrayAttr(convReshapeDims0));
@@ -316,14 +269,14 @@ class TransposeConvStridedConverter
loc, RankedTensorType::get({6}, rewriter.getI32Type()),
rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));
- conv2d = createOpAndInfer<tosa::TransposeOp>(
+ conv2d = CreateOpAndInferShape<tosa::TransposeOp>(
rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
transposeConvVal);
// Fuse striding behavior back into width / height.
llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
- conv2d = createOpAndInfer<tosa::ReshapeOp>(
+ conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
rewriter.getDenseI64ArrayAttr(convReshapeDims1));
@@ -348,7 +301,7 @@ class TransposeConvStridedConverter
sliceSize[1] = resultSliceHeight;
sliceSize[2] = resultSliceWidth;
- auto slice = createOpAndInfer<tosa::SliceOp>(
+ auto slice = CreateOpAndInferShape<tosa::SliceOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
rewriter.getDenseI64ArrayAttr(sliceBegin),
rewriter.getDenseI64ArrayAttr(sliceSize))
@@ -363,10 +316,10 @@ class TransposeConvStridedConverter
DenseElementsAttr resultPaddingAttr = DenseIntElementsAttr::get(
RankedTensorType::get({4, 2}, rewriter.getI32Type()), resultPadding);
- Value resultPaddingVal = createOpAndInfer<tosa::ConstOp>(
+ Value resultPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
rewriter, loc, resultPaddingAttr.getType(), resultPaddingAttr);
- Value resultPad = createOpAndInfer<tosa::PadOp>(
+ Value resultPad = CreateOpAndInferShape<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), slice,
resultPaddingVal);
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index f276924a8a9f62..1f6e3b2ab83919 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -102,6 +102,12 @@ computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
Value &input1, Value &input2) {
+ ImplicitLocOpBuilder builder(loc, rewriter);
+ return EqualizeRanks(builder, input1, input2);
+}
+
+LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder,
+ Value &input1, Value &input2) {
auto input1Ty = llvm::dyn_cast<RankedTensorType>(input1.getType());
auto input2Ty = llvm::dyn_cast<RankedTensorType>(input2.getType());
@@ -140,9 +146,9 @@ LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
auto reshapeOutputType = RankedTensorType::get(
ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
- auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
- loc, reshapeOutputType, lowerTensorValue,
- rewriter.getDenseI64ArrayAttr(reshapeOutputShape));
+ auto reshapeLower = builder.create<tosa::ReshapeOp>(
+ reshapeOutputType, lowerTensorValue,
+ builder.getDenseI64ArrayAttr(reshapeOutputShape));
if (input1Rank > input2Rank) {
input1 = higherTensorValue;
More information about the Mlir-commits
mailing list