[Mlir-commits] [mlir] Fixes in 'tosa.reshape' lowering and folder (PR #85798)

Suraj Sudhir llvmlistbot at llvm.org
Mon Mar 25 14:54:18 PDT 2024


================
@@ -19,24 +19,77 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
 
+#include <numeric>
+
 using namespace mlir;
 using namespace tosa;
 
-static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
-                                  ArrayRef<int64_t> rhsShape,
-                                  SmallVector<int64_t> &intermediateShape,
-                                  bool isDynamic) {
-  if (isDynamic) {
-    // TODO (natashaknk): Make dynamic intermediate shape not always be rank-1
-    intermediateShape = {ShapedType::kDynamic};
-    return true;
-  }
+namespace {
 
-  if (lhsShape.empty() || rhsShape.empty()) {
-    intermediateShape = {};
-    return true;
-  }
+// Infer the result type of 'tensor.expand_shape' in the collapse-expand
+// pair emitted for a 'tosa.reshape' op.
+TensorType inferReshapedType(TypedValue<TensorType> input,
+                             ArrayRef<int64_t> newShape) {
+  // Check if the input is static, and if so, get its total size
+  bool inputIsStatic = input.getType().hasStaticShape();
+  int64_t totalSize = inputIsStatic ? input.getType().getNumElements() : -1;
+ 
+  // Compute result shape
+  bool resultIsStatic = true;
+  auto resultShape = llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
+    // If this is not a placeholder, do not change it
+    if (size >= 0)
+      return size;
+
+    // If we do not know the total size of the tensor, keep this dimension
+    // dynamic in the result shape.
+    if (!inputIsStatic) {
+      resultIsStatic = false;
+      return ShapedType::kDynamic;
+    }
+
+    // Calculate the product of all elements in 'newShape' except for the -1
+    // placeholder, which we discard by negating the result.
+    int64_t totalSizeNoPlaceholder = -std::accumulate(
+        newShape.begin(), newShape.end(), 1, std::multiplies());
----------------
sjarus wrote:

Thanks for this PR, @rafaelubalmw! You're right to note the elimination of the explicit requirement that at most one dim be -1. This was removed from the spec some time ago (cc: @eric-k256).

Going forward, we plan to relax the shape parameter significantly. Any number of undetermined dims will be permitted, but the legalization will then have to take on dynamic shape resolution, potentially at runtime.

The compile-time-resolvable construct will only permit up to one unknown dim because, as you said, that's all that makes sense.
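
To make the single-unknown case concrete, here is a minimal standalone sketch of how one -1 placeholder can be resolved when the total element count is statically known. It is not the code from this PR: `resolvePlaceholder` is a hypothetical helper, it uses plain `std::vector` instead of MLIR types, and it assumes any negative entry is a placeholder. It rejects shapes with more than one placeholder, which is the case that would need runtime resolution.

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical helper (not part of the PR): resolve at most one -1
// placeholder in `newShape`, given a statically known element count.
// Returns std::nullopt when the shape cannot be resolved at compile time.
std::optional<std::vector<int64_t>>
resolvePlaceholder(int64_t totalSize, const std::vector<int64_t> &newShape) {
  int64_t knownProduct = 1; // product of all non-placeholder dims
  int numPlaceholders = 0;
  for (int64_t dim : newShape) {
    if (dim < 0)
      ++numPlaceholders;
    else
      knownProduct *= dim;
  }

  // More than one unknown dim is ambiguous without extra information.
  if (numPlaceholders > 1)
    return std::nullopt;

  // Fully static shape: only check that the element counts agree.
  if (numPlaceholders == 0) {
    if (knownProduct != totalSize)
      return std::nullopt;
    return newShape;
  }

  // Exactly one unknown dim: it must divide the total size evenly.
  if (knownProduct == 0 || totalSize % knownProduct != 0)
    return std::nullopt;

  std::vector<int64_t> resolved = newShape;
  for (int64_t &dim : resolved)
    if (dim < 0)
      dim = totalSize / knownProduct;
  return resolved;
}
```

For example, a 24-element input reshaped to {2, -1, 3} resolves to {2, 4, 3}, while {-1, -1} is rejected because the two unknowns cannot be separated statically.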

https://github.com/llvm/llvm-project/pull/85798

