[Mlir-commits] [mlir] [mlir][tensor] Fold `tensor.reshape` for dynamic reshape (PR #88961)

Rob Suderman llvmlistbot at llvm.org
Thu Apr 18 11:33:02 PDT 2024


================
@@ -1580,6 +1580,41 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
           llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
           getResult().getType()))
     return reshapedSource;
+
+  auto source = getSource();
+  auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
+  auto resultTy = dyn_cast<RankedTensorType>(getType());
+
+  if (!sourceTy || !resultTy || sourceTy != resultTy)
+    return {};
+
+  if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
+    auto elements = fromElements.getElements();
+    bool dynamicNoop =
+        sourceTy.getRank() == static_cast<int64_t>(elements.size());
+    for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
+      auto element = elements[id];
+
+      APSInt cstElement;
+      if (matchPattern(element, m_ConstantInt(&cstElement))) {
+        dynamicNoop &= cstElement.getExtValue() == sourceTy.getDimSize(id);
+        continue;
+      }
+
+      if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
+        dynamicNoop &= dimOp.getSource() == source;
+
+        APSInt dim;
+        dynamicNoop &= matchPattern(dimOp.getIndex(), m_ConstantInt(&dim)) &&
----------------
rsuderman wrote:

Fixed.

https://github.com/llvm/llvm-project/pull/88961


More information about the Mlir-commits mailing list