[Mlir-commits] [mlir] [WIP][mlir] Make `DenseElementsAttr::reshape(...)` take a shape instead of a type (PR #149947)

James Newling llvmlistbot at llvm.org
Tue Jul 22 09:12:11 PDT 2025


https://github.com/newling updated https://github.com/llvm/llvm-project/pull/149947

>From cc6aae0976d1a67960162598a540cbacf0608875 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 21 Jul 2025 16:30:07 -0700
Subject: [PATCH 1/4] reshape to a new shape not a new type

---
 mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h         | 3 ++-
 mlir/include/mlir/IR/BuiltinAttributes.h                  | 6 ++++--
 mlir/lib/CAPI/IR/BuiltinAttributes.cpp                    | 5 +++--
 mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp         | 2 +-
 mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp        | 4 ++--
 mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp | 2 +-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp                  | 4 ++--
 mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp    | 2 +-
 mlir/lib/IR/BuiltinAttributes.cpp                         | 4 +++-
 9 files changed, 19 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 704e39e908841..abe40227b31dc 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -92,7 +92,8 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
 
   // Reshape of a constant can be replaced with a new constant.
   if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
-    return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
+    return elements.reshape(
+        cast<ShapedType>(reshapeOp.getResult().getType()).getShape());
 
   // Fold if the producer reshape source has the same shape with at most 1
   // dynamic dimension.
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index c07ade606a775..23f0a72d7fd00 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -615,8 +615,10 @@ class DenseElementsAttr : public Attribute {
   //===--------------------------------------------------------------------===//
 
   /// Return a new DenseElementsAttr that has the same data as the current
-  /// attribute, but has been reshaped to 'newType'. The new type must have the
-  /// same total number of elements as well as element type.
+  /// attribute, but has been reshaped to 'newShape'. The new shape must have
+  /// the same total number of elements.
+  DenseElementsAttr reshape(ArrayRef<int64_t> newShape);
+
   DenseElementsAttr reshape(ShapedType newType);
 
   /// Return a new DenseElementsAttr that has the same data as the current
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 8d57ab6b59e79..f81832fcb981e 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -679,8 +679,9 @@ MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType,
 
 MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr,
                                               MlirType shapedType) {
-  return wrap(llvm::cast<DenseElementsAttr>(unwrap(attr))
-                  .reshape(llvm::cast<ShapedType>(unwrap(shapedType))));
+  auto elements = llvm::cast<DenseElementsAttr>(unwrap(attr));
+  auto targetType = llvm::cast<ShapedType>(unwrap(shapedType));
+  return wrap(elements.reshape(targetType.getShape()));
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 434d7df853a5e..dd11f4f2bafda 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -274,7 +274,7 @@ struct ConstantCompositeOpPattern final
       if (isa<RankedTensorType>(srcType)) {
         dstAttrType = RankedTensorType::get(srcType.getNumElements(),
                                             srcType.getElementType());
-        dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
+        dstElementsAttr = dstElementsAttr.reshape(dstAttrType.getShape());
       } else {
         // TODO: add support for large vectors.
         return failure();
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 5758d8d5ef506..adeecb23528db 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1387,7 +1387,7 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
       return {};
 
     return operand.reshape(
-        llvm::cast<ShapedType>(operand.getType()).clone(shapeVec));
+        llvm::cast<ShapedType>(operand.getType()).clone(shapeVec).getShape());
   }
 
   return {};
@@ -1546,7 +1546,7 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
           llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
     if (input.isSplat() && resultTy.hasStaticShape() &&
         input.getType().getElementType() == resultTy.getElementType())
-      return input.reshape(resultTy);
+      return input.reshape(resultTy.getShape());
   }
 
   // Transpose is not the identity transpose.
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
index db7a3c671dedc..9090080534bc7 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
@@ -193,7 +193,7 @@ TosaReduceTransposes::transposeDenseAttribute(DenseElementsAttr input,
       RankedTensorType::get(newShape, oldType.getElementType());
 
   if (input.isSplat()) {
-    return input.reshape(newType);
+    return input.reshape(newType.getShape());
   }
 
   auto rawData = input.getRawData();
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 7d615bfc12984..b109426caa62e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6000,7 +6000,7 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
   // shape_cast(constant) -> constant
   if (auto splatAttr =
           llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
-    return splatAttr.reshape(getType());
+    return splatAttr.reshape(getType().getShape());
 
   // shape_cast(poison) -> poison
   if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
@@ -6346,7 +6346,7 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
   // Eliminate splat constant transpose ops.
   if (auto splat =
           llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
-    return splat.reshape(getResultVectorType());
+    return splat.reshape(getResultVectorType().getShape());
 
   // Eliminate poison transpose ops.
   if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index fe17b3c0b2cfc..515dd14081626 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -36,7 +36,7 @@ linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
           loc,
           "Cannot linearize a constant scalable vector that's not a splat");
 
-    return dstElementsAttr.reshape(resType);
+    return dstElementsAttr.reshape(resType.getShape());
   }
 
   if (auto poisonAttr = dyn_cast<ub::PoisonAttr>(value))
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index fd898b7493c7f..81b2213dd5a93 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -1241,8 +1241,10 @@ ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
 /// Return a new DenseElementsAttr that has the same data as the current
 /// attribute, but has been reshaped to 'newType'. The new type must have the
 /// same total number of elements as well as element type.
-DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
+DenseElementsAttr DenseElementsAttr::reshape(ArrayRef<int64_t> newShape) {
+
   ShapedType curType = getType();
+  auto newType = curType.cloneWith(newShape, curType.getElementType());
   if (curType == newType)
     return *this;
 

>From 8035784aa3b617d2751b98e2ed9f2065d4898fe6 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 21 Jul 2025 16:32:24 -0700
Subject: [PATCH 2/4] remove now unused API

---
 mlir/include/mlir/IR/BuiltinAttributes.h | 2 --
 1 file changed, 2 deletions(-)

diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index 23f0a72d7fd00..ee26537d20e8c 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -619,8 +619,6 @@ class DenseElementsAttr : public Attribute {
   /// the same total number of elements.
   DenseElementsAttr reshape(ArrayRef<int64_t> newShape);
 
-  DenseElementsAttr reshape(ShapedType newType);
-
   /// Return a new DenseElementsAttr that has the same data as the current
   /// attribute, but with a different shape for a splat type. The new type must
   /// have the same element type.

>From e7613ead02ba850ee73748e8ac9e3d9fa38714c4 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 22 Jul 2025 09:03:20 -0700
Subject: [PATCH 3/4] new clone approach for vector types

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      |  7 ++-
 .../Vector/Transforms/VectorLinearize.cpp     |  2 +-
 mlir/lib/IR/BuiltinTypes.cpp                  | 58 ++++++++++++++++++-
 3 files changed, 61 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b109426caa62e..154493634dd02 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5999,8 +5999,9 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
   // shape_cast(constant) -> constant
   if (auto splatAttr =
-          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
+          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
     return splatAttr.reshape(getType().getShape());
+  }
 
   // shape_cast(poison) -> poison
   if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
@@ -6346,7 +6347,9 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
   // Eliminate splat constant transpose ops.
   if (auto splat =
           llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
-    return splat.reshape(getResultVectorType().getShape());
+
+    return DenseElementsAttr::get(getResultVectorType(),
+                                  splat.getSplatValue<Attribute>());
 
   // Eliminate poison transpose ops.
   if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 515dd14081626..8da95dd48d8f4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -36,7 +36,7 @@ linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
           loc,
           "Cannot linearize a constant scalable vector that's not a splat");
 
-    return dstElementsAttr.reshape(resType.getShape());
+    return dstElementsAttr.reshape(resType.getNumElements());
   }
 
   if (auto poisonAttr = dyn_cast<ub::PoisonAttr>(value))
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 1604ebba190a1..5bdb0f8701cbf 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -19,6 +19,7 @@
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include <algorithm>
 
 using namespace mlir;
 using namespace mlir::detail;
@@ -244,10 +245,61 @@ VectorType VectorType::scaleElementBitwidth(unsigned scale) {
   return VectorType();
 }
 
-VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
+VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> maybeShape,
                                  Type elementType) const {
-  return VectorType::get(shape.value_or(getShape()), elementType,
-                         getScalableDims());
+  // Case where only the element type is modified:
+  if (!maybeShape.has_value())
+    return VectorType::get(getShape(), elementType, getScalableDims());
+
+  ArrayRef<int64_t> shape = maybeShape.value();
+  int64_t rankBefore = getRank();
+  int64_t rankAfter = static_cast<int64_t>(shape.size());
+
+  // In the case where the rank is unchanged, the positions of the scalable
+  // dimensions are retained.
+  // Example: vector<4x[1]xf32> -> vector<1x[4]xi8>
+  if (rankBefore == rankAfter)
+    return VectorType::get(shape, elementType, getScalableDims());
+
+  // In the case where the rank increases, the scalable flags keep their
+  // position relative to the back (innermost dimension); the new leading
+  // dimensions are fixed-size. NOTE(review): confirm back- vs front-relative.
+  // Example: vector<4x[1]xf32> -> vector<1x2x2x[1]xi8>
+  if (rankBefore < rankAfter) {
+    SmallVector<bool> newScalableDims(rankAfter, false);
+    std::copy(getScalableDims().begin(), getScalableDims().end(),
+              newScalableDims.begin() + (rankAfter - rankBefore));
+    return VectorType::get(shape, elementType, newScalableDims);
+  }
+
+  // In the case where the rank decreases, retain the first `rankAfter` scalable
+  // dimensions. Any scalable dimensions in the final `rankBefore - rankAfter`
+  // dimensions are packed into the non-scalable gaps, innermost gap first.
+  //
+  // Examples:
+  //
+  // vector<4x[1]xf32> -> vector<[4]xi8>
+  // vector<[4]x1xf32> -> vector<[4]xi8>
+  // vector<[2]x3x[4]x5xf32> -> vector<[6]x[20]xi8>
+  //
+  // If the number of scalable dimensions exceeds the number of dimensions in
+  // the new shape, there is an assertion failure.
+  assert(rankAfter < rankBefore);
+  SmallVector<bool> newScalableDims(getScalableDims().take_front(rankAfter));
+  int nScalablesToRelocate =
+      llvm::count_if(getScalableDims().take_back(rankBefore - rankAfter),
+                     [](bool b) { return b; });
+  int currentIndex = newScalableDims.size() - 1;
+  for (; nScalablesToRelocate > 0 && currentIndex >= 0; --currentIndex) {
+    if (!newScalableDims[currentIndex]) {
+      newScalableDims[currentIndex] = true;
+      --nScalablesToRelocate;
+    }
+  }
+
+  assert(nScalablesToRelocate == 0 &&
+         "too many scalable dimensions for new (lower) rank");
+  return VectorType::get(shape, elementType, newScalableDims);
 }
 
 //===----------------------------------------------------------------------===//

>From 30d09acad38ab15bbbcaf41a2e65012213a04527 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 22 Jul 2025 09:13:14 -0700
Subject: [PATCH 4/4] cosmetics

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp               | 7 ++-----
 mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 2 +-
 2 files changed, 3 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 154493634dd02..3e8927069cf77 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5999,14 +5999,12 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
   // shape_cast(constant) -> constant
   if (auto splatAttr =
-          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
+          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
     return splatAttr.reshape(getType().getShape());
-  }
 
   // shape_cast(poison) -> poison
-  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
+  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource()))
     return ub::PoisonAttr::get(getContext());
-  }
 
   return {};
 }
@@ -6347,7 +6345,6 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
   // Eliminate splat constant transpose ops.
   if (auto splat =
           llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
-
     return DenseElementsAttr::get(getResultVectorType(),
                                   splat.getSplatValue<Attribute>());
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 8da95dd48d8f4..515dd14081626 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -36,7 +36,7 @@ linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter,
           loc,
           "Cannot linearize a constant scalable vector that's not a splat");
 
-    return dstElementsAttr.reshape(resType.getNumElements());
+    return dstElementsAttr.reshape(resType.getShape());
   }
 
   if (auto poisonAttr = dyn_cast<ub::PoisonAttr>(value))



More information about the Mlir-commits mailing list