[Mlir-commits] [mlir] [WIP][mlir] Make `DenseElementsAttr::reshape(...)` take a shape instead of a type (PR #149947)
James Newling
llvmlistbot at llvm.org
Tue Jul 22 09:12:11 PDT 2025
https://github.com/newling updated https://github.com/llvm/llvm-project/pull/149947
>From cc6aae0976d1a67960162598a540cbacf0608875 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 21 Jul 2025 16:30:07 -0700
Subject: [PATCH 1/4] reshape to a new shape not a new type
---
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h | 3 ++-
mlir/include/mlir/IR/BuiltinAttributes.h | 6 ++++--
mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 5 +++--
mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 2 +-
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 4 ++--
mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp | 2 +-
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 4 ++--
mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 2 +-
mlir/lib/IR/BuiltinAttributes.cpp | 4 +++-
9 files changed, 19 insertions(+), 13 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 704e39e908841..abe40227b31dc 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -92,7 +92,8 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
// Reshape of a constant can be replaced with a new constant.
if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
- return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
+ return elements.reshape(
+ cast<ShapedType>(reshapeOp.getResult().getType()).getShape());
// Fold if the producer reshape source has the same shape with at most 1
// dynamic dimension.
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index c07ade606a775..23f0a72d7fd00 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -615,8 +615,10 @@ class DenseElementsAttr : public Attribute {
//===--------------------------------------------------------------------===//
/// Return a new DenseElementsAttr that has the same data as the current
- /// attribute, but has been reshaped to 'newType'. The new type must have the
- /// same total number of elements as well as element type.
+ /// attribute, but has been reshaped to 'newShape'. The new shape must have
+ /// the same total number of elements.
+ DenseElementsAttr reshape(ArrayRef<int64_t> newShape);
+
DenseElementsAttr reshape(ShapedType newType);
/// Return a new DenseElementsAttr that has the same data as the current
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 8d57ab6b59e79..f81832fcb981e 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -679,8 +679,9 @@ MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType,
MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr,
MlirType shapedType) {
-  return wrap(llvm::cast<DenseElementsAttr>(unwrap(attr))
-                  .reshape(llvm::cast<ShapedType>(unwrap(shapedType))));
+  auto newType = llvm::cast<ShapedType>(unwrap(shapedType));
+  auto elements = llvm::cast<DenseElementsAttr>(unwrap(attr));
+  return wrap(elements.reshape(newType.getShape()));
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 434d7df853a5e..dd11f4f2bafda 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -274,7 +274,7 @@ struct ConstantCompositeOpPattern final
if (isa<RankedTensorType>(srcType)) {
dstAttrType = RankedTensorType::get(srcType.getNumElements(),
srcType.getElementType());
- dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
+ dstElementsAttr = dstElementsAttr.reshape(dstAttrType.getShape());
} else {
// TODO: add support for large vectors.
return failure();
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 5758d8d5ef506..adeecb23528db 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1387,7 +1387,7 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
return {};
return operand.reshape(
- llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
+ llvm::cast<ShapedType>(operand.getType()).clone(shapeVec).getShape());
}
return {};
@@ -1546,7 +1546,7 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
if (input.isSplat() && resultTy.hasStaticShape() &&
input.getType().getElementType() == resultTy.getElementType())
- return input.reshape(resultTy);
+ return input.reshape(resultTy.getShape());
}
// Transpose is not the identity transpose.
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
index db7a3c671dedc..9090080534bc7 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
@@ -193,7 +193,7 @@ TosaReduceTransposes::transposeDenseAttribute(DenseElementsAttr input,
RankedTensorType::get(newShape, oldType.getElementType());
if (input.isSplat()) {
- return input.reshape(newType);
+ return input.reshape(newType.getShape());
}
auto rawData = input.getRawData();
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 7d615bfc12984..b109426caa62e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6000,7 +6000,7 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
// shape_cast(constant) -> constant
if (auto splatAttr =
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
- return splatAttr.reshape(getType());
+ return splatAttr.reshape(getType().getShape());
// shape_cast(poison) -> poison
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
@@ -6346,7 +6346,7 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
// Eliminate splat constant transpose ops.
if (auto splat =
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
- return splat.reshape(getResultVectorType());
+ return splat.reshape(getResultVectorType().getShape());
// Eliminate poison transpose ops.
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index fe17b3c0b2cfc..515dd14081626 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -36,7 +36,7 @@ linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
loc,
"Cannot linearize a constant scalable vector that's not a splat");
- return dstElementsAttr.reshape(resType);
+ return dstElementsAttr.reshape(resType.getShape());
}
if (auto poisonAttr = dyn_cast<ub::PoisonAttr>(value))
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index fd898b7493c7f..81b2213dd5a93 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -1241,8 +1241,10 @@ ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
/// Return a new DenseElementsAttr that has the same data as the current
/// attribute, but has been reshaped to 'newType'. The new type must have the
/// same total number of elements as well as element type.
-DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
+DenseElementsAttr DenseElementsAttr::reshape(ArrayRef<int64_t> newShape) {
+
ShapedType curType = getType();
+ auto newType = curType.cloneWith(newShape, curType.getElementType());
if (curType == newType)
return *this;
>From 8035784aa3b617d2751b98e2ed9f2065d4898fe6 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 21 Jul 2025 16:32:24 -0700
Subject: [PATCH 2/4] remove now unused API
---
mlir/include/mlir/IR/BuiltinAttributes.h | 2 --
1 file changed, 2 deletions(-)
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index 23f0a72d7fd00..ee26537d20e8c 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -619,8 +619,6 @@ class DenseElementsAttr : public Attribute {
/// the same total number of elements.
DenseElementsAttr reshape(ArrayRef<int64_t> newShape);
- DenseElementsAttr reshape(ShapedType newType);
-
/// Return a new DenseElementsAttr that has the same data as the current
/// attribute, but with a different shape for a splat type. The new type must
/// have the same element type.
>From e7613ead02ba850ee73748e8ac9e3d9fa38714c4 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 22 Jul 2025 09:03:20 -0700
Subject: [PATCH 3/4] new clone approach for vector types
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 7 ++-
.../Vector/Transforms/VectorLinearize.cpp | 2 +-
mlir/lib/IR/BuiltinTypes.cpp | 58 ++++++++++++++++++-
3 files changed, 61 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b109426caa62e..154493634dd02 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5999,8 +5999,9 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
// shape_cast(constant) -> constant
if (auto splatAttr =
- llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
+ llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
return splatAttr.reshape(getType().getShape());
+ }
// shape_cast(poison) -> poison
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
@@ -6346,7 +6347,9 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
// Eliminate splat constant transpose ops.
if (auto splat =
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
- return splat.reshape(getResultVectorType().getShape());
+
+ return DenseElementsAttr::get(getResultVectorType(),
+ splat.getSplatValue<Attribute>());
// Eliminate poison transpose ops.
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 515dd14081626..8da95dd48d8f4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -36,7 +36,7 @@ linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
loc,
"Cannot linearize a constant scalable vector that's not a splat");
- return dstElementsAttr.reshape(resType.getShape());
+ return dstElementsAttr.reshape(resType.getNumElements());
}
if (auto poisonAttr = dyn_cast<ub::PoisonAttr>(value))
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 1604ebba190a1..5bdb0f8701cbf 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -19,6 +19,7 @@
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/TypeSwitch.h"
+#include <algorithm>
using namespace mlir;
using namespace mlir::detail;
@@ -244,10 +245,61 @@ VectorType VectorType::scaleElementBitwidth(unsigned scale) {
return VectorType();
}
-VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
+VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> maybeShape,
Type elementType) const {
-  return VectorType::get(shape.value_or(getShape()), elementType,
-                         getScalableDims());
+  // The scalable flags are carried over to the new shape per the cases below.
+  // Case where only the element type is modified:
+  if (!maybeShape.has_value())
+    return VectorType::get(getShape(), elementType, getScalableDims());
+
+  ArrayRef<int64_t> shape = maybeShape.value();
+  int64_t rankBefore = getRank();
+  int64_t rankAfter = static_cast<int64_t>(shape.size());
+
+  // In the case where the rank is unchanged, the positions of the scalable
+  // dimensions are retained.
+  // Example: vector<4x[1]xf32> -> vector<1x[4]xi8>
+  if (rankBefore == rankAfter)
+    return VectorType::get(shape, elementType, getScalableDims());
+
+  // In the case where the rank increases, retain the scalable flags aligned
+  // to the front (outermost dimension); the extra trailing dims are fixed.
+  // Example: vector<4x[1]xf32> -> vector<1x[2]x2x1xi8>
+  if (rankBefore < rankAfter) {
+    SmallVector<bool> newScalableDims(rankAfter, false);
+    std::copy(getScalableDims().begin(), getScalableDims().end(),
+              newScalableDims.begin());
+    return VectorType::get(shape, elementType, newScalableDims);
+  }
+
+  // In the case where the rank decreases, retain the flags of the first
+  // `rankAfter` dimensions. Each scalable dimension among the trailing
+  // `rankBefore - rankAfter` dimensions is moved onto a non-scalable
+  // dimension of the new shape, searching from the innermost outwards.
+  //
+  // Examples:
+  //   vector<4x[1]xf32> -> vector<[4]xi8>
+  //   vector<[4]x1xf32> -> vector<[4]xi8>
+  //   vector<[2]x3x[4]x5xf32> -> vector<[6]x[20]xi8>
+  //
+  // If there are more scalable dimensions than the new shape can hold, an
+  // assertion fails.
+  assert(rankAfter < rankBefore);
+  SmallVector<bool> newScalableDims(getScalableDims().take_front(rankAfter));
+  int64_t nScalablesToRelocate =
+      llvm::count_if(getScalableDims().take_back(rankBefore - rankAfter),
+                     [](bool b) { return b; });
+  for (int64_t i = static_cast<int64_t>(newScalableDims.size()) - 1;
+       i >= 0 && nScalablesToRelocate > 0; --i) {
+    if (!newScalableDims[i]) {
+      newScalableDims[i] = true;
+      --nScalablesToRelocate;
+    }
+  }
+
+  assert(nScalablesToRelocate == 0 &&
+         "too many scalable dimensions for new (lower) rank");
+  return VectorType::get(shape, elementType, newScalableDims);
}
//===----------------------------------------------------------------------===//
>From 30d09acad38ab15bbbcaf41a2e65012213a04527 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 22 Jul 2025 09:13:14 -0700
Subject: [PATCH 4/4] cosmetics
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 7 ++-----
mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 2 +-
2 files changed, 3 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 154493634dd02..3e8927069cf77 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5999,14 +5999,12 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
// shape_cast(constant) -> constant
if (auto splatAttr =
- llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
+ llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
return splatAttr.reshape(getType().getShape());
- }
// shape_cast(poison) -> poison
- if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
+ if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource()))
return ub::PoisonAttr::get(getContext());
- }
return {};
}
@@ -6347,7 +6345,6 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
// Eliminate splat constant transpose ops.
if (auto splat =
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
-
return DenseElementsAttr::get(getResultVectorType(),
splat.getSplatValue<Attribute>());
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 8da95dd48d8f4..515dd14081626 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -36,7 +36,7 @@ linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
loc,
"Cannot linearize a constant scalable vector that's not a splat");
- return dstElementsAttr.reshape(resType.getNumElements());
+ return dstElementsAttr.reshape(resType.getShape());
}
if (auto poisonAttr = dyn_cast<ub::PoisonAttr>(value))
More information about the Mlir-commits
mailing list