[Mlir-commits] [mlir] fb4c05c - [mlir][IR] Add implicit conversion operator to `TypedValue` (#164621)
llvmlistbot at llvm.org
Thu Oct 23 00:24:34 PDT 2025
Author: Matthias Springer
Date: 2025-10-23T09:24:30+02:00
New Revision: fb4c05cf036e09ed97a48a6c515befbcc9198c61
URL: https://github.com/llvm/llvm-project/commit/fb4c05cf036e09ed97a48a6c515befbcc9198c61
DIFF: https://github.com/llvm/llvm-project/commit/fb4c05cf036e09ed97a48a6c515befbcc9198c61.diff
LOG: [mlir][IR] Add implicit conversion operator to `TypedValue` (#164621)
Allow implicit conversion from `TypedValue<B>` to `TypedValue<A>` if `B`
is assignable to `A`.
Example:
```c++
TypedValue<MemRefType> val;
TypedValue<ShapedType> shapedVal = val; // this is now valid
```
Added:
Modified:
mlir/include/mlir/IR/Value.h
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/Shard/Transforms/Partition.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index 4d6d89fa69a07..af58778a0a13e 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -433,9 +433,19 @@ inline unsigned OpResultImpl::getResultNumber() const {
template <typename Ty>
struct TypedValue : Value {
using Value::Value;
+ using ValueType = Ty;
static bool classof(Value value) { return llvm::isa<Ty>(value.getType()); }
+ /// TypedValue<B> can implicitly convert to TypedValue<A> if B is assignable
+ /// to A.
+ template <typename ToTy,
+ typename = typename std::enable_if<std::is_assignable<
+ typename ToTy::ValueType &, Ty>::value>::type>
+ operator ToTy() const {
+ return llvm::cast<ToTy>(*this);
+ }
+
/// Return the known Type
Ty getType() const { return llvm::cast<Ty>(Value::getType()); }
void setType(Ty ty) { Value::setType(ty); }
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index c06a48ee4b87c..c551fba93e367 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1751,11 +1751,11 @@ OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
}
TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getSourcePtr() {
- return cast<TypedValue<PtrLikeTypeInterface>>(getSource());
+ return getSource();
}
TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getTargetPtr() {
- return cast<TypedValue<PtrLikeTypeInterface>>(getDest());
+ return getDest();
}
bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
index 5dc61a2147038..335ca1a60f8f3 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
@@ -69,10 +69,10 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
Sharding sourceSharding,
TypedValue<ShapedType> sourceShard, GridOp grid,
int64_t splitTensorAxis, GridAxis splitGridAxis) {
- TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
+ TypedValue<ShapedType> targetShard =
AllSliceOp::create(builder, sourceShard, grid,
ArrayRef<GridAxis>(splitGridAxis), splitTensorAxis)
- .getResult());
+ .getResult();
Sharding targetSharding = targetShardingInSplitLastAxis(
builder.getContext(), sourceSharding, splitTensorAxis, splitGridAxis);
return {targetShard, targetSharding};
@@ -204,9 +204,8 @@ static std::tuple<TypedValue<ShapedType>, Sharding> unsplitLastAxisInResharding(
APInt(64, splitTensorAxis));
ShapedType targetShape =
shardShapedType(sourceUnshardedShape, grid, targetSharding);
- TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
- tensor::CastOp::create(builder, targetShape, allGatherResult)
- .getResult());
+ TypedValue<ShapedType> targetShard =
+ tensor::CastOp::create(builder, targetShape, allGatherResult).getResult();
return {targetShard, targetSharding};
}
@@ -336,8 +335,8 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
ShapedType targetShape =
shardShapedType(sourceUnshardedShape, grid, targetSharding);
- TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
- tensor::CastOp::create(builder, targetShape, allToAllResult).getResult());
+ TypedValue<ShapedType> targetShard =
+ tensor::CastOp::create(builder, targetShape, allToAllResult).getResult();
return {targetShard, targetSharding};
}
@@ -510,8 +509,7 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, GridOp grid, ShardOp source,
auto targetSharding = target.getSharding();
ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
return reshard(implicitLocOpBuilder, grid, sourceSharding, targetSharding,
- cast<TypedValue<ShapedType>>(source.getSrc()),
- sourceShardValue);
+ source.getSrc(), sourceShardValue);
}
TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
More information about the Mlir-commits mailing list