[Mlir-commits] [mlir] Add OneTypedResult::getResultOfType to simplify the result type casting logic (PR #120381)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 18 00:32:32 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-memref

Author: None (xiaoleis-nv)

<details>
<summary>Changes</summary>

## Description
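This PR adds a `getResultOfType<T>()` member function template to `OneTypedResult<ResultType>::Impl` to simplify casting an op's single result to a differently typed value.
Such a cast is needed when converting a result between its concrete type and an interface type (for example, to `TypedValue<ShapedType>` or `TypedValue<BaseMemRefType>`).
Without this helper, callers must call `getResult()` and then wrap it in an explicit `cast<TypedValue<T>>(...)`, which is verbose. `getResultOfType<T>()` does both in a single call; it uses `mlir::cast`, so the conversion is still checked. The existing `getResult()` is now implemented in terms of it.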

## Examples
Before this PR:
```cpp
auto targetShard = cast<TypedValue<ShapedType>>(
    builder.create<AllSliceOp>(sourceShard, mesh,
                               ArrayRef<MeshAxis>(splitMeshAxis),
                               splitTensorAxis)
        .getResult());
```
With this PR:
```cpp
auto targetShard = builder.create<AllSliceOp>(sourceShard, mesh,
                                             ArrayRef<MeshAxis>(splitMeshAxis),
                                             splitTensorAxis)
                       .getResultOfType<ShapedType>();
```

---
Full diff: https://github.com/llvm/llvm-project/pull/120381.diff


3 Files Affected:

- (modified) mlir/include/mlir/IR/OpDefinition.h (+7-2) 
- (modified) mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp (+1-1) 
- (modified) mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp (+12-8) 


``````````diff
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 59f094d6690991..ae28e1251bd954 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -694,11 +694,16 @@ class OneTypedResult {
   class Impl
       : public TraitBase<ConcreteType, OneTypedResult<ResultType>::Impl> {
   public:
-    mlir::TypedValue<ResultType> getResult() {
-      return cast<mlir::TypedValue<ResultType>>(
+    template <typename ValTy>
+    mlir::TypedValue<ValTy> getResultOfType() {
+      return mlir::cast<mlir::TypedValue<ValTy>>(
           this->getOperation()->getResult(0));
     }
 
+    mlir::TypedValue<ResultType> getResult() {
+      return getResultOfType<ResultType>();
+    }
+
     /// If the operation returns a single value, then the Op can be implicitly
     /// converted to a Value. This yields the value of the only result.
     operator mlir::TypedValue<ResultType>() { return getResult(); }
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 450bfa0cec0c7f..6d5a68ef4d0add 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -213,7 +213,7 @@ struct ReinterpretCastOpInterface
     auto reinterpretCast = cast<ReinterpretCastOp>(op);
     auto baseMemref = reinterpretCast.getSource();
     auto resultMemref =
-        cast<TypedValue<BaseMemRefType>>(reinterpretCast.getResult());
+        reinterpretCast.getResultOfType<BaseMemRefType>();
 
     builder.setInsertionPointAfter(op);
 
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 327ea0991e4e1e..5c268c06db08e6 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -85,13 +85,13 @@ handlePartialAxesDuringResharding(OpBuilder &builder,
   }
 
   builder.setInsertionPointAfterValue(sourceShard);
-  TypedValue<ShapedType> resultValue = cast<TypedValue<ShapedType>>(
+  TypedValue<ShapedType> resultValue =
       builder
           .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
                                sourceSharding.getMeshAttr().getLeafReference(),
                                allReduceMeshAxes, sourceShard,
                                sourceSharding.getPartialType())
-          .getResult());
+          .getResultOfType<ShapedType>();
 
   llvm::SmallVector<MeshAxis> remainingPartialAxes;
   llvm::copy_if(sourceShardingPartialAxesSet,
@@ -133,12 +133,12 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                           MeshSharding sourceSharding,
                           TypedValue<ShapedType> sourceShard, MeshOp mesh,
                           int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
-  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
+  TypedValue<ShapedType> targetShard =
       builder
           .create<AllSliceOp>(sourceShard, mesh,
                               ArrayRef<MeshAxis>(splitMeshAxis),
                               splitTensorAxis)
-          .getResult());
+          .getResultOfType<ShapedType>();
   MeshSharding targetSharding = targetShardingInSplitLastAxis(
       builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
   return {targetShard, targetSharding};
@@ -274,8 +274,10 @@ unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
       APInt(64, splitTensorAxis));
   ShapedType targetShape =
       shardShapedType(sourceUnshardedShape, mesh, targetSharding);
-  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
-      builder.create<tensor::CastOp>(targetShape, allGatherResult).getResult());
+  TypedValue<ShapedType> targetShard =
+      builder
+          .create<tensor::CastOp>(targetShape, allGatherResult)
+          .getResultOfType<ShapedType>();
   return {targetShard, targetSharding};
 }
 
@@ -407,8 +409,10 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
       APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
   ShapedType targetShape =
       shardShapedType(sourceUnshardedShape, mesh, targetSharding);
-  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
-      builder.create<tensor::CastOp>(targetShape, allToAllResult).getResult());
+  TypedValue<ShapedType> targetShard =
+      builder
+          .create<tensor::CastOp>(targetShape, allToAllResult)
+          .getResultOfType<ShapedType>();
   return {targetShard, targetSharding};
 }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/120381


More information about the Mlir-commits mailing list