[Mlir-commits] [mlir] [WIP][mlir] Make `DenseElementsAttr::reshape(...)` take a shape instead of a type (PR #149947)
James Newling
llvmlistbot at llvm.org
Mon Jul 21 16:55:21 PDT 2025
https://github.com/newling created https://github.com/llvm/llvm-project/pull/149947
This is a proposal. The motivation is that this change would have prevented the issue detected in https://github.com/llvm/llvm-project/pull/147691.
Can the type of a `DenseElementsAttr` be a VectorType with scalable dimensions? If so, this approach probably won't work, because a bare shape (`ArrayRef<int64_t>`) does not say which dimensions are scalable.
>From cc6aae0976d1a67960162598a540cbacf0608875 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 21 Jul 2025 16:30:07 -0700
Subject: [PATCH 1/2] reshape to a new shape not a new type
---
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h | 3 ++-
mlir/include/mlir/IR/BuiltinAttributes.h | 6 ++++--
mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 5 +++--
mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 2 +-
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 4 ++--
mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp | 2 +-
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 4 ++--
mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 2 +-
mlir/lib/IR/BuiltinAttributes.cpp | 4 +++-
9 files changed, 19 insertions(+), 13 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 704e39e908841..abe40227b31dc 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -92,7 +92,8 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
// Reshape of a constant can be replaced with a new constant.
if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
- return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
+ return elements.reshape(
+ cast<ShapedType>(reshapeOp.getResult().getType()).getShape());
// Fold if the producer reshape source has the same shape with at most 1
// dynamic dimension.
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index c07ade606a775..23f0a72d7fd00 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -615,8 +615,10 @@ class DenseElementsAttr : public Attribute {
//===--------------------------------------------------------------------===//
/// Return a new DenseElementsAttr that has the same data as the current
- /// attribute, but has been reshaped to 'newType'. The new type must have the
- /// same total number of elements as well as element type.
+ /// attribute, but has been reshaped to 'newShape'. The new shape must have
+ /// the same total number of elements.
+ DenseElementsAttr reshape(ArrayRef<int64_t> newShape);
+
DenseElementsAttr reshape(ShapedType newType);
/// Return a new DenseElementsAttr that has the same data as the current
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 8d57ab6b59e79..f81832fcb981e 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -679,8 +679,9 @@ MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType,
MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr,
MlirType shapedType) {
- return wrap(llvm::cast<DenseElementsAttr>(unwrap(attr))
- .reshape(llvm::cast<ShapedType>(unwrap(shapedType))));
+ return wrap(
+ llvm::cast<DenseElementsAttr>(unwrap(attr))
+ .reshape(llvm::cast<ShapedType>(unwrap(shapedType)).getShape()));
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 434d7df853a5e..dd11f4f2bafda 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -274,7 +274,7 @@ struct ConstantCompositeOpPattern final
if (isa<RankedTensorType>(srcType)) {
dstAttrType = RankedTensorType::get(srcType.getNumElements(),
srcType.getElementType());
- dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
+ dstElementsAttr = dstElementsAttr.reshape(dstAttrType.getShape());
} else {
// TODO: add support for large vectors.
return failure();
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 5758d8d5ef506..adeecb23528db 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1387,7 +1387,7 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
return {};
return operand.reshape(
- llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
+ llvm::cast<ShapedType>(operand.getType()).clone(shapeVec).getShape());
}
return {};
@@ -1546,7 +1546,7 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
if (input.isSplat() && resultTy.hasStaticShape() &&
input.getType().getElementType() == resultTy.getElementType())
- return input.reshape(resultTy);
+ return input.reshape(resultTy.getShape());
}
// Transpose is not the identity transpose.
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
index db7a3c671dedc..9090080534bc7 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
@@ -193,7 +193,7 @@ TosaReduceTransposes::transposeDenseAttribute(DenseElementsAttr input,
RankedTensorType::get(newShape, oldType.getElementType());
if (input.isSplat()) {
- return input.reshape(newType);
+ return input.reshape(newType.getShape());
}
auto rawData = input.getRawData();
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 7d615bfc12984..b109426caa62e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6000,7 +6000,7 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
// shape_cast(constant) -> constant
if (auto splatAttr =
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
- return splatAttr.reshape(getType());
+ return splatAttr.reshape(getType().getShape());
// shape_cast(poison) -> poison
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
@@ -6346,7 +6346,7 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
// Eliminate splat constant transpose ops.
if (auto splat =
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
- return splat.reshape(getResultVectorType());
+ return splat.reshape(getResultVectorType().getShape());
// Eliminate poison transpose ops.
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index fe17b3c0b2cfc..515dd14081626 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -36,7 +36,7 @@ linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
loc,
"Cannot linearize a constant scalable vector that's not a splat");
- return dstElementsAttr.reshape(resType);
+ return dstElementsAttr.reshape(resType.getShape());
}
if (auto poisonAttr = dyn_cast<ub::PoisonAttr>(value))
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index fd898b7493c7f..81b2213dd5a93 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -1241,8 +1241,10 @@ ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
/// Return a new DenseElementsAttr that has the same data as the current
/// attribute, but has been reshaped to 'newType'. The new type must have the
/// same total number of elements as well as element type.
-DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
+DenseElementsAttr DenseElementsAttr::reshape(ArrayRef<int64_t> newShape) {
+
ShapedType curType = getType();
+ auto newType = curType.cloneWith(newShape, curType.getElementType());
if (curType == newType)
return *this;
>From 8035784aa3b617d2751b98e2ed9f2065d4898fe6 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 21 Jul 2025 16:32:24 -0700
Subject: [PATCH 2/2] remove now unused API
---
mlir/include/mlir/IR/BuiltinAttributes.h | 2 --
1 file changed, 2 deletions(-)
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index 23f0a72d7fd00..ee26537d20e8c 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -619,8 +619,6 @@ class DenseElementsAttr : public Attribute {
/// the same total number of elements.
DenseElementsAttr reshape(ArrayRef<int64_t> newShape);
- DenseElementsAttr reshape(ShapedType newType);
-
/// Return a new DenseElementsAttr that has the same data as the current
/// attribute, but with a different shape for a splat type. The new type must
/// have the same element type.
More information about the Mlir-commits mailing list