[Mlir-commits] [mlir] [mlir] Add Repeated<T> constructors for TypeRange and ValueRange (PR #186923)
Jakub Kuderski
llvmlistbot at llvm.org
Mon Mar 16 16:58:15 PDT 2026
https://github.com/kuhar created https://github.com/llvm/llvm-project/pull/186923
Many MLIR APIs end up using a range of the same Type / Value repeated N times, due to the (function of the) dimensionality of the problem. Allocating a vector of N identical elements is wasteful.
Add `Repeated<T>::Storage` as PointerUnion variants in TypeRange and ValueRange, enabling O(1) storage for repeated elements. Size remains 2 pointers (16 bytes on 64-bit) for both range types.
Also update several MLIR dialects and conversions to exercise the new code.
>From cb054c9522cc9c58d95c3b5eba26992e13af8999 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Mon, 16 Mar 2026 17:25:59 -0400
Subject: [PATCH] [mlir] Add Repeated<T> constructors for TypeRange and
ValueRange
Many MLIR APIs end up using a range of the same Type / Value repeated N
times, due to the dimensionality of the problem. Allocating a vector
of N identical elements is wasteful.
Add `Repeated<T>::Storage` as PointerUnion variants in TypeRange
and ValueRange, enabling O(1) storage for repeated elements.
Size remains 2 pointers (16 bytes on 64-bit) for both range types.
Also update several MLIR dialects and conversions to exercise the new
code.
Co-Authored-By: Claude Opus 4.6 <noreply at anthropic.com>
---
mlir/include/mlir/IR/TypeRange.h | 16 ++++--
mlir/include/mlir/IR/ValueRange.h | 13 +++--
mlir/include/mlir/Support/LLVM.h | 3 ++
.../Conversion/GPUCommon/GPUOpsLowering.cpp | 2 +-
.../Conversion/MathToSPIRV/MathToSPIRV.cpp | 4 +-
mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp | 4 +-
.../Conversion/VectorToAMX/VectorToAMX.cpp | 2 +-
.../OwnershipBasedBufferDeallocation.cpp | 2 +-
.../Transforms/VectorTransferOpTransforms.cpp | 4 +-
mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 2 +-
mlir/lib/IR/OperationSupport.cpp | 6 +++
mlir/lib/IR/TypeRange.cpp | 14 +++++
mlir/unittests/IR/OperationSupportTest.cpp | 51 +++++++++++++++++++
13 files changed, 107 insertions(+), 16 deletions(-)
diff --git a/mlir/include/mlir/IR/TypeRange.h b/mlir/include/mlir/IR/TypeRange.h
index c6cbf3461bcd7..3debed6212778 100644
--- a/mlir/include/mlir/IR/TypeRange.h
+++ b/mlir/include/mlir/IR/TypeRange.h
@@ -33,7 +33,9 @@ namespace mlir {
class TypeRange : public llvm::detail::indexed_accessor_range_base<
TypeRange,
llvm::PointerUnion<const Value *, const Type *,
- OpOperand *, detail::OpResultImpl *>,
+ OpOperand *, detail::OpResultImpl *,
+ const Repeated<Type>::Storage *,
+ const Repeated<Value>::Storage *>,
Type, Type, Type> {
public:
using RangeBaseT::RangeBaseT;
@@ -51,6 +53,10 @@ class TypeRange : public llvm::detail::indexed_accessor_range_base<
: TypeRange(ArrayRef<Type>(std::forward<Arg>(arg))) {}
TypeRange(std::initializer_list<Type> types LLVM_LIFETIME_BOUND)
: TypeRange(ArrayRef<Type>(types)) {}
+ /// Constructs a range from a repeated type. The Repeated object must outlive
+ /// this range.
+ TypeRange(const Repeated<Type> &repeatedValue LLVM_LIFETIME_BOUND)
+ : RangeBaseT(&repeatedValue.storage, repeatedValue.count) {}
private:
/// The owner of the range is either:
@@ -58,8 +64,12 @@ class TypeRange : public llvm::detail::indexed_accessor_range_base<
/// * A pointer to the first element of an array of types.
/// * A pointer to the first element of an array of operands.
/// * A pointer to the first element of an array of results.
- using OwnerT = llvm::PointerUnion<const Value *, const Type *, OpOperand *,
- detail::OpResultImpl *>;
+ /// * A pointer to a Repeated<Type>::Storage (single type repeated N times).
+ /// * A pointer to a Repeated<Value>::Storage (single value repeated N times,
+ /// dereferenced via getType()).
+ using OwnerT = llvm::PointerUnion<
+ const Value *, const Type *, OpOperand *, detail::OpResultImpl *,
+ const Repeated<Type>::Storage *, const Repeated<Value>::Storage *>;
/// See `llvm::detail::indexed_accessor_range_base` for details.
static OwnerT offset_base(OwnerT object, ptrdiff_t index);
diff --git a/mlir/include/mlir/IR/ValueRange.h b/mlir/include/mlir/IR/ValueRange.h
index f04ed0544c0f6..d40de878d5d10 100644
--- a/mlir/include/mlir/IR/ValueRange.h
+++ b/mlir/include/mlir/IR/ValueRange.h
@@ -17,6 +17,7 @@
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/PointerUnion.h"
+#include "llvm/ADT/Repeated.h"
#include "llvm/ADT/Sequence.h"
#include <optional>
@@ -383,13 +384,15 @@ class ResultRange::UseIterator final
class ValueRange final
: public llvm::detail::indexed_accessor_range_base<
ValueRange,
- PointerUnion<const Value *, OpOperand *, detail::OpResultImpl *>,
+ PointerUnion<const Value *, OpOperand *, detail::OpResultImpl *,
+ const Repeated<Value>::Storage *>,
Value, Value, Value> {
public:
/// The type representing the owner of a ValueRange. This is either a list of
- /// values, operands, or results.
+ /// values, operands, results, or a repeated single value.
using OwnerT =
- PointerUnion<const Value *, OpOperand *, detail::OpResultImpl *>;
+ PointerUnion<const Value *, OpOperand *, detail::OpResultImpl *,
+ const Repeated<Value>::Storage *>;
using RangeBaseT::RangeBaseT;
@@ -412,6 +415,10 @@ class ValueRange final
ValueRange(ArrayRef<Value> values = {});
ValueRange(OperandRange values);
ValueRange(ResultRange values);
+ /// Constructs a range from a repeated value. The Repeated object must outlive
+ /// this range.
+ ValueRange(const Repeated<Value> &repeatedValue LLVM_LIFETIME_BOUND)
+ : RangeBaseT(&repeatedValue.storage, repeatedValue.count) {}
/// Returns the types of the values within this range.
using type_iterator = ValueTypeIterator<iterator>;
diff --git a/mlir/include/mlir/Support/LLVM.h b/mlir/include/mlir/Support/LLVM.h
index 81bbd717c4f80..8b27ee8c4fed6 100644
--- a/mlir/include/mlir/Support/LLVM.h
+++ b/mlir/include/mlir/Support/LLVM.h
@@ -54,6 +54,8 @@ template <typename T>
class MutableArrayRef;
template <typename... PT>
class PointerUnion;
+template <typename T>
+struct Repeated;
template <typename T, typename Vector, typename Set, unsigned N>
class SetVector;
template <typename T, unsigned N>
@@ -125,6 +127,7 @@ template <typename AllocatorTy = llvm::MallocAllocator>
using StringSet = llvm::StringSet<AllocatorTy>;
using llvm::MutableArrayRef;
using llvm::PointerUnion;
+using llvm::Repeated;
using llvm::SmallPtrSet;
using llvm::SmallPtrSetImpl;
using llvm::SmallVector;
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 498bea0fd17b4..6a705ebab7aa4 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -95,7 +95,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
// New arguments will simply be `llvm.ptr` with the correct address space
Type workgroupPtrType =
rewriter.getType<LLVM::LLVMPointerType>(workgroupAddrSpace);
- SmallVector<Type> argTypes(numAttributions, workgroupPtrType);
+ Repeated<Type> argTypes(numAttributions, workgroupPtrType);
// Attributes: noalias, llvm.mlir.workgroup_attribution(<size>, <type>)
std::array attrs{
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 78f0fe1392962..e4b5da7a5ea92 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -155,11 +155,11 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
int count = vectorType.getNumElements();
intType = VectorType::get(count, intType);
- SmallVector<Value> signSplat(count, signMask);
+ Repeated<Value> signSplat(count, signMask);
signMask = spirv::CompositeConstructOp::create(rewriter, loc, intType,
signSplat);
- SmallVector<Value> valueSplat(count, valueMask);
+ Repeated<Value> valueSplat(count, valueMask);
valueMask = spirv::CompositeConstructOp::create(rewriter, loc, intType,
valueSplat);
}
diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
index b46026b855b90..7e9c9090c51df 100644
--- a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
+++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
@@ -117,8 +117,8 @@ class ScatterOpConverter : public OpRewritePattern<tosa::ScatterOp> {
auto one = createIndexConst(rewriter, loc, 1);
// Loop bounds
- auto lbs = llvm::SmallVector<Value>(2, zero);
- auto steps = llvm::SmallVector<Value>(2, one);
+ auto lbs = Repeated<Value>(2, zero);
+ auto steps = Repeated<Value>(2, one);
auto ubs = llvm::SmallVector<Value>{{dimN, dimW}};
auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
diff --git a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
index bce67b3e4748b..c6182379026df 100644
--- a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
+++ b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
@@ -354,7 +354,7 @@ static TypedValue<VectorType> storeTile(PatternRewriter &rewriter,
rewriter, loc,
MemRefType::get(tileTy.getShape(), tileTy.getElementType()));
Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
- SmallVector<Value> indices(2, zeroIndex);
+ Repeated<Value> indices(2, zeroIndex);
x86::amx::TileStoreOp::create(rewriter, loc, buf, indices, tile);
auto vecTy = VectorType::get(tileTy.getShape(), tileTy.getElementType());
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index e625f172a3bf3..f73c8476bf20e 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -723,7 +723,7 @@ BufferDeallocation::handleInterface(RegionBranchOpInterface op) {
// the outside.
Value falseVal = buildBoolValue(builder, op.getLoc(), false);
op->insertOperands(op->getNumOperands(),
- SmallVector<Value>(numMemrefOperands, falseVal));
+ Repeated<Value>(numMemrefOperands, falseVal));
int counter = op->getNumResults();
unsigned numMemrefResults = llvm::count_if(op->getResults(), isMemref);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 19db8b3b48a25..babd321e484bd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -544,7 +544,7 @@ class TransferReadDropUnitDimsPattern
Value reducedShapeSource =
rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
- SmallVector<Value> zeros(reducedRank, c0);
+ Repeated<Value> zeros(reducedRank, c0);
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
Operation *newTransferReadOp = vector::TransferReadOp::create(
@@ -658,7 +658,7 @@ class TransferWriteDropUnitDimsPattern
Value reducedShapeSource =
rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
- SmallVector<Value> zeros(reducedRank, c0);
+ Repeated<Value> zeros(reducedRank, c0);
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index df8e6cf167348..9585f5a1d774a 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -357,7 +357,7 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
builder, loc,
/*vectorType=*/vecToReadTy,
/*source=*/source,
- /*indices=*/SmallVector<Value>(vecToReadRank, zero),
+ /*indices=*/Repeated<Value>(vecToReadRank, zero),
/*padding=*/padValue,
/*inBounds=*/inBoundsVal);
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 3ff61daaac60b..f1ee879136756 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -654,6 +654,9 @@ ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
return {value + index};
if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
return {operand + index};
+ // All elements are identical; the owner pointer never advances.
+ if (llvm::isa<const Repeated<Value>::Storage *>(owner))
+ return owner;
return cast<detail::OpResultImpl *>(owner)->getNextResultAtOffset(index);
}
/// See `llvm::detail::indexed_accessor_range_base` for details.
@@ -662,6 +665,9 @@ Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
return value[index];
if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
return operand[index].get();
+ if (auto *repeated =
+ llvm::dyn_cast_if_present<const Repeated<Value>::Storage *>(owner))
+ return repeated->value;
return cast<detail::OpResultImpl *>(owner)->getNextResultAtOffset(index);
}
diff --git a/mlir/lib/IR/TypeRange.cpp b/mlir/lib/IR/TypeRange.cpp
index 88e788aa1b2b8..d14ad76c83e75 100644
--- a/mlir/lib/IR/TypeRange.cpp
+++ b/mlir/lib/IR/TypeRange.cpp
@@ -31,6 +31,10 @@ TypeRange::TypeRange(ValueRange values) : TypeRange(OwnerT(), values.size()) {
this->base = result;
else if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
this->base = operand;
+ else if (auto *repeated =
+ llvm::dyn_cast_if_present<const Repeated<Value>::Storage *>(
+ owner))
+ this->base = repeated;
else
this->base = cast<const Value *>(owner);
}
@@ -43,6 +47,10 @@ TypeRange::OwnerT TypeRange::offset_base(OwnerT object, ptrdiff_t index) {
return {operand + index};
if (auto *result = llvm::dyn_cast_if_present<detail::OpResultImpl *>(object))
return {result->getNextResultAtOffset(index)};
+ // All elements are identical; the owner pointer never advances.
+ if (llvm::isa<const Repeated<Type>::Storage *,
+ const Repeated<Value>::Storage *>(object))
+ return object;
return {llvm::dyn_cast_if_present<const Type *>(object) + index};
}
@@ -54,5 +62,11 @@ Type TypeRange::dereference_iterator(OwnerT object, ptrdiff_t index) {
return (operand + index)->get().getType();
if (auto *result = llvm::dyn_cast_if_present<detail::OpResultImpl *>(object))
return result->getNextResultAtOffset(index)->getType();
+ if (auto *repeated =
+ llvm::dyn_cast_if_present<const Repeated<Type>::Storage *>(object))
+ return repeated->value;
+ if (auto *repeated =
+ llvm::dyn_cast_if_present<const Repeated<Value>::Storage *>(object))
+ return repeated->value.getType();
return llvm::dyn_cast_if_present<const Type *>(object)[index];
}
diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp
index f00d5c1f7f927..6319fcbb0f216 100644
--- a/mlir/unittests/IR/OperationSupportTest.cpp
+++ b/mlir/unittests/IR/OperationSupportTest.cpp
@@ -11,7 +11,9 @@
#include "../../test/lib/Dialect/Test/TestOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeRange.h"
#include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/Repeated.h"
#include "llvm/Support/FormatVariadic.h"
#include "gtest/gtest.h"
@@ -375,4 +377,53 @@ TEST(OperationCloneTest, CloneWithDifferentResults) {
cloneOp->destroy();
}
+TEST(RepeatedRangeTest, TypeRangeFromRepeatedType) {
+ MLIRContext context;
+ Builder builder(&context);
+ Type i32 = builder.getI32Type();
+
+ llvm::Repeated<Type> rep(3, i32);
+ TypeRange range(rep);
+
+ EXPECT_EQ(range.size(), 3u);
+ EXPECT_FALSE(range.empty());
+ for (Type t : range)
+ EXPECT_EQ(t, i32);
+
+ llvm::Repeated<Type> emptyRep(0, Type{});
+ TypeRange emptyTypeRange(emptyRep);
+
+ EXPECT_EQ(emptyTypeRange.size(), 0u);
+ EXPECT_TRUE(emptyTypeRange.empty());
+}
+
+TEST(RepeatedRangeTest, ValueRangeFromRepeatedValue) {
+ Value nullVal;
+ llvm::Repeated<Value> rep(4, nullVal);
+ ValueRange range(rep);
+
+ EXPECT_EQ(range.size(), 4u);
+ EXPECT_FALSE(range.empty());
+ for (Value v : range)
+ EXPECT_EQ(v, nullVal);
+}
+
+TEST(RepeatedRangeTest, TypeRangeFromRepeatedValueViaValueRange) {
+ MLIRContext context;
+ Builder builder(&context);
+ Type i32 = builder.getI32Type();
+
+ Operation *op = createOp(&context, /*operands=*/{}, i32);
+ Value val = op->getResult(0);
+
+ llvm::Repeated<Value> rep(3, val);
+ TypeRange tr = ValueRange(rep);
+
+ EXPECT_EQ(tr.size(), 3u);
+ for (Type t : tr)
+ EXPECT_EQ(t, i32);
+
+ op->destroy();
+}
+
} // namespace
More information about the Mlir-commits
mailing list