[Mlir-commits] [mlir] [TOSA] Move CreateOpAndInfer into ConversionUtils.h (PR #106122)
Georgios Pinitas
llvmlistbot at llvm.org
Wed Sep 4 13:51:06 PDT 2024
================
@@ -79,6 +81,141 @@ checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op,
LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc,
Value &input1, Value &input2);
+LogicalResult EqualizeRanks(ImplicitLocOpBuilder &builder, Value &input1,
+ Value &input2);
+
+namespace {
+
+// Creates a TOSA operation and performs shape inference on the individual
+// op. This allows shape inference during the TFLite to TOSA lowering.
+template <typename TosaOp, typename... Args>
+TosaOp createOpAndInferShape(ImplicitLocOpBuilder &builder, Type result_ty,
+ Args &&...args) {
+ auto op = builder.create<TosaOp>(result_ty, args...);
+
+ InferShapedTypeOpInterface shapeInterface =
+ dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
+ if (!shapeInterface)
+ return op;
+
+ SmallVector<ShapedTypeComponents> returnedShapes;
+ if (shapeInterface
+ .inferReturnTypeComponents(op.getContext(), builder.getLoc(),
+ op->getOperands(), op->getAttrDictionary(),
+ op->getPropertiesStorage(),
+ op->getRegions(), returnedShapes)
+ .failed())
+ return op;
+
+ // We need to use the element type of the existing result type to generate
+ // the new result shaped type. This is because rescale can include a cast to
+ // different bit-width types and does not have a TypeAttr to define the
+ // target type.
+ auto result = op->getResult(0);
+ auto predictedShape = returnedShapes[0];
+ auto currentKnowledge = ValueKnowledge::getKnowledgeFromType(result_ty);
+
+ // Compute the knowledge based on the inferred type.
+ auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
+ inferredKnowledge.dtype = mlir::cast<ShapedType>(result_ty).getElementType();
+ inferredKnowledge.hasRank = predictedShape.hasRank();
+ if (predictedShape.hasRank()) {
+ for (auto dim : predictedShape.getDims()) {
+ inferredKnowledge.sizes.push_back(dim);
+ }
+ }
+
+ // Compute the new type based on the joined version.
+ auto newKnowledge = ValueKnowledge::join(currentKnowledge, inferredKnowledge);
+ Type new_ty =
----------------
GeorgeARM wrote:
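nit: rename `new_ty` to `newTy` (camelCase, consistent with the other locals in this function such as `newKnowledge`).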
https://github.com/llvm/llvm-project/pull/106122
More information about the Mlir-commits
mailing list