[Mlir-commits] [mlir] [TOSA] Move CreateOpAndInfer into ConversionUtils.h (PR #106122)

Georgios Pinitas llvmlistbot at llvm.org
Wed Sep 4 13:51:06 PDT 2024


================
@@ -79,6 +81,141 @@ checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op,
 LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc,
                             Value &input1, Value &input2);
 
+LogicalResult EqualizeRanks(ImplicitLocOpBuilder &builder, Value &input1,
+                            Value &input2);
+
+namespace {
+
+// Creates a TOSA operation and performs shape inference on the individual
+// op. This allows shape inference during the TFLite to TOSA lowering.
+template <typename TosaOp, typename... Args>
+TosaOp createOpAndInferShape(ImplicitLocOpBuilder &builder, Type result_ty,
+                             Args &&...args) {
+  auto op = builder.create<TosaOp>(result_ty, args...);
+
+  InferShapedTypeOpInterface shapeInterface =
+      dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
+  if (!shapeInterface)
+    return op;
+
+  SmallVector<ShapedTypeComponents> returnedShapes;
+  if (shapeInterface
+          .inferReturnTypeComponents(op.getContext(), builder.getLoc(),
+                                     op->getOperands(), op->getAttrDictionary(),
+                                     op->getPropertiesStorage(),
+                                     op->getRegions(), returnedShapes)
+          .failed())
+    return op;
+
+  // We need to use the element type of the existing result type to generate
+  // the new result shaped type. This is because rescale can include a cast to
+  // different bit-width types and does not have a TypeAttr to define the
+  // target type.
+  auto result = op->getResult(0);
+  auto predictedShape = returnedShapes[0];
+  auto currentKnowledge = ValueKnowledge::getKnowledgeFromType(result_ty);
+
+  // Compute the knowledge based on the inferred type.
+  auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
+  inferredKnowledge.dtype = mlir::cast<ShapedType>(result_ty).getElementType();
+  inferredKnowledge.hasRank = predictedShape.hasRank();
+  if (predictedShape.hasRank()) {
+    for (auto dim : predictedShape.getDims()) {
+      inferredKnowledge.sizes.push_back(dim);
+    }
+  }
+
+  // Compute the new type based on the joined version.
+  auto newKnowledge = ValueKnowledge::join(currentKnowledge, inferredKnowledge);
+  Type new_ty =
----------------
GeorgeARM wrote:

nit: rename `new_ty` to `newTy` (LLVM style uses camelCase for local variable names).

https://github.com/llvm/llvm-project/pull/106122


More information about the Mlir-commits mailing list