[Mlir-commits] [mlir] [mlir] fix a crash when folding a reshapeOp that has a splat operand (#73190) (PR #76321)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Dec 23 23:30:17 PST 2023


https://github.com/lipracer created https://github.com/llvm/llvm-project/pull/76321

None

>From 4f90bc68b15e673847f4f380c32c65f37a177809 Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Sun, 24 Dec 2023 15:27:32 +0800
Subject: [PATCH] [mlir] fix a crash when folding a reshapeOp that has a splat
 operand (#73190)

---
 mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 61c929dee0f272..db2b5c147c3527 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -93,7 +93,11 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
     return reshapeSrcOp.getSrc();
   // Reshape of a constant can be replaced with a new constant.
   if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front())) {
-    return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
+    return elements.isSplat()
+               ? elements.resizeSplat(
+                     cast<ShapedType>(reshapeOp.getResult().getType()))
+               : elements.reshape(
+                     cast<ShapedType>(reshapeOp.getResult().getType()));
   }
   return nullptr;
 }



More information about the Mlir-commits mailing list