[Mlir-commits] [mlir] Only fold splats for static shape result types (PR #93102)
Kavan Bickerstaff
llvmlistbot at llvm.org
Wed May 22 14:46:57 PDT 2024
https://github.com/KB9 created https://github.com/llvm/llvm-project/pull/93102
This prevents an assertion failure when constructing the DenseElementsAttr result, where the passed-in type is expected to have a static shape.
Fixes https://github.com/llvm/llvm-project/issues/92057
From fb0aad45cf44d3c9cc5210e28c1211e15d00e210 Mon Sep 17 00:00:00 2001
From: KB9 <kavanbickerstaff at googlemail.com>
Date: Wed, 22 May 2024 21:26:33 +0100
Subject: [PATCH] Only fold splats for static shape result types
This prevents an assertion failure when constructing the DenseElementsAttr
result, where the passed-in type is expected to have a static shape.
---
mlir/include/mlir/Dialect/CommonFolders.h | 5 ++++-
mlir/test/Dialect/Arith/canonicalize.mlir | 8 ++++++++
2 files changed, 12 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
index 7dabc781cd595..6f497a259262a 100644
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -298,7 +298,10 @@ Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
calculate(op.getSplatValue<ElementValueT>(), castStatus);
if (!castStatus)
return {};
- return DenseElementsAttr::get(cast<ShapedType>(resType), elementResult);
+ auto shapedResType = cast<ShapedType>(resType);
+ if (!shapedResType.hasStaticShape())
+ return {};
+ return DenseElementsAttr::get(shapedResType, elementResult);
}
if (auto op = dyn_cast<ElementsAttr>(operands[0])) {
// Operand is ElementsAttr-derived; perform an element-wise fold by
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index e4f95bb0545a2..1a387c20c4b29 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2950,6 +2950,14 @@ func.func @unsignedExtendConstantResource() -> tensor<i16> {
return %ext : tensor<i16>
}
+// Regression test for issue #92057: extsi of a splat constant to a dynamically shaped tensor must not crash the cast folder (it now skips non-static result shapes).
+// CHECK-LABEL: @signedExtendSplatAsDynamicShape
+func.func @signedExtendSplatAsDynamicShape() -> tensor<?xi64> {
+ %splat = arith.constant dense<5> : tensor<2xi16>
+ %extsplat = arith.extsi %splat : tensor<2xi16> to tensor<?xi64>
+ return %extsplat : tensor<?xi64>
+}
+
// CHECK-LABEL: @extsi_i0
// CHECK: %[[ZERO:.*]] = arith.constant 0 : i16
// CHECK: return %[[ZERO]] : i16
More information about the Mlir-commits mailing list