[Mlir-commits] [mlir] Fixes in 'tosa.reshape' lowering and folder (PR #85798)

Spenser Bauman llvmlistbot at llvm.org
Fri Mar 22 08:00:36 PDT 2024


================
@@ -19,24 +19,77 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
 
+#include <numeric>
+
 using namespace mlir;
 using namespace tosa;
 
-static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
-                                  ArrayRef<int64_t> rhsShape,
-                                  SmallVector<int64_t> &intermediateShape,
-                                  bool isDynamic) {
-  if (isDynamic) {
-    // TODO (natashaknk): Make dynamic intermediate shape not always be rank-1
-    intermediateShape = {ShapedType::kDynamic};
-    return true;
-  }
+namespace {
 
-  if (lhsShape.empty() || rhsShape.empty()) {
-    intermediateShape = {};
-    return true;
-  }
+// Infer the result type of 'tensor.expand_shape' in the collapse-expand
+// pair emitted for a 'tosa.reshape' op.
+TensorType inferReshapedType(TypedValue<TensorType> input,
+                             ArrayRef<int64_t> newShape) {
+  // Check if the input is static, and if so, get its total size
+  bool inputIsStatic = input.getType().hasStaticShape();
+  int64_t totalSize = inputIsStatic ? input.getType().getNumElements() : -1;
+ 
+  // Compute result shape
+  bool resultIsStatic = true;
+  auto resultShape = llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
+    // If this is not a placeholder, do not change it
+    if (size >= 0)
+      return size;
+
+    // If we do not know the total size of the tensor, keep this dimension
+    // dynamic in the result shape.
+    if (!inputIsStatic) {
+      resultIsStatic = false;
+      return ShapedType::kDynamic;
+    }
+
+    // Calculate the product of all elements in 'newShape' except for the -1
+    // placeholder, which we discard by negating the result.
+    int64_t totalSizeNoPlaceholder = -std::accumulate(
+        newShape.begin(), newShape.end(), int64_t{1}, std::multiplies());
----------------
sabauma wrote:

Does this code assume that `newShape` contains at most a single -1? If so, could you add an assertion to that effect at the top of the function?
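The concern can be made concrete with a standalone sketch. It re-implements the placeholder inference from `inferReshapedType` using `std::vector` instead of MLIR types and adds the requested at-most-one-`-1` assertion. The hunk does not show how the placeholder size is finally computed, so this sketch assumes it is the total size divided by the product of the other dimensions. `inferShapeSketch` and `kDynamicDim` are hypothetical names; `kDynamicDim` only stands in for `ShapedType::kDynamic`. Inside LLVM, the check itself could be written with `llvm::count(newShape, -1) <= 1`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <limits>
#include <numeric>
#include <optional>
#include <vector>

// Hypothetical stand-in for ShapedType::kDynamic.
constexpr int64_t kDynamicDim = std::numeric_limits<int64_t>::min();

// Sketch of the placeholder inference in 'inferReshapedType'.
// 'totalSize' is std::nullopt when the input shape is not static.
std::vector<int64_t> inferShapeSketch(const std::vector<int64_t> &newShape,
                                      std::optional<int64_t> totalSize) {
  // The negation trick below is only valid with at most one -1 placeholder:
  // with two of them the product would be positive again.
  assert(std::count(newShape.begin(), newShape.end(), -1) <= 1 &&
         "newShape may contain at most one -1 placeholder");

  // Use a 64-bit initial value so the accumulator is int64_t, not int.
  int64_t product = std::accumulate(newShape.begin(), newShape.end(),
                                    int64_t{1}, std::multiplies<int64_t>());

  std::vector<int64_t> result;
  result.reserve(newShape.size());
  for (int64_t size : newShape) {
    if (size >= 0) {
      // Not a placeholder: keep the requested size.
      result.push_back(size);
    } else if (!totalSize) {
      // Unknown total size: the placeholder stays dynamic.
      result.push_back(kDynamicDim);
    } else {
      // With exactly one -1, 'product' is negative; negating it gives the
      // product of the remaining dimensions (assumed division step below).
      int64_t noPlaceholder = -product;
      result.push_back(noPlaceholder == 0 ? 0 : *totalSize / noPlaceholder);
    }
  }
  return result;
}
```

For example, `inferShapeSketch({2, -1, 3}, 24)` yields `{2, 4, 3}`, while `inferShapeSketch({2, -1, -1}, 24)` trips the assertion in a debug build, which is exactly the case where negating the product would silently give a wrong, negative size. Note that `std::accumulate` accumulates in the type of its `init` argument, so a plain `1` there would multiply in `int`.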

https://github.com/llvm/llvm-project/pull/85798

